diff --git a/src/agents/run.py b/src/agents/run.py index ee08ad134..5056758fb 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -45,6 +45,7 @@ ) from .handoffs import Handoff, HandoffInputFilter, handoff from .items import ( + HandoffCallItem, ItemHelpers, ModelResponse, RunItem, @@ -60,7 +61,12 @@ from .models.multi_provider import MultiProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext -from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent +from .stream_events import ( + AgentUpdatedStreamEvent, + RawResponsesStreamEvent, + RunItemStreamEvent, + StreamEvent, +) from .tool import Tool from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData @@ -1095,14 +1101,19 @@ async def _run_single_turn_streamed( context_wrapper=context_wrapper, run_config=run_config, tool_use_tracker=tool_use_tracker, + event_queue=streamed_result._event_queue, ) - if emitted_tool_call_ids: - import dataclasses as _dc + import dataclasses as _dc + + # Filter out items that have already been sent to avoid duplicates + items_to_filter = single_step_result.new_step_items - filtered_items = [ + if emitted_tool_call_ids: + # Filter out tool call items that were already emitted during streaming + items_to_filter = [ item - for item in single_step_result.new_step_items + for item in items_to_filter if not ( isinstance(item, ToolCallItem) and ( @@ -1114,15 +1125,17 @@ async def _run_single_turn_streamed( ) ] - single_step_result_filtered = _dc.replace( - single_step_result, new_step_items=filtered_items - ) + # Filter out HandoffCallItem to avoid duplicates (already sent earlier) + items_to_filter = [ + item for item in items_to_filter + if not isinstance(item, HandoffCallItem) + ] - RunImpl.stream_step_result_to_queue( - single_step_result_filtered, streamed_result._event_queue - ) - else: - 
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue) + # Create filtered result and send to queue + filtered_result = _dc.replace( + single_step_result, new_step_items=items_to_filter + ) + RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue) return single_step_result @classmethod @@ -1207,6 +1220,7 @@ async def _get_single_step_result_from_response( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, tool_use_tracker: AgentToolUseTracker, + event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, ) -> SingleStepResult: processed_response = RunImpl.process_model_response( agent=agent, @@ -1218,6 +1232,15 @@ async def _get_single_step_result_from_response( tool_use_tracker.add_tool_use(agent, processed_response.tools_used) + # Send handoff items immediately for streaming, but avoid duplicates + if event_queue is not None and processed_response.new_items: + handoff_items = [ + item for item in processed_response.new_items + if isinstance(item, HandoffCallItem) + ] + if handoff_items: + RunImpl.stream_step_items_to_queue(cast(list[RunItem], handoff_items), event_queue) + return await RunImpl.execute_tools_and_side_effects( agent=agent, original_input=original_input, diff --git a/tests/test_stream_events.py b/tests/test_stream_events.py index 0f85b63f8..a2f0338d6 100644 --- a/tests/test_stream_events.py +++ b/tests/test_stream_events.py @@ -3,10 +3,12 @@ import pytest -from agents import Agent, Runner, function_tool +from agents import Agent, HandoffCallItem, Runner, function_tool +from agents.extensions.handoff_filters import remove_all_tools +from agents.handoffs import handoff from .fake_model import FakeModel -from .test_responses import get_function_tool_call, get_text_message +from .test_responses import get_function_tool_call, get_handoff_tool_call, get_text_message @function_tool @@ -52,3 +54,57 @@ async def test_stream_events_main(): assert 
@pytest.mark.asyncio
async def test_stream_events_main_with_handoff():
    """Stream a turn that ends in a handoff and check the emitted events.

    The triage agent's first scripted turn produces a text message, a call to
    the local ``foo`` tool and a handoff call targeting ``EnglishAgent``. The
    test drains ``result.stream_events()`` and verifies that:

    * a ``HandoffCallItem`` run-item event is delivered exactly once, and
    * an agent-updated event reports the switch to ``EnglishAgent``.
    """

    @function_tool
    async def foo(args: str) -> str:
        return f"foo_result_{args}"

    english_agent = Agent(
        name="EnglishAgent",
        instructions="You only speak English.",
        model=FakeModel(),
    )

    # Script the triage agent's model: the first turn mixes a message, a tool
    # call and a handoff call; a second turn output is queued as well.
    model = FakeModel()
    model.add_multiple_turn_outputs(
        [
            [
                get_text_message("Hello"),
                get_function_tool_call("foo", '{"args": "arg1"}'),
                get_handoff_tool_call(english_agent),
            ],
            [get_text_message("Done")],
        ]
    )

    triage_agent = Agent(
        name="TriageAgent",
        instructions="Handoff to the appropriate agent based on the language of the request.",
        handoffs=[
            handoff(english_agent, input_filter=remove_all_tools),
        ],
        tools=[foo],
        model=model,
    )

    result = Runner.run_streamed(
        triage_agent,
        input="Start",
    )

    # Count handoff events instead of latching a boolean: a flag would still
    # pass if the same handoff request were streamed more than once.
    handoff_request_count = 0
    agent_switched_to_english = False

    async for event in result.stream_events():
        if event.type == "run_item_stream_event":
            if isinstance(event.item, HandoffCallItem):
                handoff_request_count += 1
        elif event.type == "agent_updated_stream_event":
            if hasattr(event, "new_agent") and event.new_agent.name == "EnglishAgent":
                agent_switched_to_english = True

    assert handoff_request_count > 0, "handoff_requested event not observed"
    assert handoff_request_count == 1, (
        f"handoff_requested event observed {handoff_request_count} times, expected exactly once"
    )
    assert agent_switched_to_english, "Agent did not switch to EnglishAgent"