@@ -68,57 +68,48 @@ def get_all_agents(self):
6868 # Get current datetime for the Query Rewrite Agent
6969 current_datetime = datetime .now ()
7070
71- QUERY_REWRITE_AGENT = LLMAgentCreator .create (
71+ self . query_rewrite_agent = LLMAgentCreator .create (
7272 "query_rewrite_agent" , current_datetime = current_datetime
7373 )
7474
75- SQL_QUERY_GENERATION_AGENT = LLMAgentCreator .create (
75+ self . sql_query_generation_agent = LLMAgentCreator .create (
7676 "sql_query_generation_agent" ,
7777 target_engine = self .target_engine ,
7878 engine_specific_rules = self .engine_specific_rules ,
7979 ** self .kwargs ,
8080 )
8181
82- SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent (
83- target_engine = self .target_engine ,
84- engine_specific_rules = self .engine_specific_rules ,
85- ** self .kwargs ,
86- )
82+ self .sql_schema_selection_agent = SqlSchemaSelectionAgent ()
8783
88- SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator .create (
84+ self . sql_query_correction_agent = LLMAgentCreator .create (
8985 "sql_query_correction_agent" ,
9086 target_engine = self .target_engine ,
9187 engine_specific_rules = self .engine_specific_rules ,
9288 ** self .kwargs ,
9389 )
9490
95- SQL_DISAMBIGUATION_AGENT = LLMAgentCreator .create (
91+ self . sql_disambiguation_agent = LLMAgentCreator .create (
9692 "sql_disambiguation_agent" ,
9793 target_engine = self .target_engine ,
9894 engine_specific_rules = self .engine_specific_rules ,
9995 ** self .kwargs ,
10096 )
10197
102- QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator .create (
103- "question_decomposition_agent"
104- )
105-
10698 # Auto-responding UserProxyAgent
107- USER_PROXY = EmptyResponseUserProxyAgent (name = "user_proxy" )
99+ self . user_proxy = EmptyResponseUserProxyAgent (name = "user_proxy" )
108100
109101 agents = [
110- USER_PROXY ,
111- QUERY_REWRITE_AGENT ,
112- SQL_QUERY_GENERATION_AGENT ,
113- SQL_SCHEMA_SELECTION_AGENT ,
114- SQL_QUERY_CORRECTION_AGENT ,
115- QUESTION_DECOMPOSITION_AGENT ,
116- SQL_DISAMBIGUATION_AGENT ,
102+ self .user_proxy ,
103+ self .query_rewrite_agent ,
104+ self .sql_query_generation_agent ,
105+ self .sql_schema_selection_agent ,
106+ self .sql_query_correction_agent ,
107+ self .sql_disambiguation_agent ,
117108 ]
118109
119110 if self .use_query_cache :
120- SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent ()
121- agents .append (SQL_QUERY_CACHE_AGENT )
111+ self . query_cache_agent = SqlQueryCacheAgent (** self . kwargs )
112+ agents .append (self . query_cache_agent )
122113
123114 return agents
124115
@@ -195,13 +186,30 @@ def agentic_flow(self):
195186 )
196187 return flow
197188
198- async def process_question (self , task : str , chat_history : list [str ] = None ):
199- """Process the complete question through the unified system."""
189+ async def process_question (
190+ self , task : str , chat_history : list [str ] = None , parameters : dict = None
191+ ):
192+ """Process the complete question through the unified system.
193+
194+ Args:
195+ ----
196+ task (str): The user question to process.
197+ chat_history (list[str], optional): The chat history. Defaults to None.
198+ parameters (dict, optional): The parameters to pass to the agents. Defaults to None.
199+
200+ Returns:
201+ -------
202+ dict: The response from the system.
203+ """
200204
201205 logging .info ("Processing question: %s" , task )
202206 logging .info ("Chat history: %s" , chat_history )
203207
204- agent_input = {"user_question" : task , "chat_history" : {}}
208+ agent_input = {
209+ "user_question" : task ,
210+ "chat_history" : {},
211+ "parameters" : parameters ,
212+ }
205213
206214 if chat_history is not None :
207215 # Update input
0 commit comments