Skip to content

Commit 1132cd0

Browse files
committed
Parallel solving
1 parent e716b8d commit 1132cd0

File tree

3 files changed

+277
-107
lines changed

3 files changed

+277
-107
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 14 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
"""
2-
Copyright (c) Microsoft Corporation.
3-
Licensed under the MIT License.
4-
"""
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
53
from autogen_agentchat.conditions import (
64
TextMentionTermination,
75
MaxMessageTermination,
@@ -10,11 +8,8 @@
108
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
119
from autogen_text_2_sql.creators.llm_agent_creator import LLMAgentCreator
1210
import logging
13-
from autogen_text_2_sql.custom_agents.sql_query_cache_agent import (
14-
SqlQueryCacheAgent,
15-
)
16-
from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import (
17-
SqlSchemaSelectionAgent,
11+
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
12+
ParallelQuerySolvingAgent,
1813
)
1914
from autogen_text_2_sql.custom_agents.answer_and_sources_agent import (
2015
AnswerAndSourcesAgent,
@@ -45,23 +40,9 @@ async def on_messages_stream(self, messages, sender=None, config=None):
4540

4641
class AutoGenText2Sql:
4742
def __init__(self, engine_specific_rules: str, **kwargs: dict):
48-
self.pre_run_query_cache = False
4943
self.target_engine = os.environ["Text2Sql__DatabaseEngine"].upper()
5044
self.engine_specific_rules = engine_specific_rules
5145
self.kwargs = kwargs
52-
self.set_mode()
53-
54-
def set_mode(self):
55-
"""Set the mode of the plugin based on the environment variables."""
56-
self.pre_run_query_cache = (
57-
os.environ.get("Text2Sql__PreRunQueryCache", "True").lower() == "true"
58-
)
59-
self.use_column_value_store = (
60-
os.environ.get("Text2Sql__UseColumnValueStore", "True").lower() == "true"
61-
)
62-
self.use_query_cache = (
63-
os.environ.get("Text2Sql__UseQueryCache", "True").lower() == "true"
64-
)
6546

6647
def get_all_agents(self):
6748
"""Get all agents for the complete flow."""
@@ -72,43 +53,8 @@ def get_all_agents(self):
7253
"query_rewrite_agent", current_datetime=current_datetime
7354
)
7455

75-
self.sql_query_generation_agent = LLMAgentCreator.create(
76-
"sql_query_generation_agent",
77-
target_engine=self.target_engine,
78-
engine_specific_rules=self.engine_specific_rules,
79-
**self.kwargs,
80-
)
81-
82-
# If relationship_paths not provided, use a generic template
83-
if "relationship_paths" not in self.kwargs:
84-
self.kwargs[
85-
"relationship_paths"
86-
] = """
87-
Common relationship paths to consider:
88-
- Transaction → Related Dimensions (for basic analysis)
89-
- Geographic → Location hierarchies (for geographic analysis)
90-
- Temporal → Date hierarchies (for time-based analysis)
91-
- Entity → Attributes (for entity-specific analysis)
92-
"""
93-
94-
self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
95-
target_engine=self.target_engine,
96-
engine_specific_rules=self.engine_specific_rules,
97-
**self.kwargs,
98-
)
99-
100-
self.sql_query_correction_agent = LLMAgentCreator.create(
101-
"sql_query_correction_agent",
102-
target_engine=self.target_engine,
103-
engine_specific_rules=self.engine_specific_rules,
104-
**self.kwargs,
105-
)
106-
107-
self.sql_disambiguation_agent = LLMAgentCreator.create(
108-
"sql_disambiguation_agent",
109-
target_engine=self.target_engine,
110-
engine_specific_rules=self.engine_specific_rules,
111-
**self.kwargs,
56+
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(
57+
engine_specific_rules=self.engine_specific_rules, **self.kwargs
11258
)
11359

11460
self.answer_and_sources_agent = AnswerAndSourcesAgent()
@@ -119,17 +65,10 @@ def get_all_agents(self):
11965
agents = [
12066
self.user_proxy,
12167
self.query_rewrite_agent,
122-
self.sql_query_generation_agent,
123-
self.sql_schema_selection_agent,
124-
self.sql_query_correction_agent,
125-
self.sql_disambiguation_agent,
68+
self.parallel_query_solving_agent,
12669
self.answer_and_sources_agent,
12770
]
12871

129-
if self.use_query_cache:
130-
self.query_cache_agent = SqlQueryCacheAgent()
131-
agents.append(self.query_cache_agent)
132-
13372
return agents
13473

13574
@property
@@ -149,51 +88,19 @@ def unified_selector(self, messages):
14988
decision = None
15089

15190
# If this is the first message start with query_rewrite_agent
152-
if len(messages) == 1:
91+
if current_agent == "start":
15392
decision = "query_rewrite_agent"
15493
# Handle transition after query rewriting
15594
elif current_agent == "query_rewrite_agent":
156-
decision = (
157-
"sql_query_cache_agent"
158-
if self.use_query_cache
159-
else "sql_schema_selection_agent"
160-
)
161-
# Handle subsequent agent transitions
162-
elif current_agent == "sql_query_cache_agent":
163-
# Always go through schema selection after cache check
164-
decision = "sql_schema_selection_agent"
165-
elif current_agent == "sql_schema_selection_agent":
166-
decision = "sql_disambiguation_agent"
167-
elif current_agent == "sql_disambiguation_agent":
168-
decision = "sql_query_generation_agent"
169-
elif current_agent == "sql_query_generation_agent":
170-
decision = "sql_query_correction_agent"
171-
elif current_agent == "sql_query_correction_agent":
172-
try:
173-
correction_result = json.loads(messages[-1].content)
174-
if isinstance(correction_result, dict):
175-
if "answer" in correction_result and "sources" in correction_result:
176-
decision = "answer_and_sources_agent"
177-
elif "corrected_query" in correction_result:
178-
if correction_result.get("executing", False):
179-
decision = "sql_query_correction_agent"
180-
else:
181-
decision = "sql_query_generation_agent"
182-
elif "error" in correction_result:
183-
decision = "sql_query_generation_agent"
184-
elif isinstance(correction_result, list) and len(correction_result) > 0:
185-
if "requested_fix" in correction_result[0]:
186-
decision = "sql_query_generation_agent"
187-
188-
if decision is None:
189-
decision = "sql_query_generation_agent"
190-
except json.JSONDecodeError:
191-
decision = "sql_query_generation_agent"
192-
elif current_agent == "answer_and_sources_agent":
193-
decision = "user_proxy" # Let user_proxy send TERMINATE
95+
decision = "parallel_query_solving_agent"
96+
# Handle transition after parallel query solving
97+
elif current_agent == "parallel_query_solving_agent":
98+
decision = "answer_and_sources_agent"
19499

195100
if decision:
196101
logging.info(f"Agent transition: {current_agent} -> {decision}")
102+
else:
103+
logging.info(f"No agent transition defined from {current_agent}")
197104

198105
return decision
199106

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from typing import AsyncGenerator, List, Sequence
4+
5+
from autogen_agentchat.agents import BaseChatAgent
6+
from autogen_agentchat.base import Response
7+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
8+
from autogen_core import CancellationToken
9+
import json
10+
import logging
11+
import asyncio
12+
from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql
13+
14+
15+
class ParallelQuerySolvingAgent(BaseChatAgent):
16+
def __init__(self, engine_specific_rules: str, **kwargs: dict):
17+
super().__init__(
18+
"parallel_query_solving_agent",
19+
"An agent that solves each query in parallel.",
20+
)
21+
22+
self.engine_specific_rules = engine_specific_rules
23+
self.kwargs = kwargs
24+
25+
@property
26+
def produced_message_types(self) -> List[type[ChatMessage]]:
27+
return [TextMessage]
28+
29+
async def on_messages(
30+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
31+
) -> Response:
32+
# Calls the on_messages_stream.
33+
response: Response | None = None
34+
async for message in self.on_messages_stream(messages, cancellation_token):
35+
if isinstance(message, Response):
36+
response = message
37+
assert response is not None
38+
return response
39+
40+
async def on_messages_stream(
41+
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
42+
) -> AsyncGenerator[AgentMessage | Response, None]:
43+
last_response = messages[-1].content
44+
45+
# Load the json of the last message to populate the final output object
46+
query_rewrites = json.loads(last_response)
47+
48+
logging.info(f"Query Rewrite: {query_rewrites}")
49+
50+
inner_solving_tasks = []
51+
52+
for query_rewrite in query_rewrites:
53+
# Create an instance of the InnerAutoGenText2Sql class
54+
inner_autogen_text_2_sql = InnerAutoGenText2Sql(
55+
self.engine_specific_rules, **self.kwargs
56+
)
57+
58+
inner_solving_tasks.append(
59+
inner_autogen_text_2_sql.run_stream(task=query_rewrite)
60+
)
61+
62+
# Wait for all the inner solving tasks to complete
63+
inner_solving_results = await asyncio.gather(*inner_solving_tasks)
64+
65+
logging.info(f"Inner Solving Results: {inner_solving_results}")
66+
67+
yield Response(
68+
chat_message=TextMessage(
69+
content=json.dumps(inner_solving_results), source=self.name
70+
)
71+
)
72+
73+
async def on_reset(self, cancellation_token: CancellationToken) -> None:
74+
pass

0 commit comments

Comments
 (0)