Skip to content

Commit 8af6ce9

Browse files
committed
update
1 parent 0586f41 commit 8af6ce9

File tree

7 files changed

+183
-6
lines changed

7 files changed

+183
-6
lines changed

deployment/kustomizations/base/cm.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ data:
174174
llm: litellm_llm.default
175175
- name: data_exploration_assistance
176176
llm: litellm_llm.default
177+
- name: user_clarification_assistance
178+
llm: litellm_llm.default
177179
- name: sql_pairs_indexing
178180
document_store: qdrant
179181
embedder: litellm_embedder.default

docker/config.example.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ pipes:
124124
llm: litellm_llm.default
125125
- name: data_exploration_assistance
126126
llm: litellm_llm.default
127+
- name: user_clarification_assistance
128+
llm: litellm_llm.default
127129
- name: sql_pairs_indexing
128130
document_store: qdrant
129131
embedder: litellm_embedder.default

wren-ai-service/src/pipelines/generation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .sql_question import SQLQuestion
1717
from .sql_regeneration import SQLRegeneration
1818
from .sql_tables_extraction import SQLTablesExtraction
19+
from .user_clarification_assistance import UserClarificationAssistance
1920
from .user_guide_assistance import UserGuideAssistance
2021

2122
__all__ = [
@@ -38,4 +39,5 @@
3839
"MisleadingAssistance",
3940
"SQLTablesExtraction",
4041
"DataExplorationAssistance",
42+
"UserClarificationAssistance",
4143
]
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import asyncio
2+
import logging
3+
import sys
4+
from typing import Any, Optional
5+
6+
from hamilton import base
7+
from hamilton.async_driver import AsyncDriver
8+
from haystack.components.builders.prompt_builder import PromptBuilder
9+
from langfuse.decorators import observe
10+
11+
from src.core.pipeline import BasicPipeline
12+
from src.core.provider import LLMProvider
13+
from src.pipelines.common import clean_up_new_lines
14+
from src.utils import trace_cost
15+
from src.web.v1.services.ask import AskHistory
16+
17+
logger = logging.getLogger("wren-ai-service")
18+
19+
20+
user_clarification_assistance_system_prompt = """
21+
"""
22+
23+
user_clarification_assistance_user_prompt_template = """
24+
"""
25+
26+
27+
## Start of Pipeline
28+
@observe(capture_input=False)
29+
def prompt(
30+
query: str,
31+
db_schemas: list[str],
32+
language: str,
33+
histories: list[AskHistory],
34+
prompt_builder: PromptBuilder,
35+
custom_instruction: str,
36+
) -> dict:
37+
previous_query_summaries = (
38+
[history.question for history in histories] if histories else []
39+
)
40+
query = "\n".join(previous_query_summaries) + "\n" + query
41+
42+
_prompt = prompt_builder.run(
43+
query=query,
44+
db_schemas=db_schemas,
45+
language=language,
46+
custom_instruction=custom_instruction,
47+
)
48+
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
49+
50+
51+
@observe(as_type="generation", capture_input=False)
52+
@trace_cost
53+
async def user_clarification_assistance(
54+
prompt: dict, generator: Any, query_id: str, generator_name: str
55+
) -> dict:
56+
return await generator(
57+
prompt=prompt.get("prompt"),
58+
query_id=query_id,
59+
), generator_name
60+
61+
62+
## End of Pipeline
63+
64+
65+
class UserClarificationAssistance(BasicPipeline):
66+
def __init__(
67+
self,
68+
llm_provider: LLMProvider,
69+
**kwargs,
70+
):
71+
self._user_queues = {}
72+
self._components = {
73+
"generator": llm_provider.get_generator(
74+
system_prompt=user_clarification_assistance_system_prompt,
75+
streaming_callback=self._streaming_callback,
76+
),
77+
"generator_name": llm_provider.get_model(),
78+
"prompt_builder": PromptBuilder(
79+
template=user_clarification_assistance_user_prompt_template
80+
),
81+
}
82+
83+
super().__init__(
84+
AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
85+
)
86+
87+
def _streaming_callback(self, chunk, query_id):
88+
if query_id not in self._user_queues:
89+
self._user_queues[
90+
query_id
91+
] = asyncio.Queue() # Create a new queue for the user if it doesn't exist
92+
# Put the chunk content into the user's queue
93+
asyncio.create_task(self._user_queues[query_id].put(chunk.content))
94+
if chunk.meta.get("finish_reason"):
95+
asyncio.create_task(self._user_queues[query_id].put("<DONE>"))
96+
97+
async def get_streaming_results(self, query_id):
98+
async def _get_streaming_results(query_id):
99+
return await self._user_queues[query_id].get()
100+
101+
if query_id not in self._user_queues:
102+
self._user_queues[
103+
query_id
104+
] = asyncio.Queue() # Ensure the user's queue exists
105+
while True:
106+
try:
107+
# Wait for an item from the user's queue
108+
self._streaming_results = await asyncio.wait_for(
109+
_get_streaming_results(query_id), timeout=120
110+
)
111+
if (
112+
self._streaming_results == "<DONE>"
113+
): # Check for end-of-stream signal
114+
del self._user_queues[query_id]
115+
break
116+
if self._streaming_results: # Check if there are results to yield
117+
yield self._streaming_results
118+
self._streaming_results = "" # Clear after yielding
119+
except TimeoutError:
120+
break
121+
122+
@observe(name="User Clarification Assistance")
123+
async def run(
124+
self,
125+
query: str,
126+
db_schemas: list[str],
127+
language: str,
128+
query_id: Optional[str] = None,
129+
histories: Optional[list[AskHistory]] = None,
130+
custom_instruction: Optional[str] = None,
131+
):
132+
logger.info("User Clarification Assistance pipeline is running...")
133+
return await self._pipe.execute(
134+
["user_clarification_assistance"],
135+
inputs={
136+
"query": query,
137+
"db_schemas": db_schemas,
138+
"language": language,
139+
"query_id": query_id or "",
140+
"histories": histories or [],
141+
"custom_instruction": custom_instruction or "",
142+
**self._components,
143+
},
144+
)

wren-ai-service/src/web/v1/services/ask.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,7 @@ async def ask(
291291
self._pipelines["misleading_assistance"].run(
292292
query=user_query,
293293
histories=histories,
294-
db_schemas=intent_classification_result.get(
295-
"db_schemas"
296-
),
294+
db_schemas=table_ddls,
297295
language=ask_request.configurations.language,
298296
query_id=ask_request.query_id,
299297
custom_instruction=ask_request.custom_instruction,
@@ -316,9 +314,7 @@ async def ask(
316314
self._pipelines["data_assistance"].run(
317315
query=user_query,
318316
histories=histories,
319-
db_schemas=intent_classification_result.get(
320-
"db_schemas"
321-
),
317+
db_schemas=table_ddls,
322318
language=ask_request.configurations.language,
323319
query_id=ask_request.query_id,
324320
custom_instruction=ask_request.custom_instruction,
@@ -364,6 +360,7 @@ async def ask(
364360
sql_data=last_sql_data,
365361
language=ask_request.configurations.language,
366362
query_id=ask_request.query_id,
363+
custom_instruction=ask_request.custom_instruction,
367364
)
368365
)
369366

@@ -378,6 +375,28 @@ async def ask(
378375
)
379376
results["metadata"]["type"] = "GENERAL"
380377
return results
378+
elif intent == "USER_CLARIFICATION":
379+
asyncio.create_task(
380+
self._pipelines["user_clarification_assistance"].run(
381+
query=user_query,
382+
db_schemas=table_ddls,
383+
language=ask_request.configurations.language,
384+
query_id=ask_request.query_id,
385+
custom_instruction=ask_request.custom_instruction,
386+
)
387+
)
388+
389+
self._ask_results[query_id] = AskResultResponse(
390+
status="finished",
391+
type="GENERAL",
392+
rephrased_question=rephrased_question,
393+
intent_reasoning=intent_reasoning,
394+
trace_id=trace_id,
395+
is_followup=True if histories else False,
396+
general_type="USER_CLARIFICATION",
397+
)
398+
results["metadata"]["type"] = "GENERAL"
399+
return results
381400
else:
382401
self._ask_results[query_id] = AskResultResponse(
383402
status="understanding",
@@ -689,6 +708,10 @@ async def get_ask_streaming_result(
689708
_pipeline_name = "misleading_assistance"
690709
elif self._ask_results.get(query_id).general_type == "DATA_EXPLORATION":
691710
_pipeline_name = "data_exploration_assistance"
711+
elif (
712+
self._ask_results.get(query_id).general_type == "USER_CLARIFICATION"
713+
):
714+
_pipeline_name = "user_clarification_assistance"
692715
elif self._ask_results.get(query_id).status == "planning":
693716
if self._ask_results.get(query_id).is_followup:
694717
_pipeline_name = "followup_sql_generation_reasoning"

wren-ai-service/tools/config/config.example.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ pipes:
137137
llm: litellm_llm.default
138138
- name: data_exploration_assistance
139139
llm: litellm_llm.default
140+
- name: user_clarification_assistance
141+
llm: litellm_llm.default
140142
- name: sql_pairs_indexing
141143
document_store: qdrant
142144
embedder: litellm_embedder.default

wren-ai-service/tools/config/config.full.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ pipes:
137137
llm: litellm_llm.default
138138
- name: data_exploration_assistance
139139
llm: litellm_llm.default
140+
- name: user_clarification_assistance
141+
llm: litellm_llm.default
140142
- name: sql_pairs_indexing
141143
document_store: qdrant
142144
embedder: litellm_embedder.default

0 commit comments

Comments
 (0)