Skip to content

Commit 3e37afb

Browse files
committed
Use pydantic serialization for ActivityModelInput
1 parent a2423d9 commit 3e37afb

File tree

1 file changed

+13
-63
lines changed

1 file changed

+13
-63
lines changed

samples-v2/openai_agents/model_invoker.py

Lines changed: 13 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import timedelta
55
from typing import Any, Optional, Union
66

7+
from pydantic import BaseModel, Field
78
from agents import (
89
AgentOutputSchemaBase,
910
CodeInterpreterTool,
@@ -28,7 +29,6 @@
2829
AsyncOpenAI,
2930
)
3031
from openai.types.responses.tool_param import Mcp
31-
from pydantic_core import to_json
3232
try:
3333
from azure.durable_functions import ApplicationError
3434
except ImportError:
@@ -40,8 +40,7 @@ def __init__(self, message: str, non_retryable: bool = False, next_retry_delay =
4040
self.next_retry_delay = next_retry_delay
4141

4242

43-
@dataclass
44-
class HandoffInput:
43+
class HandoffInput(BaseModel):
4544
"""Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
4645
execution to determine what to handoff to, not the actual handoff invocation, which remains in the workflow context.
4746
"""
@@ -53,8 +52,7 @@ class HandoffInput:
5352
strict_json_schema: bool = True
5453

5554

56-
@dataclass
57-
class FunctionToolInput:
55+
class FunctionToolInput(BaseModel):
5856
"""Data conversion friendly representation of a FunctionTool. Contains only the fields which are needed by the model
5957
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
6058
"""
@@ -65,8 +63,7 @@ class FunctionToolInput:
6563
strict_json_schema: bool = True
6664

6765

68-
@dataclass
69-
class HostedMCPToolInput:
66+
class HostedMCPToolInput(BaseModel):
7067
"""Data conversion friendly representation of a HostedMCPTool. Contains only the fields which are needed by the model
7168
execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
7269
"""
@@ -84,13 +81,12 @@ class HostedMCPToolInput:
8481
]
8582

8683

87-
@dataclass
88-
class AgentOutputSchemaInput(AgentOutputSchemaBase):
84+
class AgentOutputSchemaInput(AgentOutputSchemaBase, BaseModel):
8985
"""Data conversion friendly representation of AgentOutputSchema."""
9086

91-
output_type_name: Optional[str]
87+
output_type_name: Optional[str] = None
9288
is_wrapped: bool
93-
output_schema: Optional[dict[str, Any]]
89+
output_schema: Optional[dict[str, Any]] = None
9490
strict_json_schema: bool
9591

9692
def is_plain_text(self) -> bool:
@@ -131,74 +127,28 @@ class ModelTracingInput(enum.IntEnum):
131127
ENABLED_WITHOUT_DATA = 2
132128

133129

134-
@dataclass
135-
class ActivityModelInput:
130+
class ActivityModelInput(BaseModel):
136131
"""Input for the invoke_model_activity activity."""
137132

138133
input: Union[str, list[TResponseInputItem]]
139134
model_settings: ModelSettings
140135
tracing: ModelTracingInput
141136
model_name: Optional[str] = None
142137
system_instructions: Optional[str] = None
143-
tools: list[ToolInput] = None
138+
tools: list[ToolInput] = Field(default_factory=list)
144139
output_schema: Optional[AgentOutputSchemaInput] = None
145-
handoffs: list[HandoffInput] = None
140+
handoffs: list[HandoffInput] = Field(default_factory=list)
146141
previous_response_id: Optional[str] = None
147142
prompt: Optional[Any] = None
148143

149-
def __post_init__(self):
150-
"""Initialize default values for list fields."""
151-
if self.tools is None:
152-
self.tools = []
153-
if self.handoffs is None:
154-
self.handoffs = []
155-
156144
def to_json(self) -> str:
157145
"""Convert the ActivityModelInput to a JSON string."""
158-
return to_json(self).decode('utf-8')
146+
return self.model_dump_json()
159147

160148
@classmethod
161149
def from_json(cls, json_str: str) -> 'ActivityModelInput':
162150
"""Create an ActivityModelInput instance from a JSON string."""
163-
import json
164-
data = json.loads(json_str)
165-
166-
# Convert complex types back from dictionaries
167-
if 'model_settings' in data and isinstance(data['model_settings'], dict):
168-
data['model_settings'] = ModelSettings(**data['model_settings'])
169-
170-
if 'tracing' in data and isinstance(data['tracing'], int):
171-
data['tracing'] = ModelTracingInput(data['tracing'])
172-
173-
# Convert tool inputs back to proper types
174-
if 'tools' in data and data['tools']:
175-
converted_tools = []
176-
for tool_data in data['tools']:
177-
if isinstance(tool_data, dict):
178-
# Check the tool type and convert accordingly
179-
if 'name' in tool_data and 'description' in tool_data and 'params_json_schema' in tool_data:
180-
# FunctionToolInput
181-
converted_tools.append(FunctionToolInput(**tool_data))
182-
elif 'tool_config' in tool_data:
183-
# HostedMCPToolInput
184-
converted_tools.append(HostedMCPToolInput(**tool_data))
185-
else:
186-
# For other tool types like FileSearchTool, WebSearchTool, etc.
187-
# These might be already properly serialized/deserialized by pydantic_core
188-
converted_tools.append(tool_data)
189-
else:
190-
converted_tools.append(tool_data)
191-
data['tools'] = converted_tools
192-
193-
# Convert handoffs back to proper types
194-
if 'handoffs' in data and data['handoffs']:
195-
data['handoffs'] = [HandoffInput(**handoff) for handoff in data['handoffs']]
196-
197-
# Convert output_schema back to proper type
198-
if 'output_schema' in data and data['output_schema'] and isinstance(data['output_schema'], dict):
199-
data['output_schema'] = AgentOutputSchemaInput(**data['output_schema'])
200-
201-
return cls(**data)
151+
return cls.model_validate_json(json_str)
202152

203153

204154
class ModelInvoker:
@@ -220,7 +170,7 @@ async def empty_on_invoke_handoff(
220170

221171
# workaround for https://github.com/pydantic/pydantic/issues/9541
222172
# ValidatorIterator returned
223-
input_json = to_json(input.input)
173+
input_json = json.dumps(input.input, default=str)
224174
input_input = json.loads(input_json)
225175

226176
def make_tool(tool: ToolInput) -> Tool:

0 commit comments

Comments
 (0)