diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 6c417b308..230756dde 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -544,7 +544,7 @@ async def execute_function_tool_calls( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> list[FunctionToolResult]: - async def run_single_tool( + async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> Any: with function_span(func_tool.name) as span_fn: diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py index 16845badd..aa42b5cdd 100644 --- a/src/agents/tool_context.py +++ b/src/agents/tool_context.py @@ -1,17 +1,15 @@ from dataclasses import dataclass, field, fields -from typing import Any, Optional - from openai.types.responses import ResponseFunctionToolCall - +from typing import Any, Optional from .run_context import RunContextWrapper, TContext -def _assert_must_pass_tool_call_id() -> str: - raise ValueError("tool_call_id must be passed to ToolContext") +def _assert_must_pass_tool_name() -> str: + raise ValueError("Tool name must be passed") -def _assert_must_pass_tool_name() -> str: - raise ValueError("tool_name must be passed to ToolContext") +def _assert_must_pass_tool_call_id() -> str: + raise ValueError("Tool call ID must be passed") @dataclass @@ -24,6 +22,9 @@ class ToolContext(RunContextWrapper[TContext]): tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id) """The ID of the tool call.""" + arguments: Optional[str] = None + """The raw JSON arguments string sent by the model for this tool call, if available.""" + @classmethod def from_agent_context( cls, @@ -34,9 +35,14 @@ def from_agent_context( """ Create a ToolContext from a RunContextWrapper. """ - # Grab the names of the RunContextWrapper's init=True fields base_values: dict[str, Any] = { f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init } - tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name() - return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values) + tool_name = tool_call.function.name if tool_call is not None else _assert_must_pass_tool_name() + args = tool_call.function.arguments if tool_call is not None else None + return cls( + tool_name=tool_name, + tool_call_id=tool_call_id, + arguments=args, + **base_values, + ) \ No newline at end of file diff --git a/tests/test_tool_context_arg.py b/tests/test_tool_context_arg.py new file mode 100644 index 000000000..ac76363e3 --- /dev/null +++ b/tests/test_tool_context_arg.py @@ -0,0 +1,77 @@ +import json +from dataclasses import fields +from types import SimpleNamespace +from typing import Optional + +import pytest + +from agents import function_tool +from agents.run_context import RunContextWrapper +from agents.tool_context import ToolContext + + +class FakeToolCall: + def __init__(self, name: str, arguments: Optional[str] = None): + self.name = name + self.arguments = arguments + + +def make_minimal_context_like_runcontext(): + ctx = SimpleNamespace() + for f in fields(RunContextWrapper): + setattr(ctx, f.name, None) + return ctx + + +def test_from_agent_context_populates_arguments_and_names(): + context_like = make_minimal_context_like_runcontext() + fake_call = FakeToolCall(name="my_tool", arguments='{"x": 1, "y": 2}') + + tc: ToolContext = ToolContext.from_agent_context( + context_like, tool_call_id="c-1", tool_call=fake_call + ) + + assert tc.tool_name == "my_tool" + assert tc.tool_call_id == "c-1" + assert tc.arguments == '{"x": 1, "y": 2}' + + +def test_from_agent_context_raises_if_tool_name_missing(): + context_like = make_minimal_context_like_runcontext() + + with pytest.raises(ValueError, match="Tool name must"): + ToolContext.from_agent_context(context_like, tool_call_id="c-2", tool_call=None) + + +@pytest.mark.asyncio +async def test_function_tool_accepts_toolcontext_generic_argless(): + def argless_with_context(ctx: ToolContext[str]) -> str: + return "ok" + + tool = function_tool(argless_with_context) + assert tool.name == "argless_with_context" + + ctx = ToolContext(context=None, tool_name="argless_with_context", tool_call_id="1") + + result = await tool.on_invoke_tool(ctx, "") + assert result == "ok" + + result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}') + assert result == "ok" + + +@pytest.mark.asyncio +async def test_function_tool_with_context_and_args_parsed(): + class DummyCtx: + def __init__(self): + self.data = "xyz" + + def with_ctx_and_name(ctx: ToolContext[DummyCtx], name: str) -> str: + return f"{name}_{ctx.context.data}" + + tool = function_tool(with_ctx_and_name) + ctx = ToolContext(context=DummyCtx(), tool_name="with_ctx_and_name", tool_call_id="1") + payload = json.dumps({"name": "uzair"}) + result = await tool.on_invoke_tool(ctx, payload) + + assert result == "uzair_xyz"