Skip to content

Commit f62a12b

Browse files
committed
Added ToolContext object to hold the tool_call_id
1 parent b4659e8 commit f62a12b

File tree

8 files changed

+73
-41
lines changed

8 files changed

+73
-41
lines changed

src/agents/_run_impl.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
MCPToolApprovalRequest,
7575
Tool,
7676
)
77+
from .tool_context import ToolContext
7778
from .tracing import (
7879
SpanError,
7980
Trace,
@@ -539,26 +540,24 @@ async def run_single_tool(
539540
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
540541
) -> Any:
541542
with function_span(func_tool.name) as span_fn:
542-
tool_context_wrapper = dataclasses.replace(
543-
context_wrapper, tool_call_id=tool_call.call_id
544-
)
543+
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
545544
if config.trace_include_sensitive_data:
546545
span_fn.span_data.input = tool_call.arguments
547546
try:
548547
_, _, result = await asyncio.gather(
549-
hooks.on_tool_start(tool_context_wrapper, agent, func_tool),
548+
hooks.on_tool_start(tool_context, agent, func_tool),
550549
(
551-
agent.hooks.on_tool_start(tool_context_wrapper, agent, func_tool)
550+
agent.hooks.on_tool_start(tool_context, agent, func_tool)
552551
if agent.hooks
553552
else _coro.noop_coroutine()
554553
),
555-
func_tool.on_invoke_tool(tool_context_wrapper, tool_call.arguments),
554+
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
556555
)
557556

558557
await asyncio.gather(
559-
hooks.on_tool_end(tool_context_wrapper, agent, func_tool, result),
558+
hooks.on_tool_end(tool_context, agent, func_tool, result),
560559
(
561-
agent.hooks.on_tool_end(tool_context_wrapper, agent, func_tool, result)
560+
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
562561
if agent.hooks
563562
else _coro.noop_coroutine()
564563
),

src/agents/function_schema.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .exceptions import UserError
1414
from .run_context import RunContextWrapper
1515
from .strict_schema import ensure_strict_json_schema
16+
from .tool_context import ToolContext
1617

1718

1819
@dataclass
@@ -237,21 +238,21 @@ def function_schema(
237238
ann = type_hints.get(first_name, first_param.annotation)
238239
if ann != inspect._empty:
239240
origin = get_origin(ann) or ann
240-
if origin is RunContextWrapper:
241+
if origin is RunContextWrapper or origin is ToolContext:
241242
takes_context = True # Mark that the function takes context
242243
else:
243244
filtered_params.append((first_name, first_param))
244245
else:
245246
filtered_params.append((first_name, first_param))
246247

247-
# For parameters other than the first, raise error if any use RunContextWrapper.
248+
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
248249
for name, param in params[1:]:
249250
ann = type_hints.get(name, param.annotation)
250251
if ann != inspect._empty:
251252
origin = get_origin(ann) or ann
252-
if origin is RunContextWrapper:
253+
if origin is RunContextWrapper or origin is ToolContext:
253254
raise UserError(
254-
f"RunContextWrapper param found at non-first position in function"
255+
f"RunContextWrapper/ToolContext param found at non-first position in function"
255256
f" {func.__name__}"
256257
)
257258
filtered_params.append((name, param))

src/agents/run_context.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,3 @@ class RunContextWrapper(Generic[TContext]):
2424
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
2525
last chunk of the stream is processed.
2626
"""
27-
28-
tool_call_id: str | None = None
29-
"""The ID of the tool call for the current tool execution."""

src/agents/tool.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .items import RunItem
2121
from .logger import logger
2222
from .run_context import RunContextWrapper
23+
from .tool_context import ToolContext
2324
from .tracing import SpanError
2425
from .util import _error_tracing
2526
from .util._types import MaybeAwaitable
@@ -28,8 +29,13 @@
2829

2930
ToolFunctionWithoutContext = Callable[ToolParams, Any]
3031
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
32+
ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]
3133

32-
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
34+
ToolFunction = Union[
35+
ToolFunctionWithoutContext[ToolParams],
36+
ToolFunctionWithContext[ToolParams],
37+
ToolFunctionWithToolContext[ToolParams],
38+
]
3339

3440

3541
@dataclass
@@ -59,7 +65,7 @@ class FunctionTool:
5965
params_json_schema: dict[str, Any]
6066
"""The JSON schema for the tool's parameters."""
6167

62-
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
68+
on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
6369
"""A function that invokes the tool with the given context and parameters. The params passed
6470
are:
6571
1. The tool run context.
@@ -330,7 +336,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
330336
strict_json_schema=strict_mode,
331337
)
332338

333-
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
339+
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
334340
try:
335341
json_data: dict[str, Any] = json.loads(input) if input else {}
336342
except Exception as e:
@@ -379,7 +385,7 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
379385

380386
return result
381387

382-
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
388+
async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
383389
try:
384390
return await _on_invoke_tool_impl(ctx, input)
385391
except Exception as e:

src/agents/tool_context.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from dataclasses import KW_ONLY, dataclass, fields
2+
from typing import Any
3+
4+
from .run_context import RunContextWrapper, TContext
5+
6+
7+
@dataclass
8+
class ToolContext(RunContextWrapper[TContext]):
9+
"""The context of a tool call."""
10+
11+
_: KW_ONLY
12+
tool_call_id: str
13+
"""The ID of the tool call."""
14+
15+
@classmethod
16+
def from_agent_context(
17+
cls, context: RunContextWrapper[TContext], tool_call_id: str
18+
) -> "ToolContext":
19+
"""
20+
Create a ToolContext from a RunContextWrapper.
21+
"""
22+
# Grab the names of the RunContextWrapper's init=True fields
23+
base_values: dict[str, Any] = {
24+
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
25+
}
26+
return cls(tool_call_id=tool_call_id, **base_values)

tests/test_function_tool.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
99
from agents.tool import default_tool_error_function
10+
from agents.tool_context import ToolContext
1011

1112

1213
def argless_function() -> str:
@@ -18,11 +19,11 @@ async def test_argless_function():
1819
tool = function_tool(argless_function)
1920
assert tool.name == "argless_function"
2021

21-
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
22+
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
2223
assert result == "ok"
2324

2425

25-
def argless_with_context(ctx: RunContextWrapper[str]) -> str:
26+
def argless_with_context(ctx: ToolContext[str]) -> str:
2627
return "ok"
2728

2829

@@ -31,11 +32,11 @@ async def test_argless_with_context():
3132
tool = function_tool(argless_with_context)
3233
assert tool.name == "argless_with_context"
3334

34-
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
35+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
3536
assert result == "ok"
3637

3738
# Extra JSON should not raise an error
38-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
39+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
3940
assert result == "ok"
4041

4142

@@ -48,15 +49,15 @@ async def test_simple_function():
4849
tool = function_tool(simple_function, failure_error_function=None)
4950
assert tool.name == "simple_function"
5051

51-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
52+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
5253
assert result == 6
5354

54-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
55+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
5556
assert result == 3
5657

5758
# Missing required argument should raise an error
5859
with pytest.raises(ModelBehaviorError):
59-
await tool.on_invoke_tool(RunContextWrapper(None), "")
60+
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
6061

6162

6263
class Foo(BaseModel):
@@ -84,7 +85,7 @@ async def test_complex_args_function():
8485
"bar": Bar(x="hello", y=10),
8586
}
8687
)
87-
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
88+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
8889
assert result == "6 hello10 hello"
8990

9091
valid_json = json.dumps(
@@ -93,7 +94,7 @@ async def test_complex_args_function():
9394
"bar": Bar(x="hello", y=10),
9495
}
9596
)
96-
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
97+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
9798
assert result == "3 hello10 hello"
9899

99100
valid_json = json.dumps(
@@ -103,12 +104,12 @@ async def test_complex_args_function():
103104
"baz": "world",
104105
}
105106
)
106-
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
107+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
107108
assert result == "3 hello10 world"
108109

109110
# Missing required argument should raise an error
110111
with pytest.raises(ModelBehaviorError):
111-
await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')
112+
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
112113

113114

114115
def test_function_config_overrides():
@@ -168,7 +169,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
168169
assert tool.params_json_schema[key] == value
169170
assert tool.strict_json_schema
170171

171-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
172+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
172173
assert result == "hello_done"
173174

174175
tool_not_strict = FunctionTool(
@@ -183,7 +184,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
183184
assert "additionalProperties" not in tool_not_strict.params_json_schema
184185

185186
result = await tool_not_strict.on_invoke_tool(
186-
RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
187+
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
187188
)
188189
assert result == "hello_done"
189190

@@ -194,7 +195,7 @@ def my_func(a: int, b: int = 5):
194195
raise ValueError("test")
195196

196197
tool = function_tool(my_func)
197-
ctx = RunContextWrapper(None)
198+
ctx = ToolContext(None, tool_call_id="1")
198199

199200
result = await tool.on_invoke_tool(ctx, "")
200201
assert "Invalid JSON" in str(result)
@@ -218,7 +219,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
218219
return f"error_{error.__class__.__name__}"
219220

220221
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
221-
ctx = RunContextWrapper(None)
222+
ctx = ToolContext(None, tool_call_id="1")
222223

223224
result = await tool.on_invoke_tool(ctx, "")
224225
assert result == "error_ModelBehaviorError"
@@ -242,7 +243,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
242243
return f"error_{error.__class__.__name__}"
243244

244245
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
245-
ctx = RunContextWrapper(None)
246+
ctx = ToolContext(None, tool_call_id="1")
246247

247248
result = await tool.on_invoke_tool(ctx, "")
248249
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77

88
from agents import function_tool
99
from agents.run_context import RunContextWrapper
10+
from agents.tool_context import ToolContext
1011

1112

1213
class DummyContext:
1314
def __init__(self):
1415
self.data = "something"
1516

1617

17-
def ctx_wrapper() -> RunContextWrapper[DummyContext]:
18-
return RunContextWrapper(DummyContext())
18+
def ctx_wrapper() -> ToolContext[DummyContext]:
19+
return ToolContext(context=DummyContext(), tool_call_id="1")
1920

2021

2122
@function_tool
@@ -44,7 +45,7 @@ async def test_sync_no_context_with_args_invocation():
4445

4546

4647
@function_tool
47-
def sync_with_context(ctx: RunContextWrapper[DummyContext], name: str) -> str:
48+
def sync_with_context(ctx: ToolContext[DummyContext], name: str) -> str:
4849
return f"{name}_{ctx.context.data}"
4950

5051

@@ -71,7 +72,7 @@ async def test_async_no_context_invocation():
7172

7273

7374
@function_tool
74-
async def async_with_context(ctx: RunContextWrapper[DummyContext], prefix: str, num: int) -> str:
75+
async def async_with_context(ctx: ToolContext[DummyContext], prefix: str, num: int) -> str:
7576
await asyncio.sleep(0)
7677
return f"{prefix}-{num}-{ctx.context.data}"
7778

tests/test_run_step_execution.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SingleStepResult,
2929
)
3030
from agents.tool import function_tool
31+
from agents.tool_context import ToolContext
3132

3233
from .test_responses import (
3334
get_final_output_message,
@@ -162,8 +163,8 @@ async def test_multiple_tool_calls():
162163

163164
@pytest.mark.asyncio
164165
async def test_multiple_tool_calls_with_tool_context():
165-
async def _fake_tool(agent_context: RunContextWrapper[str], value: str) -> str:
166-
return f"{value}-{agent_context.tool_call_id}"
166+
async def _fake_tool(context: ToolContext[str], value: str) -> str:
167+
return f"{value}-{context.tool_call_id}"
167168

168169
tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)
169170

0 commit comments

Comments
 (0)