1717import os
1818from datetime import datetime
1919
20- from text_2_sql_core .payloads .agent_request_response_pair import (
21- AgentRequestResponsePair ,
22- AgentRequestBody ,
23- AnswerWithSources ,
20+ from text_2_sql_core .payloads .interaction_payloads import (
21+ QuestionPayload ,
22+ AnswerWithSourcesPayload ,
23+ AnswerWithSourcesBody ,
2424 Source ,
25- DismabiguationRequests ,
26- )
27- from text_2_sql_core .payloads .chat_history import ChatHistoryItem
28- from text_2_sql_core .payloads .processing_update import (
25+ DismabiguationRequestPayload ,
2926 ProcessingUpdateBody ,
30- ProcessingUpdate ,
27+ ProcessingUpdatePayload ,
28+ InteractionPayload ,
29+ PayloadType ,
3130)
3231from autogen_agentchat .base import TaskResult
3332from typing import AsyncGenerator
@@ -108,17 +107,19 @@ def agentic_flow(self):
108107 )
109108 return flow
110109
111- def extract_disambiguation_request (self , messages : list ) -> DismabiguationRequests :
110+ def extract_disambiguation_request (
111+ self , messages : list
112+ ) -> DismabiguationRequestPayload :
112113 """Extract the disambiguation request from the answer."""
113114
114115 disambiguation_request = messages [- 1 ].content
115116
116117 # TODO: Properly extract the disambiguation request
117- return DismabiguationRequests (
118+ return DismabiguationRequestPayload (
118119 disambiguation_request = disambiguation_request ,
119120 )
120121
121- def extract_sources (self , messages : list ) -> AnswerWithSources :
122+ def extract_sources (self , messages : list ) -> AnswerWithSourcesPayload :
122123 """Extract the sources from the answer."""
123124
124125 answer = messages [- 1 ].content
@@ -152,16 +153,18 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
152153 logging .error ("Could not load message: %s" , sql_query_results )
153154 raise ValueError ("Could not load message" )
154155
155- return AnswerWithSources (
156- answer = answer ,
157- sources = sources ,
156+ return AnswerWithSourcesPayload (
157+ body = AnswerWithSourcesBody (
158+ answer = answer ,
159+ sources = sources ,
160+ )
158161 )
159162
160163 async def process_question (
161164 self ,
162- request : AgentRequestBody ,
163- chat_history : list [ChatHistoryItem ] = None ,
164- ) -> AsyncGenerator [AgentRequestResponsePair | ProcessingUpdate , None ]:
165+ question_payload : QuestionPayload ,
166+ chat_history : list [InteractionPayload ] = None ,
167+ ) -> AsyncGenerator [InteractionPayload , None ]:
165168 """Process the complete question through the unified system.
166169
167170 Args:
@@ -174,23 +177,22 @@ async def process_question(
174177 -------
175178 dict: The response from the system.
176179 """
177- logging .info ("Processing question: %s" , request .question )
180+ logging .info ("Processing question: %s" , question_payload . body .question )
178181 logging .info ("Chat history: %s" , chat_history )
179182
180183 agent_input = {
181- "question" : request .question ,
184+ "question" : question_payload . body .question ,
182185 "chat_history" : {},
183- "injected_parameters" : request .injected_parameters ,
186+ "injected_parameters" : question_payload . body .injected_parameters ,
184187 }
185188
186189 if chat_history is not None :
187190 # Update input
188191 for idx , chat in enumerate (chat_history ):
189- # For now only consider the user query
190- chat_history_key = f"chat_{ idx } "
191- agent_input [
192- chat_history_key
193- ] = chat .request_response_pair .request .question
192+ if chat .root .payload_type == PayloadType .QUESTION :
193+ # For now only consider the user query
194+ chat_history_key = f"chat_{ idx } "
195+ agent_input [chat_history_key ] = chat .root .body .question
194196
195197 async for message in self .agentic_flow .run_stream (task = json .dumps (agent_input )):
196198 logging .debug ("Message: %s" , message )
@@ -213,27 +215,24 @@ async def process_question(
213215 )
214216
215217 if processing_update is not None :
216- payload = ProcessingUpdate (
217- processing_update = processing_update ,
218+ payload = ProcessingUpdatePayload (
219+ body = processing_update ,
218220 )
219221
220222 elif isinstance (message , TaskResult ):
221223 # Now we need to return the final answer or the disambiguation request
222224 logging .info ("TaskResult: %s" , message )
223225
224- response = None
225226 if message .messages [- 1 ].source == "answer_agent" :
226227 # If the message is from the answer_agent, we need to return the final answer
227- response = self .extract_sources (message .messages )
228+ payload = self .extract_sources (message .messages )
228229 elif message .messages [- 1 ].source == "parallel_query_solving_agent" :
229230 # Load into disambiguation request
230- response = self .extract_disambiguation_request (message .messages )
231+ payload = self .extract_disambiguation_request (message .messages )
231232 else :
232233 logging .error ("Unexpected TaskResult: %s" , message )
233234 raise ValueError ("Unexpected TaskResult" )
234235
235- payload = AgentRequestResponsePair (request = request , response = response )
236-
237236 if payload is not None :
238237 logging .debug ("Payload: %s" , payload )
239238 yield payload
0 commit comments