Skip to content

Commit 1eb6197

Browse files
committed
Update query rewrite agent
1 parent a4775eb commit 1eb6197

File tree

4 files changed

+62
-56
lines changed

4 files changed

+62
-56
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def consume_inner_messages_from_agentic_flow(
7575
if isinstance(inner_message, TaskResult) is False:
7676
try:
7777
inner_message = json.loads(inner_message.content)
78-
logging.info(f"Loaded: {inner_message}")
78+
logging.info(f"Inner Loaded: {inner_message}")
7979

8080
# Search for specific message types and add them to the final output object
8181
if (
@@ -91,6 +91,21 @@ async def consume_inner_messages_from_agentic_flow(
9191
}
9292
)
9393

94+
if ("contains_pre_run_results" in inner_message) and (
95+
inner_message["contains_pre_run_results"] is True
96+
):
97+
for pre_run_sql_query, pre_run_result in inner_message[
98+
"cached_questions_and_schemas"
99+
].items():
100+
database_results[identifier].append(
101+
{
102+
"sql_query": pre_run_sql_query.replace(
103+
"\n", " "
104+
),
105+
"sql_rows": pre_run_result["sql_rows"],
106+
}
107+
)
108+
94109
except (JSONDecodeError, TypeError) as e:
95110
logging.error("Could not load message: %s", inner_message)
96111
logging.warning(f"Error processing message: {e}")
@@ -113,13 +128,15 @@ async def consume_inner_messages_from_agentic_flow(
113128
self.engine_specific_rules, **self.kwargs
114129
)
115130

131+
identifier = ", ".join(query_rewrite)
132+
116133
# Launch tasks for each sub-query
117134
inner_solving_generators.append(
118135
consume_inner_messages_from_agentic_flow(
119136
inner_autogen_text_2_sql.process_question(
120137
question=query_rewrite, injected_parameters=injected_parameters
121138
),
122-
query_rewrite,
139+
identifier,
123140
database_results,
124141
)
125142
)

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -39,55 +39,46 @@ async def on_messages_stream(
3939
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4040
) -> AsyncGenerator[AgentMessage | Response, None]:
4141
# Get the decomposed questions from the query_rewrite_agent
42-
parameter_input = messages[0].content
43-
last_response = messages[-1].content
4442
try:
45-
user_questions = json.loads(last_response)
46-
injected_parameters = json.loads(parameter_input)["injected_parameters"]
43+
request_details = json.loads(messages[0].content)
44+
injected_parameters = request_details["injected_parameters"]
45+
user_questions = request_details["question"]
4746
logging.info(f"Processing questions: {user_questions}")
4847
logging.info(f"Input Parameters: {injected_parameters}")
48+
except json.JSONDecodeError:
49+
# The request message is not valid JSON, so it cannot be processed
50+
raise ValueError("Could not load message")
4951

50-
# Initialize results dictionary
51-
cached_results = {
52-
"cached_questions_and_schemas": [],
53-
"contains_pre_run_results": False,
54-
}
55-
56-
# Process each question sequentially
57-
for question in user_questions:
58-
# Fetch the queries from the cache based on the question
59-
logging.info(f"Fetching queries from cache for question: {question}")
60-
cached_query = await self.sql_connector.fetch_queries_from_cache(
61-
question, injected_parameters=injected_parameters
62-
)
52+
# Initialize results dictionary
53+
cached_results = {
54+
"cached_questions_and_schemas": [],
55+
"contains_pre_run_results": False,
56+
}
6357

64-
# If any question has pre-run results, set the flag
65-
if cached_query.get("contains_pre_run_results", False):
66-
cached_results["contains_pre_run_results"] = True
58+
# Process each question sequentially
59+
for question in user_questions:
60+
# Fetch the queries from the cache based on the question
61+
logging.info(f"Fetching queries from cache for question: {question}")
62+
cached_query = await self.sql_connector.fetch_queries_from_cache(
63+
question, injected_parameters=injected_parameters
64+
)
6765

68-
# Add the cached results for this question
69-
if cached_query.get("cached_questions_and_schemas"):
70-
cached_results["cached_questions_and_schemas"].extend(
71-
cached_query["cached_questions_and_schemas"]
72-
)
66+
# If any question has pre-run results, set the flag
67+
if cached_query.get("contains_pre_run_results", False):
68+
cached_results["contains_pre_run_results"] = True
7369

74-
logging.info(f"Final cached results: {cached_results}")
75-
yield Response(
76-
chat_message=TextMessage(
77-
content=json.dumps(cached_results), source=self.name
78-
)
79-
)
80-
except json.JSONDecodeError:
81-
# If not JSON array, process as single question
82-
logging.info(f"Processing single question: {last_response}")
83-
cached_queries = await self.sql_connector.fetch_queries_from_cache(
84-
last_response
85-
)
86-
yield Response(
87-
chat_message=TextMessage(
88-
content=json.dumps(cached_queries), source=self.name
70+
# Add the cached results for this question
71+
if cached_query.get("cached_questions_and_schemas"):
72+
cached_results["cached_questions_and_schemas"].extend(
73+
cached_query["cached_questions_and_schemas"]
8974
)
75+
76+
logging.info(f"Final cached results: {cached_results}")
77+
yield Response(
78+
chat_message=TextMessage(
79+
content=json.dumps(cached_results), source=self.name
9080
)
81+
)
9182

9283
async def on_reset(self, cancellation_token: CancellationToken) -> None:
9384
pass

text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ async def fetch_queries_from_cache(
228228

229229
for sql_query, sql_result in zip(sql_queries, sql_results):
230230
query_result_store[sql_query["SqlQuery"]] = {
231-
"result": sql_result,
231+
"sql_rows": sql_result,
232232
"schemas": sql_query["Schemas"],
233233
}
234234

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ system_message: |
4444
- Determine if breaking down would simplify processing
4545
4646
3. Break Down Complex Queries:
47-
- Create independent sub-queries that can be processed separately
47+
- Create independent sub-queries that can be processed separately.
48+
- Each sub-query should be a simple, focused task.
49+
- Group dependent sub-queries together for sequential processing.
4850
- Ensure each sub-query is simple and focused
4951
- Include clear combination instructions
5052
- Preserve all necessary context in each sub-query
@@ -71,8 +73,8 @@ system_message: |
7173
Return a JSON object with sub-queries and combination instructions:
7274
{
7375
"sub_queries": [
74-
"<sub_query_1>",
75-
"<sub_query_2>",
76+
["<sub_query_1>"],
77+
["<sub_query_2>"],
7678
...
7779
],
7880
"combination_logic": "<instructions for combining results>",
@@ -87,9 +89,7 @@ system_message: |
8789
Output:
8890
{
8991
"sub_queries": [
90-
"Calculate quarterly sales totals by product category for 2008",
91-
"Identify categories with positive growth each quarter",
92-
"For these categories, find their top selling products in 2008"
92+
["Calculate quarterly sales totals by product category for 2008", "Identify categories with positive growth each quarter", "For these categories, find their top selling products in 2008"]
9393
],
9494
"combination_logic": "First identify growing categories from quarterly analysis, then find their best-selling products",
9595
"query_type": "complex"
@@ -100,7 +100,7 @@ system_message: |
100100
Output:
101101
{
102102
"sub_queries": [
103-
"How many orders did we have in 2008?"
103+
["How many orders did we have in 2008?"]
104104
],
105105
"combination_logic": "Direct count query, no combination needed",
106106
"query_type": "simple"
@@ -111,13 +111,11 @@ system_message: |
111111
Output:
112112
{
113113
"sub_queries": [
114-
"Get total sales by product in European countries",
115-
"Get total sales by product in North American countries",
116-
"Calculate total market size for each region",
117-
"Find top 5 products by sales in each region",
118-
"Calculate market share percentages for these products"
114+
["Get total sales by product in European countries"],
115+
["Get total sales by product in North American countries"],
116+
["Calculate total market size for each region", "Find top 5 products by sales in each region"]
119117
],
120-
"combination_logic": "First identify top products in each region, then calculate and compare their market shares",
118+
"combination_logic": "First identify top products in each region, then calculate and compare their market shares. Sub-queries that depend on each other's results are grouped together in the same list.",
121119
"query_type": "complex"
122120
}
123121
</examples>

0 commit comments

Comments
 (0)