Skip to content

Commit 70b73c4

Browse files
committed
Update interaction payload
1 parent 298e175 commit 70b73c4

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from text_2_sql_core.payloads.interaction_payloads import (
2222
QuestionPayload,
2323
AnswerWithSourcesPayload,
24-
DismabiguationRequestPayload,
24+
DismabiguationRequestsPayload,
2525
ProcessingUpdatePayload,
2626
InteractionPayload,
2727
PayloadType,
@@ -104,10 +104,10 @@ def agentic_flow(self):
104104

105105
def extract_disambiguation_request(
106106
self, messages: list
107-
) -> DismabiguationRequestPayload:
107+
) -> DismabiguationRequestsPayload:
108108
"""Extract the disambiguation request from the answer."""
109109
disambiguation_request = messages[-1].content
110-
return DismabiguationRequestPayload(
110+
return DismabiguationRequestsPayload(
111111
disambiguation_request=disambiguation_request,
112112
)
113113

@@ -179,7 +179,8 @@ def extract_answer_payload(self, messages: list) -> AnswerWithSourcesPayload:
179179
for sql_query_result in sql_query_result_list:
180180
if not isinstance(sql_query_result, dict):
181181
logging.error(
182-
f"Expected dict for sql_query_result, got {type(sql_query_result)}"
182+
"Expected dict for sql_query_result, got %s",
183+
type(sql_query_result),
183184
)
184185
continue
185186

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ class PayloadType(StrEnum):
3232
QUESTION = "question"
3333

3434

35-
class DismabiguationRequestPayload(PayloadBase):
35+
class ColumnFilterPair(BaseModel):
36+
column: str
37+
filter_value: str | None = Field(default=None)
38+
39+
40+
class DismabiguationRequestsPayload(PayloadBase):
3641
class Body(BaseModel):
3742
class DismabiguationRequest(BaseModel):
3843
question: str
39-
matching_columns: list[str]
40-
matching_filter_values: list[str]
41-
other_user_choices: list[str]
44+
choices: list[ColumnFilterPair] | None = Field(default=None)
4245

4346
disambiguation_requests: list[DismabiguationRequest]
4447

@@ -119,6 +122,6 @@ def __init__(self, **kwargs):
119122

120123

121124
class InteractionPayload(RootModel):
122-
root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestPayload | AnswerWithSourcesPayload = Field(
125+
root: QuestionPayload | ProcessingUpdatePayload | DismabiguationRequestsPayload | AnswerWithSourcesPayload = Field(
123126
..., discriminator="payload_type"
124127
)

0 commit comments

Comments
 (0)