Commit 307204d

Update flows

1 parent 2644e2d commit 307204d
3 files changed: +118 -115 lines changed

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 85 additions & 76 deletions

@@ -96,9 +96,9 @@ async def on_messages_stream(
        injected_parameters = {}

        # Load the json of the last message to populate the final output object
-        message_rewrites = json.loads(last_response)
+        sequential_rounds = json.loads(last_response)

-        logging.info(f"Query Rewrites: {message_rewrites}")
+        logging.info(f"Query Rewrites: {sequential_rounds}")

        async def consume_inner_messages_from_agentic_flow(
            agentic_flow, identifier, filtered_parallel_messages

@@ -197,7 +197,7 @@ async def consume_inner_messages_from_agentic_flow(

        # Convert all_non_database_query to lowercase string and compare
        all_non_database_query = str(
-            message_rewrites.get("all_non_database_query", "false")
+            sequential_rounds.get("all_non_database_query", "false")
        ).lower()

        if all_non_database_query == "true":

@@ -210,84 +210,93 @@ async def consume_inner_messages_from_agentic_flow(
            return

        # Start processing sub-queries
-        for message_rewrite in message_rewrites["decomposed_user_messages"]:
-            logging.info(f"Processing sub-query: {message_rewrite}")
-            # Create an instance of the InnerAutoGenText2Sql class
-            inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
-
-            identifier = ", ".join(message_rewrite)
-
-            # Add database connection info to injected parameters
-            query_params = injected_parameters.copy() if injected_parameters else {}
-            if "Text2Sql__Tsql__ConnectionString" in os.environ:
-                query_params["database_connection_string"] = os.environ[
-                    "Text2Sql__Tsql__ConnectionString"
-                ]
-            if "Text2Sql__Tsql__Database" in os.environ:
-                query_params["database_name"] = os.environ["Text2Sql__Tsql__Database"]
-
-            # Launch tasks for each sub-query
-            inner_solving_generators.append(
-                consume_inner_messages_from_agentic_flow(
-                    inner_autogen_text_2_sql.process_user_message(
-                        user_message=message_rewrite,
-                        injected_parameters=query_params,
-                    ),
-                    identifier,
-                    filtered_parallel_messages,
+        for sequential_round in sequential_rounds["decomposed_user_messages"]:
+            logging.info(f"Processing round: {sequential_round}")
+
+            for parallel_message in sequential_round:
+                logging.info(f"Parallel Message: {parallel_message}")
+
+                # Create an instance of the InnerAutoGenText2Sql class
+                inner_autogen_text_2_sql = InnerAutoGenText2Sql(**self.kwargs)
+
+                identifier = ", ".join(parallel_message)
+
+                # Add database connection info to injected parameters
+                query_params = injected_parameters.copy() if injected_parameters else {}
+                if "Text2Sql__Tsql__ConnectionString" in os.environ:
+                    query_params["database_connection_string"] = os.environ[
+                        "Text2Sql__Tsql__ConnectionString"
+                    ]
+                if "Text2Sql__Tsql__Database" in os.environ:
+                    query_params["database_name"] = os.environ["Text2Sql__Tsql__Database"]
+
+                # Launch tasks for each sub-query
+                inner_solving_generators.append(
+                    consume_inner_messages_from_agentic_flow(
+                        inner_autogen_text_2_sql.process_user_message(
+                            user_message=parallel_message,
+                            injected_parameters=query_params,
+                            database_results=filtered_parallel_messages.database_results
+                        ),
+                        identifier,
+                        filtered_parallel_messages,
+                    )
                )
-            )
-
-        logging.info(
-            "Created %i Inner Solving Generators", len(inner_solving_generators)
-        )
-        logging.info("Starting Inner Solving Generators")
-        combined_message_streams = stream.merge(*inner_solving_generators)
-
-        async with combined_message_streams.stream() as streamer:
-            async for inner_message in streamer:
-                if isinstance(inner_message, TextMessage):
-                    logging.debug(f"Inner Solving Message: {inner_message}")
-                    yield inner_message
-
-        # Log final results for debugging or auditing
-        logging.info(
-            "Database Results: %s", filtered_parallel_messages.database_results
-        )
-        logging.info(
-            "Disambiguation Requests: %s",
-            filtered_parallel_messages.disambiguation_requests,
-        )

-        if (
-            max(map(len, filtered_parallel_messages.disambiguation_requests.values()))
-            > 0
-        ):
-            # Final response
-            yield Response(
-                chat_message=TextMessage(
-                    content=json.dumps(
-                        {
-                            "contains_disambiguation_requests": True,
-                            "disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
-                        }
-                    ),
-                    source=self.name,
-                ),
+            logging.info(
+                "Created %i Inner Solving Generators", len(inner_solving_generators)
            )
-        else:
-            # Final response
-            yield Response(
-                chat_message=TextMessage(
-                    content=json.dumps(
-                        {
-                            "contains_database_results": True,
-                            "database_results": filtered_parallel_messages.database_results,
-                        }
+            logging.info("Starting Inner Solving Generators")
+            combined_message_streams = stream.merge(*inner_solving_generators)
+
+            async with combined_message_streams.stream() as streamer:
+                async for inner_message in streamer:
+                    if isinstance(inner_message, TextMessage):
+                        logging.debug(f"Inner Solving Message: {inner_message}")
+                        yield inner_message
+
+            # Log final results for debugging or auditing
+            logging.info(
+                "Database Results: %s", filtered_parallel_messages.database_results
+            )
+            logging.info(
+                "Disambiguation Requests: %s",
+                filtered_parallel_messages.disambiguation_requests,
+            )
+
+            # Check for disambiguation requests before processing the next round
+
+            if (
+                max(map(len, filtered_parallel_messages.disambiguation_requests.values()))
+                > 0
+            ):
+                # Final response
+                yield Response(
+                    chat_message=TextMessage(
+                        content=json.dumps(
+                            {
+                                "contains_disambiguation_requests": True,
+                                "disambiguation_requests": filtered_parallel_messages.disambiguation_requests,
+                            }
+                        ),
+                        source=self.name,
                    ),
-                    source=self.name,
                 )
-            )
+
+                break
+
+        # Final response
+        yield Response(
+            chat_message=TextMessage(
+                content=json.dumps(
+                    {
+                        "contains_database_results": True,
+                        "database_results": filtered_parallel_messages.database_results,
+                    }
                ),
+                source=self.name,
+            ),
+        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass
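For reference, the control flow introduced above treats decomposed_user_messages as a list of sequential rounds, where each round holds sub-messages that are solved in parallel and whose streams are merged before the next round begins. Below is a minimal, self-contained sketch of that pattern using aiostream's stream.merge (the same helper the agent calls); the worker coroutine and the sample payload are illustrative stand-ins, not code from this repository.

import asyncio

from aiostream import stream


async def solve_sub_query(sub_query: str):
    # Stand-in for InnerAutoGenText2Sql.process_user_message: stream a couple
    # of intermediate messages, then a final result for this sub-query.
    yield f"[{sub_query}] generating SQL"
    await asyncio.sleep(0.1)
    yield f"[{sub_query}] result"


async def run_rounds(decomposed_user_messages: list) -> None:
    # Outer list = sequential rounds; inner list = sub-queries run in parallel.
    for sequential_round in decomposed_user_messages:
        generators = [solve_sub_query(m) for m in sequential_round]
        merged = stream.merge(*generators)

        # Drain every parallel stream in this round before moving on, so a
        # later round can build on the earlier rounds' database results.
        async with merged.stream() as streamer:
            async for message in streamer:
                print(message)


asyncio.run(run_rounds([
    ["Get 2007 sales by category", "Get 2008 sales by category"],
    ["Compare year-over-year growth using the earlier results"],
]))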

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 4 additions & 0 deletions

@@ -177,6 +177,7 @@ def process_user_message(
        self,
        user_message: str,
        injected_parameters: dict = None,
+        database_results: dict = None,
    ):
        """Process the complete question through the unified system.

@@ -200,6 +201,9 @@ def process_user_message(
                "injected_parameters": injected_parameters,
            }

+            if database_results:
+                agent_input["database_results"] = database_results
+
            return self.agentic_flow.run_stream(task=json.dumps(agent_input))
        finally:
            # Restore original environment
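The new optional parameter only changes the task payload that is serialized for the inner agentic flow. A rough sketch of that shape follows, assuming a user_message key sits alongside the injected_parameters shown in the hunk above; the key layout beyond "database_results" and all values are invented for illustration.

import json

# Hypothetical payload; "database_results" is the only field added by this
# commit, and it is attached conditionally, mirroring the `if database_results:`
# guard in the diff.
agent_input = {
    "user_message": "Compare year-over-year growth using the earlier results",
    "injected_parameters": {"database_name": "AdventureWorks"},
}

database_results = {
    "Get 2007 sales by category": [{"category": "Bikes", "sales": 9500}],
}
if database_results:
    agent_input["database_results"] = database_results

print(json.dumps(agent_input, indent=2))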

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/user_message_rewrite_agent.yaml

Lines changed: 29 additions & 39 deletions

@@ -2,40 +2,37 @@ model: "4o-mini"
description: "An agent that preprocesses user inputs by decomposing complex queries into simpler sub-messages that can be processed independently and then combined."
system_message: |
  <role_and_objective>
-    You are a Senior Data Analyst specializing in breaking down complex questions into simpler sub-messages that can be processed independently and then combined for the final answer. You should identify when a question can be solved through simpler sub-messages and provide clear instructions for combining their results.
+    You are a Senior Data Analyst specializing in breaking down complex questions into simpler sub-messages that can be processed independently and then combined for the final answer. You must think through the steps needed to answer the question and produce a list of sub questions to generate and run SQL statements for.
+
+    You should consider what steps can be done in parallel and what steps depend on the results of other steps. Do not attempt to simplify the question if it is already simple to solve.
    Use the general business use case of '{{ use_case }}' to aid understanding of the user's question.
  </role_and_objective>

  <query_complexity_patterns>
-    Complex patterns that should be broken down:
-    1. Superlatives with Time Periods:
-      - "Which product categories showed the biggest improvement in sales between 2007 and 2008?"
-      → Break into:
-        a) "Get total sales by product category for 2007"
-        b) "Get total sales by product category for 2008"
-        c) "Calculate year-over-year growth percentage for each category"
-        d) "Find the category with highest growth"
+    Complex patterns that should be broken down into simpler steps of sub-messages:

-    2. Multi-dimension Analysis:
+    1. Multi-dimension Analysis:
      - "What are our top 3 selling products in each region, and how do their profit margins compare?"
      → Break into:
-        a) "Get total sales quantity by product and region"
-        b) "Find top 3 products by sales quantity for each region"
-        c) "Calculate profit margins for these products"
-        d) "Compare profit margins within each region's top 3"
+        a) "Get total sales quantity by product and region and select top 3 products for each region"
+        b) "Calculate profit margins for these products and compare profit margins within each region's top 3"

-    3. Comparative Analysis:
+    2. Comparative Analysis:
      - "How do our mountain bike sales compare to road bike sales across different seasons, and which weather conditions affect them most?"
      → Break into:
-        a) "Get sales data for mountain bikes by month"
-        b) "Get sales data for road bikes by month"
-        c) "Group months into seasons"
-        d) "Compare seasonal patterns between bike types"
+        a) "Get sales data for mountain bikes and road bikes by month"
+        b) "Group months into seasons and compare seasonal patterns between bike types"
+
+    3. Completely unrelated questions:
+      - "What is the total revenue for 2024? How many employees do we have in the marketing department?"
+      → Break into:
+        a) "Calculate total revenue for 2024"
+        b) "Get total number of employees in the marketing department"
  </query_complexity_patterns>

  <instructions>
    1. Understanding:
-      - Use the chat history (that is available in reverse order) to understand the context of the current question.
+      - Use the chat history to understand the context of the current question.
      - If the current question not fully formed and unclear. Rewrite it based on the general meaning of the old question and the new question. Include spelling and grammar corrections.
      - If the current question is clear, output the new question as is with spelling and grammar corrections.

@@ -49,12 +46,11 @@ system_message: |
    3. Analyze Query Complexity:
      - Identify if the query contains patterns that can be simplified
      - Look for superlatives, multiple dimensions, or comparisons
-      - Determine if breaking down would simplify processing

    4. Break Down Complex Queries:
      - Create independent sub-messages that can be processed separately.
      - Each sub-message should be a simple, focused task.
-      - Group dependent sub-messages together for sequential processing.
+      - Group dependent sub-messages together for parallel processing.
      - Include clear combination instructions
      - Preserve all necessary context in each sub-message

@@ -70,10 +66,9 @@ system_message: |

  <rules>
    1. Always consider if a complex query can be broken down
-    2. Make sub-messages as simple as possible
-    3. Include clear instructions for combining results
-    4. Preserve all necessary context in each sub-message
-    5. Resolve any relative dates before decomposition
+    2. Include clear instructions for combining results
+    3. Always preserve all necessary context in each sub-message. Each sub-message should be self-contained.
+    4. Resolve any relative dates before decomposition
  </rules>

  <disallowed_topics>

@@ -94,16 +89,17 @@ system_message: |
    - Queries related to data analysis
    - Topics related to {{ use_case }}
    - Questions about what you can do or your capabilities
+  </allowed_topics>
+
  <output_format>
-    Return a JSON object with sub-messages and combination instructions:
+    Return a JSON object with sub-messages and combination instructions. Each round of sub-messages will be processed in parallel:
    {
      "decomposed_user_messages": [
-        ["<sub_message_1>"],
-        ["<sub_message_2>"],
+        ["<1st_round_sub_message_1>", "<1st_round_sub_message_2>", ...],
+        ["<2nd_round_sub_message_1>", "<2nd_round_sub_message>_2", ...],
        ...
      ],
      "combination_logic": "<instructions for combining results>",
-      "query_type": "<simple|complex>",
      "all_non_database_query": "<true|false>"
    }
  </output_format>

@@ -115,10 +111,9 @@ system_message: |
  Output:
  {
    "decomposed_user_messages": [
-      ["Calculate quarterly sales totals by product category for 2008", "For these categories, find their top selling products in 2008"]
+      ["Which product categories have shown consistent growth quarter over quarter in 2008, and what were their top selling items?"]
    ],
-    "combination_logic": "First identify growing categories from quarterly analysis, then find their best-selling products",
-    "query_type": "complex",
+    "combination_logic": "Direct count query, no combination needed",
    "all_non_database_query": "false"
  }

@@ -130,7 +125,6 @@ system_message: |
      ["How many orders did we have in 2008?"]
    ],
    "combination_logic": "Direct count query, no combination needed",
-    "query_type": "simple",
    "all_non_database_query": "false"
  }

@@ -139,12 +133,9 @@ system_message: |
  Output:
  {
    "decomposed_user_messages": [
-      ["Get total sales by product in European countries"],
-      ["Get total sales by product in North American countries"],
-      ["Calculate total market size for each region", "Find top 5 products by sales in each region"],
+      ["Get total sales by product in European countries and select the top 5 products and calculate the market share", "Get total sales by product in North American countries and select the top 5 products and calculate the market share"]
    ],
    "combination_logic": "First identify top products in each region, then calculate and compare their market shares. Questions that depend on the result of each sub-message are combined.",
-    "query_type": "complex",
    "all_non_database_query": "false"
  }

@@ -156,7 +147,6 @@ system_message: |
      ["What are your capabilities?"]
    ],
    "combination_logic": "Simple greeting and capability question",
-    "query_type": "simple",
    "all_non_database_query": "true"
  }
  </examples>
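To make the revised <output_format> concrete, here is a short sketch of how a caller can parse and walk the rewrite agent's JSON, with rounds on the outside and parallel sub-messages inside; the sample response is invented and loosely mirrors the first worked example in the prompt.

import json

# Invented response following the updated contract: no "query_type" field,
# and each round is a list of sub-messages that can run in parallel.
last_response = json.dumps({
    "decomposed_user_messages": [
        ["Get total sales by product category for 2007",
         "Get total sales by product category for 2008"],
        ["Calculate year-over-year growth per category and find the highest"],
    ],
    "combination_logic": "Combine the per-year totals, then rank by growth",
    "all_non_database_query": "false",
})

sequential_rounds = json.loads(last_response)

# The flag arrives as a string ("true"/"false"), hence the str(...).lower()
# comparison used by the solving agent.
if str(sequential_rounds.get("all_non_database_query", "false")).lower() == "true":
    print("No database work required")
else:
    for round_index, sequential_round in enumerate(
        sequential_rounds["decomposed_user_messages"], start=1
    ):
        for parallel_message in sequential_round:
            print(f"round {round_index}: {parallel_message}")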
