Commit c85985e ("update")
1 parent: 912f677

9 files changed: +229 additions, −5 deletions

deployment/kustomizations/base/cm.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -172,6 +172,8 @@ data:
     llm: litellm_llm.default
   - name: data_assistance
     llm: litellm_llm.default
+  - name: data_exploration_assistance
+    llm: litellm_llm.default
   - name: sql_pairs_indexing
     document_store: qdrant
     embedder: litellm_embedder.default
```

docker/config.example.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -122,6 +122,8 @@ pipes:
     llm: litellm_llm.default
   - name: data_assistance
     llm: litellm_llm.default
+  - name: data_exploration_assistance
+    llm: litellm_llm.default
   - name: sql_pairs_indexing
     document_store: qdrant
     embedder: litellm_embedder.default
```

wren-ai-service/src/globals.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -146,6 +146,13 @@ def create_service_container(
                 **pipe_components["followup_sql_generation"],
             ),
             "sql_functions_retrieval": _sql_functions_retrieval_pipeline,
+            "sql_executor": retrieval.SQLExecutor(
+                **pipe_components["sql_executor"],
+                engine_timeout=settings.engine_timeout,
+            ),
+            "data_exploration_assistance": generation.DataExplorationAssistance(
+                **pipe_components["data_exploration_assistance"],
+            ),
         },
         allow_intent_classification=settings.allow_intent_classification,
         allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
```

wren-ai-service/src/pipelines/generation/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 from .chart_adjustment import ChartAdjustment
 from .chart_generation import ChartGeneration
 from .data_assistance import DataAssistance
+from .data_exploration_assistance import DataExplorationAssistance
 from .followup_sql_generation import FollowUpSQLGeneration
 from .followup_sql_generation_reasoning import FollowUpSQLGenerationReasoning
 from .intent_classification import IntentClassification
@@ -36,4 +37,5 @@
     "FollowUpSQLGenerationReasoning",
     "MisleadingAssistance",
     "SQLTablesExtraction",
+    "DataExplorationAssistance",
 ]
```

wren-ai-service/src/pipelines/generation/data_exploration_assistance.py (new file)

Lines changed: 141 additions & 0 deletions

```python
import asyncio
import logging
import sys
from typing import Any, Optional

from hamilton import base
from hamilton.async_driver import AsyncDriver
from haystack.components.builders.prompt_builder import PromptBuilder
from langfuse.decorators import observe

from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider

logger = logging.getLogger("wren-ai-service")


data_exploration_assistance_system_prompt = """
You are a great data analyst good at exploring data.
You are given a user question and SQL data.
You need to understand the user question and the SQL data, and then answer the user question.
### INSTRUCTIONS ###
1. Your answer should be in the same language as the language the user provided.
2. You must base your answer on the SQL data.
3. You should provide your answer in Markdown format.
4. You have the following skills:
    - explain the data in an easy-to-understand manner
    - provide insights and trends in the data
    - identify anomalies and outliers in the data
5. You only need to use the skills required to answer the user question based on the user question and the SQL data.
### OUTPUT FORMAT ###
Please provide your response in proper Markdown format without ```markdown``` tags.
"""

data_exploration_assistance_user_prompt_template = """
User Question: {{query}}
Language: {{language}}
SQL Data:
{{ sql_data }}
Please think step by step.
"""


## Start of Pipeline
@observe(capture_input=False)
def prompt(
    query: str,
    language: str,
    sql_data: dict,
    prompt_builder: PromptBuilder,
) -> dict:
    return prompt_builder.run(
        query=query,
        language=language,
        sql_data=sql_data,
    )


@observe(as_type="generation", capture_input=False)
async def data_exploration_assistance(
    prompt: dict, generator: Any, query_id: str
) -> dict:
    return await generator(prompt=prompt.get("prompt"), query_id=query_id)


## End of Pipeline


class DataExplorationAssistance(BasicPipeline):
    def __init__(
        self,
        llm_provider: LLMProvider,
        **kwargs,
    ):
        self._user_queues = {}
        self._components = {
            "generator": llm_provider.get_generator(
                system_prompt=data_exploration_assistance_system_prompt,
                streaming_callback=self._streaming_callback,
            ),
            "prompt_builder": PromptBuilder(
                template=data_exploration_assistance_user_prompt_template
            ),
        }

        super().__init__(
            AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
        )

    def _streaming_callback(self, chunk, query_id):
        if query_id not in self._user_queues:
            # Create a new queue for the user if it doesn't exist
            self._user_queues[query_id] = asyncio.Queue()
        # Put the chunk content into the user's queue
        asyncio.create_task(self._user_queues[query_id].put(chunk.content))
        if chunk.meta.get("finish_reason"):
            asyncio.create_task(self._user_queues[query_id].put("<DONE>"))

    async def get_streaming_results(self, query_id):
        async def _get_streaming_results(query_id):
            return await self._user_queues[query_id].get()

        if query_id not in self._user_queues:
            self._user_queues[query_id] = asyncio.Queue()

        while True:
            try:
                # Wait for an item from the user's queue
                self._streaming_results = await asyncio.wait_for(
                    _get_streaming_results(query_id), timeout=120
                )
                # Check for end-of-stream signal
                if self._streaming_results == "<DONE>":
                    del self._user_queues[query_id]
                    break
                if self._streaming_results:  # Check if there are results to yield
                    yield self._streaming_results
                    self._streaming_results = ""  # Clear after yielding
            except TimeoutError:
                break

    @observe(name="Data Exploration Assistance")
    async def run(
        self,
        query: str,
        sql_data: dict,
        language: str,
        query_id: Optional[str] = None,
    ):
        logger.info("Data Exploration Assistance pipeline is running...")
        return await self._pipe.execute(
            ["data_exploration_assistance"],
            inputs={
                "query": query,
                "language": language,
                "query_id": query_id or "",
                "sql_data": sql_data,
                **self._components,
            },
        )
```
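
For reference, here is a minimal usage sketch of the new pipeline, not part of the commit. The `llm_provider` wiring, the `query_id` value, and the shape of `sql_data` are assumptions for illustration; in the service itself the pipeline is built from `pipe_components["data_exploration_assistance"]` (see the `globals.py` change above).

```python
# Minimal sketch (assumptions noted inline): driving DataExplorationAssistance directly.
import asyncio

from src.pipelines.generation import DataExplorationAssistance


async def explore(llm_provider) -> None:
    # `llm_provider` is assumed to be an already-configured LLMProvider instance.
    pipeline = DataExplorationAssistance(llm_provider=llm_provider)
    query_id = "demo-query-id"  # hypothetical id

    # Run the pipeline in the background; streamed chunks land in a per-query queue.
    task = asyncio.create_task(
        pipeline.run(
            query="What's the trend of the data?",
            sql_data={"columns": ["month", "revenue"], "data": [["Jan", 120], ["Feb", 95]]},
            language="English",
            query_id=query_id,
        )
    )

    # Consume chunks until the "<DONE>" sentinel (or the 120s timeout) ends the stream.
    async for chunk in pipeline.get_streaming_results(query_id):
        print(chunk, end="", flush=True)

    await task
```

The same queue-per-query pattern is what `ask.py` relies on below when it resolves `general_type == "DATA_EXPLORATION"` to this pipeline.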

wren-ai-service/src/pipelines/generation/intent_classification.py

Lines changed: 35 additions & 3 deletions
```diff
@@ -24,7 +24,8 @@
 
 intent_classification_system_prompt = """
 ### Task ###
-You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema. Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
+You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema, or the SQL data if provided.
+Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
 
 ### Instructions ###
 - **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions.
@@ -39,6 +40,19 @@
 
 ### Intent Definitions ###
 
+<DATA_EXPLORATION>
+**When to Use:**
+- The user's question is about data exploration, such as asking for data details, an explanation of the data, insights, recommendations, comparisons, etc.
+**Requirements:**
+- SQL DATA is provided and the user's question is about exploring the data.
+- The user's question can be answered by the SQL DATA.
+- The row size of the SQL DATA is less than 500.
+**Examples:**
+- "Show me the part where the data appears abnormal"
+- "Please explain the data in the table"
+- "What's the trend of the data?"
+</DATA_EXPLORATION>
+
 <TEXT_TO_SQL>
 **When to Use:**
 - The user's inputs are about modifying SQL from previous questions.
@@ -51,6 +65,7 @@
 - Must have complete filter criteria, specific values, or clear references to previous context.
 - Include specific table and column names from the schema in your reasoning or modifying SQL from previous questions.
 - Reference phrases from the user's inputs that clearly relate to the schema.
+- The SQL DATA is not provided or cannot answer the user's question, and the user's question can be answered given the database schema.
 
 **Examples:**
 - "What is the total sales for last quarter?"
@@ -111,7 +126,7 @@
 {
     "rephrased_question": "<rephrased question in full standalone question if there are previous questions, otherwise the original question>",
     "reasoning": "<brief chain-of-thought reasoning (max 20 words)>",
-    "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" | "USER_GUIDE"
+    "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" | "GENERAL" | "USER_GUIDE"
 }
 """
 
@@ -143,6 +158,12 @@
 - {{doc.path}}: {{doc.content}}
 {% endfor %}
 
+{% if sql_data %}
+### SQL DATA ###
+{{ sql_data }}
+row size of SQL DATA: {{ sql_data_size }}
+{% endif %}
+
 ### INPUT ###
 {% if histories %}
 User's previous questions:
@@ -275,6 +296,7 @@ def prompt(
     sql_samples: Optional[list[dict]] = None,
     instructions: Optional[list[dict]] = None,
     configuration: Configuration | None = None,
+    sql_data: Optional[dict] = None,
 ) -> dict:
     _prompt = prompt_builder.run(
         query=query,
@@ -286,6 +308,8 @@ def prompt(
             instructions=instructions,
         ),
         docs=wren_ai_docs,
+        sql_data=sql_data,
+        sql_data_size=len(sql_data.get("data", [])),
     )
     return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
 
@@ -320,7 +344,13 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict
 
 class IntentClassificationResult(BaseModel):
     rephrased_question: str
-    results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL", "USER_GUIDE"]
+    results: Literal[
+        "MISLEADING_QUERY",
+        "TEXT_TO_SQL",
+        "GENERAL",
+        "DATA_EXPLORATION",
+        "USER_GUIDE",
+    ]
     reasoning: str
 
 
@@ -383,6 +413,7 @@ async def run(
         sql_samples: Optional[list[dict]] = None,
         instructions: Optional[list[dict]] = None,
         configuration: Configuration = Configuration(),
+        sql_data: Optional[dict] = None,
     ):
         logger.info("Intent Classification pipeline is running...")
         return await self._pipe.execute(
@@ -394,6 +425,7 @@ async def run(
                 "sql_samples": sql_samples or [],
                 "instructions": instructions or [],
                 "configuration": configuration,
+                "sql_data": sql_data or {},
                 **self._components,
                 **self._configs,
             },
```

wren-ai-service/src/web/v1/services/ask.py

Lines changed: 36 additions & 2 deletions
```diff
@@ -83,14 +83,14 @@ class _AskResultResponse(BaseModel):
     trace_id: Optional[str] = None
     is_followup: bool = False
     general_type: Optional[
-        Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE"]
+        Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"]
     ] = None
 
 
 class AskResultResponse(_AskResultResponse):
     is_followup: Optional[bool] = Field(False, exclude=True)
     general_type: Optional[
-        Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE"]
+        Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"]
     ] = Field(None, exclude=True)
 
 
@@ -227,6 +227,16 @@ async def ask(
             )
 
         if self._allow_intent_classification:
+            last_sql_data = None
+            if histories:
+                if last_sql := histories[-1].sql:
+                    last_sql_data = (
+                        await self._pipelines["sql_executor"].run(
+                            sql=last_sql,
+                            project_id=ask_request.project_id,
+                        )
+                    )["execute_sql"]["results"]
+
             intent_classification_result = (
                 await self._pipelines["intent_classification"].run(
                     query=user_query,
@@ -235,6 +245,7 @@ async def ask(
                     instructions=instructions,
                     project_id=ask_request.project_id,
                     configuration=ask_request.configurations,
+                    sql_data=last_sql_data,
                 )
             ).get("post_process", {})
             intent = intent_classification_result.get("intent")
@@ -317,6 +328,27 @@ async def ask(
                 )
                 results["metadata"]["type"] = "GENERAL"
                 return results
+            elif intent == "DATA_EXPLORATION":
+                asyncio.create_task(
+                    self._pipelines["data_exploration_assistance"].run(
+                        query=user_query,
+                        sql_data=last_sql_data,
+                        language=ask_request.configurations.language,
+                        query_id=ask_request.query_id,
+                    )
+                )
+
+                self._ask_results[query_id] = AskResultResponse(
+                    status="finished",
+                    type="GENERAL",
+                    rephrased_question=rephrased_question,
+                    intent_reasoning=intent_reasoning,
+                    trace_id=trace_id,
+                    is_followup=True if histories else False,
+                    general_type="DATA_EXPLORATION",
+                )
+                results["metadata"]["type"] = "GENERAL"
+                return results
             else:
                 self._ask_results[query_id] = AskResultResponse(
                     status="understanding",
@@ -639,6 +671,8 @@ async def get_ask_streaming_result(
             _pipeline_name = "data_assistance"
         elif self._ask_results.get(query_id).general_type == "MISLEADING_QUERY":
             _pipeline_name = "misleading_assistance"
+        elif self._ask_results.get(query_id).general_type == "DATA_EXPLORATION":
+            _pipeline_name = "data_exploration_assistance"
        elif self._ask_results.get(query_id).status == "planning":
             if self._ask_results.get(query_id).is_followup:
                 _pipeline_name = "followup_sql_generation_reasoning"
```

wren-ai-service/tools/config/config.example.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -135,6 +135,8 @@ pipes:
     llm: litellm_llm.default
   - name: data_assistance
     llm: litellm_llm.default
+  - name: data_exploration_assistance
+    llm: litellm_llm.default
   - name: sql_pairs_indexing
     document_store: qdrant
     embedder: litellm_embedder.default
```

wren-ai-service/tools/config/config.full.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -135,6 +135,8 @@ pipes:
     llm: litellm_llm.default
   - name: data_assistance
     llm: litellm_llm.default
+  - name: data_exploration_assistance
+    llm: litellm_llm.default
   - name: sql_pairs_indexing
     document_store: qdrant
     embedder: litellm_embedder.default
```
