44from datetime import timedelta
55from typing import Any , Optional , Union
66
7+ from pydantic import BaseModel , Field
78from agents import (
89 AgentOutputSchemaBase ,
910 CodeInterpreterTool ,
2829 AsyncOpenAI ,
2930)
3031from openai .types .responses .tool_param import Mcp
31- from pydantic_core import to_json
3232try :
3333 from azure .durable_functions import ApplicationError
3434except ImportError :
@@ -40,8 +40,7 @@ def __init__(self, message: str, non_retryable: bool = False, next_retry_delay =
4040 self .next_retry_delay = next_retry_delay
4141
4242
43- @dataclass
44- class HandoffInput :
43+ class HandoffInput (BaseModel ):
4544 """Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
4645 execution to determine what to handoff to, not the actual handoff invocation, which remains in the workflow context.
4746 """
@@ -53,8 +52,7 @@ class HandoffInput:
5352 strict_json_schema : bool = True
5453
5554
56- @dataclass
57- class FunctionToolInput :
55+ class FunctionToolInput (BaseModel ):
5856 """Data conversion friendly representation of a FunctionTool. Contains only the fields which are needed by the model
5957 execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
6058 """
@@ -65,8 +63,7 @@ class FunctionToolInput:
6563 strict_json_schema : bool = True
6664
6765
68- @dataclass
69- class HostedMCPToolInput :
66+ class HostedMCPToolInput (BaseModel ):
7067 """Data conversion friendly representation of a HostedMCPTool. Contains only the fields which are needed by the model
7168 execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
7269 """
@@ -84,13 +81,12 @@ class HostedMCPToolInput:
8481]
8582
8683
87- @dataclass
88- class AgentOutputSchemaInput (AgentOutputSchemaBase ):
84+ class AgentOutputSchemaInput (AgentOutputSchemaBase , BaseModel ):
8985 """Data conversion friendly representation of AgentOutputSchema."""
9086
91- output_type_name : Optional [str ]
87+ output_type_name : Optional [str ] = None
9288 is_wrapped : bool
93- output_schema : Optional [dict [str , Any ]]
89+ output_schema : Optional [dict [str , Any ]] = None
9490 strict_json_schema : bool
9591
9692 def is_plain_text (self ) -> bool :
@@ -131,74 +127,28 @@ class ModelTracingInput(enum.IntEnum):
131127 ENABLED_WITHOUT_DATA = 2
132128
133129
134- @dataclass
135- class ActivityModelInput :
130+ class ActivityModelInput (BaseModel ):
136131 """Input for the invoke_model_activity activity."""
137132
138133 input : Union [str , list [TResponseInputItem ]]
139134 model_settings : ModelSettings
140135 tracing : ModelTracingInput
141136 model_name : Optional [str ] = None
142137 system_instructions : Optional [str ] = None
143- tools : list [ToolInput ] = None
138+ tools : list [ToolInput ] = Field ( default_factory = list )
144139 output_schema : Optional [AgentOutputSchemaInput ] = None
145- handoffs : list [HandoffInput ] = None
140+ handoffs : list [HandoffInput ] = Field ( default_factory = list )
146141 previous_response_id : Optional [str ] = None
147142 prompt : Optional [Any ] = None
148143
149- def __post_init__ (self ):
150- """Initialize default values for list fields."""
151- if self .tools is None :
152- self .tools = []
153- if self .handoffs is None :
154- self .handoffs = []
155-
156144 def to_json (self ) -> str :
157145 """Convert the ActivityModelInput to a JSON string."""
158- return to_json ( self ). decode ( 'utf-8' )
146+ return self . model_dump_json ( )
159147
160148 @classmethod
161149 def from_json (cls , json_str : str ) -> 'ActivityModelInput' :
162150 """Create an ActivityModelInput instance from a JSON string."""
163- import json
164- data = json .loads (json_str )
165-
166- # Convert complex types back from dictionaries
167- if 'model_settings' in data and isinstance (data ['model_settings' ], dict ):
168- data ['model_settings' ] = ModelSettings (** data ['model_settings' ])
169-
170- if 'tracing' in data and isinstance (data ['tracing' ], int ):
171- data ['tracing' ] = ModelTracingInput (data ['tracing' ])
172-
173- # Convert tool inputs back to proper types
174- if 'tools' in data and data ['tools' ]:
175- converted_tools = []
176- for tool_data in data ['tools' ]:
177- if isinstance (tool_data , dict ):
178- # Check the tool type and convert accordingly
179- if 'name' in tool_data and 'description' in tool_data and 'params_json_schema' in tool_data :
180- # FunctionToolInput
181- converted_tools .append (FunctionToolInput (** tool_data ))
182- elif 'tool_config' in tool_data :
183- # HostedMCPToolInput
184- converted_tools .append (HostedMCPToolInput (** tool_data ))
185- else :
186- # For other tool types like FileSearchTool, WebSearchTool, etc.
187- # These might be already properly serialized/deserialized by pydantic_core
188- converted_tools .append (tool_data )
189- else :
190- converted_tools .append (tool_data )
191- data ['tools' ] = converted_tools
192-
193- # Convert handoffs back to proper types
194- if 'handoffs' in data and data ['handoffs' ]:
195- data ['handoffs' ] = [HandoffInput (** handoff ) for handoff in data ['handoffs' ]]
196-
197- # Convert output_schema back to proper type
198- if 'output_schema' in data and data ['output_schema' ] and isinstance (data ['output_schema' ], dict ):
199- data ['output_schema' ] = AgentOutputSchemaInput (** data ['output_schema' ])
200-
201- return cls (** data )
151+ return cls .model_validate_json (json_str )
202152
203153
204154class ModelInvoker :
@@ -220,7 +170,7 @@ async def empty_on_invoke_handoff(
220170
221171 # workaround for https://github.com/pydantic/pydantic/issues/9541
222172 # ValidatorIterator returned
223- input_json = to_json (input .input )
173+ input_json = json . dumps (input .input , default = str )
224174 input_input = json .loads (input_json )
225175
226176 def make_tool (tool : ToolInput ) -> Tool :
0 commit comments