Commit ad57b59

Fix issues with model flow between agents, add structured output modes and add follow-up questions (#158)

* Update interaction payloads
* Structured output branch
* Update prompts
* Update output
* Update steps
* Add buffered history
* Update validation
* Update
* Update logging
* Add sanitizer
* Update rewrite
* Update interactions
1 parent 9a2d5b6 commit ad57b59

24 files changed (+550, -270 lines)

text_2_sql/.env.example

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ Text2Sql__DatabaseEngine=<DatabaseEngine> # TSQL or Postgres or Snowflake or Dat
 Text2Sql__UseQueryCache=<Determines if the Query Cache will be used to speed up query generation. Defaults to True.> # True or False
 Text2Sql__PreRunQueryCache=<Determines if the results from the Query Cache will be pre-run to speed up answer generation. Defaults to True.> # True or False
 Text2Sql__UseColumnValueStore=<Determines if the Column Value Store will be used for schema selection Defaults to True.> # True or False
+Text2Sql__GenerateFollowUpQuestions=<Determines if follow up questions will be generated. Defaults to True.> # True or False

 # Open AI Connection Details
 OpenAI__CompletionDeployment=<openAICompletionDeploymentId. Used for data dictionary creator>
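For reference, the new flag is consumed in autogen_text_2_sql.py (next file in this commit) with the usual string-to-boolean pattern; a minimal sketch:

import os

# Defaults to True when unset; anything other than "true" (case-insensitive)
# disables follow-up question generation.
generate_follow_up_questions = (
    os.environ.get("Text2Sql__GenerateFollowUpQuestions", "True").lower() == "true"
)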

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 70 additions & 35 deletions
@@ -48,21 +48,31 @@ def __init__(self, state_store: StateStore, **kwargs):

         self._agentic_flow = None

+        self._generate_follow_up_questions = (
+            os.environ.get("Text2Sql__GenerateFollowUpQuestions", "True").lower()
+            == "true"
+        )
+
     def get_all_agents(self):
         """Get all agents for the complete flow."""

-        self.user_message_rewrite_agent = LLMAgentCreator.create(
+        user_message_rewrite_agent = LLMAgentCreator.create(
             "user_message_rewrite_agent", **self.kwargs
         )

-        self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
+        parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)

-        self.answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs)
+        if self._generate_follow_up_questions:
+            answer_agent = LLMAgentCreator.create(
+                "answer_with_follow_up_questions_agent", **self.kwargs
+            )
+        else:
+            answer_agent = LLMAgentCreator.create("answer_agent", **self.kwargs)

         agents = [
-            self.user_message_rewrite_agent,
-            self.parallel_query_solving_agent,
-            self.answer_agent,
+            user_message_rewrite_agent,
+            parallel_query_solving_agent,
+            answer_agent,
         ]

         return agents
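How these agents are assembled into the flow is not part of this diff; as rough orientation only, a sketch of the kind of selector-based group chat the class builds elsewhere (the class and argument names below are assumptions, not taken from this commit):

from autogen_agentchat.teams import SelectorGroupChat

# Hypothetical wiring: participants come from get_all_agents(), speaker choice
# from unified_selector, and the stop rules from termination_condition().
agentic_flow = SelectorGroupChat(
    participants=self.get_all_agents(),
    model_client=LLMModelCreator.get_model("4o-mini"),
    termination_condition=self.termination_condition(),
    selector_func=self.unified_selector,
)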
@@ -71,9 +81,16 @@ def get_all_agents(self):
     def termination_condition(self):
         """Define the termination condition for the chat."""
         termination = (
-            TextMentionTermination("TERMINATE")
-            | SourceMatchTermination("answer_agent")
-            | TextMentionTermination("contains_disambiguation_requests")
+            SourceMatchTermination("answer_agent")
+            | SourceMatchTermination("answer_with_follow_up_questions_agent")
+            # | TextMentionTermination(
+            #     "[]",
+            #     sources=["user_message_rewrite_agent"],
+            # )
+            | TextMentionTermination(
+                "contains_disambiguation_requests",
+                sources=["parallel_query_solving_agent"],
+            )
             | MaxMessageTermination(5)
         )
         return termination
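These conditions are OR-composed with |: the run stops as soon as either answer-style agent has produced a message, the parallel query solving agent mentions contains_disambiguation_requests, or five messages have been exchanged, whichever happens first.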
@@ -91,6 +108,11 @@ def unified_selector(self, messages):
         elif current_agent == "user_message_rewrite_agent":
             decision = "parallel_query_solving_agent"
         # Handle transition after parallel query solving
+        elif (
+            current_agent == "parallel_query_solving_agent"
+            and self._generate_follow_up_questions
+        ):
+            decision = "answer_with_follow_up_questions_agent"
         elif current_agent == "parallel_query_solving_agent":
             decision = "answer_agent"

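Branch order matters here: the follow-up variant is checked before the plain parallel_query_solving_agent branch, so when Text2Sql__GenerateFollowUpQuestions is enabled the selector routes to answer_with_follow_up_questions_agent instead of answer_agent.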
@@ -142,32 +164,35 @@ def parse_message_content(self, content):
         # If all parsing attempts fail, return the content as-is
         return content

-    def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]:
-        """Extract the decomposed messages from the answer."""
-        # Only load sub-message results if we have a database result
-        sub_message_results = self.parse_message_content(messages[1].content)
-        logging.info("Decomposed Results: %s", sub_message_results)
+    def last_message_by_agent(self, messages: list, agent_name: str) -> TextMessage:
+        """Get the last message by a specific agent."""
+        for message in reversed(messages):
+            if message.source == agent_name:
+                return message.content
+        return None

-        decomposed_user_messages = sub_message_results.get(
-            "decomposed_user_messages", []
+    def extract_steps(self, messages: list) -> list[list[str]]:
+        """Extract the steps messages from the answer."""
+        # Only load sub-message results if we have a database result
+        sub_message_results = json.loads(
+            self.last_message_by_agent(messages, "user_message_rewrite_agent")
         )
+        logging.info("Steps Results: %s", sub_message_results)

-        logging.debug(
-            "Returning decomposed_user_messages: %s", decomposed_user_messages
-        )
+        steps = sub_message_results.get("steps", [])
+
+        logging.debug("Returning steps: %s", steps)

-        return decomposed_user_messages
+        return steps

     def extract_disambiguation_request(
         self, messages: list
     ) -> DismabiguationRequestsPayload:
         """Extract the disambiguation request from the answer."""
         all_disambiguation_requests = self.parse_message_content(messages[-1].content)

-        decomposed_user_messages = self.extract_decomposed_user_messages(messages)
-        request_payload = DismabiguationRequestsPayload(
-            decomposed_user_messages=decomposed_user_messages
-        )
+        steps = self.extract_steps(messages)
+        request_payload = DismabiguationRequestsPayload(steps=steps)

         for per_question_disambiguation_request in all_disambiguation_requests[
             "disambiguation_requests"
@@ -187,23 +212,27 @@ def extract_disambiguation_request(

     def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
         """Extract the sources from the answer."""
-        answer = messages[-1].content
-        sql_query_results = self.parse_message_content(messages[-2].content)
+        answer_payload = json.loads(messages[-1].content)
+
+        logging.info("Answer Payload: %s", answer_payload)
+        sql_query_results = self.last_message_by_agent(
+            messages, "parallel_query_solving_agent"
+        )

         try:
             if isinstance(sql_query_results, str):
                 sql_query_results = json.loads(sql_query_results)
+            elif sql_query_results is None:
+                sql_query_results = {}
         except json.JSONDecodeError:
             logging.warning("Unable to read SQL query results: %s", sql_query_results)
             sql_query_results = {}

         try:
-            decomposed_user_messages = self.extract_decomposed_user_messages(messages)
+            steps = self.extract_steps(messages)

             logging.info("SQL Query Results: %s", sql_query_results)
-            payload = AnswerWithSourcesPayload(
-                answer=answer, decomposed_user_messages=decomposed_user_messages
-            )
+            payload = AnswerWithSourcesPayload(**answer_payload, steps=steps)

             if not isinstance(sql_query_results, dict):
                 logging.error(f"Expected dict, got {type(sql_query_results)}")
@@ -248,10 +277,9 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:

         except Exception as e:
             logging.error("Error processing results: %s", str(e))
+
             # Return payload with error context instead of empty
-            return AnswerWithSourcesPayload(
-                answer=f"{answer}\nError processing results: {str(e)}"
-            )
+            return AnswerWithSourcesPayload(**answer_payload)

     async def process_user_message(
         self,
@@ -295,7 +323,10 @@ async def process_user_message(
                     payload = ProcessingUpdatePayload(
                         message="Solving the query...",
                     )
-                elif message.source == "answer_agent":
+                elif (
+                    message.source == "answer_agent"
+                    or message.source == "answer_with_follow_up_questions_agent"
+                ):
                     payload = ProcessingUpdatePayload(
                         message="Generating the answer...",
                     )
@@ -304,7 +335,11 @@
                 # Now we need to return the final answer or the disambiguation request
                 logging.info("TaskResult: %s", message)

-                if message.messages[-1].source == "answer_agent":
+                if (
+                    message.messages[-1].source == "answer_agent"
+                    or message.messages[-1].source
+                    == "answer_with_follow_up_questions_agent"
+                ):
                     # If the message is from the answer_agent, we need to return the final answer
                     payload = self.extract_answer_payload(message.messages)
                 elif message.messages[-1].source == "parallel_query_solving_agent":

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 24 additions & 1 deletion
@@ -7,6 +7,12 @@
 from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
 from jinja2 import Template
 import logging
+from text_2_sql_core.structured_outputs import (
+    AnswerAgentOutput,
+    AnswerWithFollowUpQuestionsAgentOutput,
+    UserMessageRewriteAgentOutput,
+)
+from autogen_core.model_context import BufferedChatCompletionContext


 class LLMAgentCreator:
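The structured output models themselves are defined in text_2_sql_core.structured_outputs and are not shown in this commit view; a minimal sketch of what they plausibly look like as Pydantic models (field names are illustrative assumptions, except "steps", which extract_steps depends on):

from pydantic import BaseModel

# Illustrative shapes only; the real definitions may differ.
class AnswerAgentOutput(BaseModel):
    answer: str

class AnswerWithFollowUpQuestionsAgentOutput(BaseModel):
    answer: str
    follow_up_questions: list[str]

class UserMessageRewriteAgentOutput(BaseModel):
    steps: list[list[str]]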
@@ -106,10 +112,22 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
             for tool in agent_file["tools"]:
                 tools.append(cls.get_tool(sql_helper, tool))

+        structured_output = None
+        if agent_file.get("structured_output", False):
+            # Import the structured output agent
+            if name == "answer_agent":
+                structured_output = AnswerAgentOutput
+            elif name == "answer_with_follow_up_questions_agent":
+                structured_output = AnswerWithFollowUpQuestionsAgentOutput
+            elif name == "user_message_rewrite_agent":
+                structured_output = UserMessageRewriteAgentOutput
+
         agent = AssistantAgent(
             name=name,
             tools=tools,
-            model_client=LLMModelCreator.get_model(agent_file["model"]),
+            model_client=LLMModelCreator.get_model(
+                agent_file["model"], structured_output=structured_output
+            ),
             description=cls.get_property_and_render_parameters(
                 agent_file, "description", kwargs
             ),
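For context, a hypothetical agent definition (after YAML loading) showing the keys create() now reads; "structured_output" and "context_size" are the new ones, the rest mirror keys already referenced in this function:

# Hypothetical agent_file contents; only keys referenced by create() are meaningful.
agent_file = {
    "model": "4o-mini",
    "description": "Rewrites the user's question into sequential steps.",
    "tools": [],
    "structured_output": True,   # opt in to a response schema for this agent
    "context_size": 5,           # buffer only the last 5 messages (see next hunk)
}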
@@ -118,4 +136,9 @@ def create(cls, name: str, **kwargs) -> AssistantAgent:
             ),
         )

+        if "context_size" in agent_file:
+            agent.model_context = BufferedChatCompletionContext(
+                buffer_size=agent_file["context_size"]
+            )
+
         return agent
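BufferedChatCompletionContext keeps only the most recent buffer_size messages in the agent's model context, so agents that declare context_size work from a truncated history (the "buffered history" mentioned in the commit description) rather than the full conversation.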

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_model_creator.py

Lines changed: 11 additions & 5 deletions
@@ -12,7 +12,9 @@

 class LLMModelCreator:
     @classmethod
-    def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
+    def get_model(
+        cls, model_name: str, structured_output=None
+    ) -> AzureOpenAIChatCompletionClient:
         """Retrieves the model based on the model name.

         Args:
@@ -22,9 +24,9 @@ def get_model(cls, model_name: str) -> AzureOpenAIChatCompletionClient:
         Returns:
             AzureOpenAIChatCompletionClient: The model client."""
         if model_name == "4o-mini":
-            return cls.gpt_4o_mini_model()
+            return cls.gpt_4o_mini_model(structured_output=structured_output)
         elif model_name == "4o":
-            return cls.gpt_4o_model()
+            return cls.gpt_4o_model(structured_output=structured_output)
         else:
             raise ValueError(f"Model {model_name} not found")

@@ -46,7 +48,9 @@ def get_authentication_properties(cls) -> dict:
         return token_provider, api_key

     @classmethod
-    def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
+    def gpt_4o_mini_model(
+        cls, structured_output=None
+    ) -> AzureOpenAIChatCompletionClient:
         token_provider, api_key = cls.get_authentication_properties()
         return AzureOpenAIChatCompletionClient(
             azure_deployment=os.environ["OpenAI__MiniCompletionDeployment"],
@@ -61,10 +65,11 @@ def gpt_4o_mini_model(cls) -> AzureOpenAIChatCompletionClient:
                 "json_output": True,
             },
             temperature=0,
+            response_format=structured_output,
         )

     @classmethod
-    def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
+    def gpt_4o_model(cls, structured_output=None) -> AzureOpenAIChatCompletionClient:
         token_provider, api_key = cls.get_authentication_properties()
         return AzureOpenAIChatCompletionClient(
             azure_deployment=os.environ["OpenAI__CompletionDeployment"],
@@ -79,4 +84,5 @@ def gpt_4o_model(cls) -> AzureOpenAIChatCompletionClient:
                 "json_output": True,
             },
             temperature=0,
+            response_format=structured_output,
         )
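A brief usage sketch of the updated factory: passing a Pydantic model as structured_output flows through to the client's response_format, constraining completions to that schema, while passing nothing keeps the previous behaviour:

from text_2_sql_core.structured_outputs import AnswerWithFollowUpQuestionsAgentOutput

# Unconstrained client, as before.
default_client = LLMModelCreator.get_model("4o-mini")

# Client whose responses must conform to the follow-up-questions schema.
structured_client = LLMModelCreator.get_model(
    "4o", structured_output=AnswerWithFollowUpQuestionsAgentOutput
)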
