Skip to content

Commit 568521a

Browse files
committed
Update mesage types
1 parent 6d6886b commit 568521a

File tree

7 files changed

+157
-175
lines changed

7 files changed

+157
-175
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"source": [
5151
"import dotenv\n",
5252
"import logging\n",
53-
"from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody"
53+
"from autogen_text_2_sql import AutoGenText2Sql, QuestionBody, QuestionPayload"
5454
]
5555
},
5656
{
@@ -100,16 +100,9 @@
100100
"metadata": {},
101101
"outputs": [],
102102
"source": [
103-
"async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n",
103+
"async for message in agentic_text_2_sql.process_question(QuestionPayload(body=QuestionBody(question=\"What total number of orders in June 2008?\"))):\n",
104104
" logging.info(\"Received %s Message from Text2SQL System\", message)"
105105
]
106-
},
107-
{
108-
"cell_type": "code",
109-
"execution_count": null,
110-
"metadata": {},
111-
"outputs": [],
112-
"source": []
113106
}
114107
],
115108
"metadata": {
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from text_2_sql_core.payloads.agent_request_response_pair import AgentRequestBody
4+
from text_2_sql_core.payloads.interaction_payloads import QuestionBody, QuestionPayload
55

6-
__all__ = ["AutoGenText2Sql", "AgentRequestBody"]
6+
__all__ = ["AutoGenText2Sql", "QuestionBody", "QuestionPayload"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,16 @@
1717
import os
1818
from datetime import datetime
1919

20-
from text_2_sql_core.payloads.agent_request_response_pair import (
21-
AgentRequestResponsePair,
22-
AgentRequestBody,
23-
AnswerWithSources,
20+
from text_2_sql_core.payloads.interaction_payloads import (
21+
QuestionPayload,
22+
AnswerWithSourcesPayload,
23+
AnswerWithSourcesBody,
2424
Source,
25-
DismabiguationRequests,
26-
)
27-
from text_2_sql_core.payloads.chat_history import ChatHistoryItem
28-
from text_2_sql_core.payloads.processing_update import (
25+
DismabiguationRequestPayload,
2926
ProcessingUpdateBody,
30-
ProcessingUpdate,
27+
ProcessingUpdatePayload,
28+
InteractionPayload,
29+
PayloadType,
3130
)
3231
from autogen_agentchat.base import TaskResult
3332
from typing import AsyncGenerator
@@ -108,17 +107,19 @@ def agentic_flow(self):
108107
)
109108
return flow
110109

111-
def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests:
110+
def extract_disambiguation_request(
111+
self, messages: list
112+
) -> DismabiguationRequestPayload:
112113
"""Extract the disambiguation request from the answer."""
113114

114115
disambiguation_request = messages[-1].content
115116

116117
# TODO: Properly extract the disambiguation request
117-
return DismabiguationRequests(
118+
return DismabiguationRequestPayload(
118119
disambiguation_request=disambiguation_request,
119120
)
120121

121-
def extract_sources(self, messages: list) -> AnswerWithSources:
122+
def extract_sources(self, messages: list) -> AnswerWithSourcesPayload:
122123
"""Extract the sources from the answer."""
123124

124125
answer = messages[-1].content
@@ -152,16 +153,18 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
152153
logging.error("Could not load message: %s", sql_query_results)
153154
raise ValueError("Could not load message")
154155

155-
return AnswerWithSources(
156-
answer=answer,
157-
sources=sources,
156+
return AnswerWithSourcesPayload(
157+
body=AnswerWithSourcesBody(
158+
answer=answer,
159+
sources=sources,
160+
)
158161
)
159162

160163
async def process_question(
161164
self,
162-
request: AgentRequestBody,
163-
chat_history: list[ChatHistoryItem] = None,
164-
) -> AsyncGenerator[AgentRequestResponsePair | ProcessingUpdate, None]:
165+
question_payload: QuestionPayload,
166+
chat_history: list[InteractionPayload] = None,
167+
) -> AsyncGenerator[InteractionPayload, None]:
165168
"""Process the complete question through the unified system.
166169
167170
Args:
@@ -174,23 +177,22 @@ async def process_question(
174177
-------
175178
dict: The response from the system.
176179
"""
177-
logging.info("Processing question: %s", request.question)
180+
logging.info("Processing question: %s", question_payload.body.question)
178181
logging.info("Chat history: %s", chat_history)
179182

180183
agent_input = {
181-
"question": request.question,
184+
"question": question_payload.body.question,
182185
"chat_history": {},
183-
"injected_parameters": request.injected_parameters,
186+
"injected_parameters": question_payload.body.injected_parameters,
184187
}
185188

186189
if chat_history is not None:
187190
# Update input
188191
for idx, chat in enumerate(chat_history):
189-
# For now only consider the user query
190-
chat_history_key = f"chat_{idx}"
191-
agent_input[
192-
chat_history_key
193-
] = chat.request_response_pair.request.question
192+
if chat.root.payload_type == PayloadType.QUESTION:
193+
# For now only consider the user query
194+
chat_history_key = f"chat_{idx}"
195+
agent_input[chat_history_key] = chat.root.body.question
194196

195197
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
196198
logging.debug("Message: %s", message)
@@ -213,27 +215,24 @@ async def process_question(
213215
)
214216

215217
if processing_update is not None:
216-
payload = ProcessingUpdate(
217-
processing_update=processing_update,
218+
payload = ProcessingUpdatePayload(
219+
body=processing_update,
218220
)
219221

220222
elif isinstance(message, TaskResult):
221223
# Now we need to return the final answer or the disambiguation request
222224
logging.info("TaskResult: %s", message)
223225

224-
response = None
225226
if message.messages[-1].source == "answer_agent":
226227
# If the message is from the answer_agent, we need to return the final answer
227-
response = self.extract_sources(message.messages)
228+
payload = self.extract_sources(message.messages)
228229
elif message.messages[-1].source == "parallel_query_solving_agent":
229230
# Load into disambiguation request
230-
response = self.extract_disambiguation_request(message.messages)
231+
payload = self.extract_disambiguation_request(message.messages)
231232
else:
232233
logging.error("Unexpected TaskResult: %s", message)
233234
raise ValueError("Unexpected TaskResult")
234235

235-
payload = AgentRequestResponsePair(request=request, response=response)
236-
237236
if payload is not None:
238237
logging.debug("Payload: %s", payload)
239238
yield payload

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/agent_request_response_pair.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/chat_history.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)