Skip to content

Commit b4659e8

Browse files
committed
Added support for passing tool_call_id via the RunContextWrapper
1 parent 6e078bf commit b4659e8

File tree

4 files changed

+53
-7
lines changed

4 files changed

+53
-7
lines changed

src/agents/_run_impl.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,23 +539,26 @@ async def run_single_tool(
539539
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
540540
) -> Any:
541541
with function_span(func_tool.name) as span_fn:
542+
tool_context_wrapper = dataclasses.replace(
543+
context_wrapper, tool_call_id=tool_call.call_id
544+
)
542545
if config.trace_include_sensitive_data:
543546
span_fn.span_data.input = tool_call.arguments
544547
try:
545548
_, _, result = await asyncio.gather(
546-
hooks.on_tool_start(context_wrapper, agent, func_tool),
549+
hooks.on_tool_start(tool_context_wrapper, agent, func_tool),
547550
(
548-
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
551+
agent.hooks.on_tool_start(tool_context_wrapper, agent, func_tool)
549552
if agent.hooks
550553
else _coro.noop_coroutine()
551554
),
552-
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
555+
func_tool.on_invoke_tool(tool_context_wrapper, tool_call.arguments),
553556
)
554557

555558
await asyncio.gather(
556-
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
559+
hooks.on_tool_end(tool_context_wrapper, agent, func_tool, result),
557560
(
558-
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
561+
agent.hooks.on_tool_end(tool_context_wrapper, agent, func_tool, result)
559562
if agent.hooks
560563
else _coro.noop_coroutine()
561564
),

src/agents/run_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,6 @@ class RunContextWrapper(Generic[TContext]):
2424
"""The usage of the agent run so far. For streamed responses, the usage will be stale until the
2525
last chunk of the stream is processed.
2626
"""
27+
28+
tool_call_id: str | None = None
29+
"""The ID of the tool call for the current tool execution."""

tests/test_responses.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ def _foo() -> str:
4949
)
5050

5151

52-
def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
52+
def get_function_tool_call(
53+
name: str, arguments: str | None = None, call_id: str | None = None
54+
) -> ResponseOutputItem:
5355
return ResponseFunctionToolCall(
5456
id="1",
55-
call_id="2",
57+
call_id=call_id or "2",
5658
type="function_call",
5759
name=name,
5860
arguments=arguments or "",

tests/test_run_step_execution.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import Any
45

56
import pytest
@@ -26,6 +27,7 @@
2627
RunImpl,
2728
SingleStepResult,
2829
)
30+
from agents.tool import function_tool
2931

3032
from .test_responses import (
3133
get_final_output_message,
@@ -158,6 +160,42 @@ async def test_multiple_tool_calls():
158160
assert isinstance(result.next_step, NextStepRunAgain)
159161

160162

163+
@pytest.mark.asyncio
164+
async def test_multiple_tool_calls_with_tool_context():
165+
async def _fake_tool(agent_context: RunContextWrapper[str], value: str) -> str:
166+
return f"{value}-{agent_context.tool_call_id}"
167+
168+
tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)
169+
170+
agent = Agent(
171+
name="test",
172+
tools=[tool],
173+
)
174+
response = ModelResponse(
175+
output=[
176+
get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
177+
get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
178+
],
179+
usage=Usage(),
180+
response_id=None,
181+
)
182+
183+
result = await get_execute_result(agent, response)
184+
assert result.original_input == "hello"
185+
186+
# 4 items: new message, 2 tool calls, 2 tool call outputs
187+
assert len(result.generated_items) == 4
188+
assert isinstance(result.next_step, NextStepRunAgain)
189+
190+
items = result.generated_items
191+
assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
192+
assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
193+
assert_item_is_function_tool_call_output(items[2], "123-1")
194+
assert_item_is_function_tool_call_output(items[3], "456-2")
195+
196+
assert isinstance(result.next_step, NextStepRunAgain)
197+
198+
161199
@pytest.mark.asyncio
162200
async def test_handoff_output_leads_to_handoff_next_step():
163201
agent_1 = Agent(name="test_1")

0 commit comments

Comments
 (0)