Skip to content

Commit 0c39714

Browse files
Payload and Disambiguation Improvements (#135)
1 parent 298e175 commit 0c39714

16 files changed

+301
-187
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"source": [
5151
"import dotenv\n",
5252
"import logging\n",
53-
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload"
53+
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload"
5454
]
5555
},
5656
{
@@ -100,7 +100,7 @@
100100
"metadata": {},
101101
"outputs": [],
102102
"source": [
103-
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n",
103+
"async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"What is the total number of sales?\")):\n",
104104
" logging.info(\"Received %s Message from Text2SQL System\", message)"
105105
]
106106
},

text_2_sql/autogen/evaluate_autogen_text2sql.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
"# Add the src directory to the path\n",
6969
"sys.path.append(str(notebook_dir / \"src\"))\n",
7070
"\n",
71-
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload\n",
71+
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n",
7272
"from autogen_text_2_sql.evaluation_utils import get_final_sql_query\n",
7373
"\n",
7474
"# Configure logging\n",
@@ -127,7 +127,7 @@
127127
" all_queries = []\n",
128128
" final_query = None\n",
129129
" \n",
130-
" async for message in autogen_text2sql.process_question(QuestionPayload(question=question)):\n",
130+
" async for message in autogen_text2sql.process_user_message(UserMessagePayload(user_message=question)):\n",
131131
" if message.payload_type == \"answer_with_sources\":\n",
132132
" # Extract from results\n",
133133
" if hasattr(message.body, 'results'):\n",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from text_2_sql_core.payloads.interaction_payloads import QuestionPayload
4+
from text_2_sql_core.payloads.interaction_payloads import UserMessagePayload
55

6-
__all__ = ["AutoGenText2Sql", "QuestionPayload"]
6+
__all__ = ["AutoGenText2Sql", "UserMessagePayload"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import re
2020

2121
from text_2_sql_core.payloads.interaction_payloads import (
22-
QuestionPayload,
22+
UserMessagePayload,
2323
AnswerWithSourcesPayload,
24-
DismabiguationRequestPayload,
24+
DismabiguationRequestsPayload,
2525
ProcessingUpdatePayload,
2626
InteractionPayload,
2727
PayloadType,
@@ -40,16 +40,16 @@ def get_all_agents(self):
4040
# Get current datetime for the Query Rewrite Agent
4141
current_datetime = datetime.now()
4242

43-
self.question_rewrite_agent = LLMAgentCreator.create(
44-
"question_rewrite_agent", current_datetime=current_datetime
43+
self.user_message_rewrite_agent = LLMAgentCreator.create(
44+
"user_message_rewrite_agent", current_datetime=current_datetime
4545
)
4646

4747
self.parallel_query_solving_agent = ParallelQuerySolvingAgent(**self.kwargs)
4848

4949
self.answer_agent = LLMAgentCreator.create("answer_agent")
5050

5151
agents = [
52-
self.question_rewrite_agent,
52+
self.user_message_rewrite_agent,
5353
self.parallel_query_solving_agent,
5454
self.answer_agent,
5555
]
@@ -62,7 +62,7 @@ def termination_condition(self):
6262
termination = (
6363
TextMentionTermination("TERMINATE")
6464
| SourceMatchTermination("answer_agent")
65-
| TextMentionTermination("requires_user_information_request")
65+
| TextMentionTermination("contains_disambiguation_requests")
6666
| MaxMessageTermination(5)
6767
)
6868
return termination
@@ -73,11 +73,11 @@ def unified_selector(self, messages):
7373
current_agent = messages[-1].source if messages else "user"
7474
decision = None
7575

76-
# If this is the first message start with question_rewrite_agent
76+
# If this is the first message start with user_message_rewrite_agent
7777
if current_agent == "user":
78-
decision = "question_rewrite_agent"
78+
decision = "user_message_rewrite_agent"
7979
# Handle transition after query rewriting
80-
elif current_agent == "question_rewrite_agent":
80+
elif current_agent == "user_message_rewrite_agent":
8181
decision = "parallel_query_solving_agent"
8282
# Handle transition after parallel query solving
8383
elif current_agent == "parallel_query_solving_agent":
@@ -102,15 +102,6 @@ def agentic_flow(self):
102102
)
103103
return flow
104104

105-
def extract_disambiguation_request(
106-
self, messages: list
107-
) -> DismabiguationRequestPayload:
108-
"""Extract the disambiguation request from the answer."""
109-
disambiguation_request = messages[-1].content
110-
return DismabiguationRequestPayload(
111-
disambiguation_request=disambiguation_request,
112-
)
113-
114105
def parse_message_content(self, content):
115106
"""Parse different message content formats into a dictionary."""
116107
if isinstance(content, (list, dict)):
@@ -134,6 +125,49 @@ def parse_message_content(self, content):
134125
# If all parsing attempts fail, return the content as-is
135126
return content
136127

128+
def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]:
129+
"""Extract the decomposed messages from the answer."""
130+
# Only load sub-message results if we have a database result
131+
sub_message_results = self.parse_message_content(messages[1].content)
132+
logging.info("Decomposed Results: %s", sub_message_results)
133+
134+
decomposed_user_messages = sub_message_results.get(
135+
"decomposed_user_messages", []
136+
)
137+
138+
logging.debug(
139+
"Returning decomposed_user_messages: %s", decomposed_user_messages
140+
)
141+
142+
return decomposed_user_messages
143+
144+
def extract_disambiguation_request(
145+
self, messages: list
146+
) -> DismabiguationRequestsPayload:
147+
"""Extract the disambiguation request from the answer."""
148+
all_disambiguation_requests = self.parse_message_content(messages[-1].content)
149+
150+
decomposed_user_messages = self.extract_decomposed_user_messages(messages)
151+
request_payload = DismabiguationRequestsPayload(
152+
decomposed_user_messages=decomposed_user_messages
153+
)
154+
155+
for per_question_disambiguation_request in all_disambiguation_requests[
156+
"disambiguation_requests"
157+
].values():
158+
for disambiguation_request in per_question_disambiguation_request:
159+
logging.info(
160+
"Disambiguation Request Identified: %s", disambiguation_request
161+
)
162+
163+
request = DismabiguationRequestsPayload.Body.DismabiguationRequest(
164+
agent_question=disambiguation_request["agent_question"],
165+
user_choices=disambiguation_request["user_choices"],
166+
)
167+
request_payload.body.disambiguation_requests.append(request)
168+
169+
return request_payload
170+
137171
def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
138172
"""Extract the sources from the answer."""
139173
answer = messages[-1].content
@@ -145,41 +179,35 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
145179
except json.JSONDecodeError:
146180
logging.warning("Unable to read SQL query results: %s", sql_query_results)
147181
sql_query_results = {}
148-
sub_question_results = {}
149-
else:
150-
# Only load sub-question results if we have a database result
151-
sub_question_results = self.parse_message_content(messages[1].content)
152-
logging.info("Sub-Question Results: %s", sub_question_results)
153182

154183
try:
155-
sub_questions = [
156-
sub_question
157-
for sub_question_group in sub_question_results.get("sub_questions", [])
158-
for sub_question in sub_question_group
159-
]
184+
decomposed_user_messages = self.extract_decomposed_user_messages(messages)
160185

161186
logging.info("SQL Query Results: %s", sql_query_results)
162187
payload = AnswerWithSourcesPayload(
163-
answer=answer, sub_questions=sub_questions
188+
answer=answer, decomposed_user_messages=decomposed_user_messages
164189
)
165190

166191
if not isinstance(sql_query_results, dict):
167192
logging.error(f"Expected dict, got {type(sql_query_results)}")
168193
return payload
169194

170-
if "results" not in sql_query_results:
195+
if "database_results" not in sql_query_results:
171196
logging.error("No 'results' key in sql_query_results")
172197
return payload
173198

174-
for question, sql_query_result_list in sql_query_results["results"].items():
199+
for message, sql_query_result_list in sql_query_results[
200+
"database_results"
201+
].items():
175202
if not sql_query_result_list: # Check if list is empty
176-
logging.warning(f"No results for question: {question}")
203+
logging.warning(f"No results for message: {message}")
177204
continue
178205

179206
for sql_query_result in sql_query_result_list:
180207
if not isinstance(sql_query_result, dict):
181208
logging.error(
182-
f"Expected dict for sql_query_result, got {type(sql_query_result)}"
209+
"Expected dict for sql_query_result, got %s",
210+
type(sql_query_result),
183211
)
184212
continue
185213

@@ -208,47 +236,47 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
208236
answer=f"{answer}\nError processing results: {str(e)}"
209237
)
210238

211-
async def process_question(
239+
async def process_user_message(
212240
self,
213-
question_payload: QuestionPayload,
241+
message_payload: UserMessagePayload,
214242
chat_history: list[InteractionPayload] = None,
215243
) -> AsyncGenerator[InteractionPayload, None]:
216-
"""Process the complete question through the unified system.
244+
"""Process the complete message through the unified system.
217245
218246
Args:
219247
----
220-
task (str): The user question to process.
248+
task (str): The user message to process.
221249
chat_history (list[str], optional): The chat history. Defaults to None.
222250
injected_parameters (dict, optional): Parameters to pass to agents. Defaults to None.
223251
224252
Returns:
225253
-------
226254
dict: The response from the system.
227255
"""
228-
logging.info("Processing question: %s", question_payload.body.question)
256+
logging.info("Processing message: %s", message_payload.body.user_message)
229257
logging.info("Chat history: %s", chat_history)
230258

231259
agent_input = {
232-
"question": question_payload.body.question,
260+
"message": message_payload.body.user_message,
233261
"chat_history": {},
234-
"injected_parameters": question_payload.body.injected_parameters,
262+
"injected_parameters": message_payload.body.injected_parameters,
235263
}
236264

237265
if chat_history is not None:
238266
# Update input
239267
for idx, chat in enumerate(chat_history):
240-
if chat.root.payload_type == PayloadType.QUESTION:
268+
if chat.root.payload_type == PayloadType.USER_MESSAGE:
241269
# For now only consider the user query
242270
chat_history_key = f"chat_{idx}"
243-
agent_input[chat_history_key] = chat.root.body.question
271+
agent_input[chat_history_key] = chat.root.body.user_message
244272

245273
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
246274
logging.debug("Message: %s", message)
247275

248276
payload = None
249277

250278
if isinstance(message, TextMessage):
251-
if message.source == "question_rewrite_agent":
279+
if message.source == "user_message_rewrite_agent":
252280
payload = ProcessingUpdatePayload(
253281
message="Rewriting the query...",
254282
)
@@ -271,10 +299,10 @@ async def process_question(
271299
elif message.messages[-1].source == "parallel_query_solving_agent":
272300
# Load into disambiguation request
273301
payload = self.extract_disambiguation_request(message.messages)
274-
elif message.messages[-1].source == "question_rewrite_agent":
302+
elif message.messages[-1].source == "user_message_rewrite_agent":
275303
# Load into empty response
276304
payload = AnswerWithSourcesPayload(
277-
answer="Apologies, I cannot answer that question as it is not relevant. Please try another question or rephrase your current question."
305+
answer="Apologies, I cannot answer that message as it is not relevant. Please try another message or rephrase your current message."
278306
)
279307
else:
280308
logging.error("Unexpected TaskResult: %s", message)

text_2_sql/autogen/src/autogen_text_2_sql/creators/llm_agent_creator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def get_tool(cls, sql_helper, tool_name: str):
4343
elif tool_name == "sql_get_entity_schemas_tool":
4444
return FunctionToolAlias(
4545
sql_helper.get_entity_schemas,
46-
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
46+
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user input and use these as the search term. Several entities may be returned. Only use when the provided schemas in the message history are not sufficient to answer the question.",
4747
)
4848
elif tool_name == "sql_get_column_values_tool":
4949
return FunctionToolAlias(

0 commit comments

Comments
 (0)