Skip to content

Fix: enforce strict instructions function signature in get_system_prompt #1426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,16 +393,31 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
return run_agent

async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
"""Get the system prompt for the agent."""
if isinstance(self.instructions, str):
return self.instructions
elif callable(self.instructions):
# Inspect the signature of the instructions function
sig = inspect.signature(self.instructions)
params = list(sig.parameters.values())

# Enforce exactly 2 parameters
if len(params) != 2:
raise TypeError(
f"'instructions' callable must accept exactly 2 arguments (context, agent), "
f"but got {len(params)}: {[p.name for p in params]}"
)

# Call the instructions function properly
if inspect.iscoroutinefunction(self.instructions):
return await cast(Awaitable[str], self.instructions(run_context, self))
else:
return cast(str, self.instructions(run_context, self))

elif self.instructions is not None:
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
logger.error(
f"Instructions must be a string or a callable function, "
f"got {type(self.instructions).__name__}"
)

return None

Expand Down
113 changes: 113 additions & 0 deletions tests/test_agent_instructions_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from unittest.mock import Mock

import pytest

from agents import Agent, RunContextWrapper


class TestInstructionsSignatureValidation:
"""Test suite for instructions function signature validation"""

@pytest.fixture
def mock_run_context(self):
"""Create a mock RunContextWrapper for testing"""
return Mock(spec=RunContextWrapper)

@pytest.mark.asyncio
async def test_valid_async_signature_passes(self, mock_run_context):
"""Test that async function with correct signature works"""
async def valid_instructions(context, agent):
return "Valid async instructions"

agent = Agent(name="test_agent", instructions=valid_instructions)
result = await agent.get_system_prompt(mock_run_context)
assert result == "Valid async instructions"

@pytest.mark.asyncio
async def test_valid_sync_signature_passes(self, mock_run_context):
"""Test that sync function with correct signature works"""
def valid_instructions(context, agent):
return "Valid sync instructions"

agent = Agent(name="test_agent", instructions=valid_instructions)
result = await agent.get_system_prompt(mock_run_context)
assert result == "Valid sync instructions"

@pytest.mark.asyncio
async def test_one_parameter_raises_error(self, mock_run_context):
"""Test that function with only one parameter raises TypeError"""
def invalid_instructions(context):
return "Should fail"

agent = Agent(name="test_agent", instructions=invalid_instructions) # type: ignore[arg-type]

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got 1" in str(exc_info.value)

@pytest.mark.asyncio
async def test_three_parameters_raises_error(self, mock_run_context):
"""Test that function with three parameters raises TypeError"""
def invalid_instructions(context, agent, extra):
return "Should fail"

agent = Agent(name="test_agent", instructions=invalid_instructions) # type: ignore[arg-type]

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got 3" in str(exc_info.value)

@pytest.mark.asyncio
async def test_zero_parameters_raises_error(self, mock_run_context):
"""Test that function with no parameters raises TypeError"""
def invalid_instructions():
return "Should fail"

agent = Agent(name="test_agent", instructions=invalid_instructions) # type: ignore[arg-type]

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got 0" in str(exc_info.value)

@pytest.mark.asyncio
async def test_function_with_args_kwargs_fails(self, mock_run_context):
"""Test that function with *args/**kwargs fails validation"""
def flexible_instructions(context, agent, *args, **kwargs):
return "Flexible instructions"

agent = Agent(name="test_agent", instructions=flexible_instructions)

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got" in str(exc_info.value)

@pytest.mark.asyncio
async def test_string_instructions_still_work(self, mock_run_context):
"""Test that string instructions continue to work"""
agent = Agent(name="test_agent", instructions="Static string instructions")
result = await agent.get_system_prompt(mock_run_context)
assert result == "Static string instructions"

@pytest.mark.asyncio
async def test_none_instructions_return_none(self, mock_run_context):
"""Test that None instructions return None"""
agent = Agent(name="test_agent", instructions=None)
result = await agent.get_system_prompt(mock_run_context)
assert result is None

@pytest.mark.asyncio
async def test_non_callable_instructions_raises_error(self, mock_run_context):
"""Test that non-callable instructions raise a TypeError during initialization"""
with pytest.raises(TypeError) as exc_info:
Agent(name="test_agent", instructions=123) # type: ignore[arg-type]

assert "Agent instructions must be a string, callable, or None" in str(exc_info.value)
assert "got int" in str(exc_info.value)