Skip to content

Commit 1cf50f6

Browse files
authored
Update model_inputs.py
1 parent 5956ac5 commit 1cf50f6

File tree

1 file changed

+130
-44
lines changed

1 file changed

+130
-44
lines changed

src/agents/realtime/model_inputs.py

Lines changed: 130 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,66 @@
1+
"""
2+
Realtime model input classes with Pydantic validation.
3+
4+
This module defines input classes for sending data to realtime models.
5+
"""
6+
17
from __future__ import annotations
28

3-
from dataclasses import dataclass
49
from typing import Any, Literal, Union
510

6-
from typing_extensions import NotRequired, TypeAlias, TypedDict
11+
from pydantic import BaseModel, ConfigDict, Field
12+
from typing_extensions import TypeAlias
713

814
from .config import RealtimeSessionModelSettings
9-
from .model_events import RealtimeModelToolCallEvent
1015

1116

12-
class RealtimeModelRawClientMessage(TypedDict):
17+
class RealtimeModelRawClientMessage(BaseModel):
1318
"""A raw message to be sent to the model."""
1419

15-
type: str # explicitly required
16-
other_data: NotRequired[dict[str, Any]]
17-
"""Merged into the message body."""
20+
model_config = ConfigDict(
21+
arbitrary_types_allowed=True,
22+
validate_assignment=True,
23+
extra="forbid",
24+
frozen=True,
25+
)
1826

27+
type: str = Field(..., description="Message type identifier.")
28+
other_data: dict[str, Any] = Field(
29+
default_factory=dict, description="Additional data merged into the message body."
30+
)
1931

20-
class RealtimeModelInputTextContent(TypedDict):
32+
33+
class RealtimeModelInputTextContent(BaseModel):
2134
"""A piece of text to be sent to the model."""
2235

23-
type: Literal["input_text"]
24-
text: str
36+
model_config = ConfigDict(
37+
arbitrary_types_allowed=True,
38+
validate_assignment=True,
39+
extra="forbid",
40+
frozen=True,
41+
)
42+
43+
type: Literal["input_text"] = Field(
44+
default="input_text", description="Content type identifier."
45+
)
46+
text: str = Field(..., description="The text content.")
2547

2648

27-
class RealtimeModelUserInputMessage(TypedDict):
49+
class RealtimeModelUserInputMessage(BaseModel):
2850
"""A message to be sent to the model."""
2951

30-
type: Literal["message"]
31-
role: Literal["user"]
32-
content: list[RealtimeModelInputTextContent]
52+
model_config = ConfigDict(
53+
arbitrary_types_allowed=True,
54+
validate_assignment=True,
55+
extra="forbid",
56+
frozen=True,
57+
)
58+
59+
type: Literal["message"] = Field(default="message", description="Message type identifier.")
60+
role: Literal["user"] = Field(default="user", description="Message role identifier.")
61+
content: list[RealtimeModelInputTextContent] = Field(
62+
..., description="List of content items for the message."
63+
)
3364

3465

3566
RealtimeModelUserInput: TypeAlias = Union[str, RealtimeModelUserInputMessage]
@@ -39,62 +70,117 @@ class RealtimeModelUserInputMessage(TypedDict):
3970
# Model messages
4071

4172

42-
@dataclass
43-
class RealtimeModelSendRawMessage:
73+
class RealtimeModelSendRawMessage(BaseModel):
4474
"""Send a raw message to the model."""
4575

46-
message: RealtimeModelRawClientMessage
47-
"""The message to send."""
76+
model_config = ConfigDict(
77+
arbitrary_types_allowed=True,
78+
validate_assignment=True,
79+
extra="forbid",
80+
frozen=True,
81+
)
82+
83+
message: RealtimeModelRawClientMessage = Field(..., description="The message to send.")
84+
type: Literal["raw_message"] = Field(
85+
default="raw_message", description="Event type identifier."
86+
)
4887

4988

50-
@dataclass
51-
class RealtimeModelSendUserInput:
89+
class RealtimeModelSendUserInput(BaseModel):
5290
"""Send a user input to the model."""
5391

54-
user_input: RealtimeModelUserInput
55-
"""The user input to send."""
92+
model_config = ConfigDict(
93+
arbitrary_types_allowed=True,
94+
validate_assignment=True,
95+
extra="forbid",
96+
frozen=True,
97+
)
98+
99+
user_input: RealtimeModelUserInput = Field(..., description="The user input to send.")
100+
type: Literal["user_input"] = Field(default="user_input", description="Event type identifier.")
56101

57102

58-
@dataclass
59-
class RealtimeModelSendAudio:
103+
class RealtimeModelSendAudio(BaseModel):
60104
"""Send audio to the model."""
61105

62-
audio: bytes
63-
commit: bool = False
106+
model_config = ConfigDict(
107+
arbitrary_types_allowed=True,
108+
validate_assignment=True,
109+
extra="forbid",
110+
frozen=True,
111+
)
64112

113+
audio: bytes = Field(..., description="The audio data to send.")
114+
commit: bool = Field(default=False, description="Whether to commit the audio buffer.")
115+
type: Literal["send_audio"] = Field(default="send_audio", description="Event type identifier.")
65116

66-
@dataclass
67-
class RealtimeModelSendToolOutput:
68-
"""Send tool output to the model."""
69117

70-
tool_call: RealtimeModelToolCallEvent
71-
"""The tool call to send."""
118+
class RealtimeModelSendInterrupt(BaseModel):
119+
"""Send interrupt signal to the model."""
120+
121+
model_config = ConfigDict(
122+
arbitrary_types_allowed=True,
123+
validate_assignment=True,
124+
extra="forbid",
125+
frozen=True,
126+
)
72127

73-
output: str
74-
"""The output to send."""
128+
type: Literal["interrupt"] = Field(default="interrupt", description="Event type identifier.")
75129

76-
start_response: bool
77-
"""Whether to start a response."""
78130

131+
class RealtimeModelSendSessionUpdate(BaseModel):
132+
"""Send session update to the model."""
79133

80-
@dataclass
81-
class RealtimeModelSendInterrupt:
82-
"""Send an interrupt to the model."""
134+
model_config = ConfigDict(
135+
arbitrary_types_allowed=True,
136+
validate_assignment=True,
137+
extra="forbid",
138+
frozen=True,
139+
)
83140

141+
session: RealtimeSessionModelSettings = Field(
142+
..., description="The session configuration to update."
143+
)
144+
type: Literal["session_update"] = Field(
145+
default="session_update", description="Event type identifier."
146+
)
84147

85-
@dataclass
86-
class RealtimeModelSendSessionUpdate:
87-
"""Send a session update to the model."""
88148

89-
session_settings: RealtimeSessionModelSettings
90-
"""The updated session settings to send."""
149+
class RealtimeModelSendToolOutput(BaseModel):
150+
"""Send tool output to the model."""
151+
152+
model_config = ConfigDict(
153+
arbitrary_types_allowed=True,
154+
validate_assignment=True,
155+
extra="forbid",
156+
frozen=True,
157+
)
158+
159+
call_id: str = Field(..., description="ID of the tool call this output is for.")
160+
output: str = Field(..., description="The tool output result.")
161+
type: Literal["tool_output"] = Field(
162+
default="tool_output", description="Event type identifier."
163+
)
91164

92165

93166
RealtimeModelSendEvent: TypeAlias = Union[
94167
RealtimeModelSendRawMessage,
95168
RealtimeModelSendUserInput,
96169
RealtimeModelSendAudio,
97-
RealtimeModelSendToolOutput,
98170
RealtimeModelSendInterrupt,
99171
RealtimeModelSendSessionUpdate,
172+
RealtimeModelSendToolOutput,
100173
]
174+
"""An event to be sent to the realtime model."""
175+
176+
177+
# Rebuild models after all definitions
178+
RealtimeModelRawClientMessage.model_rebuild()
179+
RealtimeModelInputTextContent.model_rebuild()
180+
RealtimeModelUserInputMessage.model_rebuild()
181+
RealtimeModelSendRawMessage.model_rebuild()
182+
RealtimeModelSendUserInput.model_rebuild()
183+
RealtimeModelSendAudio.model_rebuild()
184+
RealtimeModelSendInterrupt.model_rebuild()
185+
RealtimeModelSendSessionUpdate.model_rebuild()
186+
RealtimeModelSendToolOutput.model_rebuild()

0 commit comments

Comments
 (0)