diff --git a/examples/basic/lifecycle_example.py b/examples/basic/lifecycle_example.py index 941b67768..1429872b8 100644 --- a/examples/basic/lifecycle_example.py +++ b/examples/basic/lifecycle_example.py @@ -46,7 +46,7 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: self.event_counter += 1 print( - f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}" + f"### {self.event_counter}: Tool {tool.name} started. name={context.tool_name}, call_id={context.tool_call_id}, args={context.tool_arguments}. Usage: {self._usage_to_str(context.usage)}" # type: ignore[attr-defined] ) async def on_tool_end( @@ -54,7 +54,7 @@ async def on_tool_end( ) -> None: self.event_counter += 1 print( - f"### {self.event_counter}: Tool {tool.name} ended with result {result}. Usage: {self._usage_to_str(context.usage)}" + f"### {self.event_counter}: Tool {tool.name} finished. result={result}, name={context.tool_name}, call_id={context.tool_call_id}, args={context.tool_arguments}. Usage: {self._usage_to_str(context.usage)}" # type: ignore[attr-defined] ) async def on_handoff( @@ -128,19 +128,19 @@ async def main() -> None: ### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens ### 2: LLM started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens ### 3: LLM ended. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens -### 4: Tool random_number started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens -### 5: Tool random_number ended with result 69. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens +### 4: Tool random_number started. name=random_number, call_id=call_IujmDZYiM800H0hy7v17VTS0, args={"max":250}. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens +### 5: Tool random_number finished. result=107, name=random_number, call_id=call_IujmDZYiM800H0hy7v17VTS0, args={"max":250}. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens ### 6: LLM started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens ### 7: LLM ended. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens ### 8: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens ### 9: Agent Multiply Agent started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens ### 10: LLM started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens ### 11: LLM ended. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens -### 12: Tool multiply_by_two started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens -### 13: Tool multiply_by_two ended with result 138. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens +### 12: Tool multiply_by_two started. name=multiply_by_two, call_id=call_KhHvTfsgaosZsfi741QvzgYw, args={"x":107}. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens +### 13: Tool multiply_by_two finished. result=214, name=multiply_by_two, call_id=call_KhHvTfsgaosZsfi741QvzgYw, args={"x":107}. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens ### 14: LLM started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens ### 15: LLM ended. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens -### 16: Agent Multiply Agent ended with output number=138. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens +### 16: Agent Multiply Agent ended with output number=214. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens Done! """ diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 62adc529c..42dcf531a 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -408,6 +408,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: usage=self._context_wrapper.usage, tool_name=event.name, tool_call_id=event.call_id, + tool_arguments=event.arguments, ) result = await func_tool.on_invoke_tool(tool_context, event.arguments) @@ -432,6 +433,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: usage=self._context_wrapper.usage, tool_name=event.name, tool_call_id=event.call_id, + tool_arguments=event.arguments, ) # Execute the handoff to get the new agent diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 16845badd..5b81239f6 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -14,6 +14,10 @@ def _assert_must_pass_tool_name() -> str: raise ValueError("tool_name must be passed to ToolContext") +def _assert_must_pass_tool_arguments() -> str: + raise ValueError("tool_arguments must be passed to ToolContext") + + @dataclass class ToolContext(RunContextWrapper[TContext]): """The context of a tool call.""" @@ -24,6 +28,9 @@ class ToolContext(RunContextWrapper[TContext]): tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) """The ID of the tool call.""" + tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments) + """The raw arguments string of the tool call.""" + @classmethod def from_agent_context( cls, @@ -39,4 +46,10 @@ def from_agent_context( f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name() - return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values) + tool_args = ( + tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments() + ) + + return cls( + tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values + ) diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 813f72c28..1b8b99682 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -277,7 +277,12 @@ async def fake_run( ) assert isinstance(tool, FunctionTool) - tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1") + tool_context = ToolContext( + context=None, + tool_name="story_tool", + tool_call_id="call_1", + tool_arguments='{"input": "hello"}', + ) output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}') assert output == "Hello world" @@ -374,7 +379,12 @@ async def extractor(result) -> str: ) assert isinstance(tool, FunctionTool) - tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2") + tool_context = ToolContext( + context=None, + tool_name="summary_tool", + tool_call_id="call_2", + tool_arguments='{"input": "summarize this"}', + ) output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') assert output == "custom output" diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 15602bbac..9f227aadb 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -27,7 +27,7 @@ async def test_argless_function(): assert tool.name == "argless_function" result = await tool.on_invoke_tool( - ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), "" + ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" ) assert result == "ok" @@ -41,12 +41,15 @@ async def test_argless_with_context(): tool = function_tool(argless_with_context) assert tool.name == "argless_with_context" - result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ) assert result == "ok" # Extra JSON should not raise an error result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + '{"a": 1}', ) assert result == "ok" @@ -61,18 +64,22 @@ async def test_simple_function(): assert tool.name == "simple_function" result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}' + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'), + '{"a": 1}', ) assert result == 6 result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}' + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'), + '{"a": 1, "b": 2}', ) assert result == 3 # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): - await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "") + await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), "" + ) class Foo(BaseModel): @@ -101,7 +108,8 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + valid_json, ) assert result == "6 hello10 hello" @@ -112,7 +120,8 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + valid_json, ) assert result == "3 hello10 hello" @@ -124,14 +133,18 @@ async def test_complex_args_function(): } ) result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json), + valid_json, ) assert result == "3 hello10 world" # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}' + ToolContext( + None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"foo": {"a": 1}}' + ), + '{"foo": {"a": 1}}', ) @@ -193,7 +206,10 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert tool.strict_json_schema result = await tool.on_invoke_tool( - ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}' + ToolContext( + None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"data": "hello"}' + ), + '{"data": "hello"}', ) assert result == "hello_done" @@ -209,7 +225,12 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str: assert "additionalProperties" not in tool_not_strict.params_json_schema result = await tool_not_strict.on_invoke_tool( - ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"), + ToolContext( + None, + tool_name=tool_not_strict.name, + tool_call_id="1", + tool_arguments='{"data": "hello", "bar": "baz"}', + ), '{"data": "hello", "bar": "baz"}', ) assert result == "hello_done" @@ -221,7 +242,7 @@ def my_func(a: int, b: int = 5): raise ValueError("test") tool = function_tool(my_func) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") result = await tool.on_invoke_tool(ctx, "") assert "Invalid JSON" in str(result) @@ -245,7 +266,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" @@ -269,7 +290,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) -> return f"error_{error.__class__.__name__}" tool = function_tool(my_func, failure_error_function=custom_sync_error_function) - ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1") + ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="") result = await tool.on_invoke_tool(ctx, "") assert result == "error_ModelBehaviorError" diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index b81d5dbe2..2f5a38223 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -16,7 +16,9 @@ def __init__(self): def ctx_wrapper() -> ToolContext[DummyContext]: - return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1") + return ToolContext( + context=DummyContext(), tool_name="dummy", tool_call_id="1", tool_arguments="" + ) @function_tool