1717import os
1818from datetime import datetime
1919
20- from text_2_sql_core .payloads .agent_request_response_pair import (
21- AgentRequestResponsePair ,
22- AgentRequestBody ,
23- AnswerWithSources ,
24- Source ,
25- DismabiguationRequests ,
26- )
27- from text_2_sql_core .payloads .chat_history import ChatHistoryItem
28- from text_2_sql_core .payloads .processing_update import (
29- ProcessingUpdateBody ,
30- ProcessingUpdate ,
20+ from text_2_sql_core .payloads .interaction_payloads import (
21+ QuestionPayload ,
22+ AnswerWithSourcesPayload ,
23+ DismabiguationRequestPayload ,
24+ ProcessingUpdatePayload ,
25+ InteractionPayload ,
26+ PayloadType ,
3127)
3228from autogen_agentchat .base import TaskResult
3329from typing import AsyncGenerator
@@ -108,17 +104,19 @@ def agentic_flow(self):
108104 )
109105 return flow
110106
111- def extract_disambiguation_request (self , messages : list ) -> DismabiguationRequests :
107+ def extract_disambiguation_request (
108+ self , messages : list
109+ ) -> DismabiguationRequestPayload :
112110 """Extract the disambiguation request from the answer."""
113111
114112 disambiguation_request = messages [- 1 ].content
115113
116114 # TODO: Properly extract the disambiguation request
117- return DismabiguationRequests (
115+ return DismabiguationRequestPayload (
118116 disambiguation_request = disambiguation_request ,
119117 )
120118
121- def extract_sources (self , messages : list ) -> AnswerWithSources :
119+ def extract_sources (self , messages : list ) -> AnswerWithSourcesPayload :
122120 """Extract the sources from the answer."""
123121
124122 answer = messages [- 1 ].content
@@ -130,7 +128,7 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
130128
131129 logging .info ("SQL Query Results: %s" , sql_query_results )
132130
133- sources = []
131+ payload = AnswerWithSourcesPayload ( answer = answer )
134132
135133 for question , sql_query_result_list in sql_query_results ["results" ].items ():
136134 logging .info (
@@ -141,27 +139,24 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
141139
142140 for sql_query_result in sql_query_result_list :
143141 logging .info ("SQL Query Result: %s" , sql_query_result )
144- sources .append (
145- Source (
146- sql_query = sql_query_result ["sql_query" ],
147- sql_rows = sql_query_result ["sql_rows" ],
148- )
142+ # Instantiate Source and append to the payload's sources list
143+ source = AnswerWithSourcesPayload .Body .Source (
144+ sql_query = sql_query_result ["sql_query" ],
145+ sql_rows = sql_query_result ["sql_rows" ],
149146 )
147+ payload .body .sources .append (source )
148+
149+ return payload
150150
151151 except json .JSONDecodeError :
152152 logging .error ("Could not load message: %s" , sql_query_results )
153153 raise ValueError ("Could not load message" )
154154
155- return AnswerWithSources (
156- answer = answer ,
157- sources = sources ,
158- )
159-
160155 async def process_question (
161156 self ,
162- request : AgentRequestBody ,
163- chat_history : list [ChatHistoryItem ] = None ,
164- ) -> AsyncGenerator [AgentRequestResponsePair | ProcessingUpdate , None ]:
157+ question_payload : QuestionPayload ,
158+ chat_history : list [InteractionPayload ] = None ,
159+ ) -> AsyncGenerator [InteractionPayload , None ]:
165160 """Process the complete question through the unified system.
166161
167162 Args:
@@ -174,65 +169,55 @@ async def process_question(
174169 -------
175170 dict: The response from the system.
176171 """
177- logging .info ("Processing question: %s" , request .question )
172+ logging .info ("Processing question: %s" , question_payload . body .question )
178173 logging .info ("Chat history: %s" , chat_history )
179174
180175 agent_input = {
181- "question" : request .question ,
176+ "question" : question_payload . body .question ,
182177 "chat_history" : {},
183- "injected_parameters" : request .injected_parameters ,
178+ "injected_parameters" : question_payload . body .injected_parameters ,
184179 }
185180
186181 if chat_history is not None :
187182 # Update input
188183 for idx , chat in enumerate (chat_history ):
189- # For now only consider the user query
190- chat_history_key = f"chat_{ idx } "
191- agent_input [
192- chat_history_key
193- ] = chat .request_response_pair .request .question
184+ if chat .root .payload_type == PayloadType .QUESTION :
185+ # For now only consider the user query
186+ chat_history_key = f"chat_{ idx } "
187+ agent_input [chat_history_key ] = chat .root .body .question
194188
195189 async for message in self .agentic_flow .run_stream (task = json .dumps (agent_input )):
196190 logging .debug ("Message: %s" , message )
197191
198192 payload = None
199193
200194 if isinstance (message , TextMessage ):
201- processing_update = None
202195 if message .source == "query_rewrite_agent" :
203- processing_update = ProcessingUpdateBody (
196+ payload = ProcessingUpdatePayload (
204197 message = "Rewriting the query..." ,
205198 )
206199 elif message .source == "parallel_query_solving_agent" :
207- processing_update = ProcessingUpdateBody (
200+ payload = ProcessingUpdatePayload (
208201 message = "Solving the query..." ,
209202 )
210203 elif message .source == "answer_agent" :
211- processing_update = ProcessingUpdateBody (
204+ payload = ProcessingUpdatePayload (
212205 message = "Generating the answer..." ,
213206 )
214207
215- if processing_update is not None :
216- payload = ProcessingUpdate (
217- processing_update = processing_update ,
218- )
219-
220208 elif isinstance (message , TaskResult ):
221209 # Now we need to return the final answer or the disambiguation request
222210 logging .info ("TaskResult: %s" , message )
223211
224- response = None
225212 if message .messages [- 1 ].source == "answer_agent" :
226213 # If the message is from the answer_agent, we need to return the final answer
227- response = self .extract_sources (message .messages )
214+ payload = self .extract_sources (message .messages )
228215 elif message .messages [- 1 ].source == "parallel_query_solving_agent" :
229216 # Load into disambiguation request
230- response = self .extract_disambiguation_request (message .messages )
231- else :
232- logging .error ("Unexpected TaskResult: %s" , message )
233- raise ValueError ("Unexpected TaskResult" )
234-
235- payload = AgentRequestResponsePair (request = request , response = response )
217+ payload = self .extract_disambiguation_request (message .messages )
218+ else :
219+ logging .error ("Unexpected TaskResult: %s" , message )
220+ raise ValueError ("Unexpected TaskResult" )
236221
237222 if payload is not None :
238223 logging .debug ("Payload: %s" , payload )
0 commit comments