Skip to content

Commit 84f55ba

Browse files
committed
emit tool call output items immediately
1 parent 114b320 commit 84f55ba

File tree

1 file changed

+74
-4
lines changed

1 file changed

+74
-4
lines changed

src/agents/run.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,21 @@
66
from dataclasses import dataclass, field
77
from typing import Any, Generic, cast
88

9-
from openai.types.responses import ResponseCompletedEvent
9+
from openai.types.responses import (
10+
ResponseCompletedEvent,
11+
ResponseComputerToolCall,
12+
ResponseFileSearchToolCall,
13+
ResponseFunctionToolCall,
14+
ResponseOutputItemAddedEvent,
15+
)
16+
from openai.types.responses.response_code_interpreter_tool_call import (
17+
ResponseCodeInterpreterToolCall,
18+
)
19+
from openai.types.responses.response_output_item import (
20+
ImageGenerationCall,
21+
LocalShellCall,
22+
McpCall,
23+
)
1024
from openai.types.responses.response_prompt_param import (
1125
ResponsePromptParam,
1226
)
@@ -41,7 +55,7 @@
4155
OutputGuardrailResult,
4256
)
4357
from .handoffs import Handoff, HandoffInputFilter, handoff
44-
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
58+
from .items import ItemHelpers, ModelResponse, RunItem, ToolCallItem, TResponseInputItem
4559
from .lifecycle import RunHooks
4660
from .logger import logger
4761
from .memory import Session
@@ -50,7 +64,7 @@
5064
from .models.multi_provider import MultiProvider
5165
from .result import RunResult, RunResultStreaming
5266
from .run_context import RunContextWrapper, TContext
53-
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
67+
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
5468
from .tool import Tool
5569
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
5670
from .tracing.span_data import AgentSpanData
@@ -833,6 +847,10 @@ async def _run_single_turn_streamed(
833847
all_tools: list[Tool],
834848
previous_response_id: str | None,
835849
) -> SingleStepResult:
850+
# Track tool call IDs we've already emitted to avoid duplicates when we later
851+
# enqueue all items at the end of the turn.
852+
emitted_tool_call_ids: set[str] = set()
853+
836854
if should_run_agent_start_hooks:
837855
await asyncio.gather(
838856
hooks.on_agent_start(context_wrapper, agent),
@@ -877,6 +895,8 @@ async def _run_single_turn_streamed(
877895
previous_response_id=previous_response_id,
878896
prompt=prompt_config,
879897
):
898+
# 1. If the event signals the end of the assistant response, remember it so we can
899+
# process the full response after the streaming loop.
880900
if isinstance(event, ResponseCompletedEvent):
881901
usage = (
882902
Usage(
@@ -897,6 +917,34 @@ async def _run_single_turn_streamed(
897917
)
898918
context_wrapper.usage.add(usage)
899919

920+
# 2. Detect tool call output-item additions **while** the model is still streaming.
921+
# Emit a high-level RunItemStreamEvent so UIs can react immediately.
922+
if isinstance(event, ResponseOutputItemAddedEvent):
923+
output_item = event.item
924+
925+
if isinstance(
926+
output_item,
927+
(
928+
ResponseFunctionToolCall,
929+
ResponseFileSearchToolCall,
930+
ResponseComputerToolCall,
931+
ResponseCodeInterpreterToolCall,
932+
ImageGenerationCall,
933+
LocalShellCall,
934+
McpCall,
935+
),
936+
):
937+
call_id = getattr(output_item, "call_id", getattr(output_item, "id", None))
938+
939+
if call_id not in emitted_tool_call_ids:
940+
emitted_tool_call_ids.add(call_id)
941+
942+
tool_item = ToolCallItem(raw_item=output_item, agent=agent)
943+
streamed_result._event_queue.put_nowait(
944+
RunItemStreamEvent(item=tool_item, name="tool_called")
945+
)
946+
947+
# Always forward the raw event.
900948
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
901949

902950
# 2. At this point, the streaming is complete for this turn of the agent loop.
@@ -918,7 +966,29 @@ async def _run_single_turn_streamed(
918966
tool_use_tracker=tool_use_tracker,
919967
)
920968

921-
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
969+
# Remove tool_called items we've already emitted during streaming to avoid duplicates.
970+
if emitted_tool_call_ids:
971+
import dataclasses as _dc # local import to avoid polluting module namespace
972+
973+
filtered_items = [
974+
item
975+
for item in single_step_result.new_step_items
976+
if not (
977+
isinstance(item, ToolCallItem)
978+
and getattr(item.raw_item, "call_id", getattr(item.raw_item, "id", None))
979+
in emitted_tool_call_ids
980+
)
981+
]
982+
983+
single_step_result_filtered = _dc.replace(
984+
single_step_result, new_step_items=filtered_items
985+
)
986+
987+
RunImpl.stream_step_result_to_queue(
988+
single_step_result_filtered, streamed_result._event_queue
989+
)
990+
else:
991+
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
922992
return single_step_result
923993

924994
@classmethod

0 commit comments

Comments
 (0)