@@ -48,21 +48,31 @@ def __init__(self, state_store: StateStore, **kwargs):
4848
4949 self ._agentic_flow = None
5050
51+ self ._generate_follow_up_questions = (
52+ os .environ .get ("Text2Sql__GenerateFollowUpQuestions" , "True" ).lower ()
53+ == "true"
54+ )
55+
5156 def get_all_agents (self ):
5257 """Get all agents for the complete flow."""
5358
54- self . user_message_rewrite_agent = LLMAgentCreator .create (
59+ user_message_rewrite_agent = LLMAgentCreator .create (
5560 "user_message_rewrite_agent" , ** self .kwargs
5661 )
5762
58- self . parallel_query_solving_agent = ParallelQuerySolvingAgent (** self .kwargs )
63+ parallel_query_solving_agent = ParallelQuerySolvingAgent (** self .kwargs )
5964
60- self .answer_agent = LLMAgentCreator .create ("answer_agent" , ** self .kwargs )
65+ if self ._generate_follow_up_questions :
66+ answer_agent = LLMAgentCreator .create (
67+ "answer_with_follow_up_questions_agent" , ** self .kwargs
68+ )
69+ else :
70+ answer_agent = LLMAgentCreator .create ("answer_agent" , ** self .kwargs )
6171
6272 agents = [
63- self . user_message_rewrite_agent ,
64- self . parallel_query_solving_agent ,
65- self . answer_agent ,
73+ user_message_rewrite_agent ,
74+ parallel_query_solving_agent ,
75+ answer_agent ,
6676 ]
6777
6878 return agents
@@ -71,9 +81,16 @@ def get_all_agents(self):
7181 def termination_condition (self ):
7282 """Define the termination condition for the chat."""
7383 termination = (
74- TextMentionTermination ("TERMINATE" )
75- | SourceMatchTermination ("answer_agent" )
76- | TextMentionTermination ("contains_disambiguation_requests" )
84+ SourceMatchTermination ("answer_agent" )
85+ | SourceMatchTermination ("answer_with_follow_up_questions_agent" )
86+ # | TextMentionTermination(
87+ # "[]",
88+ # sources=["user_message_rewrite_agent"],
89+ # )
90+ | TextMentionTermination (
91+ "contains_disambiguation_requests" ,
92+ sources = ["parallel_query_solving_agent" ],
93+ )
7794 | MaxMessageTermination (5 )
7895 )
7996 return termination
@@ -91,6 +108,11 @@ def unified_selector(self, messages):
91108 elif current_agent == "user_message_rewrite_agent" :
92109 decision = "parallel_query_solving_agent"
93110 # Handle transition after parallel query solving
111+ elif (
112+ current_agent == "parallel_query_solving_agent"
113+ and self ._generate_follow_up_questions
114+ ):
115+ decision = "answer_with_follow_up_questions_agent"
94116 elif current_agent == "parallel_query_solving_agent" :
95117 decision = "answer_agent"
96118
@@ -142,32 +164,35 @@ def parse_message_content(self, content):
142164 # If all parsing attempts fail, return the content as-is
143165 return content
144166
145- def extract_decomposed_user_messages (self , messages : list ) -> list [list [str ]]:
146- """Extract the decomposed messages from the answer."""
147- # Only load sub-message results if we have a database result
148- sub_message_results = self .parse_message_content (messages [1 ].content )
149- logging .info ("Decomposed Results: %s" , sub_message_results )
167+ def last_message_by_agent (self , messages : list , agent_name : str ) -> TextMessage :
168+ """Get the last message by a specific agent."""
169+ for message in reversed (messages ):
170+ if message .source == agent_name :
171+ return message .content
172+ return None
150173
151- decomposed_user_messages = sub_message_results .get (
152- "decomposed_user_messages" , []
174+ def extract_steps (self , messages : list ) -> list [list [str ]]:
175+ """Extract the steps messages from the answer."""
176+ # Only load sub-message results if we have a database result
177+ sub_message_results = json .loads (
178+ self .last_message_by_agent (messages , "user_message_rewrite_agent" )
153179 )
180+ logging .info ("Steps Results: %s" , sub_message_results )
154181
155- logging . debug (
156- "Returning decomposed_user_messages: %s" , decomposed_user_messages
157- )
182+ steps = sub_message_results . get ( "steps" , [])
183+
184+ logging . debug ( "Returning steps: %s" , steps )
158185
159- return decomposed_user_messages
186+ return steps
160187
161188 def extract_disambiguation_request (
162189 self , messages : list
163190 ) -> DismabiguationRequestsPayload :
164191 """Extract the disambiguation request from the answer."""
165192 all_disambiguation_requests = self .parse_message_content (messages [- 1 ].content )
166193
167- decomposed_user_messages = self .extract_decomposed_user_messages (messages )
168- request_payload = DismabiguationRequestsPayload (
169- decomposed_user_messages = decomposed_user_messages
170- )
194+ steps = self .extract_steps (messages )
195+ request_payload = DismabiguationRequestsPayload (steps = steps )
171196
172197 for per_question_disambiguation_request in all_disambiguation_requests [
173198 "disambiguation_requests"
@@ -187,23 +212,27 @@ def extract_disambiguation_request(
187212
188213 def extract_answer_payload (self , messages : list ) -> AnswerWithSourcesPayload :
189214 """Extract the sources from the answer."""
190- answer = messages [- 1 ].content
191- sql_query_results = self .parse_message_content (messages [- 2 ].content )
215+ answer_payload = json .loads (messages [- 1 ].content )
216+
217+ logging .info ("Answer Payload: %s" , answer_payload )
218+ sql_query_results = self .last_message_by_agent (
219+ messages , "parallel_query_solving_agent"
220+ )
192221
193222 try :
194223 if isinstance (sql_query_results , str ):
195224 sql_query_results = json .loads (sql_query_results )
225+ elif sql_query_results is None :
226+ sql_query_results = {}
196227 except json .JSONDecodeError :
197228 logging .warning ("Unable to read SQL query results: %s" , sql_query_results )
198229 sql_query_results = {}
199230
200231 try :
201- decomposed_user_messages = self .extract_decomposed_user_messages (messages )
232+ steps = self .extract_steps (messages )
202233
203234 logging .info ("SQL Query Results: %s" , sql_query_results )
204- payload = AnswerWithSourcesPayload (
205- answer = answer , decomposed_user_messages = decomposed_user_messages
206- )
235+ payload = AnswerWithSourcesPayload (** answer_payload , steps = steps )
207236
208237 if not isinstance (sql_query_results , dict ):
209238 logging .error (f"Expected dict, got { type (sql_query_results )} " )
@@ -248,10 +277,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
248277
249278 except Exception as e :
250279 logging .error ("Error processing results: %s" , str (e ))
280+
251281 # Return payload with error context instead of empty
252- return AnswerWithSourcesPayload (
253- answer = f"{ answer } \n Error processing results: { str (e )} "
254- )
282+ return AnswerWithSourcesPayload (** answer_payload )
255283
256284 async def process_user_message (
257285 self ,
@@ -295,7 +323,10 @@ async def process_user_message(
295323 payload = ProcessingUpdatePayload (
296324 message = "Solving the query..." ,
297325 )
298- elif message .source == "answer_agent" :
326+ elif (
327+ message .source == "answer_agent"
328+ or message .source == "answer_with_follow_up_questions_agent"
329+ ):
299330 payload = ProcessingUpdatePayload (
300331 message = "Generating the answer..." ,
301332 )
@@ -304,7 +335,11 @@ async def process_user_message(
304335 # Now we need to return the final answer or the disambiguation request
305336 logging .info ("TaskResult: %s" , message )
306337
307- if message .messages [- 1 ].source == "answer_agent" :
338+ if (
339+ message .messages [- 1 ].source == "answer_agent"
340+ or message .messages [- 1 ].source
341+ == "answer_with_follow_up_questions_agent"
342+ ):
308343 # If the message is from the answer_agent, we need to return the final answer
309344 payload = self .extract_answer_payload (message .messages )
310345 elif message .messages [- 1 ].source == "parallel_query_solving_agent" :
0 commit comments