Skip to content

Commit 5efa871

Browse files
committed
Update answer
1 parent 4520da9 commit 5efa871

File tree

6 files changed

+111
-128
lines changed

6 files changed

+111
-128
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from autogen_agentchat.conditions import (
44
TextMentionTermination,
55
MaxMessageTermination,
6+
SourceMatchTermination,
67
)
78
from autogen_agentchat.teams import SelectorGroupChat
89
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
@@ -11,9 +12,6 @@
1112
from autogen_text_2_sql.custom_agents.parallel_query_solving_agent import (
1213
ParallelQuerySolvingAgent,
1314
)
14-
from autogen_text_2_sql.custom_agents.sources_agent import (
15-
SourcesAgent,
16-
)
1715
from autogen_agentchat.agents import UserProxyAgent
1816
from autogen_agentchat.messages import TextMessage
1917
import json
@@ -66,16 +64,14 @@ def get_all_agents(self):
6664

6765
self.answer_agent = LLMAgentCreator.create("answer_agent")
6866

69-
self.sources_agent = SourcesAgent()
70-
7167
# Auto-responding UserProxyAgent
7268
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")
7369

7470
agents = [
7571
self.user_proxy,
7672
self.query_rewrite_agent,
7773
self.parallel_query_solving_agent,
78-
self.sources_agent,
74+
self.answer_agent,
7975
]
8076

8177
return agents
@@ -85,8 +81,9 @@ def termination_condition(self):
8581
"""Define the termination condition for the chat."""
8682
termination = (
8783
TextMentionTermination("TERMINATE")
88-
| (TextMentionTermination("answer") & TextMentionTermination("sources"))
89-
| MaxMessageTermination(20)
84+
| SourceMatchTermination("answer_agent")
85+
| TextMentionTermination("requires_user_information_request")
86+
| MaxMessageTermination(5)
9087
)
9188
return termination
9289

@@ -105,8 +102,6 @@ def unified_selector(self, messages):
105102
# Handle transition after parallel query solving
106103
elif current_agent == "parallel_query_solving_agent":
107104
decision = "answer_agent"
108-
elif current_agent == "answer_agent":
109-
decision = "sources_agent"
110105

111106
if decision:
112107
logging.info(f"Agent transition: {current_agent} -> {decision}")
@@ -127,6 +122,45 @@ def agentic_flow(self):
127122
)
128123
return flow
129124

125+
def extract_sources(self, messages: list) -> AnswerWithSources:
126+
"""Extract the sources from the answer."""
127+
128+
answer = messages[-1].content
129+
130+
sql_query_results = messages[-2].content
131+
132+
try:
133+
sql_query_results = json.loads(sql_query_results)
134+
135+
logging.info("SQL Query Results: %s", sql_query_results)
136+
137+
sources = []
138+
139+
for question, sql_query_result_list in sql_query_results["results"].items():
140+
logging.info(
141+
"SQL Query Result for question '%s': %s",
142+
question,
143+
sql_query_result_list,
144+
)
145+
146+
for sql_query_result in sql_query_result_list:
147+
logging.info("SQL Query Result: %s", sql_query_result)
148+
sources.append(
149+
{
150+
"sql_query": sql_query_result["sql_query"],
151+
"sql_rows": sql_query_result["sql_rows"],
152+
}
153+
)
154+
155+
except json.JSONDecodeError:
156+
logging.error("Could not load message: %s", sql_query_results)
157+
raise ValueError("Could not load message")
158+
159+
return AnswerWithSources(
160+
answer=answer,
161+
sources=sources,
162+
)
163+
130164
async def process_question(
131165
self,
132166
question: str,
@@ -160,8 +194,7 @@ async def process_question(
160194
agent_input[f"chat_{idx}"] = chat
161195

162196
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
163-
logging.info("Message: %s", message)
164-
logging.info("Message type: %s", type(message))
197+
logging.debug("Message: %s", message)
165198

166199
payload = None
167200

@@ -184,17 +217,19 @@ async def process_question(
184217

185218
elif isinstance(message, TaskResult):
186219
# Now we need to return the final answer or the disambiguation request
220+
logging.info("TaskResult: %s", message)
187221

188-
if message.source == "answer_agent":
222+
if message.messages[-1].source == "answer_agent":
189223
# If the message is from the answer_agent, we need to return the final answer
190-
payload = AnswerWithSources(
191-
**json.loads(message.content),
192-
)
193-
else:
224+
payload = self.extract_sources(message.messages)
225+
elif message.messages[-1].source == "parallel_query_solving_agent":
194226
payload = UserInformationRequest(
195-
**json.loads(message.content),
227+
**json.loads(message.messages[-1].content),
196228
)
229+
else:
230+
logging.error("Unexpected TaskResult: %s", message)
231+
raise ValueError("Unexpected TaskResult")
197232

198233
if payload is not None:
199-
logging.info("Payload: %s", payload)
234+
logging.debug("Payload: %s", payload)
200235
yield payload

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,14 @@
33
from typing import AsyncGenerator, List, Sequence
44

55
from autogen_agentchat.agents import BaseChatAgent
6-
from autogen_agentchat.base import Response
7-
from autogen_agentchat.messages import (
8-
AgentMessage,
9-
ChatMessage,
10-
TextMessage,
11-
)
6+
from autogen_agentchat.base import Response, TaskResult
7+
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
128
from autogen_core import CancellationToken
139
import json
1410
import logging
1511
from autogen_text_2_sql.inner_autogen_text_2_sql import InnerAutoGenText2Sql
16-
1712
from aiostream import stream
13+
from json import JSONDecodeError
1814

1915

2016
class ParallelQuerySolvingAgent(BaseChatAgent):
@@ -59,7 +55,7 @@ async def on_messages_stream(
5955
logging.info(f"Query Rewrites: {query_rewrites}")
6056

6157
async def consume_inner_messages_from_agentic_flow(
62-
agentic_flow, identifier, complete_inner_messages
58+
agentic_flow, identifier, database_results
6359
):
6460
"""
6561
Consume the inner messages and append them to the specified list.
@@ -71,14 +67,43 @@ async def consume_inner_messages_from_agentic_flow(
7167
"""
7268
async for inner_message in agentic_flow:
7369
# Add message to results dictionary, tagged by the function name
74-
if identifier not in complete_inner_messages:
75-
complete_inner_messages[identifier] = []
76-
complete_inner_messages[identifier].append(inner_message)
70+
if identifier not in database_results:
71+
database_results[identifier] = []
72+
73+
logging.info(f"Checking Inner Message: {inner_message}")
74+
75+
if isinstance(inner_message, TaskResult) is False:
76+
try:
77+
inner_message = json.loads(inner_message.content)
78+
logging.info(f"Loaded: {inner_message}")
79+
80+
# Search for specific message types and add them to the final output object
81+
if (
82+
"type" in inner_message
83+
and inner_message["type"] == "query_execution_with_limit"
84+
):
85+
database_results[identifier].append(
86+
{
87+
"sql_query": inner_message["sql_query"].replace(
88+
"\n", " "
89+
),
90+
"sql_rows": inner_message["sql_rows"],
91+
}
92+
)
93+
94+
except (JSONDecodeError, TypeError) as e:
95+
logging.error("Could not load message: %s", inner_message)
96+
logging.warning(f"Error processing message: {e}")
97+
98+
except Exception as e:
99+
logging.error("Could not load message: %s", inner_message)
100+
logging.error(f"Error processing message: {e}")
101+
raise e
77102

78103
yield inner_message
79104

80105
inner_solving_generators = []
81-
complete_inner_messages = {}
106+
database_results = {}
82107

83108
# Start processing sub-queries
84109
for query_rewrite in query_rewrites["sub_queries"]:
@@ -95,32 +120,33 @@ async def consume_inner_messages_from_agentic_flow(
95120
question=query_rewrite, parameters=user_parameters
96121
),
97122
query_rewrite,
98-
complete_inner_messages,
123+
database_results,
99124
)
100125
)
101126

102-
logging.info("Created %i Inner Solving Generators", inner_solving_generators)
127+
logging.info(
128+
"Created %i Inner Solving Generators", len(inner_solving_generators)
129+
)
103130
logging.info("Starting Inner Solving Generators")
104131
combined_message_streams = stream.merge(*inner_solving_generators)
105132

106133
async with combined_message_streams.stream() as streamer:
107134
async for inner_message in streamer:
108-
logging.info(f"Inner Solving Message: {inner_message}")
109-
yield inner_message
135+
if isinstance(inner_message, TextMessage):
136+
logging.debug(f"Inner Solving Message: {inner_message}")
137+
yield inner_message
110138

111139
# Log final results for debugging or auditing
112-
logging.info(f"Formatted Results: {complete_inner_messages}")
140+
logging.info(f"Database Results: {database_results}")
113141

114-
# TODO: Trim out unnecessary information from the final response
115142
# Final response
116143
yield Response(
117144
chat_message=TextMessage(
118-
content=json.dumps(complete_inner_messages), source=self.name
145+
content=json.dumps(
146+
{"contains_results": True, "results": database_results}
147+
),
148+
source=self.name,
119149
),
120-
inner_messages=[
121-
complete_inner_message["message"]
122-
for complete_inner_message in complete_inner_messages
123-
],
124150
)
125151

126152
async def on_reset(self, cancellation_token: CancellationToken) -> None:

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sources_agent.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def unified_selector(self, messages):
150150
correction_result = json.loads(messages[-1].content)
151151
if isinstance(correction_result, dict):
152152
if "answer" in correction_result and "sources" in correction_result:
153-
decision = "answer_and_sources_agent"
153+
decision = "user_proxy"
154154
elif "corrected_query" in correction_result:
155155
if correction_result.get("executing", False):
156156
decision = "sql_query_correction_agent"
@@ -166,8 +166,6 @@ def unified_selector(self, messages):
166166
decision = "sql_query_generation_agent"
167167
except json.JSONDecodeError:
168168
decision = "sql_query_generation_agent"
169-
elif current_agent == "answer_and_sources_agent":
170-
decision = "user_proxy" # Let user_proxy send TERMINATE
171169

172170
if decision:
173171
logging.info(f"Agent transition: {current_agent} -> {decision}")

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
class Source(BaseModel):
55
sql_query: str
66
sql_rows: list[dict]
7-
markdown_table: str
87

98

109
class AnswerWithSources(BaseModel):
1110
answer: str
12-
sources: list[str] = Field(default_factory=list)
11+
sources: list[Source] = Field(default_factory=list)

text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/answer_agent.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,9 @@ system_message: |
44
<role_and_objective>
55
You are a helpful AI Assistant specializing in answering a user's question.
66
</role_and_objective>
7+
8+
Use the information obtained to generate a response to the user's question. The question has been broken down into a series of SQL queries and you need to generate a response based on the results of these queries.
9+
10+
Do not use any external resources to generate the response. The response should be based solely on the information provided in the SQL queries and their results.
11+
12+
You can use Markdown and Markdown tables to format the response.

0 commit comments

Comments
 (0)