@@ -40,16 +40,16 @@ def get_all_agents(self):
4040 # Get current datetime for the Query Rewrite Agent
4141 current_datetime = datetime .now ()
4242
43- self .user_input_rewrite_agent = LLMAgentCreator .create (
44- "user_input_rewrite_agent " , current_datetime = current_datetime
43+ self .message_rewrite_agent = LLMAgentCreator .create (
44+ "message_rewrite_agent " , current_datetime = current_datetime
4545 )
4646
4747 self .parallel_query_solving_agent = ParallelQuerySolvingAgent (** self .kwargs )
4848
4949 self .answer_agent = LLMAgentCreator .create ("answer_agent" )
5050
5151 agents = [
52- self .user_input_rewrite_agent ,
52+ self .message_rewrite_agent ,
5353 self .parallel_query_solving_agent ,
5454 self .answer_agent ,
5555 ]
@@ -73,11 +73,11 @@ def unified_selector(self, messages):
7373 current_agent = messages [- 1 ].source if messages else "user"
7474 decision = None
7575
76- # If this is the first message start with user_input_rewrite_agent
76+ # If this is the first message start with message_rewrite_agent
7777 if current_agent == "user" :
78- decision = "user_input_rewrite_agent "
78+ decision = "message_rewrite_agent "
7979 # Handle transition after query rewriting
80- elif current_agent == "user_input_rewrite_agent " :
80+ elif current_agent == "message_rewrite_agent " :
8181 decision = "parallel_query_solving_agent"
8282 # Handle transition after parallel query solving
8383 elif current_agent == "parallel_query_solving_agent" :
@@ -145,24 +145,24 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
145145 except json .JSONDecodeError :
146146 logging .warning ("Unable to read SQL query results: %s" , sql_query_results )
147147 sql_query_results = {}
148- sub_user_input_results = {}
148+ sub_message_results = {}
149149 else :
150- # Only load sub-user_input results if we have a database result
151- sub_user_input_results = self .parse_message_content (messages [1 ].content )
152- logging .info ("Sub-user_input Results: %s" , sub_user_input_results )
150+ # Only load sub-message results if we have a database result
151+ sub_message_results = self .parse_message_content (messages [1 ].content )
152+ logging .info ("Sub-message Results: %s" , sub_message_results )
153153
154154 try :
155- sub_user_inputs = [
156- sub_user_input
157- for sub_user_input_group in sub_user_input_results .get (
158- "sub_user_inputs " , []
155+ decomposed_messages = [
156+ sub_message
157+ for sub_message_group in sub_message_results .get (
158+ "decomposed_messages " , []
159159 )
160- for sub_user_input in sub_user_input_group
160+ for sub_message in sub_message_group
161161 ]
162162
163163 logging .info ("SQL Query Results: %s" , sql_query_results )
164164 payload = AnswerWithSourcesPayload (
165- answer = answer , sub_user_inputs = sub_user_inputs
165+ answer = answer , decomposed_messages = decomposed_messages
166166 )
167167
168168 if not isinstance (sql_query_results , dict ):
@@ -173,11 +173,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
173173 logging .error ("No 'results' key in sql_query_results" )
174174 return payload
175175
176- for user_input , sql_query_result_list in sql_query_results [
177- "results"
178- ].items ():
176+ for message , sql_query_result_list in sql_query_results ["results" ].items ():
179177 if not sql_query_result_list : # Check if list is empty
180- logging .warning (f"No results for user_input : { user_input } " )
178+ logging .warning (f"No results for message : { message } " )
181179 continue
182180
183181 for sql_query_result in sql_query_result_list :
@@ -213,47 +211,47 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
213211 answer = f"{ answer } \n Error processing results: { str (e )} "
214212 )
215213
216- async def process_user_input (
214+ async def process_message (
217215 self ,
218- user_input_payload : UserInputPayload ,
216+ message_payload : UserInputPayload ,
219217 chat_history : list [InteractionPayload ] = None ,
220218 ) -> AsyncGenerator [InteractionPayload , None ]:
221- """Process the complete user_input through the unified system.
219+ """Process the complete message through the unified system.
222220
223221 Args:
224222 ----
225- task (str): The user user_input to process.
223+ task (str): The user message to process.
226224 chat_history (list[str], optional): The chat history. Defaults to None.
227225 injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
228226
229227 Returns:
230228 -------
231229 dict: The response from the system.
232230 """
233- logging .info ("Processing user_input : %s" , user_input_payload .body .user_input )
231+ logging .info ("Processing message : %s" , message_payload .body .message )
234232 logging .info ("Chat history: %s" , chat_history )
235233
236234 agent_input = {
237- "user_input " : user_input_payload .body .user_input ,
235+ "message " : message_payload .body .message ,
238236 "chat_history" : {},
239- "injected_parameters" : user_input_payload .body .injected_parameters ,
237+ "injected_parameters" : message_payload .body .injected_parameters ,
240238 }
241239
242240 if chat_history is not None :
243241 # Update input
244242 for idx , chat in enumerate (chat_history ):
245- if chat .root .payload_type == PayloadType .USER_INPUT :
243+ if chat .root .payload_type == PayloadType .message :
246244 # For now only consider the user query
247245 chat_history_key = f"chat_{ idx } "
248- agent_input [chat_history_key ] = chat .root .body .user_input
246+ agent_input [chat_history_key ] = chat .root .body .message
249247
250248 async for message in self .agentic_flow .run_stream (task = json .dumps (agent_input )):
251249 logging .debug ("Message: %s" , message )
252250
253251 payload = None
254252
255253 if isinstance (message , TextMessage ):
256- if message .source == "user_input_rewrite_agent " :
254+ if message .source == "message_rewrite_agent " :
257255 payload = ProcessingUpdatePayload (
258256 message = "Rewriting the query..." ,
259257 )
@@ -276,10 +274,10 @@ async def process_user_input(
276274 elif message .messages [- 1 ].source == "parallel_query_solving_agent" :
277275 # Load into disambiguation request
278276 payload = self .extract_disambiguation_request (message .messages )
279- elif message .messages [- 1 ].source == "user_input_rewrite_agent " :
277+ elif message .messages [- 1 ].source == "message_rewrite_agent " :
280278 # Load into empty response
281279 payload = AnswerWithSourcesPayload (
282- answer = "Apologies, I cannot answer that user_input as it is not relevant. Please try another user_input or rephrase your current user_input ."
280+ answer = "Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message ."
283281 )
284282 else :
285283 logging .error ("Unexpected TaskResult: %s" , message )
0 commit comments