From c85985e6c77072b4e2bd1a70924a9cc9e71f0a4e Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Thu, 14 Aug 2025 13:57:42 +0800 Subject: [PATCH 01/11] update --- deployment/kustomizations/base/cm.yaml | 2 + docker/config.example.yaml | 2 + wren-ai-service/src/globals.py | 7 + .../src/pipelines/generation/__init__.py | 2 + .../generation/data_exploration_assistance.py | 141 ++++++++++++++++++ .../generation/intent_classification.py | 38 ++++- wren-ai-service/src/web/v1/services/ask.py | 38 ++++- .../tools/config/config.example.yaml | 2 + wren-ai-service/tools/config/config.full.yaml | 2 + 9 files changed, 229 insertions(+), 5 deletions(-) create mode 100644 wren-ai-service/src/pipelines/generation/data_exploration_assistance.py diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index 72b0678bc9..e561841930 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -172,6 +172,8 @@ data: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/docker/config.example.yaml b/docker/config.example.yaml index f8f56c7bb3..48eeded0a8 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -122,6 +122,8 @@ pipes: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 728d835c91..b8e8269e56 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -146,6 +146,13 @@ def create_service_container( **pipe_components["followup_sql_generation"], ), "sql_functions_retrieval": _sql_functions_retrieval_pipeline, + "sql_executor": retrieval.SQLExecutor( + **pipe_components["sql_executor"], + engine_timeout=settings.engine_timeout, + ), + "data_exploration_assistance": generation.DataExplorationAssistance( + **pipe_components["data_exploration_assistance"], + ), }, allow_intent_classification=settings.allow_intent_classification, allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning, diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py index 6940643217..e4c155d9cc 100644 --- a/wren-ai-service/src/pipelines/generation/__init__.py +++ b/wren-ai-service/src/pipelines/generation/__init__.py @@ -1,6 +1,7 @@ from .chart_adjustment import ChartAdjustment from .chart_generation import ChartGeneration from .data_assistance import DataAssistance +from .data_exploration_assistance import DataExplorationAssistance from .followup_sql_generation import FollowUpSQLGeneration from .followup_sql_generation_reasoning import FollowUpSQLGenerationReasoning from .intent_classification import IntentClassification @@ -36,4 +37,5 @@ "FollowUpSQLGenerationReasoning", "MisleadingAssistance", "SQLTablesExtraction", + "DataExplorationAssistance", ] diff --git a/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py new file mode 100644 index 0000000000..9b6c868507 --- /dev/null +++ b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py @@ -0,0 +1,141 @@ +import asyncio +import logging +import sys 
+from typing import Any, Optional
+
+from hamilton import base
+from hamilton.async_driver import AsyncDriver
+from haystack.components.builders.prompt_builder import PromptBuilder
+from langfuse.decorators import observe
+
+from src.core.pipeline import BasicPipeline
+from src.core.provider import LLMProvider
+
+logger = logging.getLogger("wren-ai-service")
+
+
+data_exploration_assistance_system_prompt = """
+You are a great data analyst good at exploring data.
+You are given a user question and SQL data.
+You need to understand the user question and the SQL data, and then answer the user question.
+### INSTRUCTIONS ###
+1. Your answer should be in the same language as the language user provided.
+2. You must follow the SQL data to answer the user question.
+3. You should provide your answer in Markdown format.
+4. You have the following skills:
+- explain the data in an easy-to-understand manner
+- provide insights and trends in the data
+- find anomalies and outliers in the data
+5. You only need to use the skills required to answer the user question based on the user question and the SQL data.
+### OUTPUT FORMAT ###
+Please provide your response in proper Markdown format without ```markdown``` tags.
+"""
+
+data_exploration_assistance_user_prompt_template = """
+User Question: {{query}}
+Language: {{language}}
+SQL Data:
+{{ sql_data }}
+Please think step by step.
+"""
+
+
+## Start of Pipeline
+@observe(capture_input=False)
+def prompt(
+    query: str,
+    language: str,
+    sql_data: dict,
+    prompt_builder: PromptBuilder,
+) -> dict:
+    return prompt_builder.run(
+        query=query,
+        language=language,
+        sql_data=sql_data,
+    )
+
+
+@observe(as_type="generation", capture_input=False)
+async def data_exploration_assistance(
+    prompt: dict, generator: Any, query_id: str
+) -> dict:
+    return await generator(prompt=prompt.get("prompt"), query_id=query_id)
+
+
+## End of Pipeline
+
+
+class DataExplorationAssistance(BasicPipeline):
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        **kwargs,
+    ):
+        self._user_queues = {}
+        self._components = {
+            "generator": llm_provider.get_generator(
+                system_prompt=data_exploration_assistance_system_prompt,
+                streaming_callback=self._streaming_callback,
+            ),
+            "prompt_builder": PromptBuilder(
+                template=data_exploration_assistance_user_prompt_template
+            ),
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    def _streaming_callback(self, chunk, query_id):
+        if query_id not in self._user_queues:
+            self._user_queues[
+                query_id
+            ] = asyncio.Queue()  # Create a new queue for the user if it doesn't exist
+        # Put the chunk content into the user's queue
+        asyncio.create_task(self._user_queues[query_id].put(chunk.content))
+        if chunk.meta.get("finish_reason"):
+            asyncio.create_task(self._user_queues[query_id].put("<DONE>"))
+
+    async def get_streaming_results(self, query_id):
+        async def _get_streaming_results(query_id):
+            return await self._user_queues[query_id].get()
+
+        if query_id not in self._user_queues:
+            self._user_queues[query_id] = asyncio.Queue()
+
+        while True:
+            try:
+                # Wait for an item from the user's queue
+                self._streaming_results = await asyncio.wait_for(
+                    _get_streaming_results(query_id), timeout=120
+                )
+                if (
+                    self._streaming_results == "<DONE>"
+                ):  # Check for end-of-stream signal
+                    del self._user_queues[query_id]
+                    break
+                if self._streaming_results:  # Check if there are results to yield
+                    yield self._streaming_results
+                    self._streaming_results = ""  # Clear after yielding
+            except TimeoutError:
+                break
+
+    @observe(name="Data Exploration Assistance")
+    async def run(
+        self,
+        query: str,
+        sql_data: dict,
+        language: str,
+        query_id: Optional[str] = None,
+    ):
+        logger.info("Data Exploration Assistance pipeline is running...")
+        return await self._pipe.execute(
+            ["data_exploration_assistance"],
+            inputs={
+                "query": query,
+                "language": language,
+                "query_id": query_id or "",
+                "sql_data": sql_data,
+                **self._components,
+            },
+        )
diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py
index 4d6cd313cd..79783d472a 100644
--- a/wren-ai-service/src/pipelines/generation/intent_classification.py
+++ b/wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -24,7 +24,8 @@
 intent_classification_system_prompt = """
 ### Task ###
-You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema. Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
+You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema or the SQL data, if provided.
+Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
 
 ### Instructions ###
 - **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions.
@@ -39,6 +40,19 @@
 
 ### Intent Definitions ###
 
+<DATA_EXPLORATION>
+**When to Use:**
+- The user's question is about data exploration, such as asking for data details, explanations of the data, insights, recommendations, or comparisons.
+**Requirements:**
+- SQL DATA is provided and the user's question is about exploring the data.
+- The user's question can be answered by the SQL DATA.
+- The row size of the SQL DATA is less than 500.
+**Examples:**
+- "Show me the part where the data appears abnormal"
+- "Please explain the data in the table"
+- "What's the trend of the data?"
+</DATA_EXPLORATION>
+
 <TEXT_TO_SQL>
 **When to Use:**
 - The user's inputs are about modifying SQL from previous questions.
@@ -51,6 +65,7 @@
 - Must have complete filter criteria, specific values, or clear references to previous context.
 - Include specific table and column names from the schema in your reasoning or modifying SQL from previous questions.
 - Reference phrases from the user's inputs that clearly relate to the schema.
+- The SQL DATA is not provided or cannot answer the user's question, and the user's question can be answered given the database schema.
 **Examples:**
 - "What is the total sales for last quarter?"
@@ -111,7 +126,7 @@ { "rephrased_question": "", "reasoning": "", - "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE" + "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" |"GENERAL" | "USER_GUIDE" } """ @@ -143,6 +158,12 @@ - {{doc.path}}: {{doc.content}} {% endfor %} +{% if sql_data %} +### SQL DATA ### +{{ sql_data }} +row size of SQL DATA: {{ sql_data_size }} +{% endif %} + ### INPUT ### {% if histories %} User's previous questions: @@ -275,6 +296,7 @@ def prompt( sql_samples: Optional[list[dict]] = None, instructions: Optional[list[dict]] = None, configuration: Configuration | None = None, + sql_data: Optional[dict] = None, ) -> dict: _prompt = prompt_builder.run( query=query, @@ -286,6 +308,8 @@ def prompt( instructions=instructions, ), docs=wren_ai_docs, + sql_data=sql_data, + sql_data_size=len(sql_data.get("data", [])), ) return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} @@ -320,7 +344,13 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict class IntentClassificationResult(BaseModel): rephrased_question: str - results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"] + results: Literal[ + "MISLEADING_QUERY", + "TEXT_TO_SQL", + "GENERAL", + "DATA_EXPLORATION", + "USER_GUIDE", + ] reasoning: str @@ -383,6 +413,7 @@ async def run( sql_samples: Optional[list[dict]] = None, instructions: Optional[list[dict]] = None, configuration: Configuration = Configuration(), + sql_data: Optional[dict] = None, ): logger.info("Intent Classification pipeline is running...") return await self._pipe.execute( @@ -394,6 +425,7 @@ async def run( "sql_samples": sql_samples or [], "instructions": instructions or [], "configuration": configuration, + "sql_data": sql_data or {}, **self._components, **self._configs, }, diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 5e36c8a8ad..62e4115add 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -83,14 +83,14 @@ class _AskResultResponse(BaseModel): trace_id: Optional[str] = None is_followup: bool = False general_type: Optional[ - Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE"] + Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"] ] = None class AskResultResponse(_AskResultResponse): is_followup: Optional[bool] = Field(False, exclude=True) general_type: Optional[ - Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE"] + Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"] ] = Field(None, exclude=True) @@ -227,6 +227,16 @@ async def ask( ) if self._allow_intent_classification: + last_sql_data = None + if histories: + if last_sql := histories[-1].sql: + last_sql_data = ( + await self._pipelines["sql_executor"].run( + sql=last_sql, + project_id=ask_request.project_id, + ) + )["execute_sql"]["results"] + intent_classification_result = ( await self._pipelines["intent_classification"].run( query=user_query, @@ -235,6 +245,7 @@ async def ask( instructions=instructions, project_id=ask_request.project_id, configuration=ask_request.configurations, + sql_data=last_sql_data, ) ).get("post_process", {}) intent = intent_classification_result.get("intent") @@ -317,6 +328,27 @@ async def ask( ) results["metadata"]["type"] = "GENERAL" return results + elif intent == "DATA_EXPLORATION": + asyncio.create_task( + self._pipelines["data_exploration_assistance"].run( + query=user_query, + 
sql_data=last_sql_data, + language=ask_request.configurations.language, + query_id=ask_request.query_id, + ) + ) + + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="GENERAL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + trace_id=trace_id, + is_followup=True if histories else False, + general_type="DATA_EXPLORATION", + ) + results["metadata"]["type"] = "GENERAL" + return results else: self._ask_results[query_id] = AskResultResponse( status="understanding", @@ -639,6 +671,8 @@ async def get_ask_streaming_result( _pipeline_name = "data_assistance" elif self._ask_results.get(query_id).general_type == "MISLEADING_QUERY": _pipeline_name = "misleading_assistance" + elif self._ask_results.get(query_id).general_type == "DATA_EXPLORATION": + _pipeline_name = "data_exploration_assistance" elif self._ask_results.get(query_id).status == "planning": if self._ask_results.get(query_id).is_followup: _pipeline_name = "followup_sql_generation_reasoning" diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index f2c01e95a5..c89a67c42b 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -135,6 +135,8 @@ pipes: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index bb688dcfd8..d056bb29be 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -135,6 +135,8 @@ pipes: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default From c954d9ad0be5e941d1ad12ebfc332c34ffefd941 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Fri, 8 Aug 2025 08:46:38 +0800 Subject: [PATCH 02/11] update --- .../pipelines/generation/followup_sql_generation_reasoning.py | 4 ++-- .../src/pipelines/generation/intent_classification.py | 4 ++-- wren-ai-service/src/pipelines/generation/utils/sql.py | 4 +++- wren-ai-service/src/web/v1/services/ask.py | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index 42b28c5b8f..bb6a9917ed 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -49,8 +49,8 @@ {% for history in histories %} Question: {{ history.question }} -SQL: -{{ history.sql }} +Response: +{{ history.response }} {% endfor %} ### QUESTION ### diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index 79783d472a..05192d2332 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -170,8 +170,8 @@ {% for history in histories %} Question: {{ history.question }} -SQL: -{{ history.sql }} +Response: +{{ history.response }} {% endfor %} {% endif %} diff --git 
a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py
index b40528c870..ffdd11001e 100644
--- a/wren-ai-service/src/pipelines/generation/utils/sql.py
+++ b/wren-ai-service/src/pipelines/generation/utils/sql.py
@@ -515,7 +515,9 @@ def construct_ask_history_messages(
     )
     messages.append(
         ChatMessage.from_assistant(
-            history.sql if hasattr(history, "sql") else history["sql"]
+            history.response
+            if hasattr(history, "response")
+            else history["response"]
         )
     )
     return messages
diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py
index 62e4115add..cec7cbfd55 100644
--- a/wren-ai-service/src/web/v1/services/ask.py
+++ b/wren-ai-service/src/web/v1/services/ask.py
@@ -14,7 +14,7 @@
 
 class AskHistory(BaseModel):
-    sql: str
+    response: str = Field(validation_alias=AliasChoices("response", "sql"))
     question: str
 
@@ -229,7 +229,7 @@ async def ask(
         if self._allow_intent_classification:
             last_sql_data = None
             if histories:
-                if last_sql := histories[-1].sql:
+                if last_sql := histories[-1].response:
                     last_sql_data = (
                         await self._pipelines["sql_executor"].run(
                             sql=last_sql,

From a4b8477dac993c13250e67ce31d03275b690676e Mon Sep 17 00:00:00 2001
From: ChihYu Yeh
Date: Thu, 14 Aug 2025 13:59:24 +0800
Subject: [PATCH 03/11] update

---
 wren-ai-service/src/globals.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index b8e8269e56..77a9742f2b 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -125,6 +125,9 @@ def create_service_container(
             **pipe_components["user_guide_assistance"],
             wren_ai_docs=wren_ai_docs,
         ),
+        "data_exploration_assistance": generation.DataExplorationAssistance(
+            **pipe_components["data_exploration_assistance"],
+        ),
         "db_schema_retrieval": _db_schema_retrieval_pipeline,
         "historical_question": retrieval.HistoricalQuestionRetrieval(
             **pipe_components["historical_question_retrieval"],
@@ -148,10 +151,6 @@
         "sql_functions_retrieval": _sql_functions_retrieval_pipeline,
         "sql_executor": retrieval.SQLExecutor(
             **pipe_components["sql_executor"],
-            engine_timeout=settings.engine_timeout,
-        ),
-        "data_exploration_assistance": generation.DataExplorationAssistance(
-            **pipe_components["data_exploration_assistance"],
         ),
     },
     allow_intent_classification=settings.allow_intent_classification,

From 9b8ab01d37e26386eec9202b80ba0c570c9a2c39 Mon Sep 17 00:00:00 2001
From: ChihYu Yeh
Date: Tue, 12 Aug 2025 09:14:00 +0800
Subject: [PATCH 04/11] simplify

---
 .../generation/intent_classification.py       | 144 +++---------------
 wren-ai-service/src/web/v1/services/ask.py    |  48 ++++--
 2 files changed, 56 insertions(+), 136 deletions(-)

diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py
index 05192d2332..b92a741666 100644
--- a/wren-ai-service/src/pipelines/generation/intent_classification.py
+++ b/wren-ai-service/src/pipelines/generation/intent_classification.py
@@ -1,4 +1,3 @@
-import ast
 import logging
 import sys
 from typing import Any, Literal, Optional
@@ -6,14 +5,13 @@
 import orjson
 from hamilton import base
 from hamilton.async_driver import AsyncDriver
-from haystack import Document
 from haystack.components.builders.prompt_builder import PromptBuilder
 from langfuse.decorators import observe
 from pydantic import BaseModel
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
-from src.pipelines.common import build_table_ddl, clean_up_new_lines
+from src.pipelines.common import clean_up_new_lines
 from src.pipelines.generation.utils.sql import construct_instructions
 from src.utils import trace_cost
 from src.web.v1.services import Configuration
@@ -25,7 +23,7 @@
 intent_classification_system_prompt = """
 ### Task ###
 You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema or the SQL data, if provided.
-Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
+Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, `USER_GUIDE`, or `USER_CLARIFICATION`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
 
 ### Instructions ###
 - **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions.
@@ -34,8 +32,9 @@
 - **Rephrase Question:** Rewrite follow-up questions into full standalone questions using prior conversation context.
 - **Concise Reasoning:** The reasoning must be clear, concise, and limited to 20 words.
 - **Language Consistency:** Use the same language as specified in the user's output language for the rephrased question and reasoning.
-- **Vague Queries:** If the question is vague or does not related to a table or property from the schema, classify it as `MISLEADING_QUERY`.
-- **Incomplete Queries:** If the question is related to the database schema but references unspecified values (e.g., "the following", "these", "those") without providing them, classify as `GENERAL`.
+- **Vague Queries:** If the question is not related to the database schema, classify it as `MISLEADING_QUERY`.
+- **User Clarification:** If the question is related to the database schema but is missing some details needed to answer it, classify it as `USER_CLARIFICATION`.
+- **Incomplete Queries:** If the question is related to the database schema but references unspecified values (e.g., "the following", "these", "those") without providing them, classify as `USER_CLARIFICATION`.
 - **Time-related Queries:** Don't rephrase time-related information in the user's question.
 
 ### Intent Definitions ###
@@ -73,9 +72,9 @@
 - "List the top 10 products by revenue."
 </TEXT_TO_SQL>
 
-<GENERAL>
-**When to Use:**
-- The user seeks general information about the database schema or its overall capabilities.
+<USER_CLARIFICATION>
+**When to Use:**
+- The user's question is related to the database schema but is missing some details needed to answer it.
 - The query references **missing information** (e.g., "the following items" without listing them).
 - The query contains **placeholder references** that cannot be resolved from context.
 - The query is **incomplete for SQL generation** despite mentioning database concepts.
@@ -85,11 +84,18 @@
 - Identify missing parameters, unspecified references, or incomplete filter criteria.
 
 **Examples:**
-- "What is the dataset about?"
-- "Tell me more about the database."
 - "How can I analyze customer behavior with this data?"
 - "Show me orders for these products" (without specifying which products)
 - "Filter by the criteria I mentioned" (without previous context defining criteria)
+</USER_CLARIFICATION>
+
+<GENERAL>
+**When to Use:**
+- The user seeks general information about the database schema or its overall capabilities.
+
+**Examples:**
+- "What is the dataset about?"
+- "Tell me more about the database."
 </GENERAL>
 
@@ -126,7 +132,7 @@
 {
     "rephrased_question": "",
     "reasoning": "",
-    "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" |"GENERAL" | "USER_GUIDE"
+    "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" | "GENERAL" | "USER_GUIDE" | "USER_CLARIFICATION"
 }
 """
@@ -183,114 +189,11 @@
 
 ## Start of Pipeline
-@observe(capture_input=False, capture_output=False)
-async def embedding(query: str, embedder: Any, histories: list[AskHistory]) -> dict:
-    previous_query_summaries = (
-        [history.question for history in histories] if histories else []
-    )
-
-    query = "\n".join(previous_query_summaries) + "\n" + query
-
-    return await embedder.run(query)
-
-
-@observe(capture_input=False)
-async def table_retrieval(
-    embedding: dict, project_id: str, table_retriever: Any
-) -> dict:
-    filters = {
-        "operator": "AND",
-        "conditions": [
-            {"field": "type", "operator": "==", "value": "TABLE_DESCRIPTION"},
-        ],
-    }
-
-    if project_id:
-        filters["conditions"].append(
-            {"field": "project_id", "operator": "==", "value": project_id}
-        )
-
-    return await table_retriever.run(
-        query_embedding=embedding.get("embedding"),
-        filters=filters,
-    )
-
-
-@observe(capture_input=False)
-async def dbschema_retrieval(
-    table_retrieval: dict, embedding: dict, project_id: str, dbschema_retriever: Any
-) -> list[Document]:
-    tables = table_retrieval.get("documents", [])
-    table_names = []
-    for table in tables:
-        content = ast.literal_eval(table.content)
-        table_names.append(content["name"])
-
-    logger.info(f"dbschema_retrieval with table_names: {table_names}")
-
-    table_name_conditions = [
-        {"field": "name", "operator": "==", "value": table_name}
-        for table_name in table_names
-    ]
-
-    filters = {
-        "operator": "AND",
-        "conditions": [
-            {"field": "type", "operator": "==", "value": "TABLE_SCHEMA"},
-            {"operator": "OR", "conditions": table_name_conditions},
-        ],
-    }
-
-    if project_id:
-        filters["conditions"].append(
-            {"field": "project_id", "operator": "==", "value": project_id}
-        )
-
-    results = await dbschema_retriever.run(
-        query_embedding=embedding.get("embedding"), filters=filters
-    )
-    return results["documents"]
-
-
-@observe()
-def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[str]:
-    db_schemas = {}
-    for document in dbschema_retrieval:
-        content = ast.literal_eval(document.content)
-        if content["type"] == "TABLE":
-            if document.meta["name"] not in db_schemas:
-                db_schemas[document.meta["name"]] = content
-            else:
-                db_schemas[document.meta["name"]] = {
-                    **content,
-                    "columns": db_schemas[document.meta["name"]].get("columns", []),
-                }
-        elif content["type"] == "TABLE_COLUMNS":
-            if document.meta["name"] not in db_schemas:
-                db_schemas[document.meta["name"]] = {"columns": content["columns"]}
-            else:
-                if "columns" not in db_schemas[document.meta["name"]]:
-                    db_schemas[document.meta["name"]]["columns"] = content["columns"]
-                else:
-                    db_schemas[document.meta["name"]]["columns"] += content["columns"]
-
-    # remove incomplete schemas
-    db_schemas = {k: v for k, v in db_schemas.items() if "type" in v and "columns" in v}
-
-    db_schemas_in_ddl = []
-    for table_schema in list(db_schemas.values()):
-        if table_schema["type"] == 
"TABLE": - ddl, _, _ = build_table_ddl(table_schema) - db_schemas_in_ddl.append(ddl) - - return db_schemas_in_ddl - - @observe(capture_input=False) def prompt( query: str, wren_ai_docs: list[dict], - construct_db_schemas: list[str], + db_schemas: list[str], histories: list[AskHistory], prompt_builder: PromptBuilder, sql_samples: Optional[list[dict]] = None, @@ -301,7 +204,7 @@ def prompt( _prompt = prompt_builder.run( query=query, language=configuration.language, - db_schemas=construct_db_schemas, + db_schemas=db_schemas, histories=histories, sql_samples=sql_samples, instructions=construct_instructions( @@ -321,21 +224,19 @@ async def classify_intent(prompt: dict, generator: Any, generator_name: str) -> @observe(capture_input=False) -def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict: +def post_process(classify_intent: dict) -> dict: try: results = orjson.loads(classify_intent.get("replies")[0]) return { "rephrased_question": results["rephrased_question"], "intent": results["results"], "reasoning": results["reasoning"], - "db_schemas": construct_db_schemas, } except Exception: return { "rephrased_question": "", "intent": "TEXT_TO_SQL", "reasoning": "", - "db_schemas": construct_db_schemas, } @@ -350,6 +251,7 @@ class IntentClassificationResult(BaseModel): "GENERAL", "DATA_EXPLORATION", "USER_GUIDE", + "USER_CLARIFICATION", ] reasoning: str @@ -408,6 +310,7 @@ def __init__( async def run( self, query: str, + db_schemas: list[str], project_id: Optional[str] = None, histories: Optional[list[AskHistory]] = None, sql_samples: Optional[list[dict]] = None, @@ -420,6 +323,7 @@ async def run( ["post_process"], inputs={ "query": query, + "db_schemas": db_schemas, "project_id": project_id or "", "histories": histories or [], "sql_samples": sql_samples or [], diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index cec7cbfd55..49a97b18e7 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -83,14 +83,26 @@ class _AskResultResponse(BaseModel): trace_id: Optional[str] = None is_followup: bool = False general_type: Optional[ - Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"] + Literal[ + "MISLEADING_QUERY", + "DATA_ASSISTANCE", + "USER_GUIDE", + "DATA_EXPLORATION", + "USER_CLARIFICATION", + ] ] = None class AskResultResponse(_AskResultResponse): is_followup: Optional[bool] = Field(False, exclude=True) general_type: Optional[ - Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"] + Literal[ + "MISLEADING_QUERY", + "DATA_ASSISTANCE", + "USER_GUIDE", + "DATA_EXPLORATION", + "USER_CLARIFICATION", + ] ] = Field(None, exclude=True) @@ -206,7 +218,11 @@ async def ask( sql_generation_reasoning = "" else: # Run both pipeline operations concurrently - sql_samples_task, instructions_task = await asyncio.gather( + ( + sql_samples_task, + instructions_task, + db_schema_retrieval_task, + ) = await asyncio.gather( self._pipelines["sql_pairs_retrieval"].run( query=user_query, project_id=ask_request.project_id, @@ -216,6 +232,12 @@ async def ask( project_id=ask_request.project_id, scope="sql", ), + self._pipelines["db_schema_retrieval"].run( + query=user_query, + histories=histories, + project_id=ask_request.project_id, + enable_column_pruning=enable_column_pruning, + ), ) # Extract results from completed tasks @@ -225,6 +247,12 @@ async def ask( instructions = instructions_task["formatted_output"].get( "documents", [] ) + 
_retrieval_result = db_schema_retrieval_task.get( + "construct_retrieval_results", {} + ) + documents = _retrieval_result.get("retrieval_results", []) + table_names = [document.get("table_name") for document in documents] + table_ddls = [document.get("table_ddl") for document in documents] if self._allow_intent_classification: last_sql_data = None @@ -240,6 +268,7 @@ async def ask( intent_classification_result = ( await self._pipelines["intent_classification"].run( query=user_query, + db_schemas=table_ddls, histories=histories, sql_samples=sql_samples, instructions=instructions, @@ -368,19 +397,6 @@ async def ask( is_followup=True if histories else False, ) - retrieval_result = await self._pipelines["db_schema_retrieval"].run( - query=user_query, - histories=histories, - project_id=ask_request.project_id, - enable_column_pruning=enable_column_pruning, - ) - _retrieval_result = retrieval_result.get( - "construct_retrieval_results", {} - ) - documents = _retrieval_result.get("retrieval_results", []) - table_names = [document.get("table_name") for document in documents] - table_ddls = [document.get("table_ddl") for document in documents] - if not documents: logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}") if not self._is_stopped(query_id, self._ask_results): From bf3eb42607baa5f3aee5bb51b1c04ee19ddc5c9b Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 12 Aug 2025 10:15:19 +0800 Subject: [PATCH 05/11] update --- deployment/kustomizations/base/cm.yaml | 2 + docker/config.example.yaml | 2 + .../src/pipelines/generation/__init__.py | 2 + .../user_clarification_assistance.py | 144 ++++++++++++++++++ wren-ai-service/src/web/v1/services/ask.py | 35 ++++- .../tools/config/config.example.yaml | 2 + wren-ai-service/tools/config/config.full.yaml | 2 + 7 files changed, 183 insertions(+), 6 deletions(-) create mode 100644 wren-ai-service/src/pipelines/generation/user_clarification_assistance.py diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index e561841930..93f9b29744 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -174,6 +174,8 @@ data: llm: litellm_llm.default - name: data_exploration_assistance llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 48eeded0a8..9ec5b35a4d 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -124,6 +124,8 @@ pipes: llm: litellm_llm.default - name: data_exploration_assistance llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py index e4c155d9cc..8dd59cd4aa 100644 --- a/wren-ai-service/src/pipelines/generation/__init__.py +++ b/wren-ai-service/src/pipelines/generation/__init__.py @@ -16,6 +16,7 @@ from .sql_question import SQLQuestion from .sql_regeneration import SQLRegeneration from .sql_tables_extraction import SQLTablesExtraction +from .user_clarification_assistance import UserClarificationAssistance from .user_guide_assistance import UserGuideAssistance __all__ = [ @@ -38,4 +39,5 @@ "MisleadingAssistance", "SQLTablesExtraction", "DataExplorationAssistance", + "UserClarificationAssistance", ] 
diff --git a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py
new file mode 100644
index 0000000000..2062766d6a
--- /dev/null
+++ b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py
@@ -0,0 +1,144 @@
+import asyncio
+import logging
+import sys
+from typing import Any, Optional
+
+from hamilton import base
+from hamilton.async_driver import AsyncDriver
+from haystack.components.builders.prompt_builder import PromptBuilder
+from langfuse.decorators import observe
+
+from src.core.pipeline import BasicPipeline
+from src.core.provider import LLMProvider
+from src.pipelines.common import clean_up_new_lines
+from src.utils import trace_cost
+from src.web.v1.services.ask import AskHistory
+
+logger = logging.getLogger("wren-ai-service")
+
+
+user_clarification_assistance_system_prompt = """
+"""
+
+user_clarification_assistance_user_prompt_template = """
+"""
+
+
+## Start of Pipeline
+@observe(capture_input=False)
+def prompt(
+    query: str,
+    db_schemas: list[str],
+    language: str,
+    histories: list[AskHistory],
+    prompt_builder: PromptBuilder,
+    custom_instruction: str,
+) -> dict:
+    previous_query_summaries = (
+        [history.question for history in histories] if histories else []
+    )
+    query = "\n".join(previous_query_summaries) + "\n" + query
+
+    _prompt = prompt_builder.run(
+        query=query,
+        db_schemas=db_schemas,
+        language=language,
+        custom_instruction=custom_instruction,
+    )
+    return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
+
+
+@observe(as_type="generation", capture_input=False)
+@trace_cost
+async def user_clarification_assistance(
+    prompt: dict, generator: Any, query_id: str, generator_name: str
+) -> dict:
+    return await generator(
+        prompt=prompt.get("prompt"),
+        query_id=query_id,
+    ), generator_name
+
+
+## End of Pipeline
+
+
+class UserClarificationAssistance(BasicPipeline):
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        **kwargs,
+    ):
+        self._user_queues = {}
+        self._components = {
+            "generator": llm_provider.get_generator(
+                system_prompt=user_clarification_assistance_system_prompt,
+                streaming_callback=self._streaming_callback,
+            ),
+            "generator_name": llm_provider.get_model(),
+            "prompt_builder": PromptBuilder(
+                template=user_clarification_assistance_user_prompt_template
+            ),
+        }
+
+        super().__init__(
+            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
+        )
+
+    def _streaming_callback(self, chunk, query_id):
+        if query_id not in self._user_queues:
+            self._user_queues[
+                query_id
+            ] = asyncio.Queue()  # Create a new queue for the user if it doesn't exist
+        # Put the chunk content into the user's queue
+        asyncio.create_task(self._user_queues[query_id].put(chunk.content))
+        if chunk.meta.get("finish_reason"):
+            asyncio.create_task(self._user_queues[query_id].put("<DONE>"))
+
+    async def get_streaming_results(self, query_id):
+        async def _get_streaming_results(query_id):
+            return await self._user_queues[query_id].get()
+
+        if query_id not in self._user_queues:
+            self._user_queues[
+                query_id
+            ] = asyncio.Queue()  # Ensure the user's queue exists
+        while True:
+            try:
+                # Wait for an item from the user's queue
+                self._streaming_results = await asyncio.wait_for(
+                    _get_streaming_results(query_id), timeout=120
+                )
+                if (
+                    self._streaming_results == "<DONE>"
+                ):  # Check for end-of-stream signal
+                    del self._user_queues[query_id]
+                    break
+                if self._streaming_results:  # Check if there are results to yield
+                    yield 
self._streaming_results + self._streaming_results = "" # Clear after yielding + except TimeoutError: + break + + @observe(name="User Clarification Assistance") + async def run( + self, + query: str, + db_schemas: list[str], + language: str, + query_id: Optional[str] = None, + histories: Optional[list[AskHistory]] = None, + custom_instruction: Optional[str] = None, + ): + logger.info("User Clarification Assistance pipeline is running...") + return await self._pipe.execute( + ["user_clarification_assistance"], + inputs={ + "query": query, + "db_schemas": db_schemas, + "language": language, + "query_id": query_id or "", + "histories": histories or [], + "custom_instruction": custom_instruction or "", + **self._components, + }, + ) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 49a97b18e7..efa447b7b2 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -291,9 +291,7 @@ async def ask( self._pipelines["misleading_assistance"].run( query=user_query, histories=histories, - db_schemas=intent_classification_result.get( - "db_schemas" - ), + db_schemas=table_ddls, language=ask_request.configurations.language, query_id=ask_request.query_id, custom_instruction=ask_request.custom_instruction, @@ -316,9 +314,7 @@ async def ask( self._pipelines["data_assistance"].run( query=user_query, histories=histories, - db_schemas=intent_classification_result.get( - "db_schemas" - ), + db_schemas=table_ddls, language=ask_request.configurations.language, query_id=ask_request.query_id, custom_instruction=ask_request.custom_instruction, @@ -364,6 +360,7 @@ async def ask( sql_data=last_sql_data, language=ask_request.configurations.language, query_id=ask_request.query_id, + custom_instruction=ask_request.custom_instruction, ) ) @@ -378,6 +375,28 @@ async def ask( ) results["metadata"]["type"] = "GENERAL" return results + elif intent == "USER_CLARIFICATION": + asyncio.create_task( + self._pipelines["user_clarification_assistance"].run( + query=user_query, + db_schemas=table_ddls, + language=ask_request.configurations.language, + query_id=ask_request.query_id, + custom_instruction=ask_request.custom_instruction, + ) + ) + + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="GENERAL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + trace_id=trace_id, + is_followup=True if histories else False, + general_type="USER_CLARIFICATION", + ) + results["metadata"]["type"] = "GENERAL" + return results else: self._ask_results[query_id] = AskResultResponse( status="understanding", @@ -689,6 +708,10 @@ async def get_ask_streaming_result( _pipeline_name = "misleading_assistance" elif self._ask_results.get(query_id).general_type == "DATA_EXPLORATION": _pipeline_name = "data_exploration_assistance" + elif ( + self._ask_results.get(query_id).general_type == "USER_CLARIFICATION" + ): + _pipeline_name = "user_clarification_assistance" elif self._ask_results.get(query_id).status == "planning": if self._ask_results.get(query_id).is_followup: _pipeline_name = "followup_sql_generation_reasoning" diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index c89a67c42b..d0bbe64d75 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -137,6 +137,8 @@ pipes: llm: litellm_llm.default - name: data_exploration_assistance llm: litellm_llm.default + - name: 
user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index d056bb29be..ed579ad1a0 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -137,6 +137,8 @@ pipes: llm: litellm_llm.default - name: data_exploration_assistance llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default From 682934fb4bcbafc2a837b74653f24027193bf287 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 12 Aug 2025 13:43:38 +0800 Subject: [PATCH 06/11] update --- .../pipelines/generation/data_assistance.py | 13 ++++++---- .../generation/data_exploration_assistance.py | 26 ++++++++++++++++++- .../generation/misleading_assistance.py | 13 ++++++---- .../user_clarification_assistance.py | 6 +---- .../generation/user_guide_assistance.py | 13 ++++++++++ wren-ai-service/src/web/v1/services/ask.py | 3 +++ 6 files changed, 58 insertions(+), 16 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 51b91197f9..a777f998f5 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -41,6 +41,13 @@ {{ db_schema }} {% endfor %} +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + ### INPUT ### User's question: {{query}} Language: {{language}} @@ -61,13 +68,9 @@ def prompt( prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: - previous_query_summaries = ( - [history.question for history in histories] if histories else [] - ) - query = "\n".join(previous_query_summaries) + "\n" + query - _prompt = prompt_builder.run( query=query, + histories=histories, db_schemas=db_schemas, language=language, custom_instruction=custom_instruction, diff --git a/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py index 9b6c868507..f16782a3b5 100644 --- a/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py @@ -10,6 +10,9 @@ from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider +from src.pipelines.common import clean_up_new_lines +from src.utils import trace_cost +from src.web.v1.services.ask import AskHistory logger = logging.getLogger("wren-ai-service") @@ -32,10 +35,21 @@ """ data_exploration_assistance_user_prompt_template = """ +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + +### INPUT ### User Question: {{query}} Language: {{language}} SQL Data: {{ sql_data }} + +Custom Instruction: {{ custom_instruction }} + Please think step by step. 
""" @@ -44,18 +58,24 @@ @observe(capture_input=False) def prompt( query: str, + histories: list[AskHistory], language: str, sql_data: dict, prompt_builder: PromptBuilder, + custom_instruction: str, ) -> dict: - return prompt_builder.run( + _prompt = prompt_builder.run( query=query, language=language, sql_data=sql_data, + histories=histories, + custom_instruction=custom_instruction, ) + return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} @observe(as_type="generation", capture_input=False) +@trace_cost async def data_exploration_assistance( prompt: dict, generator: Any, query_id: str ) -> dict: @@ -127,6 +147,8 @@ async def run( sql_data: dict, language: str, query_id: Optional[str] = None, + histories: Optional[list[AskHistory]] = None, + custom_instruction: Optional[str] = None, ): logger.info("Data Exploration Assistance pipeline is running...") return await self._pipe.execute( @@ -136,6 +158,8 @@ async def run( "language": language, "query_id": query_id or "", "sql_data": sql_data, + "histories": histories or [], + "custom_instruction": custom_instruction or "", **self._components, }, ) diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index a35738ecf5..ac27794004 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -41,6 +41,13 @@ {{ db_schema }} {% endfor %} +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + ### INPUT ### User's question: {{query}} Language: {{language}} @@ -61,13 +68,9 @@ def prompt( prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: - previous_query_summaries = ( - [history.question for history in histories] if histories else [] - ) - query = "\n".join(previous_query_summaries) + "\n" + query - _prompt = prompt_builder.run( query=query, + histories=histories, db_schemas=db_schemas, language=language, custom_instruction=custom_instruction, diff --git a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py index 2062766d6a..0c56de2c59 100644 --- a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py @@ -34,13 +34,9 @@ def prompt( prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: - previous_query_summaries = ( - [history.question for history in histories] if histories else [] - ) - query = "\n".join(previous_query_summaries) + "\n" + query - _prompt = prompt_builder.run( query=query, + histories=histories, db_schemas=db_schemas, language=language, custom_instruction=custom_instruction, diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index be437f883b..fa8edc6ba5 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -12,6 +12,7 @@ from src.core.provider import LLMProvider from src.pipelines.common import clean_up_new_lines from src.utils import trace_cost +from src.web.v1.services.ask import AskHistory logger = logging.getLogger("wren-ai-service") @@ -34,6 +35,14 @@ """ user_guide_assistance_user_prompt_template = """ +{% if histories %} +### PREVIOUS QUESTIONS ### 
+{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + +### INPUT ### User Question: {{query}} Language: {{language}} User Guide: @@ -53,11 +62,13 @@ def prompt( query: str, language: str, wren_ai_docs: list[dict], + histories: list[AskHistory], prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: _prompt = prompt_builder.run( query=query, + histories=histories, language=language, docs=wren_ai_docs, custom_instruction=custom_instruction, @@ -144,6 +155,7 @@ async def run( query: str, language: str, query_id: Optional[str] = None, + histories: Optional[list[AskHistory]] = None, custom_instruction: Optional[str] = None, ): logger.info("User Guide Assistance pipeline is running...") @@ -153,6 +165,7 @@ async def run( "query": query, "language": language, "query_id": query_id or "", + "histories": histories or [], "custom_instruction": custom_instruction or "", **self._components, **self._configs, diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index efa447b7b2..c3837c27cf 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -336,6 +336,7 @@ async def ask( asyncio.create_task( self._pipelines["user_guide_assistance"].run( query=user_query, + histories=histories, language=ask_request.configurations.language, query_id=ask_request.query_id, custom_instruction=ask_request.custom_instruction, @@ -357,6 +358,7 @@ async def ask( asyncio.create_task( self._pipelines["data_exploration_assistance"].run( query=user_query, + histories=histories, sql_data=last_sql_data, language=ask_request.configurations.language, query_id=ask_request.query_id, @@ -379,6 +381,7 @@ async def ask( asyncio.create_task( self._pipelines["user_clarification_assistance"].run( query=user_query, + histories=histories, db_schemas=table_ddls, language=ask_request.configurations.language, query_id=ask_request.query_id, From 984f1b3c778bd63967a6b732879f6170bc9a3334 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 12 Aug 2025 13:54:29 +0800 Subject: [PATCH 07/11] update --- .../pipelines/generation/data_assistance.py | 7 +++- .../generation/data_exploration_assistance.py | 11 ++++-- .../generation/misleading_assistance.py | 9 +++-- .../user_clarification_assistance.py | 35 +++++++++++++++++++ .../generation/user_guide_assistance.py | 9 +++-- wren-ai-service/src/web/v1/services/ask.py | 5 +++ 6 files changed, 69 insertions(+), 7 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index a777f998f5..0dec79db3d 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -20,7 +20,7 @@ data_assistance_system_prompt = """ ### TASK ### You are a data analyst great at answering user's questions about given database schema. -Please carefully read user's question and database schema to answer it in easy to understand manner +Please carefully read user's question, intent for the question, and database schema to answer it in easy to understand manner using the Markdown format. Your goal is to help guide user understand its database! 
### INSTRUCTIONS ###
@@ -50,6 +50,7 @@
 ### INPUT ###
 User's question: {{query}}
+Intent for user's question: {{intent_reasoning}}
 Language: {{language}}
 
 Custom Instruction: {{ custom_instruction }}
@@ -62,6 +63,7 @@
 @observe(capture_input=False)
 def prompt(
     query: str,
+    intent_reasoning: str,
     db_schemas: list[str],
     language: str,
     histories: list[AskHistory],
@@ -70,6 +72,7 @@
 ) -> dict:
     _prompt = prompt_builder.run(
         query=query,
+        intent_reasoning=intent_reasoning,
         histories=histories,
         db_schemas=db_schemas,
         language=language,
@@ -153,6 +156,7 @@
     async def run(
         self,
         query: str,
+        intent_reasoning: str,
         db_schemas: list[str],
         language: str,
         query_id: Optional[str] = None,
@@ -164,6 +168,7 @@
             ["data_assistance"],
             inputs={
                 "query": query,
+                "intent_reasoning": intent_reasoning,
                 "db_schemas": db_schemas,
                 "language": language,
                 "query_id": query_id or "",
diff --git a/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py
index f16782a3b5..ca450f8c54 100644
--- a/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py
+++ b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py
@@ -19,8 +19,9 @@
 
 data_exploration_assistance_system_prompt = """
 You are a great data analyst good at exploring data.
-You are given a user question and SQL data.
-You need to understand the user question and the SQL data, and then answer the user question.
+You are given a user question, an intent for the question, and SQL data.
+You need to understand the user question, the intent for the question, and the SQL data, and then answer the user question.
+
 ### INSTRUCTIONS ###
 1. Your answer should be in the same language as the language user provided.
 2. You must follow the SQL data to answer the user question.
@@ -30,6 +31,7 @@
 - provide insights and trends in the data
 - find anomalies and outliers in the data
 5. You only need to use the skills required to answer the user question based on the user question and the SQL data.
+
 ### OUTPUT FORMAT ###
 Please provide your response in proper Markdown format without ```markdown``` tags. 
""" @@ -44,6 +46,7 @@ ### INPUT ### User Question: {{query}} +Intent for user's question: {{intent_reasoning}} Language: {{language}} SQL Data: {{ sql_data }} @@ -58,6 +61,7 @@ @observe(capture_input=False) def prompt( query: str, + intent_reasoning: str, histories: list[AskHistory], language: str, sql_data: dict, @@ -66,6 +70,7 @@ def prompt( ) -> dict: _prompt = prompt_builder.run( query=query, + intent_reasoning=intent_reasoning, language=language, sql_data=sql_data, histories=histories, @@ -144,6 +149,7 @@ async def _get_streaming_results(query_id): async def run( self, query: str, + intent_reasoning: str, sql_data: dict, language: str, query_id: Optional[str] = None, @@ -155,6 +161,7 @@ async def run( ["data_exploration_assistance"], inputs={ "query": query, + "intent_reasoning": intent_reasoning, "language": language, "query_id": query_id or "", "sql_data": sql_data, diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index ac27794004..4f5ea2edeb 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -19,8 +19,8 @@ misleading_assistance_system_prompt = """ ### TASK ### -You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question that is potentially misleading. -Your goal is to help guide user understand its data better and suggest few better questions to ask. +You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question, an intent for the question, and a database schema. +Your goal is to help guide user understand its data better and suggest few better questions to ask based on the intent for the question and the database schema. ### INSTRUCTIONS ### @@ -50,6 +50,7 @@ ### INPUT ### User's question: {{query}} +Intent for user's question: {{intent_reasoning}} Language: {{language}} Custom Instruction: {{ custom_instruction }} @@ -62,6 +63,7 @@ @observe(capture_input=False) def prompt( query: str, + intent_reasoning: str, db_schemas: list[str], language: str, histories: list[AskHistory], @@ -70,6 +72,7 @@ def prompt( ) -> dict: _prompt = prompt_builder.run( query=query, + intent_reasoning=intent_reasoning, histories=histories, db_schemas=db_schemas, language=language, @@ -153,6 +156,7 @@ async def _get_streaming_results(query_id): async def run( self, query: str, + intent_reasoning: str, db_schemas: list[str], language: str, query_id: Optional[str] = None, @@ -164,6 +168,7 @@ async def run( ["misleading_assistance"], inputs={ "query": query, + "intent_reasoning": intent_reasoning, "db_schemas": db_schemas, "language": language, "query_id": query_id or "", diff --git a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py index 0c56de2c59..1e709cd9cd 100644 --- a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py @@ -18,9 +18,40 @@ user_clarification_assistance_system_prompt = """ +You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question, an intent for the question, and a database schema. 
+You should tell the user why the question is not clear enough or too vague, and suggest a better question based on the intent for the question and the database schema.
+
+### INSTRUCTIONS ###
+1. Response must be in the same language the user specified in the Language section of the `### INPUT ###` section.
+2. There should be proper line breaks, whitespace, and Markdown formatting (headers, lists, tables, etc.) in your response.
+3. MUST NOT add SQL code in your response.
+4. MUST consider database schema when suggesting better questions.
+5. The maximum response length is 100 words.
+6. If the user provides a custom instruction, it should be followed strictly and you should use it to change the style of the response.
+
+### OUTPUT FORMAT ###
+Please provide your response in proper Markdown format without ```markdown``` tags.
 """
 
 user_clarification_assistance_user_prompt_template = """
+{% if histories %}
+### PREVIOUS QUESTIONS ###
+{% for history in histories %}
+    {{ history.question }}
+{% endfor %}
+{% endif %}
+
+### DATABASE SCHEMA ###
+{% for db_schema in db_schemas %}
+    {{ db_schema }}
+{% endfor %}
+
+### INPUT ###
+User Question: {{query}}
+Intent for user's question: {{intent_reasoning}}
+Language: {{language}}
+
+Custom Instruction: {{ custom_instruction }}
 """
@@ -28,6 +59,7 @@
 @observe(capture_input=False)
 def prompt(
     query: str,
+    intent_reasoning: str,
     db_schemas: list[str],
     language: str,
     histories: list[AskHistory],
@@ -36,6 +68,7 @@ def prompt(
 ) -> dict:
     _prompt = prompt_builder.run(
         query=query,
+        intent_reasoning=intent_reasoning,
         histories=histories,
         db_schemas=db_schemas,
         language=language,
@@ -119,6 +152,7 @@ async def _get_streaming_results(query_id):
     async def run(
         self,
         query: str,
+        intent_reasoning: str,
         db_schemas: list[str],
         language: str,
         query_id: Optional[str] = None,
@@ -130,6 +164,7 @@ async def run(
             ["user_clarification_assistance"],
             inputs={
                 "query": query,
+                "intent_reasoning": intent_reasoning,
                 "db_schemas": db_schemas,
                 "language": language,
                 "query_id": query_id or "",
diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py
index fa8edc6ba5..a1ce19b446 100644
--- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py
+++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py
@@ -19,8 +19,8 @@
 user_guide_assistance_system_prompt = """
 You are a helpful assistant that can help users understand Wren AI.
-You are given a user question and a user guide.
-You need to understand the user question and the user guide, and then answer the user question.
+You are given a user question, an intent for the question, and a user guide.
+You need to understand the user question, the intent for the question, and the user guide, and then answer the user question.
 
 ### INSTRUCTIONS ###
 1. Your answer should be in the same language as the language user provided.
@@ -44,6 +44,7 @@ ### INPUT ### User Question: {{query}} +Intent for user's question: {{intent_reasoning}} Language: {{language}} User Guide: {% for doc in docs %} @@ -60,6 +61,7 @@ @observe(capture_input=False) def prompt( query: str, + intent_reasoning: str, language: str, wren_ai_docs: list[dict], histories: list[AskHistory], @@ -68,6 +70,7 @@ def prompt( ) -> dict: _prompt = prompt_builder.run( query=query, + intent_reasoning=intent_reasoning, histories=histories, language=language, docs=wren_ai_docs, @@ -153,6 +156,7 @@ async def _get_streaming_results(query_id): async def run( self, query: str, + intent_reasoning: str, language: str, query_id: Optional[str] = None, histories: Optional[list[AskHistory]] = None, @@ -163,6 +167,7 @@ async def run( ["user_guide_assistance"], inputs={ "query": query, + "intent_reasoning": intent_reasoning, "language": language, "query_id": query_id or "", "histories": histories or [], diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index c3837c27cf..feed48c514 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -290,6 +290,7 @@ async def ask( asyncio.create_task( self._pipelines["misleading_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, db_schemas=table_ddls, language=ask_request.configurations.language, @@ -313,6 +314,7 @@ async def ask( asyncio.create_task( self._pipelines["data_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, db_schemas=table_ddls, language=ask_request.configurations.language, @@ -336,6 +338,7 @@ async def ask( asyncio.create_task( self._pipelines["user_guide_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, language=ask_request.configurations.language, query_id=ask_request.query_id, @@ -358,6 +361,7 @@ async def ask( asyncio.create_task( self._pipelines["data_exploration_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, sql_data=last_sql_data, language=ask_request.configurations.language, @@ -381,6 +385,7 @@ async def ask( asyncio.create_task( self._pipelines["user_clarification_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, db_schemas=table_ddls, language=ask_request.configurations.language, From 7aa0e169de7b1eb38b55cbeca993006bc5a0b4c4 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 12 Aug 2025 14:53:45 +0800 Subject: [PATCH 08/11] fix --- wren-ai-service/src/globals.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 77a9742f2b..9f74efc9c3 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -128,6 +128,9 @@ def create_service_container( "data_exploration_assistance": generation.DataExplorationAssistance( **pipe_components["data_exploration_assistance"], ), + "user_clarification_assistance": generation.UserClarificationAssistance( + **pipe_components["user_clarification_assistance"], + ), "db_schema_retrieval": _db_schema_retrieval_pipeline, "historical_question": retrieval.HistoricalQuestionRetrieval( **pipe_components["historical_question_retrieval"], From 463b2b4b3bcf4a624c9309515ff9437d0d7a1a04 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Tue, 12 Aug 2025 21:13:11 +0800 Subject: [PATCH 09/11] update --- .../src/pipelines/generation/user_clarification_assistance.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-)

diff --git a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py
index 1e709cd9cd..1d35bea872 100644
--- a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py
+++ b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py
@@ -19,14 +19,14 @@
 user_clarification_assistance_system_prompt = """
 You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question, an intent for the question, and a database schema.
-You should tell the user why the question is not clear enough or too vague, and suggest a better question based on the intent for the question and the database schema.
+You should tell the user why the question is not clear enough or too vague.
 
 ### INSTRUCTIONS ###
 1. Response must be in the same language the user specified in the Language section of the `### INPUT ###` section.
 2. There should be proper line breaks, whitespace, and Markdown formatting (headers, lists, tables, etc.) in your response.
 3. MUST NOT add SQL code in your response.
 4. MUST consider database schema when suggesting better questions.
-5. The maximum response length is 100 words.
+5. The maximum response length is 50 words.
 6. If the user provides a custom instruction, it should be followed strictly and you should use it to change the style of the response.
 
 ### OUTPUT FORMAT ###

From cb0657330a7788de7520811ef5e8d565918fd6ce Mon Sep 17 00:00:00 2001
From: ChihYu Yeh
Date: Tue, 12 Aug 2025 21:20:28 +0800
Subject: [PATCH 10/11] update

---
 wren-ai-service/src/web/v1/services/ask.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py
index feed48c514..06b512c779 100644
--- a/wren-ai-service/src/web/v1/services/ask.py
+++ b/wren-ai-service/src/web/v1/services/ask.py
@@ -15,6 +15,7 @@ class AskHistory(BaseModel):
     response: str = Field(alias=AliasChoices("response", "sql"))
+    type: Literal["text", "sql"] = "sql"
     question: str
@@ -257,10 +258,12 @@ async def ask(
         if self._allow_intent_classification:
             last_sql_data = None
             if histories:
-                if last_sql := histories[-1].response:
+                if (last_response := histories[-1].response) and histories[
+                    -1
+                ].type == "sql":
                     last_sql_data = (
                         await self._pipelines["sql_executor"].run(
-                            sql=last_sql,
+                            sql=last_response,
                             project_id=ask_request.project_id,
                         )
                     )["execute_sql"]["results"]

From 9e789da7fcb5be1dbd997f5b90446f87a104ba68 Mon Sep 17 00:00:00 2001
From: ChihYu Yeh
Date: Thu, 14 Aug 2025 14:01:06 +0800
Subject: [PATCH 11/11] update

---
 wren-ai-service/src/globals.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py
index 9f74efc9c3..015e6f08e8 100644
--- a/wren-ai-service/src/globals.py
+++ b/wren-ai-service/src/globals.py
@@ -152,9 +152,7 @@ def create_service_container(
             **pipe_components["followup_sql_generation"],
         ),
         "sql_functions_retrieval": _sql_functions_retrieval_pipeline,
-        "sql_executor": retrieval.SQLExecutor(
-            **pipe_components["sql_executor"],
-        ),
+        "sql_executor": _sql_executor_pipeline,
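
Taken together, PATCH 10/11 and PATCH 11/11 change how the ask service treats conversation history before intent classification: the previous turn's response is re-executed through the shared `sql_executor` pipeline only when that turn is actually SQL. Below is a minimal, self-contained sketch of that guard. It is an illustration, not the service code: `AskHistory` is simplified here (the real model also accepts the `sql` alias for `response`), and the `last_executable_sql` helper is a hypothetical name introduced for the example.

```python
# Minimal sketch of the history-type guard introduced in PATCH 10/11.
# Assumptions: `AskHistory` is simplified (no alias handling) and
# `last_executable_sql` is a hypothetical helper for illustration only.
from typing import Literal, Optional

from pydantic import BaseModel


class AskHistory(BaseModel):
    question: str
    response: str
    # New discriminator; defaults to "sql" so older histories keep working.
    type: Literal["text", "sql"] = "sql"


def last_executable_sql(histories: list[AskHistory]) -> Optional[str]:
    """Return the previous turn's SQL, or None when the last turn was plain text."""
    if not histories:
        return None
    # Mirrors the patched condition: bind the response with a walrus operator,
    # then check the `type` discriminator before treating it as executable SQL.
    if (last_response := histories[-1].response) and histories[-1].type == "sql":
        return last_response
    return None


if __name__ == "__main__":
    turns = [
        AskHistory(question="Total revenue?", response="SELECT SUM(price) FROM orders"),
        AskHistory(question="What does that mean?", response="It sums ...", type="text"),
    ]
    print(last_executable_sql(turns[:1]))  # SELECT SUM(price) FROM orders
    print(last_executable_sql(turns))      # None -- last turn was a text answer
```

Because `type` defaults to `"sql"`, histories recorded before the field existed still behave as SQL turns; only turns explicitly marked `"text"` skip the SQL execution that feeds `data_exploration_assistance`.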