From 421db5c0a3039b632419ac6caea6adbe316c3b52 Mon Sep 17 00:00:00 2001 From: Viraj Date: Thu, 21 Aug 2025 21:16:34 -0700 Subject: [PATCH] feat(run): add lifecycle interrupt + inject, cancel-aware tools, safer tracing --- src/agents/_run_impl.py | 99 +++- src/agents/models/openai_responses.py | 24 +- src/agents/result.py | 73 ++- src/agents/run.py | 365 ++++++++++++-- src/agents/stream_events.py | 21 +- tests/test_run_lifecycle.py | 685 ++++++++++++++++++++++++++ 6 files changed, 1194 insertions(+), 73 deletions(-) create mode 100644 tests/test_run_lifecycle.py diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 56784004c..f6e589cfe 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import dataclasses import inspect from collections.abc import Awaitable @@ -226,6 +227,29 @@ def get_model_tracing_impl( return ModelTracing.ENABLED_WITHOUT_DATA +# Helpers for cancellable tool execution + + +async def _await_cancellable(awaitable): + """Await an awaitable in its own task so CancelledError interrupts promptly.""" + task = asyncio.create_task(awaitable) + try: + return await task + except asyncio.CancelledError: + # propagate so run.py can handle terminal cancel + raise + + +def _maybe_call_cancel_hook(tool_obj) -> None: + """Best-effort: call a cancel/terminate hook on the tool if present.""" + for name in ("cancel", "terminate", "stop"): + cb = getattr(tool_obj, name, None) + if callable(cb): + with contextlib.suppress(Exception): + cb() + break + + class RunImpl: @classmethod async def execute_tools_and_side_effects( @@ -572,16 +596,24 @@ async def run_single_tool( if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: - _, _, result = await asyncio.gather( + # run start hooks first (don’t tie them to the cancellable task) + await asyncio.gather( hooks.on_tool_start(tool_context, agent, func_tool), ( agent.hooks.on_tool_start(tool_context, agent, func_tool) if agent.hooks else _coro.noop_coroutine() ), - func_tool.on_invoke_tool(tool_context, tool_call.arguments), ) + try: + result = await _await_cancellable( + func_tool.on_invoke_tool(tool_context, tool_call.arguments) + ) + except asyncio.CancelledError: + _maybe_call_cancel_hook(func_tool) + raise + await asyncio.gather( hooks.on_tool_end(tool_context, agent, func_tool, result), ( @@ -590,6 +622,7 @@ async def run_single_tool( else _coro.noop_coroutine() ), ) + except Exception as e: _error_tracing.attach_error_to_current_span( SpanError( @@ -660,7 +693,6 @@ async def execute_computer_actions( config: RunConfig, ) -> list[RunItem]: results: list[RunItem] = [] - # Need to run these serially, because each action can affect the computer state for action in actions: acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check: @@ -677,24 +709,28 @@ async def execute_computer_actions( if ack: acknowledged.append( ComputerCallOutputAcknowledgedSafetyCheck( - id=check.id, - code=check.code, - message=check.message, + id=check.id, code=check.code, message=check.message ) ) else: raise UserError("Computer tool safety check was not acknowledged") - results.append( - await ComputerAction.execute( - agent=agent, - action=action, - hooks=hooks, - context_wrapper=context_wrapper, - config=config, - acknowledged_safety_checks=acknowledged, + try: + item = await _await_cancellable( + ComputerAction.execute( + 
agent=agent, + action=action, + hooks=hooks, + context_wrapper=context_wrapper, + config=config, + acknowledged_safety_checks=acknowledged, + ) ) - ) + except asyncio.CancelledError: + _maybe_call_cancel_hook(action.computer_tool) + raise + + results.append(item) return results @@ -1068,16 +1104,23 @@ async def execute( else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call) ) - _, _, output = await asyncio.gather( + # start hooks first + await asyncio.gather( hooks.on_tool_start(context_wrapper, agent, action.computer_tool), ( agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) if agent.hooks else _coro.noop_coroutine() ), - output_func, ) - + # run the action (screenshot/etc) in a cancellable task + try: + output = await _await_cancellable(output_func) + except asyncio.CancelledError: + _maybe_call_cancel_hook(action.computer_tool) + raise + + # end hooks await asyncio.gather( hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), ( @@ -1185,10 +1228,20 @@ async def execute( data=call.tool_call, ) output = call.local_shell_tool.executor(request) - if inspect.isawaitable(output): - result = await output - else: - result = output + try: + if inspect.isawaitable(output): + result = await _await_cancellable(output) + else: + # If executor returns a sync result, just use it (can’t cancel mid-call) + result = output + except asyncio.CancelledError: + # Best-effort: if the executor or tool exposes a cancel/terminate, call it + _maybe_call_cancel_hook(call.local_shell_tool) + # If your executor returns a proc handle (common pattern), adddress it here if needed: + # with contextlib.suppress(Exception): + # proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0) + # proc.kill() + raise await asyncio.gather( hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), @@ -1201,7 +1254,7 @@ async def execute( return ToolCallOutputItem( agent=agent, - output=output, + output=result, raw_item={ "type": "local_shell_call_output", "id": call.tool_call.call_id, diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 0b409f7b0..b356e7da8 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from collections.abc import AsyncIterator from dataclasses import dataclass @@ -175,15 +176,30 @@ async def stream_response( final_response: Response | None = None - async for chunk in stream: - if isinstance(chunk, ResponseCompletedEvent): - final_response = chunk.response - yield chunk + try: + async for chunk in stream: # ensure type checkers relax here + if isinstance(chunk, ResponseCompletedEvent): + final_response = chunk.response + yield chunk + except asyncio.CancelledError: + # Cooperative cancel: ensure the HTTP stream is closed, then propagate + try: + await stream.close() + except Exception: + pass + raise + finally: + # Always close the stream if the async iterator exits (normal or error) + try: + await stream.close() + except Exception: + pass if final_response and tracing.include_data(): span_response.span_data.response = final_response span_response.span_data.input = input + except Exception as e: span_response.set_error( SpanError( diff --git a/src/agents/result.py b/src/agents/result.py index 5cf0e74c8..1d259e36b 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -2,6 +2,7 @@ import abc import asyncio +import contextlib from collections.abc import 
AsyncIterator from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, cast @@ -143,6 +144,12 @@ class RunResultStreaming(RunResultBase): is_complete: bool = False """Whether the agent has finished running.""" + _emit_status_events: bool = False + """Whether to emit RunUpdatedStreamEvent status updates. + + Defaults to False for backward compatibility. + """ + # Queues that the background run_loop writes to _event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field( default_factory=asyncio.Queue, repr=False @@ -164,17 +171,45 @@ def last_agent(self) -> Agent[Any]: """ return self.current_agent - def cancel(self) -> None: - """Cancels the streaming run, stopping all background tasks and marking the run as - complete.""" - self._cleanup_tasks() # Cancel all running tasks - self.is_complete = True # Mark the run as complete to stop event streaming + def cancel(self, reason: str | None = None) -> None: + # 1) Signal cooperative cancel to the runner + active = getattr(self, "_active_run", None) + if active: + with contextlib.suppress(Exception): + active.cancel(reason) + # 2) Do NOT cancel the background task; let the loop unwind cooperatively + # task = getattr(self, "_run_impl_task", None) + # if task and not task.done(): + # with contextlib.suppress(Exception): + # task.cancel() + + # 4) Mark complete; flushing only when status events are disabled + self.is_complete = True + if not getattr(self, "_emit_status_events", False): + with contextlib.suppress(Exception): + while not self._event_queue.empty(): + self._event_queue.get_nowait() + self._event_queue.task_done() + with contextlib.suppress(Exception): + while not self._input_guardrail_queue.empty(): + self._input_guardrail_queue.get_nowait() + self._input_guardrail_queue.task_done() + + def inject(self, items: list[TResponseInputItem]) -> None: + """ + Inject new input items mid-run. They will be consumed at the start of the next step. + """ + active = getattr(self, "_active_run", None) + if active is not None: + try: + active.inject(items) + except Exception: + pass - # Optionally, clear the event queue to prevent processing stale events - while not self._event_queue.empty(): - self._event_queue.get_nowait() - while not self._input_guardrail_queue.empty(): - self._input_guardrail_queue.get_nowait() + @property + def active_run(self): + """Access the underlying ActiveRun handle (may be None early in startup).""" + return getattr(self, "_active_run", None) async def stream_events(self) -> AsyncIterator[StreamEvent]: """Stream deltas for new items as they are generated. 
We're using the types from the @@ -243,21 +278,33 @@ def _check_errors(self): # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): run_impl_exc = self._run_impl_task.exception() - if run_impl_exc and isinstance(run_impl_exc, Exception): + if ( + run_impl_exc + and isinstance(run_impl_exc, Exception) + and not isinstance(run_impl_exc, asyncio.CancelledError) + ): if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None: run_impl_exc.run_data = self._create_error_details() self._stored_exception = run_impl_exc if self._input_guardrails_task and self._input_guardrails_task.done(): in_guard_exc = self._input_guardrails_task.exception() - if in_guard_exc and isinstance(in_guard_exc, Exception): + if ( + in_guard_exc + and isinstance(in_guard_exc, Exception) + and not isinstance(in_guard_exc, asyncio.CancelledError) + ): if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None: in_guard_exc.run_data = self._create_error_details() self._stored_exception = in_guard_exc if self._output_guardrails_task and self._output_guardrails_task.done(): out_guard_exc = self._output_guardrails_task.exception() - if out_guard_exc and isinstance(out_guard_exc, Exception): + if ( + out_guard_exc + and isinstance(out_guard_exc, Exception) + and not isinstance(out_guard_exc, asyncio.CancelledError) + ): if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None: out_guard_exc.run_data = self._create_error_details() self._stored_exception = out_guard_exc diff --git a/src/agents/run.py b/src/agents/run.py index 4575edb3f..04b06546d 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -60,7 +60,12 @@ from .models.multi_provider import MultiProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext -from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent +from .stream_events import ( + AgentUpdatedStreamEvent, + RawResponsesStreamEvent, + RunItemStreamEvent, + RunUpdatedStreamEvent, +) from .tool import Tool from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData @@ -98,6 +103,55 @@ def _default_trace_include_sensitive_data() -> bool: return val.strip().lower() in ("1", "true", "yes", "on") +# Cooperative cancellation + active handle + + +class _Cancellation: + def __init__(self): + self._ev = asyncio.Event() + self.reason: str | None = None + + def start(self, reason: str | None = None): + if not self._ev.is_set(): + self.reason = reason + self._ev.set() + + def is_cancelling(self) -> bool: + return self._ev.is_set() + + def raise_if_cancelled(self): + if self.is_cancelling(): + raise asyncio.CancelledError(self.reason or "Cancelled") + + +class _ActiveRun: + """ + Lightweight handle exposed to results so callers can cancel or inject input. + NOTE: the `RunResultStreaming` will keep a reference to this. 
+ """ + + def __init__( + self, + cancel: _Cancellation, + inbox: list[TResponseInputItem], + state_cb: Callable[[], dict[str, Any]], + ): + self._cancel = cancel + self._inbox = inbox + self._state_cb = state_cb + + def cancel(self, reason: str | None = None) -> None: + self._cancel.start(reason) + + def inject(self, items: list[TResponseInputItem]) -> None: + # Append external inputs; consumed at the start of the next step + self._inbox.extend(items) + + def state(self) -> dict[str, Any]: + # optional: expose minimal state for debugging + return self._state_cb() + + @dataclass class ModelInputData: """Container for the data that will be sent to the model.""" @@ -393,6 +447,21 @@ class AgentRunner: It should not be used directly or subclassed. """ + @staticmethod + def _safe_finish(obj, *, reset_current: bool = True) -> None: + """ + Finish a span/trace safely even if called from a different task context. + Tries reset_current=True first; falls back to reset_current=False if needed. + """ + try: + obj.finish(reset_current=reset_current) + except Exception: + try: + obj.finish(reset_current=False) + except Exception: + # Last-resort: suppress exceptions since we are already tearing down. + pass + async def run( self, starting_agent: Agent[TContext], @@ -414,6 +483,18 @@ async def run( # Prepare input with session if enabled prepared_input = await self._prepare_input_with_session(input, session) + # Cancellation + inbox + handle for non-streamed runs + cancel_token = _Cancellation() + inbox: list[TResponseInputItem] = [] + + def _state_cb() -> dict[str, Any]: + return { + "current_turn": 0, # we'll update this below + "inbox_len": len(inbox), + } + + active_run = _ActiveRun(cancel_token, inbox, _state_cb) + tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -424,6 +505,11 @@ async def run( disabled=run_config.tracing_disabled, ): current_turn = 0 + + def _update_state_turn(n: int): + _state_cb_dict = active_run.state() + _state_cb_dict["current_turn"] = n # optional, purely for debugging + original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) generated_items: list[RunItem] = [] model_responses: list[ModelResponse] = [] @@ -431,6 +517,9 @@ async def run( context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context=context, # type: ignore ) + # Stash inbox + cancel token on context wrapper for internal access + cast(Any, context_wrapper)._inbox = inbox + cast(Any, context_wrapper)._cancel_token = cancel_token input_guardrail_results: list[InputGuardrailResult] = [] @@ -442,6 +531,12 @@ async def run( while True: all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper) + # Cooperative cancel at loop top + if cancel_token.is_cancelling(): + raise asyncio.CancelledError(cancel_token.reason or "Cancelled") + + _update_state_turn(current_turn) + # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. 
if current_span is None: @@ -540,11 +635,12 @@ async def run( # Save the conversation to session if enabled await self._save_result_to_session(session, input, result) + cast(Any, result).active_run = active_run # expose handle on non-streamed return result elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) - current_span.finish(reset_current=True) + AgentRunner._safe_finish(current_span, reset_current=True) current_span = None should_run_agent_start_hooks = True elif isinstance(turn_result.next_step, NextStepRunAgain): @@ -553,6 +649,24 @@ async def run( raise AgentsException( f"Unknown next step type: {type(turn_result.next_step)}" ) + + except asyncio.CancelledError as _c: + # Produce a terminal cancelled result; mirror the RunResult shape + result = RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=None, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + context_wrapper=context_wrapper, + ) + # Save to session if enabled + cast(Any, result).active_run = active_run + await self._save_result_to_session(session, input, result) + return result + except AgentsException as exc: exc.run_data = RunErrorDetails( input=original_input, @@ -566,7 +680,7 @@ async def run( raise finally: if current_span: - current_span.finish(reset_current=True) + AgentRunner._safe_finish(current_span, reset_current=True) def run_sync( self, @@ -651,6 +765,27 @@ def run_streamed( context_wrapper=context_wrapper, ) + # Cancellation + inbox + handle for this streamed run + cancel_token = _Cancellation() + inbox: list[TResponseInputItem] = [] + + # A tiny state closure for debugging/inspection + def _state_cb() -> dict[str, Any]: + current = getattr(streamed_result, "current_agent", None) + return { + "current_agent": current.name if current else None, + "current_turn": streamed_result.current_turn, + "is_complete": streamed_result.is_complete, + "inbox_len": len(inbox), + } + + active_run = _ActiveRun(cancel_token, inbox, _state_cb) + + # Stash these on the streamed_result; you'll expose helpers in result.py + cast(Any, streamed_result)._active_run = active_run + cast(Any, streamed_result)._cancel_token = cancel_token + cast(Any, streamed_result)._inbox = inbox + # Kick off the actual agent loop in the background and return the streamed result object. 
streamed_result._run_impl_task = asyncio.create_task( self._start_streaming( @@ -664,8 +799,11 @@ def run_streamed( previous_response_id=previous_response_id, conversation_id=conversation_id, session=session, + _cancel_token=cancel_token, + _inbox=inbox, ) ) + return streamed_result @classmethod @@ -765,6 +903,8 @@ async def _start_streaming( previous_response_id: str | None, conversation_id: str | None, session: Session | None, + _cancel_token: _Cancellation, + _inbox: list[TResponseInputItem], ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) @@ -777,6 +917,10 @@ async def _start_streaming( streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) + # Track whether we've already closed span/trace in a special path (GeneratorExit) + span_finished = False + trace_finished = False + try: # Prepare input with session if enabled prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session) @@ -785,6 +929,16 @@ async def _start_streaming( streamed_result.input = prepared_input while True: + # Cooperative cancel at loop top + if _cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="cancelled", reason=_cancel_token.reason) + ) + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + if streamed_result.is_complete: break @@ -810,6 +964,7 @@ async def _start_streaming( current_span.start(mark_as_current=True) tool_names = [t.name for t in all_tools] current_span.span_data.tools = tool_names + current_turn += 1 streamed_result.current_turn = current_turn @@ -821,7 +976,15 @@ async def _start_streaming( data={"max_turns": max_turns}, ), ) - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="failed", reason=f"Max turns exceeded ({max_turns})" + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) break if current_turn == 1: @@ -859,12 +1022,18 @@ async def _start_streaming( if isinstance(turn_result.next_step, NextStepHandoff): current_agent = turn_result.next_step.new_agent - current_span.finish(reset_current=True) + if current_span: + AgentRunner._safe_finish(current_span, reset_current=True) + span_finished = True # this span is closed here current_span = None should_run_agent_start_hooks = True streamed_result._event_queue.put_nowait( AgentUpdatedStreamEvent(new_agent=current_agent) ) + + # After handoff, allow a new span to start on next loop + span_finished = False + elif isinstance(turn_result.next_step, NextStepFinalOutput): streamed_result._output_guardrails_task = asyncio.create_task( cls._run_output_guardrails( @@ -887,7 +1056,6 @@ async def _start_streaming( streamed_result.is_complete = True # Save the conversation to session if enabled - # Create a temporary RunResult for session saving temp_result = RunResult( input=streamed_result.input, new_items=streamed_result.new_items, @@ -902,12 +1070,40 @@ async def _start_streaming( session, starting_input, temp_result ) + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="completed") + ) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + elif 
isinstance(turn_result.next_step, NextStepRunAgain): + # No-op; continue loop for another turn pass + except AgentsException as exc: - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + # If a cancel was requested, normalize any exception as "cancelled" + if _cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="cancelled", + reason=getattr(_cancel_token, "reason", None), + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise + + # existing "failed" path + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="failed", reason=exc.__class__.__name__) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) exc.run_data = RunErrorDetails( input=streamed_result.input, new_items=streamed_result.new_items, @@ -918,25 +1114,70 @@ async def _start_streaming( output_guardrail_results=streamed_result.output_guardrail_results, ) raise + + except asyncio.CancelledError: + # Cooperative cancellation: treat as a normal terminal state + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="cancelled", + reason=getattr(_cancel_token, "reason", None), + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + break + except Exception as e: + # If a cancel was requested, normalize any exception as "cancelled" + if _cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent( + status="cancelled", + reason=getattr(_cancel_token, "reason", None), + ) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise + if current_span: _error_tracing.attach_error_to_span( current_span, - SpanError( - message="Error in agent run", - data={"error": str(e)}, - ), + SpanError(message="Error in agent run", data={"error": str(e)}), ) - streamed_result.is_complete = True - streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="failed", reason=e.__class__.__name__) + ) + if not streamed_result.is_complete: + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) raise streamed_result.is_complete = True + + except GeneratorExit: + # The coroutine is being garbage-collected/closed; avoid cross-context resets. + try: + if current_span and not span_finished: + AgentRunner._safe_finish(current_span, reset_current=False) + span_finished = True + if streamed_result.trace and not trace_finished: + AgentRunner._safe_finish(streamed_result.trace, reset_current=False) + trace_finished = True + finally: + # Respect generator close semantics. 
+ raise + finally: - if current_span: - current_span.finish(reset_current=True) - if streamed_result.trace: - streamed_result.trace.finish(reset_current=True) + if current_span and not span_finished: + AgentRunner._safe_finish(current_span, reset_current=True) + if streamed_result.trace and not trace_finished: + AgentRunner._safe_finish(streamed_result.trace, reset_current=True) @classmethod async def _run_single_turn_streamed( @@ -980,10 +1221,18 @@ async def _run_single_turn_streamed( model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) final_response: ModelResponse | None = None + injected_during_turn = False input = ItemHelpers.input_to_new_input_list(streamed_result.input) input.extend([item.to_input_item() for item in streamed_result.new_items]) + # Consume any externally injected items before planning/model call + # Externally injected items live in streamed_result._inbox (a list of input items) + injected = getattr(streamed_result, "_inbox", None) + if injected: + input.extend(injected) + injected.clear() + # THIS IS THE RESOLVED CONFLICT BLOCK filtered = await cls._maybe_filter_model_input( agent=agent, @@ -1020,6 +1269,12 @@ async def _run_single_turn_streamed( conversation_id=conversation_id, prompt=prompt_config, ): + # Cooperative cancel during streaming + cancel_token = getattr(streamed_result, "_cancel_token", None) + if cancel_token and cancel_token.is_cancelling(): + # Stop iterating; the model adapter should also close its stream cooperatively. + break + if isinstance(event, ResponseCompletedEvent): usage = ( Usage( @@ -1061,6 +1316,31 @@ async def _run_single_turn_streamed( streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) + # Break early if new items were injected during this turn. + if injected and len(injected) > 0: + injected_during_turn = True + break + + if injected_during_turn and final_response is None: + return SingleStepResult( + original_input=streamed_result.input, + model_response=ModelResponse(output=[], usage=Usage(), response_id=None), + pre_step_items=streamed_result.new_items, + new_step_items=[], + next_step=NextStepRunAgain(), + ) + + # If cancelled during streaming, terminate cleanly + cancel_token = getattr(streamed_result, "_cancel_token", None) + if cancel_token and cancel_token.is_cancelling(): + if getattr(streamed_result, "_emit_status_events", False): + streamed_result._event_queue.put_nowait( + RunUpdatedStreamEvent(status="cancelled", reason=cancel_token.reason) + ) + streamed_result.is_complete = True + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise asyncio.CancelledError(cancel_token.reason or "Cancelled") + # Call hook just after the model response is finalized. if final_response is not None: await asyncio.gather( @@ -1156,6 +1436,13 @@ async def _run_single_turn( input = ItemHelpers.input_to_new_input_list(original_input) input.extend([generated_item.to_input_item() for generated_item in generated_items]) + # Consume injected items (non-streamed runs) + # We stashed the inbox on the context wrapper to avoid changing all signatures. 
+ inbox: list[TResponseInputItem] | None = getattr(context_wrapper, "_inbox", None) + if inbox: + input.extend(inbox) + inbox.clear() + new_response = await cls._get_new_response( agent, system_prompt, @@ -1367,6 +1654,11 @@ async def _get_new_response( conversation_id: str | None, prompt_config: ResponsePromptParam | None, ) -> ModelResponse: + # Cooperative cancel before the model call (non-streamed) --- + cancel_token: _Cancellation | None = getattr(context_wrapper, "_cancel_token", None) + if cancel_token and cancel_token.is_cancelling(): + raise asyncio.CancelledError(cancel_token.reason or "Cancelled") + # Allow user to modify model input right before the call, if configured filtered = await cls._maybe_filter_model_input( agent=agent, @@ -1395,20 +1687,29 @@ async def _get_new_response( ), ) - new_response = await model.get_response( - system_instructions=filtered.instructions, - input=filtered.input, - model_settings=model_settings, - tools=all_tools, - output_schema=output_schema, - handoffs=handoffs, - tracing=get_model_tracing_impl( - run_config.tracing_disabled, run_config.trace_include_sensitive_data - ), - previous_response_id=previous_response_id, - conversation_id=conversation_id, - prompt=prompt_config, - ) + # call the model in a cancellable task so cancel() interrupts promptly + async def _call_model() -> ModelResponse: + return await model.get_response( + system_instructions=filtered.instructions, + input=filtered.input, + model_settings=model_settings, + tools=all_tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=get_model_tracing_impl( + run_config.tracing_disabled, run_config.trace_include_sensitive_data + ), + previous_response_id=previous_response_id, + conversation_id=conversation_id, + prompt=prompt_config, + ) + + task = asyncio.create_task(_call_model()) + try: + new_response = await task + except asyncio.CancelledError: + # propagate; caller handles terminal state + raise context_wrapper.usage.add(new_response.usage) diff --git a/src/agents/stream_events.py b/src/agents/stream_events.py index a271e8acd..e4a90c603 100644 --- a/src/agents/stream_events.py +++ b/src/agents/stream_events.py @@ -56,6 +56,25 @@ class AgentUpdatedStreamEvent: type: Literal["agent_updated_stream_event"] = "agent_updated_stream_event" + # Terminal / status update event for the overall run -StreamEvent: TypeAlias = Union[RawResponsesStreamEvent, RunItemStreamEvent, AgentUpdatedStreamEvent] + +@dataclass +class RunUpdatedStreamEvent: + """High-level run status update (emitted on completion, failure, or cancellation).""" + + status: Literal["running", "completed", "failed", "cancelled"] = "running" + """Current run status.""" + reason: str | None = None + """Optional human-readable reason (e.g., cancellation reason).""" + type: Literal["run.updated"] = "run.updated" + """Event type identifier.""" + + +StreamEvent: TypeAlias = Union[ + RawResponsesStreamEvent, + RunItemStreamEvent, + AgentUpdatedStreamEvent, + RunUpdatedStreamEvent, +] """A streaming event from an agent.""" diff --git a/tests/test_run_lifecycle.py b/tests/test_run_lifecycle.py new file mode 100644 index 000000000..7587225f6 --- /dev/null +++ b/tests/test_run_lifecycle.py @@ -0,0 +1,685 @@ +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator +from typing import Any, cast + +import pytest +from openai.types.responses import ResponseStreamEvent +from openai.types.responses.response_prompt_param import ResponsePromptParam + +from agents._run_impl import 
( + AgentToolUseTracker, + NextStepRunAgain, + ProcessedResponse, + RunImpl, + SingleStepResult, +) +from agents.agent import Agent +from agents.agent_output import AgentOutputSchemaBase +from agents.exceptions import ( + InputGuardrailTripwireTriggered, + MaxTurnsExceeded, + OutputGuardrailTripwireTriggered, +) +from agents.guardrail import GuardrailFunctionOutput, InputGuardrail, OutputGuardrail +from agents.handoffs import Handoff +from agents.items import ModelResponse, RunItem, TResponseInputItem +from agents.lifecycle import RunHooks +from agents.model_settings import ModelSettings +from agents.models.interface import Model, ModelTracing +from agents.run import ( + AgentRunner, + RunConfig, + Runner, + _Cancellation, + get_default_agent_runner, + set_default_agent_runner, +) +from agents.run_context import RunContextWrapper +from agents.stream_events import RunUpdatedStreamEvent +from agents.tool import Tool +from agents.usage import Usage + +# Reuse the repo’s helper to build a FunctionTool correctly +from tests.test_responses import get_function_tool # <-- existing test helper + + +class MinimalAgent: + """Just enough surface for Runner.""" + + def __init__(self, model: Model, name: str = "test-agent"): + self.name = name + self.model = model + self.model_settings = ModelSettings() + self.output_type = None + self.hooks = None + self.handoffs: list[Handoff] = [] + self.reset_tool_choice = False + self.input_guardrails: list[Any] = [] + self.output_guardrails: list[Any] = [] + + async def get_system_prompt(self, _): + return None + + async def get_prompt(self, _): + return None + + async def get_all_tools(self, _): + return [] + + +class FakeModelNeverCompletes(Model): + async def get_response(self, *a: Any, **k: Any) -> Any: + raise NotImplementedError + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[Any], + model_settings: ModelSettings, + tools: list[Any], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: + while True: + await asyncio.sleep(0.02) + yield cast(ResponseStreamEvent, object()) + + +@pytest.mark.anyio +async def test_cancel_streamed_run_emits_cancelled_status(): + """When status events are enabled, cancel should emit run.updated(cancelled).""" + agent = MinimalAgent(model=FakeModelNeverCompletes()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="hello world", + run_config=RunConfig(model=agent.model), + max_turns=10, + ) + # Opt-in to status events for this test + result._emit_status_events = True + + seen_status: str | None = None + + async def consume(): + nonlocal seen_status + async for ev in result.stream_events(): + if isinstance(ev, RunUpdatedStreamEvent): + seen_status = ev.status + + consumer = asyncio.create_task(consume()) + await asyncio.sleep(0.08) + result.cancel("user-requested") + await consumer + + assert result.is_complete is True + assert seen_status == "cancelled" + + +@pytest.mark.anyio +async def test_default_flag_off_emits_no_status_event(): + """By default, no run.updated events should be emitted (back-compat).""" + agent = MinimalAgent(model=FakeModelNeverCompletes()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="x", + run_config=RunConfig(model=agent.model), + ) + # DO NOT set result._emit_status_events here + statuses: list[str] = [] + + async def consume(): 
+ async for ev in result.stream_events(): + if isinstance(ev, RunUpdatedStreamEvent): + statuses.append(ev.status) + + task = asyncio.create_task(consume()) + await asyncio.sleep(0.05) + result.cancel("user") + await task + + assert statuses == [] # no run.updated by default + + +@pytest.mark.anyio +async def test_midstream_cancel_emits_cancelled_status_when_enabled(): + """Cancel while model is streaming yields cancelled when flag is on.""" + agent = MinimalAgent(model=FakeModelNeverCompletes()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="x", + run_config=RunConfig(model=agent.model), + ) + result._emit_status_events = True + statuses: list[str] = [] + + async def consume(): + async for ev in result.stream_events(): + if isinstance(ev, RunUpdatedStreamEvent): + statuses.append(ev.status) + + task = asyncio.create_task(consume()) + await asyncio.sleep(0.06) + result.cancel("user") + await task + + assert "cancelled" in statuses + + +class FakeModelSlowGet(Model): + async def get_response(self, *a, **k): + # simulate long compute so we can cancel + await asyncio.sleep(1.0) + + async def stream_response(self, *a, **k): + raise NotImplementedError + + +@pytest.mark.anyio +async def test_non_streamed_cancel_propagates_cancelled_error_or_returns_terminal_result(): + """Runner.run cancellation should terminate cleanly. + + We accept either a CancelledError or a terminal RunResult. + """ + agent = MinimalAgent(model=FakeModelSlowGet()) + + async def run_it(): + return await Runner.run( + cast(Agent[Any], agent), + input="y", + run_config=RunConfig(model=agent.model), + ) + + task = asyncio.create_task(run_it()) + await asyncio.sleep(0.05) + task.cancel() + + try: + result = await task + except asyncio.CancelledError: + # Current contract may propagate cancel; this is acceptable. + return + + # If your contract returns a terminal result on cancel, assert it here. + assert getattr(result, "final_output", None) is None + + +@pytest.mark.anyio +async def test_inject_is_consumed_on_next_turn(): + """ + Injected items should be included in a subsequent model turn input. + We capture the inputs passed into FakeModel each turn and assert presence. 
+ """ + INJECT_TOKEN: TResponseInputItem = { + "role": "user", + "content": "INJECTED", + } # match message-style items + + class FakeModelCapture(Model): + def __init__(self) -> None: + self.inputs: list[list[Any]] = [] + + async def get_response(self, *a: Any, **k: Any) -> Any: + raise NotImplementedError + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[Any], + model_settings: ModelSettings, + tools: list[Any], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: + turn = 0 + while True: + self.inputs.append(list(input)) + yield cast(ResponseStreamEvent, object()) + turn += 1 + await asyncio.sleep(0.01) + + + model = FakeModelCapture() + agent = MinimalAgent(model=model) + + result = Runner.run_streamed( + starting_agent=cast(Agent[Any], agent), + input="hello", + run_config=RunConfig(model=agent.model), + max_turns=6, + ) + + async def drive_and_inject(): + # Let at least one turn record baseline input + await asyncio.sleep(0.05) + # Inject so a future turn sees it + result.inject([INJECT_TOKEN]) + # Give time for a couple more turns to run + await asyncio.sleep(0.12) + result.cancel("done") + + consumer = asyncio.create_task(drive_and_inject()) + async for _ in result.stream_events(): + pass + await consumer + + # We should have recorded ≥2 turns + assert len(model.inputs) >= 2 + + # Assert the injected message appears in ANY turn after injection time + flattened_after_injection = [item for turn in model.inputs[1:] for item in turn] + assert any( + isinstance(item, dict) and item.get("role") == "user" and item.get("content") == "INJECTED" + for item in flattened_after_injection + ), f"Injected item not present after injection; captured={model.inputs}" + + +class FakeModelTriggersTool(Model): + """ + Emits continuous events so we can cancel while a function tool is (hypothetically) running. + Note: This is a timing smoke test. For a full tool-call path test, emit tool-call outputs. 
+ """ + + async def get_response(self, *a: Any, **k: Any) -> Any: + raise NotImplementedError + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[Any], + model_settings: ModelSettings, + tools: list[Any], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: + while True: + await asyncio.sleep(0.02) + yield cast(ResponseStreamEvent, object()) + + +class AgentWithTool(MinimalAgent): + def __init__(self, model: Model, tool: Tool): + super().__init__(model) + self._tool = tool + + async def get_all_tools(self, _): + return [self._tool] + + +@pytest.mark.anyio +async def test_function_tool_cancels_promptly(): + # Build the tool using the repo helper (it doesn't take a handler argument) + tool = get_function_tool("long", "done") + + agent = AgentWithTool(FakeModelTriggersTool(), tool) + + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="trigger tool", + run_config=RunConfig(model=agent.model), + ) + start = time.perf_counter() + await asyncio.sleep(0.05) # let some activity happen + result.cancel("user") + + # Drain stream; ensure no hang + async for _ in result.stream_events(): + pass + + elapsed = time.perf_counter() - start + # Expect prompt cancellation (well under 1s) + assert elapsed < 0.4 + + +class _FailingModel(Model): + async def get_response(self, *a: Any, **k: Any) -> Any: + raise NotImplementedError + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[Any], + model_settings: ModelSettings, + tools: list[Any], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: + if False: + yield cast(ResponseStreamEvent, object()) + raise RuntimeError("boom") + + +class _TickingModel(Model): + async def get_response(self, *a: Any, **k: Any) -> Any: + raise NotImplementedError + + async def stream_response( + self, + system_instructions: str | None, + input: str | list[Any], + model_settings: ModelSettings, + tools: list[Any], + output_schema: AgentOutputSchemaBase | None, + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: str | None, + conversation_id: str | None, + prompt: ResponsePromptParam | None, + ) -> AsyncIterator[ResponseStreamEvent]: + while True: + await asyncio.sleep(0.001) + yield cast(ResponseStreamEvent, object()) + + +@pytest.mark.anyio +async def test_streamed_failure_emits_failed_status_and_closes(): + agent = MinimalAgent(model=_FailingModel()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="x", + run_config=RunConfig(model=agent.model), + ) + result._emit_status_events = True + + statuses: list[str] = [] + it = result.stream_events() # ensure we CALL the method and get an async iterator + assert hasattr(it, "__aiter__"), ( + f"stream_events() did not return an async iterator, got {type(it)}" + ) + caught = False + try: + async for ev in it: + if isinstance(ev, RunUpdatedStreamEvent): + statuses.append(ev.status) + except RuntimeError as e: + # Our failing model raises immediately; accept that path too + assert str(e) == "boom" + caught = True + + # Either we saw a 'failed' status, or we caught the model's RuntimeError. 
+ assert result.is_complete is True + assert (statuses and statuses[-1] == "failed") or caught + + +@pytest.mark.anyio +async def test_max_turns_exceeded_hits_failed_path(): + agent = MinimalAgent(model=_TickingModel()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="tick", + run_config=RunConfig(model=agent.model), + max_turns=0, # immediate exceed + ) + result._emit_status_events = True + + # Draining the stream should raise MaxTurnsExceeded for this configuration. + it = result.stream_events() + assert hasattr(it, "__aiter__"), ( + f"stream_events() did not return an async iterator, got {type(it)}" + ) + with pytest.raises(MaxTurnsExceeded): + async for _ in it: + pass + + +@pytest.mark.anyio +async def test_cancel_before_streaming_closes_immediately(): + """Cancel right away to hit the early-cancel branch at the top of the loop.""" + agent = MinimalAgent(model=FakeModelNeverCompletes()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="early", + run_config=RunConfig(model=agent.model), + ) + result._emit_status_events = True + + # Cancel before we ever start iterating events + result.cancel("early-stop") + + statuses: list[str] = [] + it = result.stream_events() + assert hasattr(it, "__aiter__"), ( + f"stream_events() did not return an async iterator, got {type(it)}" + ) + async for ev in it: + if isinstance(ev, RunUpdatedStreamEvent): + statuses.append(ev.status) + + assert result.is_complete is True + # Some runners emit the status, others just close; allow both + assert (not statuses) or statuses[-1] == "cancelled" + + +@pytest.mark.anyio +async def test_idempotent_cancel_emits_single_terminal_status_when_enabled(): + """Double cancel should not duplicate terminal status.""" + agent = MinimalAgent(model=FakeModelNeverCompletes()) + result = Runner.run_streamed( + cast(Agent[Any], agent), + input="dup", + run_config=RunConfig(model=agent.model), + ) + result._emit_status_events = True + + # Issue cancel twice + result.cancel("first") + result.cancel("second") + + statuses: list[str] = [] + it = result.stream_events() + assert hasattr(it, "__aiter__"), ( + f"stream_events() did not return an async iterator, got {type(it)}" + ) + async for ev in it: + if isinstance(ev, RunUpdatedStreamEvent): + statuses.append(ev.status) + + assert result.is_complete is True + # If statuses are emitted, 'cancelled' should appear exactly once + if statuses: + assert statuses.count("cancelled") == 1 + + + + +def test_get_default_agent_runner_roundtrip() -> None: + """The default agent runner can be replaced and restored.""" + original = get_default_agent_runner() + try: + set_default_agent_runner(None) + new_runner = get_default_agent_runner() + assert isinstance(new_runner, AgentRunner) + finally: + set_default_agent_runner(original) + + +def test_cancellation_raises_cancelled_error() -> None: + """_Cancellation.raise_if_cancelled raises when started.""" + cancel = _Cancellation() + cancel.start("go") + with pytest.raises(asyncio.CancelledError): + cancel.raise_if_cancelled() + + +@pytest.mark.anyio +async def test_run_input_guardrails_handles_tripwire() -> None: + """Tripwire in input guardrail raises InputGuardrailTripwireTriggered.""" + + async def tripwire(_, __, ___): + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True) + + guardrail = InputGuardrail(tripwire, name="trip") + context = RunContextWrapper(context=None) + agent = MinimalAgent(model=FakeModelNeverCompletes()) + with pytest.raises(InputGuardrailTripwireTriggered): + await 
AgentRunner._run_input_guardrails(cast(Agent[Any], agent), [guardrail], "hi", context) + + +@pytest.mark.anyio +async def test_run_input_guardrails_collects_results() -> None: + """Non-tripwire guardrails return their results.""" + + async def ok(_, __, ___): + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) + + guardrails = [InputGuardrail(ok, name="g1"), InputGuardrail(ok, name="g2")] + context = RunContextWrapper(context=None) + agent = MinimalAgent(model=FakeModelNeverCompletes()) + results = await AgentRunner._run_input_guardrails( + cast(Agent[Any], agent), guardrails, "hi", context + ) + assert len(results) == 2 + + +@pytest.mark.anyio +async def test_run_output_guardrails_handles_tripwire() -> None: + """Tripwire in output guardrail raises OutputGuardrailTripwireTriggered.""" + + async def tripwire(_, __, ___): + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True) + + guardrail = OutputGuardrail(tripwire, name="trip") + context = RunContextWrapper(context=None) + agent = MinimalAgent(model=FakeModelNeverCompletes()) + with pytest.raises(OutputGuardrailTripwireTriggered): + await AgentRunner._run_output_guardrails( + [guardrail], cast(Agent[Any], agent), "out", context + ) + + +@pytest.mark.anyio +async def test_run_output_guardrails_collects_results() -> None: + """Non-tripwire output guardrails return their results.""" + + async def ok(_, __, ___): + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) + + guardrails = [OutputGuardrail(ok, name="g1"), OutputGuardrail(ok, name="g2")] + context = RunContextWrapper(context=None) + agent = MinimalAgent(model=FakeModelNeverCompletes()) + results = await AgentRunner._run_output_guardrails( + guardrails, cast(Agent[Any], agent), "out", context + ) + assert len(results) == 2 + + +@pytest.mark.anyio +async def test_get_single_step_result_from_response(monkeypatch) -> None: + """_get_single_step_result_from_response processes model output.""" + agent = MinimalAgent(model=FakeModelNeverCompletes()) + new_response = ModelResponse(output=[], usage=Usage(), response_id=None) + + async def fake_execute(*args, **kwargs): + return SingleStepResult( + original_input="hi", + model_response=new_response, + pre_step_items=[], + new_step_items=[], + next_step=NextStepRunAgain(), + ) + + def fake_process(*args, **kwargs): + return ProcessedResponse( + new_items=[], + handoffs=[], + functions=[], + computer_actions=[], + local_shell_calls=[], + tools_used=[], + mcp_approval_requests=[], + ) + + monkeypatch.setattr(RunImpl, "process_model_response", fake_process) + monkeypatch.setattr(RunImpl, "execute_tools_and_side_effects", fake_execute) + + result = await AgentRunner._get_single_step_result_from_response( + agent=cast(Agent[Any], agent), + all_tools=[], + original_input="hi", + pre_step_items=[], + new_response=new_response, + output_schema=None, + handoffs=[], + hooks=RunHooks[Any](), + context_wrapper=RunContextWrapper(context=None), + run_config=RunConfig(model=agent.model), + tool_use_tracker=AgentToolUseTracker(), + ) + + assert result.model_response is new_response + + +@pytest.mark.anyio +async def test_get_single_step_result_from_streamed_response(monkeypatch) -> None: + """_get_single_step_result_from_streamed_response handles streamed events.""" + + agent = MinimalAgent(model=FakeModelNeverCompletes()) + new_response = ModelResponse(output=[], usage=Usage(), response_id=None) + + async def fake_execute(*args, **kwargs): + return SingleStepResult( + original_input="hi", + 
model_response=new_response, + pre_step_items=[], + new_step_items=[], + next_step=NextStepRunAgain(), + ) + + def fake_process(*args, **kwargs): + return ProcessedResponse( + new_items=[], + handoffs=[], + functions=[], + computer_actions=[], + local_shell_calls=[], + tools_used=[], + mcp_approval_requests=[], + ) + + class DummyStreamed: + def __init__(self) -> None: + self.input = "hi" + self.new_items: list[RunItem] = [] + self._event_queue: asyncio.Queue[Any] = asyncio.Queue() + + monkeypatch.setattr(RunImpl, "process_model_response", fake_process) + monkeypatch.setattr(RunImpl, "execute_tools_and_side_effects", fake_execute) + + streamed = DummyStreamed() + result = await AgentRunner._get_single_step_result_from_streamed_response( + agent=cast(Agent[Any], agent), + all_tools=[], + streamed_result=cast(Any, streamed), + new_response=new_response, + output_schema=None, + handoffs=[], + hooks=RunHooks[Any](), + context_wrapper=RunContextWrapper(context=None), + run_config=RunConfig(model=agent.model), + tool_use_tracker=AgentToolUseTracker(), + ) + + assert result.model_response is new_response
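
Reviewer note (illustrative, not part of the patch): a minimal end-to-end sketch of the new lifecycle surface — cancel(reason), inject(items), the opt-in _emit_status_events flag, and RunUpdatedStreamEvent. Only those names come from the diff above; the "triage" agent, the prompt text, and the availability of a configured default model/API key are assumptions.

import asyncio

from agents import Agent, Runner
from agents.stream_events import RunUpdatedStreamEvent


async def main() -> None:
    agent = Agent(name="triage", instructions="Answer briefly.")  # placeholder agent

    result = Runner.run_streamed(agent, input="Summarize the release notes.")
    result._emit_status_events = True  # opt in; off by default for back-compat

    async def consume() -> None:
        async for event in result.stream_events():
            if isinstance(event, RunUpdatedStreamEvent):
                print("run status:", event.status, event.reason)

    consumer = asyncio.create_task(consume())

    # Steer the run mid-flight; injected items are consumed at the start of the next step.
    result.inject([{"role": "user", "content": "Focus only on breaking changes."}])

    # Cooperative cancel: the loop unwinds and a terminal "cancelled" status is emitted.
    await asyncio.sleep(1.0)
    result.cancel("user-requested")
    await consumer


if __name__ == "__main__":
    asyncio.run(main())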
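
Reviewer note (illustrative, not part of the patch): _maybe_call_cancel_hook probes the tool object for a callable named cancel, terminate, or stop. A sketch of a long-running function tool that cooperates with that hook; the function_tool decorator is the SDK's, but the shared cancel flag and the attribute attach onto the resulting FunctionTool instance are illustrative assumptions.

import asyncio

from agents import function_tool

_cancel_event = asyncio.Event()  # flipped by the cancel hook below


@function_tool
async def crawl_site(url: str) -> str:
    """Slow tool that periodically checks the cancel flag."""
    pages = 0
    while pages < 50 and not _cancel_event.is_set():
        await asyncio.sleep(0.1)  # stand-in for real I/O
        pages += 1
    return f"crawled {pages} pages of {url} (cancelled={_cancel_event.is_set()})"


# The run loop calls cancel/terminate/stop on the tool object if present; FunctionTool
# is a plain dataclass, so attaching the hook as an attribute is assumed to be allowed.
crawl_site.cancel = _cancel_event.set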
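
Reviewer note (illustrative, not part of the patch): the local-shell branch only best-efforts a cancel hook on call.local_shell_tool, and the diff's comment hints at terminating a spawned process. A sketch of an executor that keeps its subprocess reachable and exposes a terminate hook; LocalShellTool and LocalShellCommandRequest are the SDK's types, while the request.data.action.command shape, the module-level _current_proc bookkeeping, and the attribute attach are assumptions.

from __future__ import annotations

import asyncio
import contextlib

from agents.tool import LocalShellCommandRequest, LocalShellTool

_current_proc: asyncio.subprocess.Process | None = None


async def run_shell(request: LocalShellCommandRequest) -> str:
    """Executor that records the spawned process so a cancel hook can stop it."""
    global _current_proc
    # request.data.action.command is assumed to be an argv-style list of strings.
    _current_proc = await asyncio.create_subprocess_exec(
        *request.data.action.command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    out, _ = await _current_proc.communicate()
    return out.decode()


def _terminate_proc() -> None:
    # Mirrors the hint in the diff: terminate the child process on cooperative cancel.
    proc = _current_proc
    if proc is not None and proc.returncode is None:
        with contextlib.suppress(ProcessLookupError):
            proc.terminate()


shell_tool = LocalShellTool(executor=run_shell)
shell_tool.terminate = _terminate_proc  # picked up by _maybe_call_cancel_hook (cancel/terminate/stop)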