Skip to content

Fix: Emit tool_called events immediately in streaming runs #1300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
11 changes: 4 additions & 7 deletions examples/basic/stream_function_call_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def main():

result = Runner.run_streamed(
agent,
input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn"
input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn",
)

# Track function calls for detailed output
Expand All @@ -50,23 +50,20 @@ async def main():
function_name = getattr(event.data.item, "name", "unknown")
call_id = getattr(event.data.item, "call_id", "unknown")

function_calls[call_id] = {
'name': function_name,
'arguments': ""
}
function_calls[call_id] = {"name": function_name, "arguments": ""}
current_active_call_id = call_id
print(f"\n📞 Function call streaming started: {function_name}()")
print("📝 Arguments building...")

# Real-time argument streaming
elif isinstance(event.data, ResponseFunctionCallArgumentsDeltaEvent):
if current_active_call_id and current_active_call_id in function_calls:
function_calls[current_active_call_id]['arguments'] += event.data.delta
function_calls[current_active_call_id]["arguments"] += event.data.delta
print(event.data.delta, end="", flush=True)

# Function call completed
elif event.data.type == "response.output_item.done":
if hasattr(event.data.item, 'call_id'):
if hasattr(event.data.item, "call_id"):
call_id = getattr(event.data.item, "call_id", "unknown")
if call_id in function_calls:
function_info = function_calls[call_id]
Expand Down
10 changes: 8 additions & 2 deletions examples/customer_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ class AirlineAgentContext(BaseModel):
)
async def faq_lookup_tool(question: str) -> str:
question_lower = question.lower()
if any(keyword in question_lower for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]):
if any(
keyword in question_lower
for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]
):
return (
"You are allowed to bring one bag on the plane. "
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
Expand All @@ -52,7 +55,10 @@ async def faq_lookup_tool(question: str) -> str:
"Exit rows are rows 4 and 16. "
"Rows 5-8 are Economy Plus, with extra legroom. "
)
elif any(keyword in question_lower for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]):
elif any(
keyword in question_lower
for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]
):
return "We have free wifi on the plane, join Airline-Wifi"
return "I'm sorry, I don't know the answer to that question."

Expand Down
3 changes: 2 additions & 1 deletion examples/realtime/app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import logging
import struct
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, assert_never
from typing import TYPE_CHECKING, Any

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from typing_extensions import assert_never

from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent

Expand Down
6 changes: 3 additions & 3 deletions src/agents/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ class Handoff(Generic[TContext, TAgent]):
True, as it increases the likelihood of correct JSON input.
"""

is_enabled: bool | Callable[
[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]
] = True
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
True
)
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
a handoff based on your context/state."""
Expand Down
1 change: 0 additions & 1 deletion src/agents/model_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class MCPToolChoice:
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]



@dataclass
class ModelSettings:
"""Settings to use when calling an LLM.
Expand Down
126 changes: 68 additions & 58 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Generic, cast
from typing import Any, Generic, cast, get_args

from openai.types.responses import ResponseCompletedEvent
from openai.types.responses import (
ResponseCompletedEvent,
ResponseOutputItemAddedEvent,
)
from openai.types.responses.response_prompt_param import (
ResponsePromptParam,
)
Expand Down Expand Up @@ -41,7 +44,14 @@
OutputGuardrailResult,
)
from .handoffs import Handoff, HandoffInputFilter, handoff
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .items import (
ItemHelpers,
ModelResponse,
RunItem,
ToolCallItem,
ToolCallItemTypes,
TResponseInputItem,
)
from .lifecycle import RunHooks
from .logger import logger
from .memory import Session
Expand All @@ -50,7 +60,7 @@
from .models.multi_provider import MultiProvider
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
from .tool import Tool
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.span_data import AgentSpanData
Expand Down Expand Up @@ -833,6 +843,8 @@ async def _run_single_turn_streamed(
all_tools: list[Tool],
previous_response_id: str | None,
) -> SingleStepResult:
emitted_tool_call_ids: set[str] = set()

if should_run_agent_start_hooks:
await asyncio.gather(
hooks.on_agent_start(context_wrapper, agent),
Expand Down Expand Up @@ -897,16 +909,35 @@ async def _run_single_turn_streamed(
)
context_wrapper.usage.add(usage)

if isinstance(event, ResponseOutputItemAddedEvent):
output_item = event.item

if isinstance(output_item, _TOOL_CALL_TYPES):
call_id: str | None = getattr(
output_item, "call_id", getattr(output_item, "id", None)
)

if call_id and call_id not in emitted_tool_call_ids:
emitted_tool_call_ids.add(call_id)

tool_item = ToolCallItem(
raw_item=cast(ToolCallItemTypes, output_item),
agent=agent,
)
streamed_result._event_queue.put_nowait(
RunItemStreamEvent(item=tool_item, name="tool_called")
)

streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))

# 2. At this point, the streaming is complete for this turn of the agent loop.
if not final_response:
raise ModelBehaviorError("Model did not produce a final response!")

# 3. Now, we can process the turn as we do in the non-streaming case
return await cls._get_single_step_result_from_streamed_response(
single_step_result = await cls._get_single_step_result_from_response(
agent=agent,
streamed_result=streamed_result,
original_input=streamed_result.input,
pre_step_items=streamed_result.new_items,
new_response=final_response,
output_schema=output_schema,
all_tools=all_tools,
Expand All @@ -917,6 +948,34 @@ async def _run_single_turn_streamed(
tool_use_tracker=tool_use_tracker,
)

if emitted_tool_call_ids:
import dataclasses as _dc

filtered_items = [
item
for item in single_step_result.new_step_items
if not (
isinstance(item, ToolCallItem)
and (
call_id := getattr(
item.raw_item, "call_id", getattr(item.raw_item, "id", None)
)
)
and call_id in emitted_tool_call_ids
)
]

single_step_result_filtered = _dc.replace(
single_step_result, new_step_items=filtered_items
)

RunImpl.stream_step_result_to_queue(
single_step_result_filtered, streamed_result._event_queue
)
else:
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
return single_step_result

@classmethod
async def _run_single_turn(
cls,
Expand Down Expand Up @@ -1019,57 +1078,6 @@ async def _get_single_step_result_from_response(
run_config=run_config,
)

@classmethod
async def _get_single_step_result_from_streamed_response(
    cls,
    *,
    agent: Agent[TContext],
    all_tools: list[Tool],
    streamed_result: RunResultStreaming,
    new_response: ModelResponse,
    output_schema: AgentOutputSchemaBase | None,
    handoffs: list[Handoff],
    hooks: RunHooks[TContext],
    context_wrapper: RunContextWrapper[TContext],
    run_config: RunConfig,
    tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
    """Turn a completed streamed model response into a `SingleStepResult`.

    Streaming variant of the single-step post-processing: it processes the
    final `ModelResponse`, pushes the newly produced run items onto the
    streamed result's event queue as they become known, executes tools and
    side effects, and finally streams any *additional* items created by tool
    execution (without re-emitting the ones already sent).

    Args:
        agent: The agent that produced the response.
        all_tools: Tools available to the agent this turn.
        streamed_result: Holds the accumulated input/items and the event queue
            that consumers of `Runner.run_streamed` read from.
        new_response: The final model response for this turn.
        output_schema: Structured-output schema, if the agent declares one.
        handoffs: Handoffs available to the agent this turn.
        hooks: Lifecycle hooks to invoke during tool execution.
        context_wrapper: Per-run context/usage wrapper.
        run_config: Run-wide configuration.
        tool_use_tracker: Records which tools each agent has used.

    Returns:
        The `SingleStepResult` produced by executing tools and side effects.
    """

    original_input = streamed_result.input
    pre_step_items = streamed_result.new_items
    # Private queue feeding the user-visible stream of RunItemStreamEvents.
    event_queue = streamed_result._event_queue

    processed_response = RunImpl.process_model_response(
        agent=agent,
        all_tools=all_tools,
        response=new_response,
        output_schema=output_schema,
        handoffs=handoffs,
    )
    new_items_processed_response = processed_response.new_items
    tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
    # Emit the items derived directly from the model response *before* tools
    # run, so stream consumers see e.g. tool_called events promptly.
    RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue)

    single_step_result = await RunImpl.execute_tools_and_side_effects(
        agent=agent,
        original_input=original_input,
        pre_step_items=pre_step_items,
        new_response=new_response,
        processed_response=processed_response,
        output_schema=output_schema,
        hooks=hooks,
        context_wrapper=context_wrapper,
        run_config=run_config,
    )
    # Only stream items created during tool execution; the ones already
    # emitted above are filtered out by identity/equality membership.
    # NOTE(review): `item not in list` relies on RunItem equality semantics —
    # presumably dataclass equality; confirm duplicates cannot slip through.
    new_step_items = [
        item
        for item in single_step_result.new_step_items
        if item not in new_items_processed_response
    ]
    RunImpl.stream_step_items_to_queue(new_step_items, event_queue)

    return single_step_result

@classmethod
async def _run_input_guardrails(
cls,
Expand Down Expand Up @@ -1287,3 +1295,5 @@ async def _save_result_to_session(


DEFAULT_AGENT_RUNNER = AgentRunner()

_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)
4 changes: 2 additions & 2 deletions src/agents/tracing/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def set_api_key(self, api_key: str):
client.
"""
# Clear the cached property if it exists
if 'api_key' in self.__dict__:
del self.__dict__['api_key']
if "api_key" in self.__dict__:
del self.__dict__["api_key"]

# Update the private attribute
self._api_key = api_key
Expand Down
5 changes: 2 additions & 3 deletions tests/test_agent_clone_shallow_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
def greet(name: str) -> str:
    """Return a friendly greeting addressed to *name*."""
    return "Hello, " + name + "!"


def test_agent_clone_shallow_copy():
"""Test that clone creates shallow copy with tools.copy() workaround"""
target_agent = Agent(name="Target")
Expand All @@ -16,9 +17,7 @@ def test_agent_clone_shallow_copy():
)

cloned = original.clone(
name="Cloned",
tools=original.tools.copy(),
handoffs=original.handoffs.copy()
name="Cloned", tools=original.tools.copy(), handoffs=original.handoffs.copy()
)

# Basic assertions
Expand Down
1 change: 1 addition & 0 deletions tests/test_stream_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ async def foo() -> str:
await asyncio.sleep(3)
return "success!"


@pytest.mark.asyncio
async def test_stream_events_main():
model = FakeModel()
Expand Down