Skip to content

Commit d507b15

Browse files
Update message types to have common schema (#121)
1 parent 6d6886b commit d507b15

File tree

7 files changed

+163
-195
lines changed

7 files changed

+163
-195
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"source": [
5151
"import dotenv\n",
5252
"import logging\n",
53-
"from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody"
53+
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload"
5454
]
5555
},
5656
{
@@ -100,16 +100,9 @@
100100
"metadata": {},
101101
"outputs": [],
102102
"source": [
103-
"async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n",
103+
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What total number of orders in June 2008?\")):\n",
104104
" logging.info(\"Received %s Message from Text2SQL System\", message)"
105105
]
106-
},
107-
{
108-
"cell_type": "code",
109-
"execution_count": null,
110-
"metadata": {},
111-
"outputs": [],
112-
"source": []
113106
}
114107
],
115108
"metadata": {
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from text_2_sql_core.payloads.agent_request_response_pair import AgentRequestBody
4+
from text_2_sql_core.payloads.interaction_payloads import QuestionPayload
55

6-
__all__ = ["AutoGenText2Sql", "AgentRequestBody"]
6+
__all__ = ["AutoGenText2Sql", "QuestionPayload"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,13 @@
1717
import os
1818
from datetime import datetime
1919

20-
from text_2_sql_core.payloads.agent_request_response_pair import (
21-
AgentRequestResponsePair,
22-
AgentRequestBody,
23-
AnswerWithSources,
24-
Source,
25-
DismabiguationRequests,
26-
)
27-
from text_2_sql_core.payloads.chat_history import ChatHistoryItem
28-
from text_2_sql_core.payloads.processing_update import (
29-
ProcessingUpdateBody,
30-
ProcessingUpdate,
20+
from text_2_sql_core.payloads.interaction_payloads import (
21+
QuestionPayload,
22+
AnswerWithSourcesPayload,
23+
DismabiguationRequestPayload,
24+
ProcessingUpdatePayload,
25+
InteractionPayload,
26+
PayloadType,
3127
)
3228
from autogen_agentchat.base import TaskResult
3329
from typing import AsyncGenerator
@@ -108,17 +104,19 @@ def agentic_flow(self):
108104
)
109105
return flow
110106

111-
def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests:
107+
def extract_disambiguation_request(
108+
self, messages: list
109+
) -> DismabiguationRequestPayload:
112110
"""Extract the disambiguation request from the answer."""
113111

114112
disambiguation_request = messages[-1].content
115113

116114
# TODO: Properly extract the disambiguation request
117-
return DismabiguationRequests(
115+
return DismabiguationRequestPayload(
118116
disambiguation_request=disambiguation_request,
119117
)
120118

121-
def extract_sources(self, messages: list) -> AnswerWithSources:
119+
def extract_sources(self, messages: list) -> AnswerWithSourcesPayload:
122120
"""Extract the sources from the answer."""
123121

124122
answer = messages[-1].content
@@ -130,7 +128,7 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
130128

131129
logging.info("SQL Query Results: %s", sql_query_results)
132130

133-
sources = []
131+
payload = AnswerWithSourcesPayload(answer=answer)
134132

135133
for question, sql_query_result_list in sql_query_results["results"].items():
136134
logging.info(
@@ -141,27 +139,24 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
141139

142140
for sql_query_result in sql_query_result_list:
143141
logging.info("SQL Query Result: %s", sql_query_result)
144-
sources.append(
145-
Source(
146-
sql_query=sql_query_result["sql_query"],
147-
sql_rows=sql_query_result["sql_rows"],
148-
)
142+
# Instantiate Source and append to the payload's sources list
143+
source = AnswerWithSourcesPayload.Body.Source(
144+
sql_query=sql_query_result["sql_query"],
145+
sql_rows=sql_query_result["sql_rows"],
149146
)
147+
payload.body.sources.append(source)
148+
149+
return payload
150150

151151
except json.JSONDecodeError:
152152
logging.error("Could not load message: %s", sql_query_results)
153153
raise ValueError("Could not load message")
154154

155-
return AnswerWithSources(
156-
answer=answer,
157-
sources=sources,
158-
)
159-
160155
async def process_question(
161156
self,
162-
request: AgentRequestBody,
163-
chat_history: list[ChatHistoryItem] = None,
164-
) -> AsyncGenerator[AgentRequestResponsePair | ProcessingUpdate, None]:
157+
question_payload: QuestionPayload,
158+
chat_history: list[InteractionPayload] = None,
159+
) -> AsyncGenerator[InteractionPayload, None]:
165160
"""Process the complete question through the unified system.
166161
167162
Args:
@@ -174,65 +169,55 @@ async def process_question(
174169
-------
175170
dict: The response from the system.
176171
"""
177-
logging.info("Processing question: %s", request.question)
172+
logging.info("Processing question: %s", question_payload.body.question)
178173
logging.info("Chat history: %s", chat_history)
179174

180175
agent_input = {
181-
"question": request.question,
176+
"question": question_payload.body.question,
182177
"chat_history": {},
183-
"injected_parameters": request.injected_parameters,
178+
"injected_parameters": question_payload.body.injected_parameters,
184179
}
185180

186181
if chat_history is not None:
187182
# Update input
188183
for idx, chat in enumerate(chat_history):
189-
# For now only consider the user query
190-
chat_history_key = f"chat_{idx}"
191-
agent_input[
192-
chat_history_key
193-
] = chat.request_response_pair.request.question
184+
if chat.root.payload_type == PayloadType.QUESTION:
185+
# For now only consider the user query
186+
chat_history_key = f"chat_{idx}"
187+
agent_input[chat_history_key] = chat.root.body.question
194188

195189
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
196190
logging.debug("Message: %s", message)
197191

198192
payload = None
199193

200194
if isinstance(message, TextMessage):
201-
processing_update = None
202195
if message.source == "query_rewrite_agent":
203-
processing_update = ProcessingUpdateBody(
196+
payload = ProcessingUpdatePayload(
204197
message="Rewriting the query...",
205198
)
206199
elif message.source == "parallel_query_solving_agent":
207-
processing_update = ProcessingUpdateBody(
200+
payload = ProcessingUpdatePayload(
208201
message="Solving the query...",
209202
)
210203
elif message.source == "answer_agent":
211-
processing_update = ProcessingUpdateBody(
204+
payload = ProcessingUpdatePayload(
212205
message="Generating the answer...",
213206
)
214207

215-
if processing_update is not None:
216-
payload = ProcessingUpdate(
217-
processing_update=processing_update,
218-
)
219-
220208
elif isinstance(message, TaskResult):
221209
# Now we need to return the final answer or the disambiguation request
222210
logging.info("TaskResult: %s", message)
223211

224-
response = None
225212
if message.messages[-1].source == "answer_agent":
226213
# If the message is from the answer_agent, we need to return the final answer
227-
response = self.extract_sources(message.messages)
214+
payload = self.extract_sources(message.messages)
228215
elif message.messages[-1].source == "parallel_query_solving_agent":
229216
# Load into disambiguation request
230-
response = self.extract_disambiguation_request(message.messages)
231-
else:
232-
logging.error("Unexpected TaskResult: %s", message)
233-
raise ValueError("Unexpected TaskResult")
234-
235-
payload = AgentRequestResponsePair(request=request, response=response)
217+
payload = self.extract_disambiguation_request(message.messages)
218+
else:
219+
logging.error("Unexpected TaskResult: %s", message)
220+
raise ValueError("Unexpected TaskResult")
236221

237222
if payload is not None:
238223
logging.debug("Payload: %s", payload)

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_request_response_pair.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)