
Fix: Emit tool_called events immediately in streaming runs #1300

Open
wants to merge 10 commits into base: main
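As a consumer-side illustration of what the title describes, here is a minimal sketch, assuming the SDK's public Runner.run_streamed() / stream_events() API; the agent, tool, and prompt are invented for the example. With this patch, the tool_called run-item event should arrive while the model output for the turn is still streaming, rather than only after the turn completes.

import asyncio

from agents import Agent, Runner, function_tool


@function_tool
def get_weather(city: str) -> str:
    """Placeholder tool, only here to trigger a tool call."""
    return f"Sunny in {city}"


async def main() -> None:
    agent = Agent(
        name="Assistant",
        instructions="Use the weather tool when asked about the weather.",
        tools=[get_weather],
    )
    result = Runner.run_streamed(agent, input="What's the weather in Oslo?")

    async for event in result.stream_events():
        # With this change, tool_called should be observed as soon as the
        # tool call appears in the raw stream, before the tool has run.
        if event.type == "run_item_stream_event" and event.name == "tool_called":
            print("tool call streamed:", event.item.raw_item)


if __name__ == "__main__":
    asyncio.run(main())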
69 changes: 63 additions & 6 deletions src/agents/run.py
@@ -4,9 +4,12 @@
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Generic, cast
from typing import Any, Generic, cast, get_args

from openai.types.responses import ResponseCompletedEvent
from openai.types.responses import (
ResponseCompletedEvent,
ResponseOutputItemAddedEvent,
)
from openai.types.responses.response_prompt_param import (
ResponsePromptParam,
)
@@ -41,7 +44,14 @@
OutputGuardrailResult,
)
from .handoffs import Handoff, HandoffInputFilter, handoff
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .items import (
ItemHelpers,
ModelResponse,
RunItem,
ToolCallItem,
ToolCallItemTypes,
TResponseInputItem,
)
from .lifecycle import RunHooks
from .logger import logger
from .memory import Session
@@ -50,7 +60,7 @@
from .models.multi_provider import MultiProvider
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
from .tool import Tool
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.span_data import AgentSpanData
@@ -833,6 +843,8 @@ async def _run_single_turn_streamed(
all_tools: list[Tool],
previous_response_id: str | None,
) -> SingleStepResult:
emitted_tool_call_ids: set[str] = set()

if should_run_agent_start_hooks:
await asyncio.gather(
hooks.on_agent_start(context_wrapper, agent),
@@ -897,9 +909,27 @@
)
context_wrapper.usage.add(usage)

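# Surface tool calls to the event queue as soon as they show up in the
# raw response stream, instead of waiting for the turn to finish.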
if isinstance(event, ResponseOutputItemAddedEvent):
output_item = event.item

if isinstance(output_item, _TOOL_CALL_TYPES):
call_id: str | None = getattr(
output_item, "call_id", getattr(output_item, "id", None)
)

if call_id and call_id not in emitted_tool_call_ids:
emitted_tool_call_ids.add(call_id)

tool_item = ToolCallItem(
raw_item=cast(ToolCallItemTypes, output_item),
agent=agent,
)
streamed_result._event_queue.put_nowait(
RunItemStreamEvent(item=tool_item, name="tool_called")
)

streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

# 2. At this point, the streaming is complete for this turn of the agent loop.
if not final_response:
raise ModelBehaviorError("Model did not produce a final response!")

@@ -918,7 +948,32 @@
tool_use_tracker=tool_use_tracker,
)

RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
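# Tool calls already streamed above are filtered out of new_step_items so
# they are not emitted to the queue a second time.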
if emitted_tool_call_ids:
import dataclasses as _dc

filtered_items = [
item
for item in single_step_result.new_step_items
if not (
isinstance(item, ToolCallItem)
and (
call_id := getattr(
item.raw_item, "call_id", getattr(item.raw_item, "id", None)
)
)
and call_id in emitted_tool_call_ids
)
]

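# dataclasses.replace() returns a copy of the step result with only
# new_step_items swapped for the filtered list.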
single_step_result_filtered = _dc.replace(
single_step_result, new_step_items=filtered_items
)

RunImpl.stream_step_result_to_queue(
single_step_result_filtered, streamed_result._event_queue
)
else:
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
return single_step_result

@classmethod
@@ -1240,3 +1295,5 @@ async def _save_result_to_session(


DEFAULT_AGENT_RUNNER = AgentRunner()

_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)
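
For context on the helper above: typing.get_args() applied to a Union alias returns a plain tuple of the member types, which is what lets _TOOL_CALL_TYPES be passed directly to the isinstance() check earlier in the diff. A standalone sketch with invented placeholder types (not the real members of ToolCallItemTypes):

from typing import Union, get_args


class FunctionToolCall: ...


class FileSearchToolCall: ...


# Placeholder union standing in for ToolCallItemTypes.
FakeToolCallTypes = Union[FunctionToolCall, FileSearchToolCall]

TOOL_CALL_TYPES: tuple[type, ...] = get_args(FakeToolCallTypes)

assert TOOL_CALL_TYPES == (FunctionToolCall, FileSearchToolCall)
assert isinstance(FunctionToolCall(), TOOL_CALL_TYPES)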