Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from .handoffs import Handoff, HandoffInputFilter, handoff
from .items import (
HandoffCallItem,
ItemHelpers,
ModelResponse,
RunItem,
Expand All @@ -60,7 +61,12 @@
from .models.multi_provider import MultiProvider
from .result import RunResult, RunResultStreaming
from .run_context import RunContextWrapper, TContext
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
from .stream_events import (
AgentUpdatedStreamEvent,
RawResponsesStreamEvent,
RunItemStreamEvent,
StreamEvent,
)
from .tool import Tool
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.span_data import AgentSpanData
Expand Down Expand Up @@ -1095,14 +1101,19 @@ async def _run_single_turn_streamed(
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
event_queue=streamed_result._event_queue,
)

if emitted_tool_call_ids:
import dataclasses as _dc
import dataclasses as _dc

# Filter out items that have already been sent to avoid duplicates
items_to_filter = single_step_result.new_step_items

filtered_items = [
if emitted_tool_call_ids:
# Filter out tool call items that were already emitted during streaming
items_to_filter = [
item
for item in single_step_result.new_step_items
for item in items_to_filter
if not (
isinstance(item, ToolCallItem)
and (
Expand All @@ -1114,15 +1125,17 @@ async def _run_single_turn_streamed(
)
]

single_step_result_filtered = _dc.replace(
single_step_result, new_step_items=filtered_items
)
# Filter out HandoffCallItem to avoid duplicates (already sent earlier)
items_to_filter = [
item for item in items_to_filter
if not isinstance(item, HandoffCallItem)
]

RunImpl.stream_step_result_to_queue(
single_step_result_filtered, streamed_result._event_queue
)
else:
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
# Create filtered result and send to queue
filtered_result = _dc.replace(
single_step_result, new_step_items=items_to_filter
)
RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue)
return single_step_result

@classmethod
Expand Down Expand Up @@ -1207,6 +1220,7 @@ async def _get_single_step_result_from_response(
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None,
) -> SingleStepResult:
processed_response = RunImpl.process_model_response(
agent=agent,
Expand All @@ -1218,6 +1232,15 @@ async def _get_single_step_result_from_response(

tool_use_tracker.add_tool_use(agent, processed_response.tools_used)

# Send handoff items immediately for streaming, but avoid duplicates
if event_queue is not None and processed_response.new_items:
handoff_items = [
item for item in processed_response.new_items
if isinstance(item, HandoffCallItem)
]
if handoff_items:
RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue)

return await RunImpl.execute_tools_and_side_effects(
agent=agent,
original_input=original_input,
Expand Down
60 changes: 58 additions & 2 deletions tests/test_stream_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import pytest

from agents import Agent, Runner, function_tool
from agents import Agent, HandoffCallItem, Runner, function_tool
from agents.extensions.handoff_filters import remove_all_tools
from agents.handoffs import handoff

from .fake_model import FakeModel
from .test_responses import get_function_tool_call, get_text_message
from .test_responses import get_function_tool_call, get_handoff_tool_call, get_text_message


@function_tool
Expand Down Expand Up @@ -52,3 +54,57 @@ async def test_stream_events_main():
assert tool_call_start_time > 0, "tool_call_item was not observed"
assert tool_call_end_time > 0, "tool_call_output_item was not observed"
assert tool_call_start_time < tool_call_end_time, "Tool call ended before or equals it started?"


@pytest.mark.asyncio
async def test_stream_events_main_with_handoff():
    """Stream a run whose first turn requests a handoff and check the events.

    The triage agent's fake model is scripted so that its first turn emits a
    text message, a ``foo`` tool call and a handoff call targeting
    ``EnglishAgent``. The test then consumes the event stream and requires
    that both a ``HandoffCallItem`` run-item event and an agent-updated event
    naming ``EnglishAgent`` are observed.
    """

    @function_tool
    async def foo(args: str) -> str:
        return f"foo_result_{args}"

    english_agent = Agent(
        name="EnglishAgent",
        instructions="You only speak English.",
        model=FakeModel(),
    )

    # Script the triage model: turn one mixes text, a tool call and the
    # handoff request; turn two is a plain text reply.
    first_turn = [
        get_text_message("Hello"),
        get_function_tool_call("foo", '{"args": "arg1"}'),
        get_handoff_tool_call(english_agent),
    ]
    second_turn = [get_text_message("Done")]
    model = FakeModel()
    model.add_multiple_turn_outputs([first_turn, second_turn])

    triage_agent = Agent(
        name="TriageAgent",
        instructions="Handoff to the appropriate agent based on the language of the request.",
        handoffs=[handoff(english_agent, input_filter=remove_all_tools)],
        tools=[foo],
        model=model,
    )

    result = Runner.run_streamed(triage_agent, input="Start")

    saw_handoff_call = False
    switched_to_english = False

    async for event in result.stream_events():
        if event.type == "run_item_stream_event" and isinstance(
            event.item, HandoffCallItem
        ):
            saw_handoff_call = True
        elif (
            event.type == "agent_updated_stream_event"
            and hasattr(event, "new_agent")
            and event.new_agent.name == "EnglishAgent"
        ):
            switched_to_english = True

    assert saw_handoff_call, "handoff_requested event not observed"
    assert switched_to_english, "Agent did not switch to EnglishAgent"