Skip to content

Commit 6a5fd8e

Browse files
committed
update
1 parent 8af6ce9 commit 6a5fd8e

File tree

6 files changed

+58
-16
lines changed

6 files changed

+58
-16
lines changed

wren-ai-service/src/pipelines/generation/data_assistance.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
{{ db_schema }}
4242
{% endfor %}
4343
44+
{% if histories %}
45+
### PREVIOUS QUESTIONS ###
46+
{% for history in histories %}
47+
{{ history.question }}
48+
{% endfor %}
49+
{% endif %}
50+
4451
### INPUT ###
4552
User's question: {{query}}
4653
Language: {{language}}
@@ -61,13 +68,9 @@ def prompt(
6168
prompt_builder: PromptBuilder,
6269
custom_instruction: str,
6370
) -> dict:
64-
previous_query_summaries = (
65-
[history.question for history in histories] if histories else []
66-
)
67-
query = "\n".join(previous_query_summaries) + "\n" + query
68-
6971
_prompt = prompt_builder.run(
7072
query=query,
73+
histories=histories,
7174
db_schemas=db_schemas,
7275
language=language,
7376
custom_instruction=custom_instruction,

wren-ai-service/src/pipelines/generation/data_exploration_assistance.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
from src.core.pipeline import BasicPipeline
1212
from src.core.provider import LLMProvider
13+
from src.pipelines.common import clean_up_new_lines
14+
from src.utils import trace_cost
15+
from src.web.v1.services.ask import AskHistory
1316

1417
logger = logging.getLogger("wren-ai-service")
1518

@@ -32,10 +35,21 @@
3235
"""
3336

3437
data_exploration_assistance_user_prompt_template = """
38+
{% if histories %}
39+
### PREVIOUS QUESTIONS ###
40+
{% for history in histories %}
41+
{{ history.question }}
42+
{% endfor %}
43+
{% endif %}
44+
45+
### INPUT ###
3546
User Question: {{query}}
3647
Language: {{language}}
3748
SQL Data:
3849
{{ sql_data }}
50+
51+
Custom Instruction: {{ custom_instruction }}
52+
3953
Please think step by step.
4054
"""
4155

@@ -44,18 +58,24 @@
4458
@observe(capture_input=False)
4559
def prompt(
4660
query: str,
61+
histories: list[AskHistory],
4762
language: str,
4863
sql_data: dict,
4964
prompt_builder: PromptBuilder,
65+
custom_instruction: str,
5066
) -> dict:
51-
return prompt_builder.run(
67+
_prompt = prompt_builder.run(
5268
query=query,
5369
language=language,
5470
sql_data=sql_data,
71+
histories=histories,
72+
custom_instruction=custom_instruction,
5573
)
74+
return {"prompt": clean_up_new_lines(_prompt.get("prompt"))}
5675

5776

5877
@observe(as_type="generation", capture_input=False)
78+
@trace_cost
5979
async def data_exploration_assistance(
6080
prompt: dict, generator: Any, query_id: str
6181
) -> dict:
@@ -127,6 +147,8 @@ async def run(
127147
sql_data: dict,
128148
language: str,
129149
query_id: Optional[str] = None,
150+
histories: Optional[list[AskHistory]] = None,
151+
custom_instruction: Optional[str] = None,
130152
):
131153
logger.info("Data Exploration Assistance pipeline is running...")
132154
return await self._pipe.execute(
@@ -136,6 +158,8 @@ async def run(
136158
"language": language,
137159
"query_id": query_id or "",
138160
"sql_data": sql_data,
161+
"histories": histories or [],
162+
"custom_instruction": custom_instruction or "",
139163
**self._components,
140164
},
141165
)

wren-ai-service/src/pipelines/generation/misleading_assistance.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
{{ db_schema }}
4242
{% endfor %}
4343
44+
{% if histories %}
45+
### PREVIOUS QUESTIONS ###
46+
{% for history in histories %}
47+
{{ history.question }}
48+
{% endfor %}
49+
{% endif %}
50+
4451
### INPUT ###
4552
User's question: {{query}}
4653
Language: {{language}}
@@ -61,13 +68,9 @@ def prompt(
6168
prompt_builder: PromptBuilder,
6269
custom_instruction: str,
6370
) -> dict:
64-
previous_query_summaries = (
65-
[history.question for history in histories] if histories else []
66-
)
67-
query = "\n".join(previous_query_summaries) + "\n" + query
68-
6971
_prompt = prompt_builder.run(
7072
query=query,
73+
histories=histories,
7174
db_schemas=db_schemas,
7275
language=language,
7376
custom_instruction=custom_instruction,

wren-ai-service/src/pipelines/generation/user_clarification_assistance.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,9 @@ def prompt(
3434
prompt_builder: PromptBuilder,
3535
custom_instruction: str,
3636
) -> dict:
37-
previous_query_summaries = (
38-
[history.question for history in histories] if histories else []
39-
)
40-
query = "\n".join(previous_query_summaries) + "\n" + query
41-
4237
_prompt = prompt_builder.run(
4338
query=query,
39+
histories=histories,
4440
db_schemas=db_schemas,
4541
language=language,
4642
custom_instruction=custom_instruction,

wren-ai-service/src/pipelines/generation/user_guide_assistance.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from src.core.provider import LLMProvider
1313
from src.pipelines.common import clean_up_new_lines
1414
from src.utils import trace_cost
15+
from src.web.v1.services.ask import AskHistory
1516

1617
logger = logging.getLogger("wren-ai-service")
1718

@@ -34,6 +35,14 @@
3435
"""
3536

3637
user_guide_assistance_user_prompt_template = """
38+
{% if histories %}
39+
### PREVIOUS QUESTIONS ###
40+
{% for history in histories %}
41+
{{ history.question }}
42+
{% endfor %}
43+
{% endif %}
44+
45+
### INPUT ###
3746
User Question: {{query}}
3847
Language: {{language}}
3948
User Guide:
@@ -53,11 +62,13 @@ def prompt(
5362
query: str,
5463
language: str,
5564
wren_ai_docs: list[dict],
65+
histories: list[AskHistory],
5666
prompt_builder: PromptBuilder,
5767
custom_instruction: str,
5868
) -> dict:
5969
_prompt = prompt_builder.run(
6070
query=query,
71+
histories=histories,
6172
language=language,
6273
docs=wren_ai_docs,
6374
custom_instruction=custom_instruction,
@@ -144,6 +155,7 @@ async def run(
144155
query: str,
145156
language: str,
146157
query_id: Optional[str] = None,
158+
histories: Optional[list[AskHistory]] = None,
147159
custom_instruction: Optional[str] = None,
148160
):
149161
logger.info("User Guide Assistance pipeline is running...")
@@ -153,6 +165,7 @@ async def run(
153165
"query": query,
154166
"language": language,
155167
"query_id": query_id or "",
168+
"histories": histories or [],
156169
"custom_instruction": custom_instruction or "",
157170
**self._components,
158171
**self._configs,

wren-ai-service/src/web/v1/services/ask.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ async def ask(
336336
asyncio.create_task(
337337
self._pipelines["user_guide_assistance"].run(
338338
query=user_query,
339+
histories=histories,
339340
language=ask_request.configurations.language,
340341
query_id=ask_request.query_id,
341342
custom_instruction=ask_request.custom_instruction,
@@ -357,6 +358,7 @@ async def ask(
357358
asyncio.create_task(
358359
self._pipelines["data_exploration_assistance"].run(
359360
query=user_query,
361+
histories=histories,
360362
sql_data=last_sql_data,
361363
language=ask_request.configurations.language,
362364
query_id=ask_request.query_id,
@@ -379,6 +381,7 @@ async def ask(
379381
asyncio.create_task(
380382
self._pipelines["user_clarification_assistance"].run(
381383
query=user_query,
384+
histories=histories,
382385
db_schemas=table_ddls,
383386
language=ask_request.configurations.language,
384387
query_id=ask_request.query_id,

0 commit comments

Comments
 (0)