Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,23 @@ tool = FunctionTool(
)
```

### Tool context

When `on_invoke_tool` is called, it receives a `ToolContext` instance. The object contains:

- `context` – the context object you passed to `Runner.run()`.
- `usage` – usage information for the run so far.
- `tool_name` – the name of the tool being invoked.
- `tool_call_id` – the ID of the tool call.

You can access these fields inside your tool function:

```python
async def run_function(ctx: ToolContext[Any], args: str) -> str:
print("Tool invoked:", ctx.tool_name)
...
```

### Automatic argument and docstring parsing

As mentioned before, we automatically parse the function signature to extract the schema for the tool, and we parse the docstring to extract descriptions for the tool and for individual arguments. Some notes on that:
Expand Down
4 changes: 3 additions & 1 deletion src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,9 @@ async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
tool_context = ToolContext.from_agent_context(
context_wrapper, func_tool.name, tool_call.call_id
)
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
Expand Down
11 changes: 9 additions & 2 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,23 @@ def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")


def _assert_must_pass_tool_name() -> str:
raise ValueError("tool_name must be passed to ToolContext")


@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""

tool_name: str = field(default_factory=_assert_must_pass_tool_name)
"""The name of the tool being invoked."""

tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""

@classmethod
def from_agent_context(
cls, context: RunContextWrapper[TContext], tool_call_id: str
cls, context: RunContextWrapper[TContext], tool_name: str, tool_call_id: str
) -> "ToolContext":
"""
Create a ToolContext from a RunContextWrapper.
Expand All @@ -26,4 +33,4 @@ def from_agent_context(
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
return cls(tool_call_id=tool_call_id, **base_values)
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
49 changes: 34 additions & 15 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ async def test_argless_function():
tool = function_tool(argless_function)
assert tool.name == "argless_function"

result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
result = await tool.on_invoke_tool(
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
)
assert result == "ok"


Expand All @@ -32,11 +34,13 @@ async def test_argless_with_context():
tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
assert result == "ok"

# Extra JSON should not raise an error
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
)
assert result == "ok"


Expand All @@ -49,15 +53,19 @@ async def test_simple_function():
tool = function_tool(simple_function, failure_error_function=None)
assert tool.name == "simple_function"

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
)
assert result == 6

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
)
assert result == 3

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")


class Foo(BaseModel):
Expand Down Expand Up @@ -85,7 +93,9 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "6 hello10 hello"

valid_json = json.dumps(
Expand All @@ -94,7 +104,9 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "3 hello10 hello"

valid_json = json.dumps(
Expand All @@ -104,12 +116,16 @@ async def test_complex_args_function():
"baz": "world",
}
)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
)
assert result == "3 hello10 world"

# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
)


def test_function_config_overrides():
Expand Down Expand Up @@ -169,7 +185,9 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert tool.params_json_schema[key] == value
assert tool.strict_json_schema

result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
result = await tool.on_invoke_tool(
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
)
assert result == "hello_done"

tool_not_strict = FunctionTool(
Expand All @@ -184,7 +202,8 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
assert "additionalProperties" not in tool_not_strict.params_json_schema

result = await tool_not_strict.on_invoke_tool(
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
'{"data": "hello", "bar": "baz"}',
)
assert result == "hello_done"

Expand All @@ -195,7 +214,7 @@ def my_func(a: int, b: int = 5):
raise ValueError("test")

tool = function_tool(my_func)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert "Invalid JSON" in str(result)
Expand All @@ -219,7 +238,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand All @@ -243,7 +262,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
return f"error_{error.__class__.__name__}"

tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = ToolContext(None, tool_call_id="1")
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")

result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):


def ctx_wrapper() -> ToolContext[DummyContext]:
return ToolContext(context=DummyContext(), tool_call_id="1")
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")


@function_tool
Expand Down