Skip to content

Commit a4775eb

Browse files
committed
Update data contract
1 parent 66660ed commit a4775eb

File tree

11 files changed

+149
-88
lines changed

11 files changed

+149
-88
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"source": [
5151
"import dotenv\n",
5252
"import logging\n",
53-
"from autogen_text_2_sql import AutoGenText2Sql"
53+
"from autogen_text_2_sql import AutoGenText2Sql, AgentRequestBody"
5454
]
5555
},
5656
{
@@ -100,7 +100,7 @@
100100
"metadata": {},
101101
"outputs": [],
102102
"source": [
103-
"async for message in agentic_text_2_sql.process_question(question=\"What total number of orders in June 2008?\"):\n",
103+
"async for message in agentic_text_2_sql.process_question(AgentRequestBody(question=\"What total number of orders in June 2008?\")):\n",
104104
" logging.info(\"Received %s Message from Text2SQL System\", message)"
105105
]
106106
},
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
13
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4+
from text_2_sql_core.payloads.agent_response import AgentRequestBody
25

3-
__all__ = ["AutoGenText2Sql"]
6+
__all__ = ["AutoGenText2Sql", "AgentRequestBody"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,17 @@
1818
import os
1919
from datetime import datetime
2020

21-
from text_2_sql_core.payloads import (
21+
from text_2_sql_core.payloads.agent_response import (
22+
AgentResponse,
23+
AgentRequestBody,
2224
AnswerWithSources,
23-
UserInformationRequest,
25+
Source,
26+
DismabiguationRequests,
27+
)
28+
from text_2_sql_core.payloads.chat_history import ChatHistoryItem
29+
from text_2_sql_core.payloads.processing_update import (
30+
ProcessingUpdateBody,
2431
ProcessingUpdate,
25-
ChatHistoryItem,
2632
)
2733
from autogen_agentchat.base import Response, TaskResult
2834
from typing import AsyncGenerator
@@ -123,6 +129,16 @@ def agentic_flow(self):
123129
)
124130
return flow
125131

132+
def extract_disambiguation_request(self, messages: list) -> DismabiguationRequests:
    """Extract the disambiguation request from the final message.

    The previous implementation passed the raw message content as a
    ``disambiguation_request`` keyword, which is not a field of
    ``DismabiguationRequests`` and left the required ``requests`` field
    unset, so construction always failed validation.

    Args:
        messages: The messages of the finished run; only the content of the
            last one is read.

    Returns:
        A ``DismabiguationRequests`` payload whose ``requests`` are built
        from the decoded content.

    Raises:
        ValueError: If the content of the last message is not valid JSON.
            Model validation errors from pydantic propagate unchanged.
    """
    raw_content = messages[-1].content

    try:
        decoded = json.loads(raw_content)
    except (TypeError, json.JSONDecodeError) as exc:
        logging.error("Unable to decode disambiguation request: %s", raw_content)
        raise ValueError("Unable to decode disambiguation request") from exc

    # TODO: confirm the exact shape emitted by the upstream agent. Until then
    # accept the three plausible encodings: an object wrapping a ``requests``
    # list, a bare list of requests, or a single request object.
    if isinstance(decoded, dict) and "requests" in decoded:
        requests = decoded["requests"]
    elif isinstance(decoded, list):
        requests = decoded
    else:
        requests = [decoded]

    # Plain dicts are validated into request models by pydantic.
    return DismabiguationRequests(requests=requests)
141+
126142
def extract_sources(self, messages: list) -> AnswerWithSources:
127143
"""Extract the sources from the answer."""
128144

@@ -147,10 +163,10 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
147163
for sql_query_result in sql_query_result_list:
148164
logging.info("SQL Query Result: %s", sql_query_result)
149165
sources.append(
150-
{
151-
"sql_query": sql_query_result["sql_query"],
152-
"sql_rows": sql_query_result["sql_rows"],
153-
}
166+
Source(
167+
sql_query=sql_query_result["sql_query"],
168+
sql_rows=sql_query_result["sql_rows"],
169+
)
154170
)
155171

156172
except json.JSONDecodeError:
@@ -164,10 +180,9 @@ def extract_sources(self, messages: list) -> AnswerWithSources:
164180

165181
async def process_question(
166182
self,
167-
question: str,
183+
request: AgentRequestBody,
168184
chat_history: list[ChatHistoryItem] = None,
169-
injected_parameters: dict = None,
170-
) -> AsyncGenerator[AnswerWithSources | UserInformationRequest, None]:
185+
) -> AsyncGenerator[AgentResponse | ProcessingUpdate, None]:
171186
"""Process the complete question through the unified system.
172187
173188
Args:
@@ -180,58 +195,63 @@ async def process_question(
180195
-------
181196
dict: The response from the system.
182197
"""
183-
logging.info("Processing question: %s", question)
198+
logging.info("Processing question: %s", request.question)
184199
logging.info("Chat history: %s", chat_history)
185200

186201
agent_input = {
187-
"question": question,
202+
"question": request.question,
188203
"chat_history": {},
189-
"injected_parameters": injected_parameters,
204+
"injected_parameters": request.injected_parameters,
190205
}
191206

192207
if chat_history is not None:
193208
# Update input
194209
for idx, chat in enumerate(chat_history):
195210
# For now only consider the user query
196-
agent_input[f"chat_{idx}"] = chat.user_query
211+
agent_input[f"chat_{idx}"] = chat.request.question
197212

198213
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
199214
logging.debug("Message: %s", message)
200215

201216
payload = None
202217

203218
if isinstance(message, TextMessage):
219+
processing_update = None
204220
if message.source == "query_rewrite_agent":
205-
# If the message is from the query_rewrite_agent, we need to update the chat history
206-
payload = ProcessingUpdate(
221+
processing_update = ProcessingUpdateBody(
207222
message="Rewriting the query...",
208223
)
209224
elif message.source == "parallel_query_solving_agent":
210-
# If the message is from the parallel_query_solving_agent, we need to update the chat history
211-
payload = ProcessingUpdate(
225+
processing_update = ProcessingUpdateBody(
212226
message="Solving the query...",
213227
)
214228
elif message.source == "answer_agent":
215-
# If the message is from the answer_agent, we need to update the chat history
216-
payload = ProcessingUpdate(
229+
processing_update = ProcessingUpdateBody(
217230
message="Generating the answer...",
218231
)
219232

233+
if processing_update is not None:
234+
payload = ProcessingUpdate(
235+
processing_update=processing_update,
236+
)
237+
220238
elif isinstance(message, TaskResult):
221239
# Now we need to return the final answer or the disambiguation request
222240
logging.info("TaskResult: %s", message)
223241

242+
response = None
224243
if message.messages[-1].source == "answer_agent":
225244
# If the message is from the answer_agent, we need to return the final answer
226-
payload = self.extract_sources(message.messages)
245+
response = self.extract_sources(message.messages)
227246
elif message.messages[-1].source == "parallel_query_solving_agent":
228-
payload = UserInformationRequest(
229-
**json.loads(message.messages[-1].content),
230-
)
247+
# Load into disambiguation request
248+
response = self.extract_disambiguation_request(message.messages)
231249
else:
232250
logging.error("Unexpected TaskResult: %s", message)
233251
raise ValueError("Unexpected TaskResult")
234252

253+
payload = AgentResponse(request=request, response=response)
254+
235255
if payload is not None:
236256
logging.debug("Payload: %s", payload)
237257
yield payload

text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/parallel_query_solving_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ async def on_messages_stream(
4444
last_response = messages[-1].content
4545
parameter_input = messages[0].content
4646
try:
47-
user_parameters = json.loads(parameter_input)["parameters"]
47+
injected_parameters = json.loads(parameter_input)["injected_parameters"]
4848
except json.JSONDecodeError:
4949
logging.error("Error decoding the user parameters.")
50-
user_parameters = {}
50+
injected_parameters = {}
5151

5252
# Load the json of the last message to populate the final output object
5353
query_rewrites = json.loads(last_response)
@@ -117,7 +117,7 @@ async def consume_inner_messages_from_agentic_flow(
117117
inner_solving_generators.append(
118118
consume_inner_messages_from_agentic_flow(
119119
inner_autogen_text_2_sql.process_question(
120-
question=query_rewrite, parameters=user_parameters
120+
question=query_rewrite, injected_parameters=injected_parameters
121121
),
122122
query_rewrite,
123123
database_results,
Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +0,0 @@
1-
from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources, Source
2-
from text_2_sql_core.payloads.user_information_request import UserInformationRequest
3-
from text_2_sql_core.payloads.processing_update import ProcessingUpdate
4-
from text_2_sql_core.payloads.chat_history import ChatHistoryItem
5-
6-
__all__ = [
7-
"AnswerWithSources",
8-
"Source",
9-
"UserInformationRequest",
10-
"ProcessingUpdate",
11-
"ChatHistoryItem",
12-
]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from pydantic import BaseModel, RootModel, Field
4+
from enum import StrEnum
5+
6+
from typing import Literal
7+
from datetime import datetime, timezone
8+
9+
10+
class AgentResponseHeader(BaseModel):
    """Header of an agent response: token counts plus a creation timestamp."""

    prompt_tokens: int
    completion_tokens: int
    # The positional ``...`` (required marker) was dropped: it contradicts
    # ``default_factory`` and is rejected outright by pydantic v1. The field
    # still defaults to the current time in UTC when not supplied.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="Timestamp in UTC",
    )
18+
19+
20+
class AgentResponseType(StrEnum):
21+
ANSWER_WITH_SOURCES = "answer_with_sources"
22+
DISAMBIGUATION = "disambiguation"
23+
24+
25+
class DismabiguationRequest(BaseModel):
    """A single disambiguation request.

    All four fields are required. NOTE(review): the field meanings below are
    inferred from their names only -- confirm against the producing agent.
    """

    # Presumably the clarifying question to put to the user.
    question: str
    # Presumably the candidate columns that matched.
    matching_columns: list[str]
    # Presumably the candidate filter values that matched.
    matching_filter_values: list[str]
    # Presumably further options the user may pick from.
    other_user_choices: list[str]
30+
31+
32+
class DismabiguationRequests(BaseModel):
    """Response body carrying a list of disambiguation requests."""

    # Fixed tag; AgentResponseBody discriminates its union on this field.
    response_type: Literal[
        AgentResponseType.DISAMBIGUATION
    ] = AgentResponseType.DISAMBIGUATION
    # Required: the individual requests.
    requests: list[DismabiguationRequest]
37+
38+
39+
class Source(BaseModel):
    """One SQL source: a query string and its rows."""

    # The SQL query text.
    sql_query: str
    # The rows, one dict per row.
    sql_rows: list[dict]
42+
43+
44+
class AnswerWithSources(BaseModel):
    """Response body carrying an answer and the sources backing it."""

    # Fixed tag; AgentResponseBody discriminates its union on this field.
    response_type: Literal[
        AgentResponseType.ANSWER_WITH_SOURCES
    ] = AgentResponseType.ANSWER_WITH_SOURCES
    # Required: the answer text.
    answer: str
    # Optional: defaults to a fresh empty list per instance.
    sources: list[Source] = Field(default_factory=list)
50+
51+
52+
class AgentResponseBody(RootModel):
    """Root model wrapping one of the two response bodies.

    The concrete body is selected through the ``response_type`` discriminator.
    """

    # Required root value; no default is supplied.
    root: DismabiguationRequests | AnswerWithSources = Field(
        discriminator="response_type",
    )
56+
57+
58+
class AgentRequestBody(BaseModel):
    """Request sent to the agent: a question plus optional injected parameters."""

    # Required: the question text.
    question: str
    # Optional: defaults to a fresh empty dict per instance.
    injected_parameters: dict = Field(default_factory=dict)
61+
62+
63+
class AgentResponse(BaseModel):
    """Envelope pairing a request with its response body and an optional header."""

    # Optional header; None unless supplied.
    header: AgentResponseHeader | None = None
    # Required: the request this response belongs to.
    request: AgentRequestBody
    # Required: the response body.
    response: AgentResponseBody

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/answer_with_sources.py

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from pydantic import BaseModel
2-
from text_2_sql_core.payloads.answer_with_sources import AnswerWithSources
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from pydantic import BaseModel, Field
4+
from text_2_sql_core.payloads.agent_response import AgentResponse
5+
from datetime import datetime, timezone
36

47

58
class ChatHistoryItem(BaseModel):
69
"""Chat history item with user message and agent response."""
710

8-
user_query: str
9-
agent_response: AnswerWithSources
11+
timestamp: datetime = Field(
12+
...,
13+
description="Timestamp in UTC",
14+
default_factory=lambda: datetime.now(timezone.utc),
15+
)
16+
agent_response: AgentResponse
Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
13
from pydantic import BaseModel, Field
4+
from datetime import datetime, timezone
25

36

4-
class ProcessingUpdate(BaseModel):
7+
class ProcessingUpdateHeader(BaseModel):
    """Header of a processing update; carries only a creation timestamp."""

    # The positional ``...`` (required marker) was dropped: it contradicts
    # ``default_factory`` and is rejected outright by pydantic v1. The field
    # still defaults to the current time in UTC when not supplied.
    timestamp: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="Timestamp in UTC",
    )
13+
14+
15+
class ProcessingUpdateBody(BaseModel):
    """Body of a processing update: a title and a message.

    Both fields may be None and both default to "Processing...".
    """

    title: str | None = "Processing..."
    message: str | None = "Processing..."
18+
19+
20+
class ProcessingUpdate(BaseModel):
    """Envelope for a processing update: an optional header plus the body."""

    # A new header is created per instance when none is supplied; an explicit
    # None is also accepted.
    header: ProcessingUpdateHeader | None = Field(default_factory=ProcessingUpdateHeader)
    # Required: the update body.
    processing_update: ProcessingUpdateBody

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/user_information_request.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

0 commit comments

Comments
 (0)