diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 728d835c91..d41d543144 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -213,13 +213,17 @@ def create_service_container( ), "db_schema_retrieval": _db_schema_retrieval_pipeline, "sql_generation": generation.SQLGeneration( - **pipe_components["question_recommendation_sql_generation"], + **pipe_components[ + "question_recommendation_sql_generation" + ], ), "sql_pairs_retrieval": _sql_pair_retrieval_pipeline, "instructions_retrieval": _instructions_retrieval_pipeline, "sql_functions_retrieval": _sql_functions_retrieval_pipeline, + "sql_correction": _sql_correction_pipeline, }, allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval, + max_sql_correction_retries=settings.max_sql_correction_retries, **query_cache, ), sql_pairs_service=services.SqlPairsService( diff --git a/wren-ai-service/src/web/v1/services/question_recommendation.py b/wren-ai-service/src/web/v1/services/question_recommendation.py index 24df1e3e62..41c69ca58d 100644 --- a/wren-ai-service/src/web/v1/services/question_recommendation.py +++ b/wren-ai-service/src/web/v1/services/question_recommendation.py @@ -31,6 +31,7 @@ def __init__( self, pipelines: Dict[str, BasicPipeline], allow_sql_functions_retrieval: bool = True, + max_sql_correction_retries: int = 3, maxsize: int = 1_000_000, ttl: int = 120, ): @@ -39,6 +40,7 @@ def __init__( maxsize=maxsize, ttl=ttl ) self._allow_sql_functions_retrieval = allow_sql_functions_retrieval + self._max_sql_correction_retries = max_sql_correction_retries def _handle_exception( self, @@ -127,10 +129,45 @@ async def _instructions_retrieval() -> list[dict]: post_process = generated_sql["post_process"] + # If initial generation fails, try correction loop similar to ask flow if len(post_process["valid_generation_result"]) == 0: - return post_process + failed_dry_run_result = post_process.get( + "invalid_generation_result" + ) + current_sql_correction_retries = 0 + + while ( + failed_dry_run_result + and current_sql_correction_retries + < self._max_sql_correction_retries + ): + current_sql_correction_retries += 1 + + sql_correction_results = await self._pipelines[ + "sql_correction" + ].run( + contexts=table_ddls, + invalid_generation_result=failed_dry_run_result, + instructions=instructions, + project_id=project_id, + ) + + post_process = sql_correction_results["post_process"] + if valid_generation_result := post_process.get( + "valid_generation_result" + ): + valid_sql = valid_generation_result["sql"] + break + + failed_dry_run_result = post_process.get( + "invalid_generation_result" + ) + else: + # Still no valid SQL after corrections + return post_process - valid_sql = post_process["valid_generation_result"]["sql"] + else: + valid_sql = post_process["valid_generation_result"]["sql"] # Partial update the resource current = self._cache[request_id]