Skip to content

Commit 534e2d5

Browse files
Fix: enforce strict instructions function signature in get_system_prompt (#1426)
1 parent bad88e7 commit 534e2d5

File tree

2 files changed

+130
-2
lines changed

2 files changed

+130
-2
lines changed

src/agents/agent.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,16 +393,31 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
393393
return run_agent
394394

395395
async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
396-
"""Get the system prompt for the agent."""
397396
if isinstance(self.instructions, str):
398397
return self.instructions
399398
elif callable(self.instructions):
399+
# Inspect the signature of the instructions function
400+
sig = inspect.signature(self.instructions)
401+
params = list(sig.parameters.values())
402+
403+
# Enforce exactly 2 parameters
404+
if len(params) != 2:
405+
raise TypeError(
406+
f"'instructions' callable must accept exactly 2 arguments (context, agent), "
407+
f"but got {len(params)}: {[p.name for p in params]}"
408+
)
409+
410+
# Call the instructions function properly
400411
if inspect.iscoroutinefunction(self.instructions):
401412
return await cast(Awaitable[str], self.instructions(run_context, self))
402413
else:
403414
return cast(str, self.instructions(run_context, self))
415+
404416
elif self.instructions is not None:
405-
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
417+
logger.error(
418+
f"Instructions must be a string or a callable function, "
419+
f"got {type(self.instructions).__name__}"
420+
)
406421

407422
return None
408423

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
5+
from agents import Agent, RunContextWrapper
6+
7+
8+
class TestInstructionsSignatureValidation:
9+
"""Test suite for instructions function signature validation"""
10+
11+
@pytest.fixture
12+
def mock_run_context(self):
13+
"""Create a mock RunContextWrapper for testing"""
14+
return Mock(spec=RunContextWrapper)
15+
16+
@pytest.mark.asyncio
17+
async def test_valid_async_signature_passes(self, mock_run_context):
18+
"""Test that async function with correct signature works"""
19+
async def valid_instructions(context, agent):
20+
return "Valid async instructions"
21+
22+
agent = Agent(name="test_agent", instructions=valid_instructions)
23+
result = await agent.get_system_prompt(mock_run_context)
24+
assert result == "Valid async instructions"
25+
26+
@pytest.mark.asyncio
27+
async def test_valid_sync_signature_passes(self, mock_run_context):
28+
"""Test that sync function with correct signature works"""
29+
def valid_instructions(context, agent):
30+
return "Valid sync instructions"
31+
32+
agent = Agent(name="test_agent", instructions=valid_instructions)
33+
result = await agent.get_system_prompt(mock_run_context)
34+
assert result == "Valid sync instructions"
35+
36+
@pytest.mark.asyncio
37+
async def test_one_parameter_raises_error(self, mock_run_context):
38+
"""Test that function with only one parameter raises TypeError"""
39+
def invalid_instructions(context):
40+
return "Should fail"
41+
42+
agent = Agent(name="test_agent", instructions=invalid_instructions) # type: ignore[arg-type]
43+
44+
with pytest.raises(TypeError) as exc_info:
45+
await agent.get_system_prompt(mock_run_context)
46+
47+
assert "must accept exactly 2 arguments" in str(exc_info.value)
48+
assert "but got 1" in str(exc_info.value)
49+
50+
@pytest.mark.asyncio
51+
async def test_three_parameters_raises_error(self, mock_run_context):
52+
"""Test that function with three parameters raises TypeError"""
53+
def invalid_instructions(context, agent, extra):
54+
return "Should fail"
55+
56+
agent = Agent(name="test_agent", instructions=invalid_instructions) # type: ignore[arg-type]
57+
58+
with pytest.raises(TypeError) as exc_info:
59+
await agent.get_system_prompt(mock_run_context)
60+
61+
assert "must accept exactly 2 arguments" in str(exc_info.value)
62+
assert "but got 3" in str(exc_info.value)
63+
64+
@pytest.mark.asyncio
65+
async def test_zero_parameters_raises_error(self, mock_run_context):
66+
"""Test that function with no parameters raises TypeError"""
67+
def invalid_instructions():
68+
return "Should fail"
69+
70+
agent = Agent(name="test_agent", instructions=invalid_instructions) # type: ignore[arg-type]
71+
72+
with pytest.raises(TypeError) as exc_info:
73+
await agent.get_system_prompt(mock_run_context)
74+
75+
assert "must accept exactly 2 arguments" in str(exc_info.value)
76+
assert "but got 0" in str(exc_info.value)
77+
78+
@pytest.mark.asyncio
79+
async def test_function_with_args_kwargs_fails(self, mock_run_context):
80+
"""Test that function with *args/**kwargs fails validation"""
81+
def flexible_instructions(context, agent, *args, **kwargs):
82+
return "Flexible instructions"
83+
84+
agent = Agent(name="test_agent", instructions=flexible_instructions)
85+
86+
with pytest.raises(TypeError) as exc_info:
87+
await agent.get_system_prompt(mock_run_context)
88+
89+
assert "must accept exactly 2 arguments" in str(exc_info.value)
90+
assert "but got" in str(exc_info.value)
91+
92+
@pytest.mark.asyncio
93+
async def test_string_instructions_still_work(self, mock_run_context):
94+
"""Test that string instructions continue to work"""
95+
agent = Agent(name="test_agent", instructions="Static string instructions")
96+
result = await agent.get_system_prompt(mock_run_context)
97+
assert result == "Static string instructions"
98+
99+
@pytest.mark.asyncio
100+
async def test_none_instructions_return_none(self, mock_run_context):
101+
"""Test that None instructions return None"""
102+
agent = Agent(name="test_agent", instructions=None)
103+
result = await agent.get_system_prompt(mock_run_context)
104+
assert result is None
105+
106+
@pytest.mark.asyncio
107+
async def test_non_callable_instructions_raises_error(self, mock_run_context):
108+
"""Test that non-callable instructions raise a TypeError during initialization"""
109+
with pytest.raises(TypeError) as exc_info:
110+
Agent(name="test_agent", instructions=123) # type: ignore[arg-type]
111+
112+
assert "Agent instructions must be a string, callable, or None" in str(exc_info.value)
113+
assert "got int" in str(exc_info.value)

0 commit comments

Comments
 (0)