Skip to content

Commit e7bf9ac

Browse files
authored
Merge pull request #3 from DanielHashmi/add_validations
Added 16 Validations & Some Critical Tests
2 parents e3639aa + bc1f792 commit e7bf9ac

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ on:
44
push:
55
branches:
66
- main
7+
- add_validations
78
pull_request:
89
# All PRs, including stacked PRs
910

src/agents/agent.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,118 @@ class Agent(AgentBase, Generic[TContext]):
222222
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
223223

224224
def __post_init__(self):
225+
from typing import get_origin
226+
225227
if not isinstance(self.name, str):
226228
raise TypeError(f"Agent name must be a string, got {type(self.name).__name__}")
227229

230+
if self.handoff_description is not None and not isinstance(self.handoff_description, str):
231+
raise TypeError(
232+
f"Agent handoff_description must be a string or None, "
233+
f"got {type(self.handoff_description).__name__}"
234+
)
235+
236+
if not isinstance(self.tools, list):
237+
raise TypeError(f"Agent tools must be a list, got {type(self.tools).__name__}")
238+
239+
if not isinstance(self.mcp_servers, list):
240+
raise TypeError(
241+
f"Agent mcp_servers must be a list, got {type(self.mcp_servers).__name__}"
242+
)
243+
244+
if not isinstance(self.mcp_config, dict):
245+
raise TypeError(
246+
f"Agent mcp_config must be a dict, got {type(self.mcp_config).__name__}"
247+
)
248+
249+
if (
250+
self.instructions is not None
251+
and not isinstance(self.instructions, str)
252+
and not callable(self.instructions)
253+
):
254+
raise TypeError(
255+
f"Agent instructions must be a string, callable, or None, "
256+
f"got {type(self.instructions).__name__}"
257+
)
258+
259+
if (
260+
self.prompt is not None
261+
and not callable(self.prompt)
262+
and not hasattr(self.prompt, "get")
263+
):
264+
raise TypeError(
265+
f"Agent prompt must be a Prompt, DynamicPromptFunction, or None, "
266+
f"got {type(self.prompt).__name__}"
267+
)
268+
269+
if not isinstance(self.handoffs, list):
270+
raise TypeError(f"Agent handoffs must be a list, got {type(self.handoffs).__name__}")
271+
272+
if self.model is not None and not isinstance(self.model, str):
273+
from .models.interface import Model
274+
275+
if not isinstance(self.model, Model):
276+
raise TypeError(
277+
f"Agent model must be a string, Model, or None, got {type(self.model).__name__}"
278+
)
279+
280+
if not isinstance(self.model_settings, ModelSettings):
281+
raise TypeError(
282+
f"Agent model_settings must be a ModelSettings instance, "
283+
f"got {type(self.model_settings).__name__}"
284+
)
285+
286+
if not isinstance(self.input_guardrails, list):
287+
raise TypeError(
288+
f"Agent input_guardrails must be a list, got {type(self.input_guardrails).__name__}"
289+
)
290+
291+
if not isinstance(self.output_guardrails, list):
292+
raise TypeError(
293+
f"Agent output_guardrails must be a list, "
294+
f"got {type(self.output_guardrails).__name__}"
295+
)
296+
297+
if self.output_type is not None:
298+
from .agent_output import AgentOutputSchemaBase
299+
300+
if not (
301+
isinstance(self.output_type, (type, AgentOutputSchemaBase))
302+
or get_origin(self.output_type) is not None
303+
):
304+
raise TypeError(
305+
f"Agent output_type must be a type, AgentOutputSchemaBase, or None, "
306+
f"got {type(self.output_type).__name__}"
307+
)
308+
309+
if self.hooks is not None:
310+
from .lifecycle import AgentHooksBase
311+
312+
if not isinstance(self.hooks, AgentHooksBase):
313+
raise TypeError(
314+
f"Agent hooks must be an AgentHooks instance or None, "
315+
f"got {type(self.hooks).__name__}"
316+
)
317+
318+
if (
319+
not (
320+
isinstance(self.tool_use_behavior, str)
321+
and self.tool_use_behavior in ["run_llm_again", "stop_on_first_tool"]
322+
)
323+
and not isinstance(self.tool_use_behavior, dict)
324+
and not callable(self.tool_use_behavior)
325+
):
326+
raise TypeError(
327+
f"Agent tool_use_behavior must be 'run_llm_again', 'stop_on_first_tool', "
328+
f"StopAtTools dict, or callable, got {type(self.tool_use_behavior).__name__}"
329+
)
330+
331+
if not isinstance(self.reset_tool_choice, bool):
332+
raise TypeError(
333+
f"Agent reset_tool_choice must be a boolean, "
334+
f"got {type(self.reset_tool_choice).__name__}"
335+
)
336+
228337
def clone(self, **kwargs: Any) -> Agent[TContext]:
229338
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
230339
```

tests/test_agent_config.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from pydantic import BaseModel
33

44
from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, handoff
5+
from agents.lifecycle import AgentHooksBase
6+
from agents.model_settings import ModelSettings
57
from agents.run import AgentRunner
68

79

@@ -167,3 +169,58 @@ async def test_agent_final_output():
167169
assert schema.is_strict_json_schema() is True
168170
assert schema.json_schema() is not None
169171
assert not schema.is_plain_text()
172+
173+
174+
class TestAgentValidation:
175+
"""Essential validation tests for Agent __post_init__"""
176+
177+
def test_name_validation_critical_cases(self):
178+
"""Test name validation - the original issue that started this PR"""
179+
# This was the original failing case that caused JSON serialization errors
180+
with pytest.raises(TypeError, match="Agent name must be a string, got int"):
181+
Agent(name=1) # type: ignore
182+
183+
with pytest.raises(TypeError, match="Agent name must be a string, got NoneType"):
184+
Agent(name=None) # type: ignore
185+
186+
def test_tool_use_behavior_dict_validation(self):
187+
"""Test tool_use_behavior accepts StopAtTools dict - fixes existing test failures"""
188+
# This test ensures the existing failing tests now pass
189+
Agent(name="test", tool_use_behavior={"stop_at_tool_names": ["tool1"]})
190+
191+
# Invalid cases that should fail
192+
with pytest.raises(TypeError, match="Agent tool_use_behavior must be"):
193+
Agent(name="test", tool_use_behavior=123) # type: ignore
194+
195+
def test_hooks_validation_python39_compatibility(self):
196+
"""Test hooks validation works with Python 3.9 - fixes generic type issues"""
197+
198+
class MockHooks(AgentHooksBase):
199+
pass
200+
201+
# Valid case
202+
Agent(name="test", hooks=MockHooks()) # type: ignore
203+
204+
# Invalid case
205+
with pytest.raises(TypeError, match="Agent hooks must be an AgentHooks instance"):
206+
Agent(name="test", hooks="invalid") # type: ignore
207+
208+
def test_list_field_validation(self):
209+
"""Test critical list fields that commonly get wrong types"""
210+
# These are the most common mistakes users make
211+
with pytest.raises(TypeError, match="Agent tools must be a list"):
212+
Agent(name="test", tools="not_a_list") # type: ignore
213+
214+
with pytest.raises(TypeError, match="Agent handoffs must be a list"):
215+
Agent(name="test", handoffs="not_a_list") # type: ignore
216+
217+
def test_model_settings_validation(self):
218+
"""Test model_settings validation - prevents runtime errors"""
219+
# Valid case
220+
Agent(name="test", model_settings=ModelSettings())
221+
222+
# Invalid case that could cause runtime issues
223+
with pytest.raises(
224+
TypeError, match="Agent model_settings must be a ModelSettings instance"
225+
):
226+
Agent(name="test", model_settings={}) # type: ignore

0 commit comments

Comments
 (0)