Skip to content

Commit f5fcd30

Browse files
authored
Added runtime validation for Agent constructor arguments (#998)
Add `__post_init__` validation to ensure Agent name is a string The Agent class previously only used type hints for the name field without runtime validation, allowing non-string values like integers to be passed. This caused downstream errors during JSON serialization, tracing, and other operations that expect the name to be a string. Changes: - Add `__post_init__` method to Agent class with `isinstance` check - Raise TypeError with descriptive message for non-string names - Validation occurs at instantiation time to fail fast Fixes issue where `Agent(name=1)` would succeed but cause errors later in the execution pipeline. Fixes #996
1 parent 417c19b commit f5fcd30

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed

src/agents/agent.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,119 @@ class Agent(AgentBase, Generic[TContext]):
223223
"""Whether to reset the tool choice to the default value after a tool has been called. Defaults
224224
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
225225

226+
def __post_init__(self):
227+
from typing import get_origin
228+
229+
if not isinstance(self.name, str):
230+
raise TypeError(f"Agent name must be a string, got {type(self.name).__name__}")
231+
232+
if self.handoff_description is not None and not isinstance(self.handoff_description, str):
233+
raise TypeError(
234+
f"Agent handoff_description must be a string or None, "
235+
f"got {type(self.handoff_description).__name__}"
236+
)
237+
238+
if not isinstance(self.tools, list):
239+
raise TypeError(f"Agent tools must be a list, got {type(self.tools).__name__}")
240+
241+
if not isinstance(self.mcp_servers, list):
242+
raise TypeError(
243+
f"Agent mcp_servers must be a list, got {type(self.mcp_servers).__name__}"
244+
)
245+
246+
if not isinstance(self.mcp_config, dict):
247+
raise TypeError(
248+
f"Agent mcp_config must be a dict, got {type(self.mcp_config).__name__}"
249+
)
250+
251+
if (
252+
self.instructions is not None
253+
and not isinstance(self.instructions, str)
254+
and not callable(self.instructions)
255+
):
256+
raise TypeError(
257+
f"Agent instructions must be a string, callable, or None, "
258+
f"got {type(self.instructions).__name__}"
259+
)
260+
261+
if (
262+
self.prompt is not None
263+
and not callable(self.prompt)
264+
and not hasattr(self.prompt, "get")
265+
):
266+
raise TypeError(
267+
f"Agent prompt must be a Prompt, DynamicPromptFunction, or None, "
268+
f"got {type(self.prompt).__name__}"
269+
)
270+
271+
if not isinstance(self.handoffs, list):
272+
raise TypeError(f"Agent handoffs must be a list, got {type(self.handoffs).__name__}")
273+
274+
if self.model is not None and not isinstance(self.model, str):
275+
from .models.interface import Model
276+
277+
if not isinstance(self.model, Model):
278+
raise TypeError(
279+
f"Agent model must be a string, Model, or None, got {type(self.model).__name__}"
280+
)
281+
282+
if not isinstance(self.model_settings, ModelSettings):
283+
raise TypeError(
284+
f"Agent model_settings must be a ModelSettings instance, "
285+
f"got {type(self.model_settings).__name__}"
286+
)
287+
288+
if not isinstance(self.input_guardrails, list):
289+
raise TypeError(
290+
f"Agent input_guardrails must be a list, got {type(self.input_guardrails).__name__}"
291+
)
292+
293+
if not isinstance(self.output_guardrails, list):
294+
raise TypeError(
295+
f"Agent output_guardrails must be a list, "
296+
f"got {type(self.output_guardrails).__name__}"
297+
)
298+
299+
if self.output_type is not None:
300+
from .agent_output import AgentOutputSchemaBase
301+
302+
if not (
303+
isinstance(self.output_type, (type, AgentOutputSchemaBase))
304+
or get_origin(self.output_type) is not None
305+
):
306+
raise TypeError(
307+
f"Agent output_type must be a type, AgentOutputSchemaBase, or None, "
308+
f"got {type(self.output_type).__name__}"
309+
)
310+
311+
if self.hooks is not None:
312+
from .lifecycle import AgentHooksBase
313+
314+
if not isinstance(self.hooks, AgentHooksBase):
315+
raise TypeError(
316+
f"Agent hooks must be an AgentHooks instance or None, "
317+
f"got {type(self.hooks).__name__}"
318+
)
319+
320+
if (
321+
not (
322+
isinstance(self.tool_use_behavior, str)
323+
and self.tool_use_behavior in ["run_llm_again", "stop_on_first_tool"]
324+
)
325+
and not isinstance(self.tool_use_behavior, dict)
326+
and not callable(self.tool_use_behavior)
327+
):
328+
raise TypeError(
329+
f"Agent tool_use_behavior must be 'run_llm_again', 'stop_on_first_tool', "
330+
f"StopAtTools dict, or callable, got {type(self.tool_use_behavior).__name__}"
331+
)
332+
333+
if not isinstance(self.reset_tool_choice, bool):
334+
raise TypeError(
335+
f"Agent reset_tool_choice must be a boolean, "
336+
f"got {type(self.reset_tool_choice).__name__}"
337+
)
338+
226339
def clone(self, **kwargs: Any) -> Agent[TContext]:
227340
"""Make a copy of the agent, with the given arguments changed.
228341
Notes:

tests/test_agent_config.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from pydantic import BaseModel
33

44
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff
5+
from agents.lifecycle import AgentHooksBase
6+
from agents.model_settings import ModelSettings
57
from agents.run import AgentRunner
68

79

@@ -167,3 +169,58 @@ async def test_agent_final_output():
167169
assert schema.is_strict_json_schema() is True
168170
assert schema.json_schema() is not None
169171
assert not schema.is_plain_text()
172+
173+
174+
class TestAgentValidation:
175+
"""Essential validation tests for Agent __post_init__"""
176+
177+
def test_name_validation_critical_cases(self):
178+
"""Test name validation - the original issue that started this PR"""
179+
# This was the original failing case that caused JSON serialization errors
180+
with pytest.raises(TypeError, match="Agent name must be a string, got int"):
181+
Agent(name=1) # type: ignore
182+
183+
with pytest.raises(TypeError, match="Agent name must be a string, got NoneType"):
184+
Agent(name=None) # type: ignore
185+
186+
def test_tool_use_behavior_dict_validation(self):
187+
"""Test tool_use_behavior accepts StopAtTools dict - fixes existing test failures"""
188+
# This test ensures the existing failing tests now pass
189+
Agent(name="test", tool_use_behavior={"stop_at_tool_names": ["tool1"]})
190+
191+
# Invalid cases that should fail
192+
with pytest.raises(TypeError, match="Agent tool_use_behavior must be"):
193+
Agent(name="test", tool_use_behavior=123) # type: ignore
194+
195+
def test_hooks_validation_python39_compatibility(self):
196+
"""Test hooks validation works with Python 3.9 - fixes generic type issues"""
197+
198+
class MockHooks(AgentHooksBase):
199+
pass
200+
201+
# Valid case
202+
Agent(name="test", hooks=MockHooks()) # type: ignore
203+
204+
# Invalid case
205+
with pytest.raises(TypeError, match="Agent hooks must be an AgentHooks instance"):
206+
Agent(name="test", hooks="invalid") # type: ignore
207+
208+
def test_list_field_validation(self):
209+
"""Test critical list fields that commonly get wrong types"""
210+
# These are the most common mistakes users make
211+
with pytest.raises(TypeError, match="Agent tools must be a list"):
212+
Agent(name="test", tools="not_a_list") # type: ignore
213+
214+
with pytest.raises(TypeError, match="Agent handoffs must be a list"):
215+
Agent(name="test", handoffs="not_a_list") # type: ignore
216+
217+
def test_model_settings_validation(self):
218+
"""Test model_settings validation - prevents runtime errors"""
219+
# Valid case
220+
Agent(name="test", model_settings=ModelSettings())
221+
222+
# Invalid case that could cause runtime issues
223+
with pytest.raises(
224+
TypeError, match="Agent model_settings must be a ModelSettings instance"
225+
):
226+
Agent(name="test", model_settings={}) # type: ignore

0 commit comments

Comments
 (0)