Skip to content

Commit 7211486

Browse files
committed
Update payloads
1 parent 70b73c4 commit 7211486

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from autogen_text_2_sql.autogen_text_2_sql import AutoGenText2Sql
4-
from text_2_sql_core.payloads.interaction_payloads import QuestionPayload
4+
from text_2_sql_core.payloads.interaction_payloads import UserInputPayload
55

6-
__all__ = ["AutoGenText2Sql", "QuestionPayload"]
6+
__all__ = ["AutoGenText2Sql", "UserInputPayload"]

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import re
2020

2121
from text_2_sql_core.payloads.interaction_payloads import (
22-
QuestionPayload,
22+
UserInputPayload,
2323
AnswerWithSourcesPayload,
2424
DismabiguationRequestsPayload,
2525
ProcessingUpdatePayload,
@@ -211,7 +211,7 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
211211

212212
async def process_question(
213213
self,
214-
question_payload: QuestionPayload,
214+
question_payload: UserInputPayload,
215215
chat_history: list[InteractionPayload] = None,
216216
) -> AsyncGenerator[InteractionPayload, None]:
217217
"""Process the complete question through the unified system.
@@ -238,7 +238,7 @@ async def process_question(
238238
if chat_history is not None:
239239
# Update input
240240
for idx, chat in enumerate(chat_history):
241-
if chat.root.payload_type == PayloadType.QUESTION:
241+
if chat.root.payload_type == PayloadType.USER_INPUT:
242242
# For now only consider the user query
243243
chat_history_key = f"chat_{idx}"
244244
agent_input[chat_history_key] = chat.root.body.question

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ class PayloadType(StrEnum):
2929
ANSWER_WITH_SOURCES = "answer_with_sources"
3030
DISAMBIGUATION_REQUEST = "disambiguation_request"
3131
PROCESSING_UPDATE = "processing_update"
32-
QUESTION = "question"
32+
USER_INPUT = "user_input"
3333

3434

3535
class ColumnFilterPair(BaseModel):
36+
fqn: str
3637
column: str
3738
filter_value: str | None = Field(default=None)
3839

@@ -57,6 +58,32 @@ def __init__(self, **kwargs):
5758
self.body = self.Body(**kwargs)
5859

5960

61+
request = DismabiguationRequestsPayload(
62+
disambiguation_requests=[
63+
{
64+
"question": "Which of the following do you mean?",
65+
"choices": [
66+
{"fqn": "<fqn>", "column": "product_name", "filter_value": "Road Bike"},
67+
{
68+
"fqn": "<fqn>",
69+
"column": "product_name",
70+
"filter_value": "Mountain Bike",
71+
},
72+
],
73+
},
74+
{
75+
"question": "Do you mean total sales by volume or number of customers?",
76+
"choices": [
77+
{"fqn": "<fqn>", "column": "sales_volume", "filter_value": None},
78+
{"fqn": "<fqn>", "column": "customer_count", "filter_value": None},
79+
],
80+
},
81+
]
82+
)
83+
84+
print(request.model_dump())
85+
86+
6087
class AnswerWithSourcesPayload(PayloadBase):
6188
class Body(BaseModel):
6289
class Source(BaseModel):
@@ -94,7 +121,7 @@ def __init__(self, **kwargs):
94121
self.body = self.Body(**kwargs)
95122

96123

97-
class QuestionPayload(PayloadBase):
124+
class UserInputPayload(PayloadBase):
98125
class Body(BaseModel):
99126
question: str
100127
injected_parameters: dict = Field(default_factory=dict)
@@ -111,7 +138,7 @@ def add_defaults(cls, values):
111138
values["injected_parameters"] = {**defaults, **injected}
112139
return values
113140

114-
payload_type: Literal[PayloadType.QUESTION] = PayloadType.QUESTION
141+
payload_type: Literal[PayloadType.USER_INPUT] = PayloadType.USER_INPUT
115142
payload_source: Literal[PayloadSource.USER] = PayloadSource.USER
116143
body: Body | None = Field(default=None)
117144

@@ -122,6 +149,6 @@ def __init__(self, **kwargs):
122149

123150

124151
class InteractionPayload(RootModel):
125-
root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
152+
root: UserInputPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
126153
..., discriminator="payload_type"
127154
)

0 commit comments

Comments
 (0)