Skip to content

Commit 9b8ab01

Browse files
committed
simplify
1 parent a4b8477 commit 9b8ab01

File tree

2 files changed

+56
-136
lines changed

2 files changed

+56
-136
lines changed

wren-ai-service/src/pipelines/generation/intent_classification.py

Lines changed: 24 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
import ast
21
import logging
32
import sys
43
from typing import Any, Literal, Optional
54

65
import orjson
76
from hamilton import base
87
from hamilton.async_driver import AsyncDriver
9-
from haystack import Document
108
from haystack.components.builders.prompt_builder import PromptBuilder
119
from langfuse.decorators import observe
1210
from pydantic import BaseModel
1311

1412
from src.core.pipeline import BasicPipeline
1513
from src.core.provider import DocumentStoreProvider, EmbedderProvider, LLMProvider
16-
from src.pipelines.common import build_table_ddl, clean_up_new_lines
14+
from src.pipelines.common import clean_up_new_lines
1715
from src.pipelines.generation.utils.sql import construct_instructions
1816
from src.utils import trace_cost
1917
from src.web.v1.services import Configuration
@@ -25,7 +23,7 @@
2523
intent_classification_system_prompt = """
2624
### Task ###
2725
You are an expert detective specializing in intent classification. Combine the user's current question and previous questions to determine their true intent based on the provided database schema or sql data if provided.
28-
Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, or `USER_GUIDE`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
26+
Classify the intent into one of these categories: `MISLEADING_QUERY`, `TEXT_TO_SQL`, `DATA_EXPLORATION`, `GENERAL`, `USER_GUIDE`, or `USER_CLARIFICATION`. Additionally, provide a concise reasoning (maximum 20 words) for your classification.
2927
3028
### Instructions ###
3129
- **Follow the user's previous questions:** If there are previous questions, try to understand the user's current question as following the previous questions.
@@ -34,8 +32,9 @@
3432
- **Rephrase Question:** Rewrite follow-up questions into full standalone questions using prior conversation context.
3533
- **Concise Reasoning:** The reasoning must be clear, concise, and limited to 20 words.
3634
- **Language Consistency:** Use the same language as specified in the user's output language for the rephrased question and reasoning.
37-
- **Vague Queries:** If the question is vague or does not related to a table or property from the schema, classify it as `MISLEADING_QUERY`.
38-
- **Incomplete Queries:** If the question is related to the database schema but references unspecified values (e.g., "the following", "these", "those") without providing them, classify as `GENERAL`.
35+
- **Vague Queries:** If the question does not relate to the database schema, classify it as `MISLEADING_QUERY`.
36+
- **User Clarification:** If the question is related to the database schema, but is missing some details needed in order to answer the question, classify it as `USER_CLARIFICATION`.
37+
- **Incomplete Queries:** If the question is related to the database schema but references unspecified values (e.g., "the following", "these", "those") without providing them, classify as `USER_CLARIFICATION`.
3938
- **Time-related Queries:** Don't rephrase time-related information in the user's question.
4039
4140
### Intent Definitions ###
@@ -73,9 +72,9 @@
7372
- "List the top 10 products by revenue."
7473
</TEXT_TO_SQL>
7574
76-
<GENERAL>
77-
**When to Use:**
78-
- The user seeks general information about the database schema or its overall capabilities.
75+
<USER_CLARIFICATION>
76+
**When to Use:**
77+
- The user's question is related to the database schema, but is missing some details needed in order to answer the question.
7978
- The query references **missing information** (e.g., "the following items" without listing them).
8079
- The query contains **placeholder references** that cannot be resolved from context.
8180
- The query is **incomplete for SQL generation** despite mentioning database concepts.
@@ -85,11 +84,18 @@
8584
- Identify missing parameters, unspecified references, or incomplete filter criteria.
8685
8786
**Examples:**
88-
- "What is the dataset about?"
89-
- "Tell me more about the database."
9087
- "How can I analyze customer behavior with this data?"
9188
- "Show me orders for these products" (without specifying which products)
9289
- "Filter by the criteria I mentioned" (without previous context defining criteria)
90+
</USER_CLARIFICATION>
91+
92+
<GENERAL>
93+
**When to Use:**
94+
- The user seeks general information about the database schema or its overall capabilities
95+
96+
**Examples:**
97+
- "What is the dataset about?"
98+
- "Tell me more about the database."
9399
</GENERAL>
94100
95101
<USER_GUIDE>
@@ -126,7 +132,7 @@
126132
{
127133
"rephrased_question": "<rephrased question in full standalone question if there are previous questions, otherwise the original question>",
128134
"reasoning": "<brief chain-of-thought reasoning (max 20 words)>",
129-
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" |"GENERAL" | "USER_GUIDE"
135+
"results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "DATA_EXPLORATION" | "GENERAL" | "USER_GUIDE" | "USER_CLARIFICATION"
130136
}
131137
"""
132138

@@ -183,114 +189,11 @@
183189

184190

185191
## Start of Pipeline
186-
@observe(capture_input=False, capture_output=False)
187-
async def embedding(query: str, embedder: Any, histories: list[AskHistory]) -> dict:
188-
previous_query_summaries = (
189-
[history.question for history in histories] if histories else []
190-
)
191-
192-
query = "\n".join(previous_query_summaries) + "\n" + query
193-
194-
return await embedder.run(query)
195-
196-
197-
@observe(capture_input=False)
198-
async def table_retrieval(
199-
embedding: dict, project_id: str, table_retriever: Any
200-
) -> dict:
201-
filters = {
202-
"operator": "AND",
203-
"conditions": [
204-
{"field": "type", "operator": "==", "value": "TABLE_DESCRIPTION"},
205-
],
206-
}
207-
208-
if project_id:
209-
filters["conditions"].append(
210-
{"field": "project_id", "operator": "==", "value": project_id}
211-
)
212-
213-
return await table_retriever.run(
214-
query_embedding=embedding.get("embedding"),
215-
filters=filters,
216-
)
217-
218-
219-
@observe(capture_input=False)
220-
async def dbschema_retrieval(
221-
table_retrieval: dict, embedding: dict, project_id: str, dbschema_retriever: Any
222-
) -> list[Document]:
223-
tables = table_retrieval.get("documents", [])
224-
table_names = []
225-
for table in tables:
226-
content = ast.literal_eval(table.content)
227-
table_names.append(content["name"])
228-
229-
logger.info(f"dbschema_retrieval with table_names: {table_names}")
230-
231-
table_name_conditions = [
232-
{"field": "name", "operator": "==", "value": table_name}
233-
for table_name in table_names
234-
]
235-
236-
filters = {
237-
"operator": "AND",
238-
"conditions": [
239-
{"field": "type", "operator": "==", "value": "TABLE_SCHEMA"},
240-
{"operator": "OR", "conditions": table_name_conditions},
241-
],
242-
}
243-
244-
if project_id:
245-
filters["conditions"].append(
246-
{"field": "project_id", "operator": "==", "value": project_id}
247-
)
248-
249-
results = await dbschema_retriever.run(
250-
query_embedding=embedding.get("embedding"), filters=filters
251-
)
252-
return results["documents"]
253-
254-
255-
@observe()
256-
def construct_db_schemas(dbschema_retrieval: list[Document]) -> list[str]:
257-
db_schemas = {}
258-
for document in dbschema_retrieval:
259-
content = ast.literal_eval(document.content)
260-
if content["type"] == "TABLE":
261-
if document.meta["name"] not in db_schemas:
262-
db_schemas[document.meta["name"]] = content
263-
else:
264-
db_schemas[document.meta["name"]] = {
265-
**content,
266-
"columns": db_schemas[document.meta["name"]].get("columns", []),
267-
}
268-
elif content["type"] == "TABLE_COLUMNS":
269-
if document.meta["name"] not in db_schemas:
270-
db_schemas[document.meta["name"]] = {"columns": content["columns"]}
271-
else:
272-
if "columns" not in db_schemas[document.meta["name"]]:
273-
db_schemas[document.meta["name"]]["columns"] = content["columns"]
274-
else:
275-
db_schemas[document.meta["name"]]["columns"] += content["columns"]
276-
277-
# remove incomplete schemas
278-
db_schemas = {k: v for k, v in db_schemas.items() if "type" in v and "columns" in v}
279-
280-
db_schemas_in_ddl = []
281-
for table_schema in list(db_schemas.values()):
282-
if table_schema["type"] == "TABLE":
283-
ddl, _, _ = build_table_ddl(table_schema)
284-
db_schemas_in_ddl.append(ddl)
285-
286-
return db_schemas_in_ddl
287-
288-
289192
@observe(capture_input=False)
290193
def prompt(
291194
query: str,
292195
wren_ai_docs: list[dict],
293-
construct_db_schemas: list[str],
196+
db_schemas: list[str],
294197
histories: list[AskHistory],
295198
prompt_builder: PromptBuilder,
296199
sql_samples: Optional[list[dict]] = None,
@@ -301,7 +204,7 @@ def prompt(
301204
_prompt = prompt_builder.run(
302205
query=query,
303206
language=configuration.language,
304-
db_schemas=construct_db_schemas,
207+
db_schemas=db_schemas,
305208
histories=histories,
306209
sql_samples=sql_samples,
307210
instructions=construct_instructions(
@@ -321,21 +224,19 @@ async def classify_intent(prompt: dict, generator: Any, generator_name: str) ->
321224

322225

323226
@observe(capture_input=False)
324-
def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict:
227+
def post_process(classify_intent: dict) -> dict:
325228
try:
326229
results = orjson.loads(classify_intent.get("replies")[0])
327230
return {
328231
"rephrased_question": results["rephrased_question"],
329232
"intent": results["results"],
330233
"reasoning": results["reasoning"],
331-
"db_schemas": construct_db_schemas,
332234
}
333235
except Exception:
334236
return {
335237
"rephrased_question": "",
336238
"intent": "TEXT_TO_SQL",
337239
"reasoning": "",
338-
"db_schemas": construct_db_schemas,
339240
}
340241

341242

@@ -350,6 +251,7 @@ class IntentClassificationResult(BaseModel):
350251
"GENERAL",
351252
"DATA_EXPLORATION",
352253
"USER_GUIDE",
254+
"USER_CLARIFICATION",
353255
]
354256
reasoning: str
355257

@@ -408,6 +310,7 @@ def __init__(
408310
async def run(
409311
self,
410312
query: str,
313+
db_schemas: list[str],
411314
project_id: Optional[str] = None,
412315
histories: Optional[list[AskHistory]] = None,
413316
sql_samples: Optional[list[dict]] = None,
@@ -420,6 +323,7 @@ async def run(
420323
["post_process"],
421324
inputs={
422325
"query": query,
326+
"db_schemas": db_schemas,
423327
"project_id": project_id or "",
424328
"histories": histories or [],
425329
"sql_samples": sql_samples or [],

wren-ai-service/src/web/v1/services/ask.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,26 @@ class _AskResultResponse(BaseModel):
8383
trace_id: Optional[str] = None
8484
is_followup: bool = False
8585
general_type: Optional[
86-
Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"]
86+
Literal[
87+
"MISLEADING_QUERY",
88+
"DATA_ASSISTANCE",
89+
"USER_GUIDE",
90+
"DATA_EXPLORATION",
91+
"USER_CLARIFICATION",
92+
]
8793
] = None
8894

8995

9096
class AskResultResponse(_AskResultResponse):
9197
is_followup: Optional[bool] = Field(False, exclude=True)
9298
general_type: Optional[
93-
Literal["MISLEADING_QUERY", "DATA_ASSISTANCE", "USER_GUIDE", "DATA_EXPLORATION"]
99+
Literal[
100+
"MISLEADING_QUERY",
101+
"DATA_ASSISTANCE",
102+
"USER_GUIDE",
103+
"DATA_EXPLORATION",
104+
"USER_CLARIFICATION",
105+
]
94106
] = Field(None, exclude=True)
95107

96108

@@ -206,7 +218,11 @@ async def ask(
206218
sql_generation_reasoning = ""
207219
else:
208220
# Run both pipeline operations concurrently
209-
sql_samples_task, instructions_task = await asyncio.gather(
221+
(
222+
sql_samples_task,
223+
instructions_task,
224+
db_schema_retrieval_task,
225+
) = await asyncio.gather(
210226
self._pipelines["sql_pairs_retrieval"].run(
211227
query=user_query,
212228
project_id=ask_request.project_id,
@@ -216,6 +232,12 @@ async def ask(
216232
project_id=ask_request.project_id,
217233
scope="sql",
218234
),
235+
self._pipelines["db_schema_retrieval"].run(
236+
query=user_query,
237+
histories=histories,
238+
project_id=ask_request.project_id,
239+
enable_column_pruning=enable_column_pruning,
240+
),
219241
)
220242

221243
# Extract results from completed tasks
@@ -225,6 +247,12 @@ async def ask(
225247
instructions = instructions_task["formatted_output"].get(
226248
"documents", []
227249
)
250+
_retrieval_result = db_schema_retrieval_task.get(
251+
"construct_retrieval_results", {}
252+
)
253+
documents = _retrieval_result.get("retrieval_results", [])
254+
table_names = [document.get("table_name") for document in documents]
255+
table_ddls = [document.get("table_ddl") for document in documents]
228256

229257
if self._allow_intent_classification:
230258
last_sql_data = None
@@ -240,6 +268,7 @@ async def ask(
240268
intent_classification_result = (
241269
await self._pipelines["intent_classification"].run(
242270
query=user_query,
271+
db_schemas=table_ddls,
243272
histories=histories,
244273
sql_samples=sql_samples,
245274
instructions=instructions,
@@ -368,19 +397,6 @@ async def ask(
368397
is_followup=True if histories else False,
369398
)
370399

371-
retrieval_result = await self._pipelines["db_schema_retrieval"].run(
372-
query=user_query,
373-
histories=histories,
374-
project_id=ask_request.project_id,
375-
enable_column_pruning=enable_column_pruning,
376-
)
377-
_retrieval_result = retrieval_result.get(
378-
"construct_retrieval_results", {}
379-
)
380-
documents = _retrieval_result.get("retrieval_results", [])
381-
table_names = [document.get("table_name") for document in documents]
382-
table_ddls = [document.get("table_ddl") for document in documents]
383-
384400
if not documents:
385401
logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}")
386402
if not self._is_stopped(query_id, self._ask_results):

0 commit comments

Comments
 (0)