1- """
2- Copyright (c) Microsoft Corporation.
3- Licensed under the MIT License.
4- """
1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT License.
53from autogen_agentchat .conditions import (
64 TextMentionTermination ,
75 MaxMessageTermination ,
108from autogen_text_2_sql .creators .llm_model_creator import LLMModelCreator
119from autogen_text_2_sql .creators .llm_agent_creator import LLMAgentCreator
1210import logging
13- from autogen_text_2_sql .custom_agents .sql_query_cache_agent import (
14- SqlQueryCacheAgent ,
15- )
16- from autogen_text_2_sql .custom_agents .sql_schema_selection_agent import (
17- SqlSchemaSelectionAgent ,
11+ from autogen_text_2_sql .custom_agents .parallel_query_solving_agent import (
12+ ParallelQuerySolvingAgent ,
1813)
1914from autogen_text_2_sql .custom_agents .answer_and_sources_agent import (
2015 AnswerAndSourcesAgent ,
@@ -45,23 +40,9 @@ async def on_messages_stream(self, messages, sender=None, config=None):
4540
4641class AutoGenText2Sql :
4742 def __init__ (self , engine_specific_rules : str , ** kwargs : dict ):
48- self .pre_run_query_cache = False
4943 self .target_engine = os .environ ["Text2Sql__DatabaseEngine" ].upper ()
5044 self .engine_specific_rules = engine_specific_rules
5145 self .kwargs = kwargs
52- self .set_mode ()
53-
54- def set_mode (self ):
55- """Set the mode of the plugin based on the environment variables."""
56- self .pre_run_query_cache = (
57- os .environ .get ("Text2Sql__PreRunQueryCache" , "True" ).lower () == "true"
58- )
59- self .use_column_value_store = (
60- os .environ .get ("Text2Sql__UseColumnValueStore" , "True" ).lower () == "true"
61- )
62- self .use_query_cache = (
63- os .environ .get ("Text2Sql__UseQueryCache" , "True" ).lower () == "true"
64- )
6546
6647 def get_all_agents (self ):
6748 """Get all agents for the complete flow."""
@@ -72,43 +53,8 @@ def get_all_agents(self):
7253 "query_rewrite_agent" , current_datetime = current_datetime
7354 )
7455
75- self .sql_query_generation_agent = LLMAgentCreator .create (
76- "sql_query_generation_agent" ,
77- target_engine = self .target_engine ,
78- engine_specific_rules = self .engine_specific_rules ,
79- ** self .kwargs ,
80- )
81-
82- # If relationship_paths not provided, use a generic template
83- if "relationship_paths" not in self .kwargs :
84- self .kwargs [
85- "relationship_paths"
86- ] = """
87- Common relationship paths to consider:
88- - Transaction → Related Dimensions (for basic analysis)
89- - Geographic → Location hierarchies (for geographic analysis)
90- - Temporal → Date hierarchies (for time-based analysis)
91- - Entity → Attributes (for entity-specific analysis)
92- """
93-
94- self .sql_schema_selection_agent = SqlSchemaSelectionAgent (
95- target_engine = self .target_engine ,
96- engine_specific_rules = self .engine_specific_rules ,
97- ** self .kwargs ,
98- )
99-
100- self .sql_query_correction_agent = LLMAgentCreator .create (
101- "sql_query_correction_agent" ,
102- target_engine = self .target_engine ,
103- engine_specific_rules = self .engine_specific_rules ,
104- ** self .kwargs ,
105- )
106-
107- self .sql_disambiguation_agent = LLMAgentCreator .create (
108- "sql_disambiguation_agent" ,
109- target_engine = self .target_engine ,
110- engine_specific_rules = self .engine_specific_rules ,
111- ** self .kwargs ,
56+ self .parallel_query_solving_agent = ParallelQuerySolvingAgent (
57+ engine_specific_rules = self .engine_specific_rules , ** self .kwargs
11258 )
11359
11460 self .answer_and_sources_agent = AnswerAndSourcesAgent ()
@@ -119,17 +65,10 @@ def get_all_agents(self):
11965 agents = [
12066 self .user_proxy ,
12167 self .query_rewrite_agent ,
122- self .sql_query_generation_agent ,
123- self .sql_schema_selection_agent ,
124- self .sql_query_correction_agent ,
125- self .sql_disambiguation_agent ,
68+ self .parallel_query_solving_agent ,
12669 self .answer_and_sources_agent ,
12770 ]
12871
129- if self .use_query_cache :
130- self .query_cache_agent = SqlQueryCacheAgent ()
131- agents .append (self .query_cache_agent )
132-
13372 return agents
13473
13574 @property
@@ -149,51 +88,19 @@ def unified_selector(self, messages):
14988 decision = None
15089
15190 # If this is the first message start with query_rewrite_agent
152- if len ( messages ) == 1 :
91+ if current_agent == "start" :
15392 decision = "query_rewrite_agent"
15493 # Handle transition after query rewriting
15594 elif current_agent == "query_rewrite_agent" :
156- decision = (
157- "sql_query_cache_agent"
158- if self .use_query_cache
159- else "sql_schema_selection_agent"
160- )
161- # Handle subsequent agent transitions
162- elif current_agent == "sql_query_cache_agent" :
163- # Always go through schema selection after cache check
164- decision = "sql_schema_selection_agent"
165- elif current_agent == "sql_schema_selection_agent" :
166- decision = "sql_disambiguation_agent"
167- elif current_agent == "sql_disambiguation_agent" :
168- decision = "sql_query_generation_agent"
169- elif current_agent == "sql_query_generation_agent" :
170- decision = "sql_query_correction_agent"
171- elif current_agent == "sql_query_correction_agent" :
172- try :
173- correction_result = json .loads (messages [- 1 ].content )
174- if isinstance (correction_result , dict ):
175- if "answer" in correction_result and "sources" in correction_result :
176- decision = "answer_and_sources_agent"
177- elif "corrected_query" in correction_result :
178- if correction_result .get ("executing" , False ):
179- decision = "sql_query_correction_agent"
180- else :
181- decision = "sql_query_generation_agent"
182- elif "error" in correction_result :
183- decision = "sql_query_generation_agent"
184- elif isinstance (correction_result , list ) and len (correction_result ) > 0 :
185- if "requested_fix" in correction_result [0 ]:
186- decision = "sql_query_generation_agent"
187-
188- if decision is None :
189- decision = "sql_query_generation_agent"
190- except json .JSONDecodeError :
191- decision = "sql_query_generation_agent"
192- elif current_agent == "answer_and_sources_agent" :
193- decision = "user_proxy" # Let user_proxy send TERMINATE
95+ decision = "parallel_query_solving_agent"
96+ # Handle transition after parallel query solving
97+ elif current_agent == "parallel_query_solving_agent" :
98+ decision = "answer_and_sources_agent"
19499
195100 if decision :
196101 logging .info (f"Agent transition: { current_agent } -> { decision } " )
102+ else :
103+ logging .info (f"No agent transition defined from { current_agent } " )
197104
198105 return decision
199106
0 commit comments