diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index 72b0678bc9..93f9b29744 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -172,6 +172,10 @@ data: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/docker/config.example.yaml b/docker/config.example.yaml index f8f56c7bb3..9ec5b35a4d 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -122,6 +122,10 @@ pipes: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 728d835c91..015e6f08e8 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -125,6 +125,12 @@ def create_service_container( **pipe_components["user_guide_assistance"], wren_ai_docs=wren_ai_docs, ), + "data_exploration_assistance": generation.DataExplorationAssistance( + **pipe_components["data_exploration_assistance"], + ), + "user_clarification_assistance": generation.UserClarificationAssistance( + **pipe_components["user_clarification_assistance"], + ), "db_schema_retrieval": _db_schema_retrieval_pipeline, "historical_question": retrieval.HistoricalQuestionRetrieval( **pipe_components["historical_question_retrieval"], @@ -146,6 +152,7 @@ def create_service_container( **pipe_components["followup_sql_generation"], ), "sql_functions_retrieval": _sql_functions_retrieval_pipeline, + "sql_executor": _sql_executor_pipeline, }, allow_intent_classification=settings.allow_intent_classification, allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning, diff --git a/wren-ai-service/src/pipelines/generation/__init__.py b/wren-ai-service/src/pipelines/generation/__init__.py index 6940643217..8dd59cd4aa 100644 --- a/wren-ai-service/src/pipelines/generation/__init__.py +++ b/wren-ai-service/src/pipelines/generation/__init__.py @@ -1,6 +1,7 @@ from .chart_adjustment import ChartAdjustment from .chart_generation import ChartGeneration from .data_assistance import DataAssistance +from .data_exploration_assistance import DataExplorationAssistance from .followup_sql_generation import FollowUpSQLGeneration from .followup_sql_generation_reasoning import FollowUpSQLGenerationReasoning from .intent_classification import IntentClassification @@ -15,6 +16,7 @@ from .sql_question import SQLQuestion from .sql_regeneration import SQLRegeneration from .sql_tables_extraction import SQLTablesExtraction +from .user_clarification_assistance import UserClarificationAssistance from .user_guide_assistance import UserGuideAssistance __all__ = [ @@ -36,4 +38,6 @@ "FollowUpSQLGenerationReasoning", "MisleadingAssistance", "SQLTablesExtraction", + "DataExplorationAssistance", + "UserClarificationAssistance", ] diff --git a/wren-ai-service/src/pipelines/generation/data_assistance.py b/wren-ai-service/src/pipelines/generation/data_assistance.py index 51b91197f9..0dec79db3d 100644 --- a/wren-ai-service/src/pipelines/generation/data_assistance.py +++ 
b/wren-ai-service/src/pipelines/generation/data_assistance.py @@ -20,7 +20,7 @@ data_assistance_system_prompt = """ ### TASK ### You are a data analyst great at answering user's questions about given database schema. -Please carefully read user's question and database schema to answer it in easy to understand manner +Please carefully read the user's question, the intent for the question, and the database schema to answer it in an easy-to-understand manner using the Markdown format. Your goal is to help guide user understand its database! ### INSTRUCTIONS ### @@ -41,8 +41,16 @@ {{ db_schema }} {% endfor %} +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + ### INPUT ### User's question: {{query}} +Intent for user's question: {{intent_reasoning}} Language: {{language}} Custom Instruction: {{ custom_instruction }} @@ -55,19 +63,17 @@ @observe(capture_input=False) def prompt( query: str, + intent_reasoning: str, db_schemas: list[str], language: str, histories: list[AskHistory], prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: - previous_query_summaries = ( - [history.question for history in histories] if histories else [] - ) - query = "\n".join(previous_query_summaries) + "\n" + query - _prompt = prompt_builder.run( query=query, + intent_reasoning=intent_reasoning, + histories=histories, db_schemas=db_schemas, language=language, custom_instruction=custom_instruction, @@ -150,6 +156,7 @@ async def _get_streaming_results(query_id): async def run( self, query: str, + intent_reasoning: str, db_schemas: list[str], language: str, query_id: Optional[str] = None, @@ -161,6 +168,7 @@ async def run( ["data_assistance"], inputs={ "query": query, + "intent_reasoning": intent_reasoning, "db_schemas": db_schemas, "language": language, "query_id": query_id or "", diff --git a/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py new file mode 100644 index 0000000000..ca450f8c54 --- /dev/null +++ b/wren-ai-service/src/pipelines/generation/data_exploration_assistance.py @@ -0,0 +1,172 @@ +import asyncio +import logging +import sys +from typing import Any, Optional + +from hamilton import base +from hamilton.async_driver import AsyncDriver +from haystack.components.builders.prompt_builder import PromptBuilder +from langfuse.decorators import observe + +from src.core.pipeline import BasicPipeline +from src.core.provider import LLMProvider +from src.pipelines.common import clean_up_new_lines +from src.utils import trace_cost +from src.web.v1.services.ask import AskHistory + +logger = logging.getLogger("wren-ai-service") + + +data_exploration_assistance_system_prompt = """ +You are a data analyst who is great at exploring data. +You are given a user question, the intent for the question, and SQL data. +You need to understand the user question, the intent for the question, and the SQL data, and then answer the user question. + +### INSTRUCTIONS ### +1. Your answer should be in the same language as the language the user provided. +2. You must answer the user question based on the SQL data. +3. You should provide your answer in Markdown format. +4. You have the following skills: +- explain the data in an easy-to-understand manner +- provide insights and trends in the data +- find anomalies and outliers in the data +5. Use only the skills required to answer the user question, based on the user question and the SQL data.
+ +### OUTPUT FORMAT ### +Please provide your response in proper Markdown format without ```markdown``` tags. +""" + +data_exploration_assistance_user_prompt_template = """ +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + +### INPUT ### +User Question: {{query}} +Intent for user's question: {{intent_reasoning}} +Language: {{language}} +SQL Data: +{{ sql_data }} + +Custom Instruction: {{ custom_instruction }} + +Please think step by step. +""" + + +## Start of Pipeline +@observe(capture_input=False) +def prompt( + query: str, + intent_reasoning: str, + histories: list[AskHistory], + language: str, + sql_data: dict, + prompt_builder: PromptBuilder, + custom_instruction: str, +) -> dict: + _prompt = prompt_builder.run( + query=query, + intent_reasoning=intent_reasoning, + language=language, + sql_data=sql_data, + histories=histories, + custom_instruction=custom_instruction, + ) + return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} + + +@observe(as_type="generation", capture_input=False) +@trace_cost +async def data_exploration_assistance( + prompt: dict, generator: Any, query_id: str +) -> dict: + return await generator(prompt=prompt.get("prompt"), query_id=query_id) + + +## End of Pipeline + + +class DataExplorationAssistance(BasicPipeline): + def __init__( + self, + llm_provider: LLMProvider, + **kwargs, + ): + self._user_queues = {} + self._components = { + "generator": llm_provider.get_generator( + system_prompt=data_exploration_assistance_system_prompt, + streaming_callback=self._streaming_callback, + ), + "prompt_builder": PromptBuilder( + template=data_exploration_assistance_user_prompt_template + ), + } + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _streaming_callback(self, chunk, query_id): + if query_id not in self._user_queues: + self._user_queues[ + query_id + ] = asyncio.Queue() # Create a new queue for the user if it doesn't exist + # Put the chunk content into the user's queue + asyncio.create_task(self._user_queues[query_id].put(chunk.content)) + if chunk.meta.get("finish_reason"): + asyncio.create_task(self._user_queues[query_id].put("")) + + async def get_streaming_results(self, query_id): + async def _get_streaming_results(query_id): + return await self._user_queues[query_id].get() + + if query_id not in self._user_queues: + self._user_queues[query_id] = asyncio.Queue() + + while True: + try: + # Wait for an item from the user's queue + self._streaming_results = await asyncio.wait_for( + _get_streaming_results(query_id), timeout=120 + ) + if ( + self._streaming_results == "" + ): # Check for end-of-stream signal + del self._user_queues[query_id] + break + if self._streaming_results: # Check if there are results to yield + yield self._streaming_results + self._streaming_results = "" # Clear after yielding + except TimeoutError: + break + + @observe(name="Data Exploration Assistance") + async def run( + self, + query: str, + intent_reasoning: str, + sql_data: dict, + language: str, + query_id: Optional[str] = None, + histories: Optional[list[AskHistory]] = None, + custom_instruction: Optional[str] = None, + ): + logger.info("Data Exploration Assistance pipeline is running...") + return await self._pipe.execute( + ["data_exploration_assistance"], + inputs={ + "query": query, + "intent_reasoning": intent_reasoning, + "language": language, + "query_id": query_id or "", + "sql_data": sql_data, + "histories": histories or [], 
+ "custom_instruction": custom_instruction or "", + **self._components, + }, + ) diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py index 42b28c5b8f..bb6a9917ed 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation_reasoning.py @@ -49,8 +49,8 @@ {% for history in histories %} Question: {{ history.question }} -SQL: -{{ history.sql }} +Response: +{{ history.response }} {% endfor %} ### QUESTION ### diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index 4d6cd313cd..b92a741666 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -1,4 +1,3 @@ -import ast import logging import sys from typing import Any, Literal, Optional @@ -6,14 +5,13 @@ import orjson from hamilton import base from hamilton.async_driver import AsyncDriver -from haystack import Document from haystack.components.builders.prompt_builder import PromptBuilder from langfuse.decorators import observe from pydantic import BaseModel from src.core.pipeline import BasicPipeline from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider -from src.pipelines.common import build_table_ddl, clean_up_new_lines +from src.pipelines.common import clean_up_new_lines from src.pipelines.generation.utils.sql import construct_instructions from src.utils import trace_cost from src.web.v1.services import Configuration @@ -24,7 +22,8 @@ intent_classification_system_prompt = """ ### Task ### -You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema. Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification. +You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema or sql data if provided. +Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, `USER_GUIDE`, or `USER_CLARIFICATION`. Additionally, provide a concise reasoning (maximum 20 words) for your classification. ### Instructions ### - **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions. @@ -33,12 +32,26 @@ - **Rephrase Question:** Rewrite follow-up questions into full standalone questions using prior conversation context. - **Concise Reasoning:** The reasoning must be clear, concise, and limited to 20 words. - **Language Consistency:** Use the same language as specified in the user's output language for the rephrased question and reasoning. -- **Vague Queries:** If the question is vague or does not related to a table or property from the schema, classify it as `MISLEADING_QUERY`. -- **Incomplete Queries:** If the question is related to the database schema but references unspecified values (e.g., "the following", "these", "those") without providing them, classify as `GENERAL`. 
+- **Vague Queries:** If the question is not related to the database schema, classify it as `MISLEADING_QUERY`. +- **User Clarification:** If the question is related to the database schema but is missing some details needed to answer it, classify it as `USER_CLARIFICATION`. +- **Incomplete Queries:** If the question is related to the database schema but references unspecified values (e.g., "the following", "these", "those") without providing them, classify as `USER_CLARIFICATION`. - **Time-related Queries:** Don't rephrase time-related information in the user's question. ### Intent Definitions ### +<DATA_EXPLORATION> +**When to Use:** +- The user's question is about data exploration, such as asking for data details, explanations of the data, insights, recommendations, or comparisons. +**Requirements:** +- SQL DATA is provided and the user's question is about exploring the data. +- The user's question can be answered by the SQL DATA. +- The row size of the SQL DATA is less than 500. +**Examples:** +- "Show me the part where the data appears abnormal" +- "Please explain the data in the table" +- "What's the trend of the data?" +</DATA_EXPLORATION> + <TEXT_TO_SQL> **When to Use:** - The user's inputs are about modifying SQL from previous questions. @@ -51,6 +64,7 @@ - Must have complete filter criteria, specific values, or clear references to previous context. - Include specific table and column names from the schema in your reasoning or modifying SQL from previous questions. - Reference phrases from the user's inputs that clearly relate to the schema. +- The SQL DATA is not provided or the SQL DATA cannot answer the user's question, and the user's question can be answered given the database schema. **Examples:** - "What is the total sales for last quarter?" @@ -58,9 +72,9 @@ - "List the top 10 products by revenue." </TEXT_TO_SQL> -<GENERAL> -**When to Use:** -- The user seeks general information about the database schema or its overall capabilities. +<USER_CLARIFICATION> +**When to Use:** +- The user's question is related to the database schema but is missing some details needed to answer it. - The query references **missing information** (e.g., "the following items" without listing them). - The query contains **placeholder references** that cannot be resolved from context. - The query is **incomplete for SQL generation** despite mentioning database concepts. @@ -70,11 +84,18 @@ - Identify missing parameters, unspecified references, or incomplete filter criteria. **Examples:** -- "What is the dataset about?" -- "Tell me more about the database." - "How can I analyze customer behavior with this data?" - "Show me orders for these products" (without specifying which products) - "Filter by the criteria I mentioned" (without previous context defining criteria) +</USER_CLARIFICATION> + +<GENERAL> +**When to Use:** +- The user seeks general information about the database schema or its overall capabilities. + +**Examples:** +- "What is the dataset about?" +- "Tell me more about the database."
@@ -111,7 +132,7 @@ { "rephrased_question": "", "reasoning": "", - "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE" + "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" | "GENERAL" | "USER_GUIDE" | "USER_CLARIFICATION" } """ @@ -143,14 +164,20 @@ - {{doc.path}}: {{doc.content}} {% endfor %} +{% if sql_data %} +### SQL DATA ### +{{ sql_data }} +row size of SQL DATA: {{ sql_data_size }} +{% endif %} + ### INPUT ### {% if histories %} User's previous questions: {% for history in histories %} Question: {{ history.question }} -SQL: -{{ history.sql }} +Response: +{{ history.response }} {% endfor %} {% endif %} @@ -162,130 +189,30 @@ ## Start of Pipeline -@observe(capture_input=False, capture_output=False) -async def embedding(query: str, embedder: Any, histories: list[AskHistory]) -> dict: - previous_query_summaries = ( - [history.question for history in histories] if histories else [] - ) - - query = "\n".join(previous_query_summaries) + "\n" + query - - return await embedder.run(query) - - -@observe(capture_input=False) -async def table_retrieval( - embedding: dict, project_id: str, table_retriever: Any -) -> dict: - filters = { - "operator": "AND", - "conditions": [ - {"field": "type", "operator": "==", "value": "TABLE_DESCRIPTION"}, - ], - } - - if project_id: - filters["conditions"].append( - {"field": "project_id", "operator": "==", "value": project_id} - ) - - return await table_retriever.run( - query_embedding=embedding.get("embedding"), - filters=filters, - ) - - -@observe(capture_input=False) -async def dbschema_retrieval( - table_retrieval: dict, embedding: dict, project_id: str, dbschema_retriever: Any -) -> list[Document]: - tables = table_retrieval.get("documents", []) - table_names = [] - for table in tables: - content = ast.literal_eval(table.content) - table_names.append(content["name"]) - - logger.info(f"dbschema_retrieval with table_names: {table_names}") - - table_name_conditions = [ - {"field": "name", "operator": "==", "value": table_name} - for table_name in table_names - ] - - filters = { - "operator": "AND", - "conditions": [ - {"field": "type", "operator": "==", "value": "TABLE_SCHEMA"}, - {"operator": "OR", "conditions": table_name_conditions}, - ], - } - - if project_id: - filters["conditions"].append( - {"field": "project_id", "operator": "==", "value": project_id} - ) - - results = await dbschema_retriever.run( - query_embedding=embedding.get("embedding"), filters=filters - ) - return results["documents"] - - -@observe() -def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[str]: - db_schemas = {} - for document in dbschema_retrieval: - content = ast.literal_eval(document.content) - if content["type"] == "TABLE": - if document.meta["name"] not in db_schemas: - db_schemas[document.meta["name"]] = content - else: - db_schemas[document.meta["name"]] = { - **content, - "columns": db_schemas[document.meta["name"]].get("columns", []), - } - elif content["type"] == "TABLE_COLUMNS": - if document.meta["name"] not in db_schemas: - db_schemas[document.meta["name"]] = {"columns": content["columns"]} - else: - if "columns" not in db_schemas[document.meta["name"]]: - db_schemas[document.meta["name"]]["columns"] = content["columns"] - else: - db_schemas[document.meta["name"]]["columns"] += content["columns"] - - # remove incomplete schemas - db_schemas = {k: v for k, v in db_schemas.items() if "type" in v and "columns" in v} - - db_schemas_in_ddl = [] - for table_schema in list(db_schemas.values()): - if table_schema["type"] 
== "TABLE": - ddl, _, _ = build_table_ddl(table_schema) - db_schemas_in_ddl.append(ddl) - - return db_schemas_in_ddl - - @observe(capture_input=False) def prompt( query: str, wren_ai_docs: list[dict], - construct_db_schemas: list[str], + db_schemas: list[str], histories: list[AskHistory], prompt_builder: PromptBuilder, sql_samples: Optional[list[dict]] = None, instructions: Optional[list[dict]] = None, configuration: Configuration | None = None, + sql_data: Optional[dict] = None, ) -> dict: _prompt = prompt_builder.run( query=query, language=configuration.language, - db_schemas=construct_db_schemas, + db_schemas=db_schemas, histories=histories, sql_samples=sql_samples, instructions=construct_instructions( instructions=instructions, ), docs=wren_ai_docs, + sql_data=sql_data, + sql_data_size=len(sql_data.get("data", [])), ) return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} @@ -297,21 +224,19 @@ async def classify_intent(prompt: dict, generator: Any, generator_name: str) -> @observe(capture_input=False) -def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict: +def post_process(classify_intent: dict) -> dict: try: results = orjson.loads(classify_intent.get("replies")[0]) return { "rephrased_question": results["rephrased_question"], "intent": results["results"], "reasoning": results["reasoning"], - "db_schemas": construct_db_schemas, } except Exception: return { "rephrased_question": "", "intent": "TEXT_TO_SQL", "reasoning": "", - "db_schemas": construct_db_schemas, } @@ -320,7 +245,14 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict class IntentClassificationResult(BaseModel): rephrased_question: str - results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"] + results: Literal[ + "MISLEADING_QUERY", + "TEXT_TO_SQL", + "GENERAL", + "DATA_EXPLORATION", + "USER_GUIDE", + "USER_CLARIFICATION", + ] reasoning: str @@ -378,22 +310,26 @@ def __init__( async def run( self, query: str, + db_schemas: list[str], project_id: Optional[str] = None, histories: Optional[list[AskHistory]] = None, sql_samples: Optional[list[dict]] = None, instructions: Optional[list[dict]] = None, configuration: Configuration = Configuration(), + sql_data: Optional[dict] = None, ): logger.info("Intent Classification pipeline is running...") return await self._pipe.execute( ["post_process"], inputs={ "query": query, + "db_schemas": db_schemas, "project_id": project_id or "", "histories": histories or [], "sql_samples": sql_samples or [], "instructions": instructions or [], "configuration": configuration, + "sql_data": sql_data or {}, **self._components, **self._configs, }, diff --git a/wren-ai-service/src/pipelines/generation/misleading_assistance.py b/wren-ai-service/src/pipelines/generation/misleading_assistance.py index a35738ecf5..4f5ea2edeb 100644 --- a/wren-ai-service/src/pipelines/generation/misleading_assistance.py +++ b/wren-ai-service/src/pipelines/generation/misleading_assistance.py @@ -19,8 +19,8 @@ misleading_assistance_system_prompt = """ ### TASK ### -You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question that is potentially misleading. -Your goal is to help guide user understand its data better and suggest few better questions to ask. +You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question, an intent for the question, and a database schema. 
+Your goal is to help guide the user to understand their data better and to suggest a few better questions to ask, based on the intent for the question and the database schema. ### INSTRUCTIONS ### @@ -41,8 +41,16 @@ {{ db_schema }} {% endfor %} +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + ### INPUT ### User's question: {{query}} +Intent for user's question: {{intent_reasoning}} Language: {{language}} Custom Instruction: {{ custom_instruction }} @@ -55,19 +63,17 @@ @observe(capture_input=False) def prompt( query: str, + intent_reasoning: str, db_schemas: list[str], language: str, histories: list[AskHistory], prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: - previous_query_summaries = ( - [history.question for history in histories] if histories else [] - ) - query = "\n".join(previous_query_summaries) + "\n" + query - _prompt = prompt_builder.run( query=query, + intent_reasoning=intent_reasoning, + histories=histories, db_schemas=db_schemas, language=language, custom_instruction=custom_instruction, @@ -150,6 +156,7 @@ async def _get_streaming_results(query_id): async def run( self, query: str, + intent_reasoning: str, db_schemas: list[str], language: str, query_id: Optional[str] = None, @@ -161,6 +168,7 @@ async def run( ["misleading_assistance"], inputs={ "query": query, + "intent_reasoning": intent_reasoning, "db_schemas": db_schemas, "language": language, "query_id": query_id or "", diff --git a/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py new file mode 100644 index 0000000000..1d35bea872 --- /dev/null +++ b/wren-ai-service/src/pipelines/generation/user_clarification_assistance.py @@ -0,0 +1,175 @@ +import asyncio +import logging +import sys +from typing import Any, Optional + +from hamilton import base +from hamilton.async_driver import AsyncDriver +from haystack.components.builders.prompt_builder import PromptBuilder +from langfuse.decorators import observe + +from src.core.pipeline import BasicPipeline +from src.core.provider import LLMProvider +from src.pipelines.common import clean_up_new_lines +from src.utils import trace_cost +from src.web.v1.services.ask import AskHistory + +logger = logging.getLogger("wren-ai-service") + + +user_clarification_assistance_system_prompt = """ +You are a helpful assistant that can help users understand their data better. Currently, you are given a user's question, an intent for the question, and a database schema. +You should tell the user why the question is not clear enough or is too vague. + +### INSTRUCTIONS ### +1. The response must be in the same language the user specified in the Language section of the `### INPUT ###` section. +2. There should be proper line breaks, whitespace, and Markdown formatting (headers, lists, tables, etc.) in your response. +3. MUST NOT add SQL code in your response. +4. MUST consider the database schema when suggesting better questions. +5. The maximum response length is 50 words. +6. If the user provides a custom instruction, it should be followed strictly and you should use it to change the style of the response. + +### OUTPUT FORMAT ### +Please provide your response in proper Markdown format without ```markdown``` tags.
+""" + +user_clarification_assistance_user_prompt_template = """ +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + +### DATABASE SCHEMA ### +{% for db_schema in db_schemas %} + {{ db_schema }} +{% endfor %} + +### INPUT ### +User Question: {{query}} +Intent for user's question: {{intent_reasoning}} +Language: {{language}} + +Custom Instruction: {{ custom_instruction }} +""" + + +## Start of Pipeline +@observe(capture_input=False) +def prompt( + query: str, + intent_reasoning: str, + db_schemas: list[str], + language: str, + histories: list[AskHistory], + prompt_builder: PromptBuilder, + custom_instruction: str, +) -> dict: + _prompt = prompt_builder.run( + query=query, + intent_reasoning=intent_reasoning, + histories=histories, + db_schemas=db_schemas, + language=language, + custom_instruction=custom_instruction, + ) + return {"prompt": clean_up_new_lines(_prompt.get("prompt"))} + + +@observe(as_type="generation", capture_input=False) +@trace_cost +async def user_clarification_assistance( + prompt: dict, generator: Any, query_id: str, generator_name: str +) -> dict: + return await generator( + prompt=prompt.get("prompt"), + query_id=query_id, + ), generator_name + + +## End of Pipeline + + +class UserClarificationAssistance(BasicPipeline): + def __init__( + self, + llm_provider: LLMProvider, + **kwargs, + ): + self._user_queues = {} + self._components = { + "generator": llm_provider.get_generator( + system_prompt=user_clarification_assistance_system_prompt, + streaming_callback=self._streaming_callback, + ), + "generator_name": llm_provider.get_model(), + "prompt_builder": PromptBuilder( + template=user_clarification_assistance_user_prompt_template + ), + } + + super().__init__( + AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult()) + ) + + def _streaming_callback(self, chunk, query_id): + if query_id not in self._user_queues: + self._user_queues[ + query_id + ] = asyncio.Queue() # Create a new queue for the user if it doesn't exist + # Put the chunk content into the user's queue + asyncio.create_task(self._user_queues[query_id].put(chunk.content)) + if chunk.meta.get("finish_reason"): + asyncio.create_task(self._user_queues[query_id].put("")) + + async def get_streaming_results(self, query_id): + async def _get_streaming_results(query_id): + return await self._user_queues[query_id].get() + + if query_id not in self._user_queues: + self._user_queues[ + query_id + ] = asyncio.Queue() # Ensure the user's queue exists + while True: + try: + # Wait for an item from the user's queue + self._streaming_results = await asyncio.wait_for( + _get_streaming_results(query_id), timeout=120 + ) + if ( + self._streaming_results == "" + ): # Check for end-of-stream signal + del self._user_queues[query_id] + break + if self._streaming_results: # Check if there are results to yield + yield self._streaming_results + self._streaming_results = "" # Clear after yielding + except TimeoutError: + break + + @observe(name="User Clarification Assistance") + async def run( + self, + query: str, + intent_reasoning: str, + db_schemas: list[str], + language: str, + query_id: Optional[str] = None, + histories: Optional[list[AskHistory]] = None, + custom_instruction: Optional[str] = None, + ): + logger.info("User Clarification Assistance pipeline is running...") + return await self._pipe.execute( + ["user_clarification_assistance"], + inputs={ + "query": query, + "intent_reasoning": intent_reasoning, + "db_schemas": 
db_schemas, + "language": language, + "query_id": query_id or "", + "histories": histories or [], + "custom_instruction": custom_instruction or "", + **self._components, + }, + ) diff --git a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py index be437f883b..a1ce19b446 100644 --- a/wren-ai-service/src/pipelines/generation/user_guide_assistance.py +++ b/wren-ai-service/src/pipelines/generation/user_guide_assistance.py @@ -12,14 +12,15 @@ from src.core.provider import LLMProvider from src.pipelines.common import clean_up_new_lines from src.utils import trace_cost +from src.web.v1.services.ask import AskHistory logger = logging.getLogger("wren-ai-service") user_guide_assistance_system_prompt = """ You are a helpful assistant that can help users understand Wren AI. -You are given a user question and a user guide. -You need to understand the user question and the user guide, and then answer the user question. +You are given a user question, an intent for the question, and a user guide. +You need to understand the user question, the intent reasoning, and the user guide, and then answer the user question. ### INSTRUCTIONS ### 1. Your answer should be in the same language as the language user provided. @@ -34,7 +35,16 @@ """ user_guide_assistance_user_prompt_template = """ +{% if histories %} +### PREVIOUS QUESTIONS ### +{% for history in histories %} + {{ history.question }} +{% endfor %} +{% endif %} + +### INPUT ### User Question: {{query}} +Intent for user's question: {{intent_reasoning}} Language: {{language}} User Guide: {% for doc in docs %} @@ -51,13 +61,17 @@ @observe(capture_input=False) def prompt( query: str, + intent_reasoning: str, language: str, wren_ai_docs: list[dict], + histories: list[AskHistory], prompt_builder: PromptBuilder, custom_instruction: str, ) -> dict: _prompt = prompt_builder.run( query=query, + intent_reasoning=intent_reasoning, + histories=histories, language=language, docs=wren_ai_docs, custom_instruction=custom_instruction, @@ -142,8 +156,10 @@ async def _get_streaming_results(query_id): async def run( self, query: str, + intent_reasoning: str, language: str, query_id: Optional[str] = None, + histories: Optional[list[AskHistory]] = None, custom_instruction: Optional[str] = None, ): logger.info("User Guide Assistance pipeline is running...") @@ -151,8 +167,10 @@ async def run( ["user_guide_assistance"], inputs={ "query": query, + "intent_reasoning": intent_reasoning, "language": language, "query_id": query_id or "", + "histories": histories or [], "custom_instruction": custom_instruction or "", **self._components, **self._configs, diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py index b40528c870..ffdd11001e 100644 --- a/wren-ai-service/src/pipelines/generation/utils/sql.py +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -515,7 +515,9 @@ def construct_ask_history_messages( ) messages.append( ChatMessage.from_assistant( - history.sql if hasattr(history, "sql") else history["sql"] + history.response + if hasattr(history, "response") + else history["response"] ) ) return messages diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 5e36c8a8ad..06b512c779 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -14,7 +14,8 @@ class AskHistory(BaseModel): - sql: str + response: str = 
Field(alias=AliasChoices("response", "sql")) + type: Literal["text", "sql"] = "sql" question: str @@ -83,14 +84,26 @@ class _AskResultResponse(BaseModel): trace_id: Optional[str] = None is_followup: bool = False general_type: Optional[ - Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE"] + Literal[ + "MISLEADING_QUERY", + "DATA_ASSISTANCE", + "USER_GUIDE", + "DATA_EXPLORATION", + "USER_CLARIFICATION", + ] ] = None class AskResultResponse(_AskResultResponse): is_followup: Optional[bool] = Field(False, exclude=True) general_type: Optional[ - Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE"] + Literal[ + "MISLEADING_QUERY", + "DATA_ASSISTANCE", + "USER_GUIDE", + "DATA_EXPLORATION", + "USER_CLARIFICATION", + ] ] = Field(None, exclude=True) @@ -206,7 +219,11 @@ async def ask( sql_generation_reasoning = "" else: # Run both pipeline operations concurrently - sql_samples_task, instructions_task = await asyncio.gather( + ( + sql_samples_task, + instructions_task, + db_schema_retrieval_task, + ) = await asyncio.gather( self._pipelines["sql_pairs_retrieval"].run( query=user_query, project_id=ask_request.project_id, @@ -216,6 +233,12 @@ async def ask( project_id=ask_request.project_id, scope="sql", ), + self._pipelines["db_schema_retrieval"].run( + query=user_query, + histories=histories, + project_id=ask_request.project_id, + enable_column_pruning=enable_column_pruning, + ), ) # Extract results from completed tasks @@ -225,16 +248,36 @@ async def ask( instructions = instructions_task["formatted_output"].get( "documents", [] ) + _retrieval_result = db_schema_retrieval_task.get( + "construct_retrieval_results", {} + ) + documents = _retrieval_result.get("retrieval_results", []) + table_names = [document.get("table_name") for document in documents] + table_ddls = [document.get("table_ddl") for document in documents] if self._allow_intent_classification: + last_sql_data = None + if histories: + if (last_response := histories[-1].response) and histories[ + -1 + ].type == "sql": + last_sql_data = ( + await self._pipelines["sql_executor"].run( + sql=last_response, + project_id=ask_request.project_id, + ) + )["execute_sql"]["results"] + intent_classification_result = ( await self._pipelines["intent_classification"].run( query=user_query, + db_schemas=table_ddls, histories=histories, sql_samples=sql_samples, instructions=instructions, project_id=ask_request.project_id, configuration=ask_request.configurations, + sql_data=last_sql_data, ) ).get("post_process", {}) intent = intent_classification_result.get("intent") @@ -250,10 +293,9 @@ async def ask( asyncio.create_task( self._pipelines["misleading_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, - db_schemas=intent_classification_result.get( - "db_schemas" - ), + db_schemas=table_ddls, language=ask_request.configurations.language, query_id=ask_request.query_id, custom_instruction=ask_request.custom_instruction, @@ -275,10 +317,9 @@ async def ask( asyncio.create_task( self._pipelines["data_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, histories=histories, - db_schemas=intent_classification_result.get( - "db_schemas" - ), + db_schemas=table_ddls, language=ask_request.configurations.language, query_id=ask_request.query_id, custom_instruction=ask_request.custom_instruction, @@ -300,6 +341,8 @@ async def ask( asyncio.create_task( self._pipelines["user_guide_assistance"].run( query=user_query, + intent_reasoning=intent_reasoning, + histories=histories, 
language=ask_request.configurations.language, query_id=ask_request.query_id, custom_instruction=ask_request.custom_instruction, @@ -317,6 +360,54 @@ async def ask( ) results["metadata"]["type"] = "GENERAL" return results + elif intent == "DATA_EXPLORATION": + asyncio.create_task( + self._pipelines["data_exploration_assistance"].run( + query=user_query, + intent_reasoning=intent_reasoning, + histories=histories, + sql_data=last_sql_data, + language=ask_request.configurations.language, + query_id=ask_request.query_id, + custom_instruction=ask_request.custom_instruction, + ) + ) + + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="GENERAL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + trace_id=trace_id, + is_followup=True if histories else False, + general_type="DATA_EXPLORATION", + ) + results["metadata"]["type"] = "GENERAL" + return results + elif intent == "USER_CLARIFICATION": + asyncio.create_task( + self._pipelines["user_clarification_assistance"].run( + query=user_query, + intent_reasoning=intent_reasoning, + histories=histories, + db_schemas=table_ddls, + language=ask_request.configurations.language, + query_id=ask_request.query_id, + custom_instruction=ask_request.custom_instruction, + ) + ) + + self._ask_results[query_id] = AskResultResponse( + status="finished", + type="GENERAL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + trace_id=trace_id, + is_followup=True if histories else False, + general_type="USER_CLARIFICATION", + ) + results["metadata"]["type"] = "GENERAL" + return results else: self._ask_results[query_id] = AskResultResponse( status="understanding", @@ -336,19 +427,6 @@ async def ask( is_followup=True if histories else False, ) - retrieval_result = await self._pipelines["db_schema_retrieval"].run( - query=user_query, - histories=histories, - project_id=ask_request.project_id, - enable_column_pruning=enable_column_pruning, - ) - _retrieval_result = retrieval_result.get( - "construct_retrieval_results", {} - ) - documents = _retrieval_result.get("retrieval_results", []) - table_names = [document.get("table_name") for document in documents] - table_ddls = [document.get("table_ddl") for document in documents] - if not documents: logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}") if not self._is_stopped(query_id, self._ask_results): @@ -639,6 +717,12 @@ async def get_ask_streaming_result( _pipeline_name = "data_assistance" elif self._ask_results.get(query_id).general_type == "MISLEADING_QUERY": _pipeline_name = "misleading_assistance" + elif self._ask_results.get(query_id).general_type == "DATA_EXPLORATION": + _pipeline_name = "data_exploration_assistance" + elif ( + self._ask_results.get(query_id).general_type == "USER_CLARIFICATION" + ): + _pipeline_name = "user_clarification_assistance" elif self._ask_results.get(query_id).status == "planning": if self._ask_results.get(query_id).is_followup: _pipeline_name = "followup_sql_generation_reasoning" diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index f2c01e95a5..d0bbe64d75 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -135,6 +135,10 @@ pipes: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing 
document_store: qdrant embedder: litellm_embedder.default diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index bb688dcfd8..ed579ad1a0 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -135,6 +135,10 @@ pipes: llm: litellm_llm.default - name: data_assistance llm: litellm_llm.default + - name: data_exploration_assistance + llm: litellm_llm.default + - name: user_clarification_assistance + llm: litellm_llm.default - name: sql_pairs_indexing document_store: qdrant embedder: litellm_embedder.default
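Note on the `AskHistory` change in `wren-ai-service/src/web/v1/services/ask.py` above: callers now send `response` plus a `type` discriminator instead of `sql`, and the alias keeps older payloads that still send `sql` valid. The sketch below is illustrative only, not the shipped model; it assumes Pydantic v2 and uses `validation_alias` to accept both keys (the diff itself passes `AliasChoices` to `alias`), so the exact field options may differ.

```python
# Illustrative sketch only — not the shipped model. Assumes Pydantic v2,
# where AliasChoices lets both "response" and "sql" keys validate.
from typing import Literal

from pydantic import AliasChoices, BaseModel, Field


class AskHistory(BaseModel):
    response: str = Field(validation_alias=AliasChoices("response", "sql"))
    type: Literal["text", "sql"] = "sql"
    question: str


# Old-style payload (clients still sending "sql") keeps validating via the alias:
old = AskHistory.model_validate(
    {"question": "Top 10 products by revenue?", "sql": "SELECT product, SUM(revenue) ..."}
)
assert old.type == "sql" and old.response.startswith("SELECT")

# New-style payload for a text answer (e.g. produced by data_exploration_assistance):
new = AskHistory.model_validate(
    {"question": "Explain the trend", "response": "Sales grew steadily ...", "type": "text"}
)
```

This mirrors the guard added in `ask.py`, where the previous turn's `response` is only passed to the `sql_executor` pipeline (to build `last_sql_data` for intent classification) when its `type` is `"sql"`.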