Skip to content

Commit db50380

Browse files
Fix: enforce strict instructions function signature in get_system_prompt
1 parent bad88e7 commit db50380

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

src/agents/agent.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,21 +393,33 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
393393
return run_agent
394394

395395
async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
396-
"""Get the system prompt for the agent."""
397396
if isinstance(self.instructions, str):
398397
return self.instructions
399398
elif callable(self.instructions):
399+
# Inspect the signature of the instructions function
400+
sig = inspect.signature(self.instructions)
401+
params = list(sig.parameters.values())
402+
403+
# Enforce exactly 2 parameters
404+
if len(params) != 2:
405+
raise TypeError(
406+
f"'instructions' callable must accept exactly 2 arguments (context, agent), "
407+
f"but got {len(params)}: {[p.name for p in params]}"
408+
)
409+
410+
# Call the instructions function properly
400411
if inspect.iscoroutinefunction(self.instructions):
401412
return await cast(Awaitable[str], self.instructions(run_context, self))
402413
else:
403414
return cast(str, self.instructions(run_context, self))
415+
404416
elif self.instructions is not None:
405-
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
417+
logger.error(f"Instructions must be a string or a callable function, got {type(self.instructions).__name__}")
406418

407419
return None
408420

409421
async def get_prompt(
410422
self, run_context: RunContextWrapper[TContext]
411423
) -> ResponsePromptParam | None:
412424
"""Get the prompt for the agent."""
413-
return await PromptUtil.to_model_input(self.prompt, run_context, self)
425+
return await PromptUtil.to_model_input(self.prompt, run_context, self)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
from src.agents.agent import Agent # adjust if needed
3+
4+
class DummyContext:
5+
pass
6+
7+
class DummyAgent(Agent):
8+
def __init__(self, instructions):
9+
super().__init__(instructions=instructions)
10+
11+
@pytest.mark.asyncio
12+
async def test_valid_signature():
13+
async def good_instructions(ctx, agent):
14+
return "valid"
15+
a = DummyAgent(good_instructions)
16+
result = await a.get_system_prompt(DummyContext())
17+
assert result == "valid"
18+
19+
@pytest.mark.asyncio
20+
async def test_invalid_signature_raises():
21+
async def bad_instructions(ctx):
22+
return "invalid"
23+
a = DummyAgent(bad_instructions)
24+
import pytest
25+
with pytest.raises(TypeError):
26+
await a.get_system_prompt(DummyContext())

0 commit comments

Comments
 (0)