
Commit 087797b

Adds data contract and reduces agent calls. (#119)
1 parent 66660ed commit 087797b

18 files changed: 306 additions & 293 deletions

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@
     "source": [
         "import dotenv\n",
         "import logging\n",
-        "from autogen_text_2_sql import AutoGenText2Sql"
+        "from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody"
     ]
 },
 {
@@ -100,7 +100,7 @@
     "metadata": {},
     "outputs": [],
     "source": [
-        "async for message in agentic_text_2_sql.process_question(question=\"What total number of orders in June 2008?\"):\n",
+        "async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n",
         " logging.info(\"Received %s Message from Text2SQL System\", message)"
     ]
 },
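
Note on usage: the notebook change above reflects the new request contract, where process_question takes an AgentRequestBody instead of a bare question string. The sketch below shows the call pattern end to end; the injected_parameters keyword is an assumption inferred from request.injected_parameters in autogen_text_2_sql.py further down, and the authoritative model lives in text_2_sql_core.payloads.agent_response.

# Minimal usage sketch of the new request contract (illustrative, not part of this commit).
# The injected_parameters keyword is assumed from request.injected_parameters below.
import logging

from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody


async def ask(agentic_text_2_sql: AutoGenText2Sql) -> None:
    request = AgentRequestBody(
        question="What total number of orders in June 2008?",
        injected_parameters={},  # assumed keyword, passed through to the agents
    )
    async for message in agentic_text_2_sql.process_question(request):
        # Each yielded message is a ProcessingUpdate progress event or the final AgentResponse.
        logging.info("Received %s Message from Text2SQL System", message)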
Lines changed: 4 additions & 1 deletion
@@ -1,3 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
 from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
+from text_2_sql_core.payloads.agent_response import AgentRequestBody
 
-__all__ = ["AutoGenText2Sql"]
+__all__ = ["AutoGenText2Sql", "AgentRequestBody"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 44 additions & 24 deletions
@@ -18,11 +18,17 @@
 import os
 from datetime import datetime
 
-from text_2_sql_core.payloads import (
+from text_2_sql_core.payloads.agent_response import (
+    AgentResponse,
+    AgentRequestBody,
     AnswerWithSources,
-    UserInformationRequest,
+    Source,
+    DismabiguationRequests,
+)
+from text_2_sql_core.payloads.chat_history import ChatHistoryItem
+from text_2_sql_core.payloads.processing_update import (
+    ProcessingUpdateBody,
     ProcessingUpdate,
-    ChatHistoryItem,
 )
 from autogen_agentchat.base import Response, TaskResult
 from typing import AsyncGenerator
@@ -123,6 +129,16 @@ def agentic_flow(self):
         )
         return flow
 
+    def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests:
+        """Extract the disambiguation request from the answer."""
+
+        disambiguation_request = messages[-1].content
+
+        # TODO: Properly extract the disambiguation request
+        return DismabiguationRequests(
+            disambiguation_request=disambiguation_request,
+        )
+
     def extract_sources(self, messages: list) -> AnswerWithSources:
         """Extract the sources from the answer."""
 
@@ -147,10 +163,10 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
             for sql_query_result in sql_query_result_list:
                 logging.info("SQL Query Result: %s", sql_query_result)
                 sources.append(
-                    {
-                        "sql_query": sql_query_result["sql_query"],
-                        "sql_rows": sql_query_result["sql_rows"],
-                    }
+                    Source(
+                        sql_query=sql_query_result["sql_query"],
+                        sql_rows=sql_query_result["sql_rows"],
+                    )
                 )
 
         except json.JSONDecodeError:
@@ -164,10 +180,9 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
 
     async def process_question(
         self,
-        question: str,
+        request: AgentRequestBody,
         chat_history: list[ChatHistoryItem] = None,
-        injected_parameters: dict = None,
-    ) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]:
+    ) -> AsyncGenerator[AgentResponse | ProcessingUpdate, None]:
         """Process the complete question through the unified system.
 
         Args:
@@ -180,58 +195,63 @@ async def process_question(
         -------
             dict: The response from the system.
         """
-        logging.info("Processing question: %s", question)
+        logging.info("Processing question: %s", request.question)
         logging.info("Chat history: %s", chat_history)
 
         agent_input = {
-            "question": question,
+            "question": request.question,
             "chat_history": {},
-            "injected_parameters": injected_parameters,
+            "injected_parameters": request.injected_parameters,
         }
 
         if chat_history is not None:
             # Update input
             for idx, chat in enumerate(chat_history):
                 # For now only consider the user query
-                agent_input[f"chat_{idx}"] = chat.user_query
+                agent_input[f"chat_{idx}"] = chat.request.question
 
         async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
             logging.debug("Message: %s", message)
 
             payload = None
 
             if isinstance(message, TextMessage):
+                processing_update = None
                 if message.source == "query_rewrite_agent":
-                    # If the message is from the query_rewrite_agent, we need to update the chat history
-                    payload = ProcessingUpdate(
+                    processing_update = ProcessingUpdateBody(
                         message="Rewriting the query...",
                     )
                 elif message.source == "parallel_query_solving_agent":
-                    # If the message is from the parallel_query_solving_agent, we need to update the chat history
-                    payload = ProcessingUpdate(
+                    processing_update = ProcessingUpdateBody(
                         message="Solving the query...",
                     )
                 elif message.source == "answer_agent":
-                    # If the message is from the answer_agent, we need to update the chat history
-                    payload = ProcessingUpdate(
+                    processing_update = ProcessingUpdateBody(
                         message="Generating the answer...",
                    )
 
+                if processing_update is not None:
+                    payload = ProcessingUpdate(
+                        processing_update=processing_update,
+                    )
+
             elif isinstance(message, TaskResult):
                 # Now we need to return the final answer or the disambiguation request
                 logging.info("TaskResult: %s", message)
 
+                response = None
                 if message.messages[-1].source == "answer_agent":
                     # If the message is from the answer_agent, we need to return the final answer
-                    payload = self.extract_sources(message.messages)
+                    response = self.extract_sources(message.messages)
                 elif message.messages[-1].source == "parallel_query_solving_agent":
-                    payload = UserInformationRequest(
-                        **json.loads(message.messages[-1].content),
-                    )
+                    # Load into disambiguation request
+                    response = self.extract_disambiguation_request(message.messages)
                 else:
                     logging.error("Unexpected TaskResult: %s", message)
                     raise ValueError("Unexpected TaskResult")
 
+                payload = AgentResponse(request=request, response=response)
+
             if payload is not None:
                 logging.debug("Payload: %s", payload)
                 yield payload
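
To make the data contract introduced in this file easier to follow, here is a rough sketch of the payload shapes it implies, assuming Pydantic-style models. Field names marked as assumed are guesses; the authoritative definitions live in text_2_sql_core.payloads and may differ.

# Rough sketch of the payload contract implied above (assumed Pydantic models; fields
# marked "assumed" are guesses — the real classes live in text_2_sql_core.payloads).
from pydantic import BaseModel


class AgentRequestBody(BaseModel):
    question: str
    injected_parameters: dict | None = None


class Source(BaseModel):
    sql_query: str
    sql_rows: list


class AnswerWithSources(BaseModel):
    answer: str  # assumed field name
    sources: list[Source] = []  # populated by extract_sources above


class DismabiguationRequests(BaseModel):  # spelling follows the code in this commit
    disambiguation_request: str


class AgentResponse(BaseModel):
    request: AgentRequestBody
    response: AnswerWithSources | DismabiguationRequests


class ProcessingUpdateBody(BaseModel):
    message: str


class ProcessingUpdate(BaseModel):
    processing_update: ProcessingUpdateBody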

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 1 addition & 6 deletions
@@ -42,18 +42,13 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
         elif tool_name == "sql_get_entity_schemas_tool":
             return FunctionToolAlias(
                 sql_helper.get_entity_schemas,
-                description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
+                description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
             )
         elif tool_name == "sql_get_column_values_tool":
             return FunctionToolAlias(
                 ai_search_helper.get_column_values,
                 description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
             )
-        elif tool_name == "current_datetime_tool":
-            return FunctionToolAlias(
-                sql_helper.get_current_datetime,
-                description="Gets the current date and time.",
-            )
         else:
             raise ValueError(f"Tool {tool_name} not found")
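
The current_datetime_tool is dropped above as part of reducing agent calls. The replacement is not shown in this diff; one plausible pattern, given that injected_parameters now flow through every request, is to supply the timestamp up front instead of via a tool call. The parameter key below is hypothetical.

# Hypothetical sketch only: supplying the current date at request time instead of a tool call.
# The "current_datetime" key is illustrative and is not defined by this commit.
from datetime import datetime

from autogen_text_2_sql import AgentRequestBody

request = AgentRequestBody(
    question="What was the total number of orders placed this month?",
    injected_parameters={"current_datetime": datetime.now().isoformat()},  # assumed key
)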

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 22 additions & 5 deletions
@@ -44,10 +44,10 @@ async def on_messages_stream(
         last_response = messages[-1].content
         parameter_input = messages[0].content
         try:
-            user_parameters = json.loads(parameter_input)["parameters"]
+            injected_parameters = json.loads(parameter_input)["injected_parameters"]
         except json.JSONDecodeError:
             logging.error("Error decoding the user parameters.")
-            user_parameters = {}
+            injected_parameters = {}
 
         # Load the json of the last message to populate the final output object
         query_rewrites = json.loads(last_response)
@@ -75,7 +75,7 @@ async def consume_inner_messages_from_agentic_flow(
             if isinstance(inner_message, TaskResult) is False:
                 try:
                     inner_message = json.loads(inner_message.content)
-                    logging.info(f"Loaded: {inner_message}")
+                    logging.info(f"Inner Loaded: {inner_message}")
 
                     # Search for specific message types and add them to the final output object
                     if (
@@ -91,6 +91,21 @@
                             }
                         )
 
+                    if ("contains_pre_run_results" in inner_message) and (
+                        inner_message["contains_pre_run_results"] is True
+                    ):
+                        for pre_run_sql_query, pre_run_result in inner_message[
+                            "cached_questions_and_schemas"
+                        ].items():
+                            database_results[identifier].append(
+                                {
+                                    "sql_query": pre_run_sql_query.replace(
+                                        "\n", " "
+                                    ),
+                                    "sql_rows": pre_run_result["sql_rows"],
+                                }
+                            )
+
                 except (JSONDecodeError, TypeError) as e:
                     logging.error("Could not load message: %s", inner_message)
                     logging.warning(f"Error processing message: {e}")
@@ -113,13 +128,15 @@
                 self.engine_specific_rules, **self.kwargs
             )
 
+            identifier = ", ".join(query_rewrite)
+
             # Launch tasks for each sub-query
             inner_solving_generators.append(
                 consume_inner_messages_from_agentic_flow(
                     inner_autogen_text_2_sql.process_question(
-                        question=query_rewrite, parameters=user_parameters
+                        question=query_rewrite, injected_parameters=injected_parameters
                     ),
-                    query_rewrite,
+                    identifier,
                     database_results,
                 )
             )
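
For orientation, the sketch below shows the shape database_results takes after these changes: each sub-query rewrite is joined into a comma-separated identifier, and both freshly executed and cached pre-run results are appended under it as sql_query/sql_rows pairs. All values shown are made up.

# Illustrative shape of database_results after the changes above (values are invented).
database_results = {
    "total orders in June 2008": [  # identifier = ", ".join(query_rewrite)
        {
            "sql_query": "SELECT COUNT(*) FROM SalesOrderHeader WHERE ...",  # placeholder SQL
            "sql_rows": [{"order_count": 998}],  # placeholder rows
        },
        # Cached pre-run results are appended in the same sql_query / sql_rows shape
        # whenever contains_pre_run_results is True.
    ],
}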

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py

Lines changed: 31 additions & 40 deletions
@@ -39,55 +39,46 @@ async def on_messages_stream(
         self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
     ) -> AsyncGenerator[AgentMessage | Response, None]:
         # Get the decomposed questions from the query_rewrite_agent
-        parameter_input = messages[0].content
-        last_response = messages[-1].content
         try:
-            user_questions = json.loads(last_response)
-            injected_parameters = json.loads(parameter_input)["injected_parameters"]
+            request_details = json.loads(messages[0].content)
+            injected_parameters = request_details["injected_parameters"]
+            user_questions = request_details["question"]
             logging.info(f"Processing questions: {user_questions}")
             logging.info(f"Input Parameters: {injected_parameters}")
+        except json.JSONDecodeError:
+            # If not JSON array, process as single question
+            raise ValueError("Could not load message")
 
-            # Initialize results dictionary
-            cached_results = {
-                "cached_questions_and_schemas": [],
-                "contains_pre_run_results": False,
-            }
-
-            # Process each question sequentially
-            for question in user_questions:
-                # Fetch the queries from the cache based on the question
-                logging.info(f"Fetching queries from cache for question: {question}")
-                cached_query = await self.sql_connector.fetch_queries_from_cache(
-                    question, injected_parameters=injected_parameters
-                )
+        # Initialize results dictionary
+        cached_results = {
+            "cached_questions_and_schemas": [],
+            "contains_pre_run_results": False,
+        }
 
-                # If any question has pre-run results, set the flag
-                if cached_query.get("contains_pre_run_results", False):
-                    cached_results["contains_pre_run_results"] = True
+        # Process each question sequentially
+        for question in user_questions:
+            # Fetch the queries from the cache based on the question
+            logging.info(f"Fetching queries from cache for question: {question}")
+            cached_query = await self.sql_connector.fetch_queries_from_cache(
+                question, injected_parameters=injected_parameters
+            )
 
-                # Add the cached results for this question
-                if cached_query.get("cached_questions_and_schemas"):
-                    cached_results["cached_questions_and_schemas"].extend(
-                        cached_query["cached_questions_and_schemas"]
-                    )
+            # If any question has pre-run results, set the flag
+            if cached_query.get("contains_pre_run_results", False):
+                cached_results["contains_pre_run_results"] = True
 
-            logging.info(f"Final cached results: {cached_results}")
-            yield Response(
-                chat_message=TextMessage(
-                    content=json.dumps(cached_results), source=self.name
-                )
-            )
-        except json.JSONDecodeError:
-            # If not JSON array, process as single question
-            logging.info(f"Processing single question: {last_response}")
-            cached_queries = await self.sql_connector.fetch_queries_from_cache(
-                last_response
-            )
-            yield Response(
-                chat_message=TextMessage(
-                    content=json.dumps(cached_queries), source=self.name
+            # Add the cached results for this question
+            if cached_query.get("cached_questions_and_schemas"):
+                cached_results["cached_questions_and_schemas"].extend(
+                    cached_query["cached_questions_and_schemas"]
                 )
-            )
+
+        logging.info(f"Final cached results: {cached_results}")
+        yield Response(
+            chat_message=TextMessage(
+                content=json.dumps(cached_results), source=self.name
+            )
+        )
 
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         pass
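
For reference, the cache agent now always replies with a single TextMessage whose content is the JSON-serialised dictionary sketched below. The shape of each cached entry is defined by fetch_queries_from_cache and is not part of this diff, so it is left as a placeholder.

# Illustrative payload yielded by the cache agent (entry contents are placeholders;
# fetch_queries_from_cache defines the real schema).
cached_results = {
    "cached_questions_and_schemas": [
        # one entry per cache hit, as returned by fetch_queries_from_cache
    ],
    "contains_pre_run_results": False,  # True when pre-run SQL results are available
}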
