Skip to content

Commit a3e444c

Browse files
committed
Update entry points
1 parent 8b882d8 commit a3e444c

File tree

7 files changed

+49
-45
lines changed

7 files changed

+49
-45
lines changed

text_2_sql/autogen/Iteration 5 - Agentic Vector Based Text2SQL.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"source": [
5151
"import dotenv\n",
5252
"import logging\n",
53-
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload"
53+
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload"
5454
]
5555
},
5656
{
@@ -100,7 +100,7 @@
100100
"metadata": {},
101101
"outputs": [],
102102
"source": [
103-
"async for message in agentic_text_2_sql.process_question(QuestionPayload(question=\"What is the total number of sales?\")):\n",
103+
"async for message in agentic_text_2_sql.process_user_message(UserMessagePayload(user_message=\"What is the total number of sales?\")):\n",
104104
" logging.info(\"Received %s Message from Text2SQL System\", message)"
105105
]
106106
},

text_2_sql/autogen/evaluate_autogen_text2sql.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
"# Add the src directory to the path\n",
6969
"sys.path.append(str(notebook_dir / \"src\"))\n",
7070
"\n",
71-
"from autogen_text_2_sql import AutoGenText2Sql, QuestionPayload\n",
71+
"from autogen_text_2_sql import AutoGenText2Sql, UserMessagePayload\n",
7272
"from autogen_text_2_sql.evaluation_utils import get_final_sql_query\n",
7373
"\n",
7474
"# Configure logging\n",
@@ -127,7 +127,7 @@
127127
" all_queries = []\n",
128128
" final_query = None\n",
129129
" \n",
130-
" async for message in autogen_text2sql.process_question(QuestionPayload(question=question)):\n",
130+
" async for message in autogen_text2sql.process_user_message(UserMessagePayload(user_message=question)):\n",
131131
" if message.payload_type == \"answer_with_sources\":\n",
132132
" # Extract from results\n",
133133
" if hasattr(message.body, 'results'):\n",
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from text_2_sql_core.payloads.interaction_payloads import UserInputPayload
4+
from text_2_sql_core.payloads.interaction_payloads import UserMessagePayload
55

6-
__all__ = ["AutoGenText2Sql", "UserInputPayload"]
6+
__all__ = ["AutoGenText2Sql", "UserMessagePayload"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import re
2020

2121
from text_2_sql_core.payloads.interaction_payloads import (
22-
UserInputPayload,
22+
UserMessagePayload,
2323
AnswerWithSourcesPayload,
2424
DismabiguationRequestsPayload,
2525
ProcessingUpdatePayload,
@@ -102,15 +102,6 @@ def agentic_flow(self):
102102
)
103103
return flow
104104

105-
def extract_disambiguation_request(
106-
self, messages: list
107-
) -> DismabiguationRequestsPayload:
108-
"""Extract the disambiguation request from the answer."""
109-
disambiguation_request = messages[-1].content
110-
return DismabiguationRequestsPayload(
111-
disambiguation_request=disambiguation_request,
112-
)
113-
114105
def parse_message_content(self, content):
115106
"""Parse different message content formats into a dictionary."""
116107
if isinstance(content, (list, dict)):
@@ -134,6 +125,26 @@ def parse_message_content(self, content):
134125
# If all parsing attempts fail, return the content as-is
135126
return content
136127

128+
def extract_decomposed_user_messages(self, messages: list) -> list[list[str]]:
129+
"""Extract the decomposed messages from the answer."""
130+
# Only load sub-message results if we have a database result
131+
sub_message_results = self.parse_message_content(messages[1].content)
132+
logging.info("Decomposed Results: %s", sub_message_results)
133+
134+
return sub_message_results.get("decomposed_messages", [])
135+
136+
def extract_disambiguation_request(
137+
self, messages: list
138+
) -> DismabiguationRequestsPayload:
139+
"""Extract the disambiguation request from the answer."""
140+
disambiguation_request = messages[-1].content
141+
142+
decomposed_user_messages = self.extract_decomposed_user_messages(messages)
143+
return DismabiguationRequestsPayload(
144+
disambiguation_request=disambiguation_request,
145+
decomposed_user_messages=decomposed_user_messages,
146+
)
147+
137148
def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
138149
"""Extract the sources from the answer."""
139150
answer = messages[-1].content
@@ -145,24 +156,13 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
145156
except json.JSONDecodeError:
146157
logging.warning("Unable to read SQL query results: %s", sql_query_results)
147158
sql_query_results = {}
148-
sub_message_results = {}
149-
else:
150-
# Only load sub-message results if we have a database result
151-
sub_message_results = self.parse_message_content(messages[1].content)
152-
logging.info("Sub-message Results: %s", sub_message_results)
153159

154160
try:
155-
decomposed_messages = [
156-
sub_message
157-
for sub_message_group in sub_message_results.get(
158-
"decomposed_messages", []
159-
)
160-
for sub_message in sub_message_group
161-
]
161+
decomposed_user_messages = self.extract_decomposed_user_messages(messages)
162162

163163
logging.info("SQL Query Results: %s", sql_query_results)
164164
payload = AnswerWithSourcesPayload(
165-
answer=answer, decomposed_messages=decomposed_messages
165+
answer=answer, decomposed_user_messages=decomposed_user_messages
166166
)
167167

168168
if not isinstance(sql_query_results, dict):
@@ -213,7 +213,7 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
213213

214214
async def process_message(
215215
self,
216-
message_payload: UserInputPayload,
216+
message_payload: UserMessagePayload,
217217
chat_history: list[InteractionPayload] = None,
218218
) -> AsyncGenerator[InteractionPayload, None]:
219219
"""Process the complete message through the unified system.
@@ -228,22 +228,22 @@ async def process_message(
228228
-------
229229
dict: The response from the system.
230230
"""
231-
logging.info("Processing message: %s", message_payload.body.message)
231+
logging.info("Processing message: %s", message_payload.body.user_message)
232232
logging.info("Chat history: %s", chat_history)
233233

234234
agent_input = {
235-
"message": message_payload.body.message,
235+
"message": message_payload.body.user_message,
236236
"chat_history": {},
237237
"injected_parameters": message_payload.body.injected_parameters,
238238
}
239239

240240
if chat_history is not None:
241241
# Update input
242242
for idx, chat in enumerate(chat_history):
243-
if chat.root.payload_type == PayloadType.message:
243+
if chat.root.payload_type == PayloadType.USER_MESSAGE:
244244
# For now only consider the user query
245245
chat_history_key = f"chat_{idx}"
246-
agent_input[chat_history_key] = chat.root.body.message
246+
agent_input[chat_history_key] = chat.root.body.user_message
247247

248248
async for message in self.agentic_flow.run_stream(task=json.dumps(agent_input)):
249249
logging.debug("Message: %s", message)

text_2_sql/autogen/src/autogen_text_2_sql/evaluation_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
13
import re
24
from typing import Optional, List, Dict, Any
35

text_2_sql/autogen/src/autogen_text_2_sql/inner_autogen_text_2_sql.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def agentic_flow(self):
169169
)
170170
return flow
171171

172-
def process_question(
172+
def process_user_message(
173173
self,
174-
question: str,
174+
user_message: str,
175175
injected_parameters: dict = None,
176176
):
177177
"""Process the complete question through the unified system.
@@ -185,15 +185,14 @@ def process_question(
185185
-------
186186
dict: The response from the system.
187187
"""
188-
logging.info("Processing question: %s", question)
188+
logging.info("Processing question: %s", user_message)
189189

190190
# Update environment with injected parameters
191191
self._update_environment(injected_parameters)
192192

193193
try:
194194
agent_input = {
195-
"question": question,
196-
"chat_history": {},
195+
"user_message": user_message,
197196
"injected_parameters": injected_parameters,
198197
}
199198

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class PayloadType(StrEnum):
1717
ANSWER_WITH_SOURCES = "answer_with_sources"
1818
DISAMBIGUATION_REQUESTS = "disambiguation_requests"
1919
PROCESSING_UPDATE = "processing_update"
20-
USER_INPUT = "user_input"
20+
USER_MESSAGE = "user_message"
2121

2222

2323
class InteractionPayloadBase(BaseModel):
@@ -51,6 +51,9 @@ class DismabiguationRequest(InteractionPayloadBase):
5151
disambiguation_requests: list[DismabiguationRequest] = Field(
5252
alias="disambiguationRequests"
5353
)
54+
decomposed_user_messages: list[list[str]] = Field(
55+
default_factory=list, alias="decomposedUserMessages"
56+
)
5457

5558
payload_type: Literal[PayloadType.DISAMBIGUATION_REQUESTS] = Field(
5659
PayloadType.DISAMBIGUATION_REQUESTS, alias="payloadType"
@@ -73,7 +76,7 @@ class Source(InteractionPayloadBase):
7376
sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows")
7477

7578
answer: str
76-
decomposed_user_messages: list[str] = Field(
79+
decomposed_user_messages: list[list[str]] = Field(
7780
default_factory=list, alias="decomposedUserMessages"
7881
)
7982
sources: list[Source] = Field(default_factory=list)
@@ -111,7 +114,7 @@ def __init__(self, **kwargs):
111114
self.body = self.Body(**kwargs)
112115

113116

114-
class UserInputPayload(InteractionPayloadBase):
117+
class UserMessagePayload(InteractionPayloadBase):
115118
class Body(InteractionPayloadBase):
116119
user_message: str = Field(..., alias="userMessage")
117120
injected_parameters: dict = Field(
@@ -130,8 +133,8 @@ def add_defaults(cls, values):
130133
values["injected_parameters"] = {**defaults, **injected}
131134
return values
132135

133-
payload_type: Literal[PayloadType.USER_INPUT] = Field(
134-
PayloadType.USER_INPUT, alias="payloadType"
136+
payload_type: Literal[PayloadType.USER_MESSAGE] = Field(
137+
PayloadType.USER_MESSAGE, alias="payloadType"
135138
)
136139
payload_source: Literal[PayloadSource.USER] = Field(
137140
PayloadSource.USER, alias="payloadSource"
@@ -145,6 +148,6 @@ def __init__(self, **kwargs):
145148

146149

147150
class InteractionPayload(RootModel):
148-
root: UserInputPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
151+
root: UserMessagePayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
149152
..., discriminator="payload_type"
150153
)

0 commit comments

Comments
 (0)