Skip to content

Fix: enforce strict instructions function signature in get_system_prompt #1426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,21 +393,33 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
return run_agent

async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
"""Get the system prompt for the agent."""
if isinstance(self.instructions, str):
return self.instructions
elif callable(self.instructions):
# Inspect the signature of the instructions function
sig = inspect.signature(self.instructions)
params = list(sig.parameters.values())

# Enforce exactly 2 parameters
if len(params) != 2:
raise TypeError(
f"'instructions' callable must accept exactly 2 arguments (context, agent), "
f"but got {len(params)}: {[p.name for p in params]}"
)

# Call the instructions function properly
if inspect.iscoroutinefunction(self.instructions):
return await cast(Awaitable[str], self.instructions(run_context, self))
else:
return cast(str, self.instructions(run_context, self))

elif self.instructions is not None:
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
logger.error(f"Instructions must be a string or a callable function, got {type(self.instructions).__name__}")

return None

async def get_prompt(
self, run_context: RunContextWrapper[TContext]
) -> ResponsePromptParam | None:
"""Get the prompt for the agent."""
return await PromptUtil.to_model_input(self.prompt, run_context, self)
return await PromptUtil.to_model_input(self.prompt, run_context, self)
109 changes: 109 additions & 0 deletions tests/test_agent_instructions_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pytest
from unittest.mock import Mock
# Adjust import based on actual repo structure
from src.agents.agent import Agent, RunContextWrapper

class TestInstructionsSignatureValidation:
"""Test suite for instructions function signature validation"""

@pytest.fixture
def mock_run_context(self):
"""Create a mock RunContextWrapper for testing"""
return Mock(spec=RunContextWrapper)

@pytest.mark.asyncio
async def test_valid_async_signature_passes(self, mock_run_context):
"""Test that async function with correct signature works"""
async def valid_instructions(context, agent):
return "Valid async instructions"

agent = Agent(instructions=valid_instructions)
result = await agent.get_system_prompt(mock_run_context)
assert result == "Valid async instructions"

@pytest.mark.asyncio
async def test_valid_sync_signature_passes(self, mock_run_context):
"""Test that sync function with correct signature works"""
def valid_instructions(context, agent):
return "Valid sync instructions"

agent = Agent(instructions=valid_instructions)
result = await agent.get_system_prompt(mock_run_context)
assert result == "Valid sync instructions"

@pytest.mark.asyncio
async def test_one_parameter_raises_error(self, mock_run_context):
"""Test that function with only one parameter raises TypeError"""
def invalid_instructions(context):
return "Should fail"

agent = Agent(instructions=invalid_instructions)

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got 1" in str(exc_info.value)

@pytest.mark.asyncio
async def test_three_parameters_raises_error(self, mock_run_context):
"""Test that function with three parameters raises TypeError"""
def invalid_instructions(context, agent, extra):
return "Should fail"

agent = Agent(instructions=invalid_instructions)

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got 3" in str(exc_info.value)

@pytest.mark.asyncio
async def test_zero_parameters_raises_error(self, mock_run_context):
"""Test that function with no parameters raises TypeError"""
def invalid_instructions():
return "Should fail"

agent = Agent(instructions=invalid_instructions)

with pytest.raises(TypeError) as exc_info:
await agent.get_system_prompt(mock_run_context)

assert "must accept exactly 2 arguments" in str(exc_info.value)
assert "but got 0" in str(exc_info.value)

@pytest.mark.asyncio
async def test_function_with_args_kwargs_passes(self, mock_run_context):
"""Test that function with *args/**kwargs still works (edge case)"""
def flexible_instructions(context, agent, *args, **kwargs):
return "Flexible instructions"

agent = Agent(instructions=flexible_instructions)
# This should potentially pass as it can accept the 2 required args
# Adjust this test based on your desired behavior
result = await agent.get_system_prompt(mock_run_context)
assert result == "Flexible instructions"

@pytest.mark.asyncio
async def test_string_instructions_still_work(self, mock_run_context):
"""Test that string instructions continue to work"""
agent = Agent(instructions="Static string instructions")
result = await agent.get_system_prompt(mock_run_context)
assert result == "Static string instructions"

@pytest.mark.asyncio
async def test_none_instructions_return_none(self, mock_run_context):
"""Test that None instructions return None"""
agent = Agent(instructions=None)
result = await agent.get_system_prompt(mock_run_context)
assert result is None

@pytest.mark.asyncio
async def test_non_callable_instructions_log_error(self, mock_run_context, caplog):
"""Test that non-callable instructions log an error"""
agent = Agent(instructions=123) # Invalid type
result = await agent.get_system_prompt(mock_run_context)
assert result is None
# Check that error was logged (adjust based on actual logging setup)
# assert "Instructions must be a string or a function" in caplog.text
Loading