Skip to content

Commit d43fd3a

Browse files
committed
make tool_arguments required field in ToolContext for RunHooks
1 parent 613476f commit d43fd3a

File tree

5 files changed

+61
-20
lines changed

5 files changed

+61
-20
lines changed

src/agents/realtime/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
408408
usage=self._context_wrapper.usage,
409409
tool_name=event.name,
410410
tool_call_id=event.call_id,
411+
tool_arguments=event.arguments,
411412
)
412413
result = await func_tool.on_invoke_tool(tool_context, event.arguments)
413414

@@ -432,6 +433,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
432433
usage=self._context_wrapper.usage,
433434
tool_name=event.name,
434435
tool_call_id=event.call_id,
436+
tool_arguments=event.arguments,
435437
)
436438

437439
# Execute the handoff to get the new agent

src/agents/tool_context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ def _assert_must_pass_tool_name() -> str:
1414
raise ValueError("tool_name must be passed to ToolContext")
1515

1616

17+
def _assert_must_pass_tool_arguments() -> str:
18+
raise ValueError("tool_arguments must be passed to ToolContext")
19+
20+
1721
@dataclass
1822
class ToolContext(RunContextWrapper[TContext]):
1923
"""The context of a tool call."""
@@ -24,7 +28,7 @@ class ToolContext(RunContextWrapper[TContext]):
2428
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
2529
"""The ID of the tool call."""
2630

27-
tool_arguments: Optional[str] = None
31+
tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments)
2832
"""The raw arguments string of the tool call."""
2933

3034
@classmethod
@@ -42,7 +46,9 @@ def from_agent_context(
4246
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
4347
}
4448
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
45-
tool_args = tool_call.arguments if tool_call is not None else None
49+
tool_args = (
50+
tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments()
51+
)
4652

4753
return cls(
4854
tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values

tests/test_agent_as_tool.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,12 @@ async def fake_run(
277277
)
278278

279279
assert isinstance(tool, FunctionTool)
280-
tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1")
280+
tool_context = ToolContext(
281+
context=None,
282+
tool_name="story_tool",
283+
tool_call_id="call_1",
284+
tool_arguments='{"input": "hello"}',
285+
)
281286
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')
282287

283288
assert output == "Hello world"
@@ -374,7 +379,12 @@ async def extractor(result) -> str:
374379
)
375380

376381
assert isinstance(tool, FunctionTool)
377-
tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2")
382+
tool_context = ToolContext(
383+
context=None,
384+
tool_name="summary_tool",
385+
tool_call_id="call_2",
386+
tool_arguments='{"input": "summarize this"}',
387+
)
378388
output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}')
379389

380390
assert output == "custom output"

tests/test_function_tool.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def test_argless_function():
2727
assert tool.name == "argless_function"
2828

2929
result = await tool.on_invoke_tool(
30-
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
30+
ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
3131
)
3232
assert result == "ok"
3333

@@ -41,12 +41,15 @@ async def test_argless_with_context():
4141
tool = function_tool(argless_with_context)
4242
assert tool.name == "argless_with_context"
4343

44-
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
44+
result = await tool.on_invoke_tool(
45+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
46+
)
4547
assert result == "ok"
4648

4749
# Extra JSON should not raise an error
4850
result = await tool.on_invoke_tool(
49-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
51+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
52+
'{"a": 1}',
5053
)
5154
assert result == "ok"
5255

@@ -61,18 +64,22 @@ async def test_simple_function():
6164
assert tool.name == "simple_function"
6265

6366
result = await tool.on_invoke_tool(
64-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
67+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
68+
'{"a": 1}',
6569
)
6670
assert result == 6
6771

6872
result = await tool.on_invoke_tool(
69-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
73+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'),
74+
'{"a": 1, "b": 2}',
7075
)
7176
assert result == 3
7277

7378
# Missing required argument should raise an error
7479
with pytest.raises(ModelBehaviorError):
75-
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
80+
await tool.on_invoke_tool(
81+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
82+
)
7683

7784

7885
class Foo(BaseModel):
@@ -101,7 +108,8 @@ async def test_complex_args_function():
101108
}
102109
)
103110
result = await tool.on_invoke_tool(
104-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
111+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
112+
valid_json,
105113
)
106114
assert result == "6 hello10 hello"
107115

@@ -112,7 +120,8 @@ async def test_complex_args_function():
112120
}
113121
)
114122
result = await tool.on_invoke_tool(
115-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
123+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
124+
valid_json,
116125
)
117126
assert result == "3 hello10 hello"
118127

@@ -124,14 +133,18 @@ async def test_complex_args_function():
124133
}
125134
)
126135
result = await tool.on_invoke_tool(
127-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
136+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
137+
valid_json,
128138
)
129139
assert result == "3 hello10 world"
130140

131141
# Missing required argument should raise an error
132142
with pytest.raises(ModelBehaviorError):
133143
await tool.on_invoke_tool(
134-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
144+
ToolContext(
145+
None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"foo": {"a": 1}}'
146+
),
147+
'{"foo": {"a": 1}}',
135148
)
136149

137150

@@ -193,7 +206,10 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
193206
assert tool.strict_json_schema
194207

195208
result = await tool.on_invoke_tool(
196-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
209+
ToolContext(
210+
None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"data": "hello"}'
211+
),
212+
'{"data": "hello"}',
197213
)
198214
assert result == "hello_done"
199215

@@ -209,7 +225,12 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
209225
assert "additionalProperties" not in tool_not_strict.params_json_schema
210226

211227
result = await tool_not_strict.on_invoke_tool(
212-
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
228+
ToolContext(
229+
None,
230+
tool_name=tool_not_strict.name,
231+
tool_call_id="1",
232+
tool_arguments='{"data": "hello", "bar": "baz"}',
233+
),
213234
'{"data": "hello", "bar": "baz"}',
214235
)
215236
assert result == "hello_done"
@@ -221,7 +242,7 @@ def my_func(a: int, b: int = 5):
221242
raise ValueError("test")
222243

223244
tool = function_tool(my_func)
224-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
245+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")
225246

226247
result = await tool.on_invoke_tool(ctx, "")
227248
assert "Invalid JSON" in str(result)
@@ -245,7 +266,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
245266
return f"error_{error.__class__.__name__}"
246267

247268
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
248-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
269+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")
249270

250271
result = await tool.on_invoke_tool(ctx, "")
251272
assert result == "error_ModelBehaviorError"
@@ -269,7 +290,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
269290
return f"error_{error.__class__.__name__}"
270291

271292
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
272-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
293+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")
273294

274295
result = await tool.on_invoke_tool(ctx, "")
275296
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def __init__(self):
1616

1717

1818
def ctx_wrapper() -> ToolContext[DummyContext]:
19-
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")
19+
return ToolContext(
20+
context=DummyContext(), tool_name="dummy", tool_call_id="1", tool_arguments=""
21+
)
2022

2123

2224
@function_tool

0 commit comments

Comments
 (0)