1919import re
2020
2121from text_2_sql_core .payloads .interaction_payloads import (
22- QuestionPayload ,
22+ UserMessagePayload ,
2323 AnswerWithSourcesPayload ,
24- DismabiguationRequestPayload ,
24+ DismabiguationRequestsPayload ,
2525 ProcessingUpdatePayload ,
2626 InteractionPayload ,
2727 PayloadType ,
@@ -40,16 +40,16 @@ def get_all_agents(self):
4040 # Get current datetime for the Query Rewrite Agent
4141 current_datetime = datetime .now ()
4242
43- self .question_rewrite_agent = LLMAgentCreator .create (
44- "question_rewrite_agent " , current_datetime = current_datetime
43+ self .user_message_rewrite_agent = LLMAgentCreator .create (
44+ "user_message_rewrite_agent " , current_datetime = current_datetime
4545 )
4646
4747 self .parallel_query_solving_agent = ParallelQuerySolvingAgent (** self .kwargs )
4848
4949 self .answer_agent = LLMAgentCreator .create ("answer_agent" )
5050
5151 agents = [
52- self .question_rewrite_agent ,
52+ self .user_message_rewrite_agent ,
5353 self .parallel_query_solving_agent ,
5454 self .answer_agent ,
5555 ]
@@ -62,7 +62,7 @@ def termination_condition(self):
6262 termination = (
6363 TextMentionTermination ("TERMINATE" )
6464 | SourceMatchTermination ("answer_agent" )
65- | TextMentionTermination ("requires_user_information_request " )
65+ | TextMentionTermination ("contains_disambiguation_requests " )
6666 | MaxMessageTermination (5 )
6767 )
6868 return termination
@@ -73,11 +73,11 @@ def unified_selector(self, messages):
7373 current_agent = messages [- 1 ].source if messages else "user"
7474 decision = None
7575
76- # If this is the first message start with question_rewrite_agent
76+ # If this is the first message start with user_message_rewrite_agent
7777 if current_agent == "user" :
78- decision = "question_rewrite_agent "
78+ decision = "user_message_rewrite_agent "
7979 # Handle transition after query rewriting
80- elif current_agent == "question_rewrite_agent " :
80+ elif current_agent == "user_message_rewrite_agent " :
8181 decision = "parallel_query_solving_agent"
8282 # Handle transition after parallel query solving
8383 elif current_agent == "parallel_query_solving_agent" :
@@ -102,15 +102,6 @@ def agentic_flow(self):
102102 )
103103 return flow
104104
105- def extract_disambiguation_request (
106- self , messages : list
107- ) -> DismabiguationRequestPayload :
108- """Extract the disambiguation request from the answer."""
109- disambiguation_request = messages [- 1 ].content
110- return DismabiguationRequestPayload (
111- disambiguation_request = disambiguation_request ,
112- )
113-
114105 def parse_message_content (self , content ):
115106 """Parse different message content formats into a dictionary."""
116107 if isinstance (content , (list , dict )):
@@ -134,6 +125,49 @@ def parse_message_content(self, content):
134125 # If all parsing attempts fail, return the content as-is
135126 return content
136127
128+ def extract_decomposed_user_messages (self , messages : list ) -> list [list [str ]]:
129+ """Extract the decomposed messages from the answer."""
130+ # Only load sub-message results if we have a database result
131+ sub_message_results = self .parse_message_content (messages [1 ].content )
132+ logging .info ("Decomposed Results: %s" , sub_message_results )
133+
134+ decomposed_user_messages = sub_message_results .get (
135+ "decomposed_user_messages" , []
136+ )
137+
138+ logging .debug (
139+ "Returning decomposed_user_messages: %s" , decomposed_user_messages
140+ )
141+
142+ return decomposed_user_messages
143+
144+ def extract_disambiguation_request (
145+ self , messages : list
146+ ) -> DismabiguationRequestsPayload :
147+ """Extract the disambiguation request from the answer."""
148+ all_disambiguation_requests = self .parse_message_content (messages [- 1 ].content )
149+
150+ decomposed_user_messages = self .extract_decomposed_user_messages (messages )
151+ request_payload = DismabiguationRequestsPayload (
152+ decomposed_user_messages = decomposed_user_messages
153+ )
154+
155+ for per_question_disambiguation_request in all_disambiguation_requests [
156+ "disambiguation_requests"
157+ ].values ():
158+ for disambiguation_request in per_question_disambiguation_request :
159+ logging .info (
160+ "Disambiguation Request Identified: %s" , disambiguation_request
161+ )
162+
163+ request = DismabiguationRequestsPayload .Body .DismabiguationRequest (
164+ agent_question = disambiguation_request ["agent_question" ],
165+ user_choices = disambiguation_request ["user_choices" ],
166+ )
167+ request_payload .body .disambiguation_requests .append (request )
168+
169+ return request_payload
170+
137171 def extract_answer_payload (self , messages : list ) -> AnswerWithSourcesPayload :
138172 """Extract the sources from the answer."""
139173 answer = messages [- 1 ].content
@@ -145,41 +179,35 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
145179 except json .JSONDecodeError :
146180 logging .warning ("Unable to read SQL query results: %s" , sql_query_results )
147181 sql_query_results = {}
148- sub_question_results = {}
149- else :
150- # Only load sub-question results if we have a database result
151- sub_question_results = self .parse_message_content (messages [1 ].content )
152- logging .info ("Sub-Question Results: %s" , sub_question_results )
153182
154183 try :
155- sub_questions = [
156- sub_question
157- for sub_question_group in sub_question_results .get ("sub_questions" , [])
158- for sub_question in sub_question_group
159- ]
184+ decomposed_user_messages = self .extract_decomposed_user_messages (messages )
160185
161186 logging .info ("SQL Query Results: %s" , sql_query_results )
162187 payload = AnswerWithSourcesPayload (
163- answer = answer , sub_questions = sub_questions
188+ answer = answer , decomposed_user_messages = decomposed_user_messages
164189 )
165190
166191 if not isinstance (sql_query_results , dict ):
167192 logging .error (f"Expected dict, got { type (sql_query_results )} " )
168193 return payload
169194
170- if "results " not in sql_query_results :
195+ if "database_results " not in sql_query_results :
171196 logging .error ("No 'results' key in sql_query_results" )
172197 return payload
173198
174- for question , sql_query_result_list in sql_query_results ["results" ].items ():
199+ for message , sql_query_result_list in sql_query_results [
200+ "database_results"
201+ ].items ():
175202 if not sql_query_result_list : # Check if list is empty
176- logging .warning (f"No results for question : { question } " )
203+ logging .warning (f"No results for message : { message } " )
177204 continue
178205
179206 for sql_query_result in sql_query_result_list :
180207 if not isinstance (sql_query_result , dict ):
181208 logging .error (
182- f"Expected dict for sql_query_result, got { type (sql_query_result )} "
209+ "Expected dict for sql_query_result, got %s" ,
210+ type (sql_query_result ),
183211 )
184212 continue
185213
@@ -208,47 +236,47 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
208236 answer = f"{ answer } \n Error processing results: { str (e )} "
209237 )
210238
211- async def process_question (
239+ async def process_user_message (
212240 self ,
213- question_payload : QuestionPayload ,
241+ message_payload : UserMessagePayload ,
214242 chat_history : list [InteractionPayload ] = None ,
215243 ) -> AsyncGenerator [InteractionPayload , None ]:
216- """Process the complete question through the unified system.
244+ """Process the complete message through the unified system.
217245
218246 Args:
219247 ----
220- task (str): The user question to process.
248+ task (str): The user message to process.
221249 chat_history (list[str], optional): The chat history. Defaults to None.
222250 injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
223251
224252 Returns:
225253 -------
226254 dict: The response from the system.
227255 """
228- logging .info ("Processing question : %s" , question_payload .body .question )
256+ logging .info ("Processing message : %s" , message_payload .body .user_message )
229257 logging .info ("Chat history: %s" , chat_history )
230258
231259 agent_input = {
232- "question " : question_payload .body .question ,
260+ "message " : message_payload .body .user_message ,
233261 "chat_history" : {},
234- "injected_parameters" : question_payload .body .injected_parameters ,
262+ "injected_parameters" : message_payload .body .injected_parameters ,
235263 }
236264
237265 if chat_history is not None :
238266 # Update input
239267 for idx , chat in enumerate (chat_history ):
240- if chat .root .payload_type == PayloadType .QUESTION :
268+ if chat .root .payload_type == PayloadType .USER_MESSAGE :
241269 # For now only consider the user query
242270 chat_history_key = f"chat_{ idx } "
243- agent_input [chat_history_key ] = chat .root .body .question
271+ agent_input [chat_history_key ] = chat .root .body .user_message
244272
245273 async for message in self .agentic_flow .run_stream (task = json .dumps (agent_input )):
246274 logging .debug ("Message: %s" , message )
247275
248276 payload = None
249277
250278 if isinstance (message , TextMessage ):
251- if message .source == "question_rewrite_agent " :
279+ if message .source == "user_message_rewrite_agent " :
252280 payload = ProcessingUpdatePayload (
253281 message = "Rewriting the query..." ,
254282 )
@@ -271,10 +299,10 @@ async def process_question(
271299 elif message .messages [- 1 ].source == "parallel_query_solving_agent" :
272300 # Load into disambiguation request
273301 payload = self .extract_disambiguation_request (message .messages )
274- elif message .messages [- 1 ].source == "question_rewrite_agent " :
302+ elif message .messages [- 1 ].source == "user_message_rewrite_agent " :
275303 # Load into empty response
276304 payload = AnswerWithSourcesPayload (
277- answer = "Apologies, I cannot answer that question as it is not relevant. Please try another question or rephrase your current question ."
305+ answer = "Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message ."
278306 )
279307 else :
280308 logging .error ("Unexpected TaskResult: %s" , message )
0 commit comments