Skip to content

Commit 8b882d8

Browse files
committed
Add model config
1 parent f46d9bd commit 8b882d8

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

text_2_sql/text_2_sql_core/src/text_2_sql_core/payloads/interaction_payloads.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
from pydantic import BaseModel, RootModel, Field, model_validator
3+
from pydantic import BaseModel, RootModel, Field, model_validator, ConfigDict
44
from enum import StrEnum
55

66
from typing import Literal
@@ -20,7 +20,11 @@ class PayloadType(StrEnum):
2020
USER_INPUT = "user_input"
2121

2222

23-
class PayloadBase(BaseModel):
23+
class InteractionPayloadBase(BaseModel):
24+
model_config = ConfigDict(allow_population_by_field_name=True, extra="ignore")
25+
26+
27+
class PayloadBase(InteractionPayloadBase):
2428
prompt_tokens: int | None = Field(
2529
None, description="Number of tokens in the prompt", alias="promptTokens"
2630
)
@@ -38,9 +42,9 @@ class PayloadBase(BaseModel):
3842
payload_source: PayloadSource = Field(..., alias="payloadSource")
3943

4044

41-
class DismabiguationRequestsPayload(PayloadBase):
42-
class Body(BaseModel):
43-
class DismabiguationRequest(BaseModel):
45+
class DismabiguationRequestsPayload(InteractionPayloadBase):
46+
class Body(InteractionPayloadBase):
47+
class DismabiguationRequest(InteractionPayloadBase):
4448
agent_question: str | None = Field(..., alias="agentQuestion")
4549
user_choices: list[str] | None = Field(default=None, alias="userChoices")
4650

@@ -62,9 +66,9 @@ def __init__(self, **kwargs):
6266
self.body = self.Body(**kwargs)
6367

6468

65-
class AnswerWithSourcesPayload(PayloadBase):
66-
class Body(BaseModel):
67-
class Source(BaseModel):
69+
class AnswerWithSourcesPayload(InteractionPayloadBase):
70+
class Body(InteractionPayloadBase):
71+
class Source(InteractionPayloadBase):
6872
sql_query: str = Field(alias="sqlQuery")
6973
sql_rows: list[dict] = Field(default_factory=list, alias="sqlRows")
7074

@@ -88,8 +92,8 @@ def __init__(self, **kwargs):
8892
self.body = self.Body(**kwargs)
8993

9094

91-
class ProcessingUpdatePayload(PayloadBase):
92-
class Body(BaseModel):
95+
class ProcessingUpdatePayload(InteractionPayloadBase):
96+
class Body(InteractionPayloadBase):
9397
title: str | None = "Processing..."
9498
message: str | None = "Processing..."
9599

@@ -107,8 +111,8 @@ def __init__(self, **kwargs):
107111
self.body = self.Body(**kwargs)
108112

109113

110-
class UserInputPayload(PayloadBase):
111-
class Body(BaseModel):
114+
class UserInputPayload(InteractionPayloadBase):
115+
class Body(InteractionPayloadBase):
112116
user_message: str = Field(..., alias="userMessage")
113117
injected_parameters: dict = Field(
114118
default_factory=dict, alias="injectedParameters"

0 commit comments

Comments
 (0)