From ec7388ec38013ab932f53ab6c1111ebd6fb252b7 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Tue, 15 Jul 2025 21:55:40 -0400 Subject: [PATCH] Realtime: handoffs --- src/agents/agent.py | 2 +- src/agents/guardrail.py | 2 +- src/agents/handoffs.py | 46 +++++-- src/agents/models/chatcmpl_converter.py | 2 +- src/agents/models/openai_responses.py | 2 +- src/agents/realtime/__init__.py | 3 + src/agents/realtime/agent.py | 11 +- src/agents/realtime/config.py | 2 + src/agents/realtime/handoffs.py | 165 +++++++++++++++++++++++ src/agents/realtime/openai_realtime.py | 20 ++- src/agents/realtime/session.py | 43 +++++- tests/realtime/test_realtime_handoffs.py | 96 +++++++++++++ tests/realtime/test_session.py | 42 +++--- tests/test_handoff_tool.py | 5 + 14 files changed, 393 insertions(+), 48 deletions(-) create mode 100644 src/agents/realtime/handoffs.py create mode 100644 tests/realtime/test_realtime_handoffs.py diff --git a/src/agents/agent.py b/src/agents/agent.py index 9c107a81b..b67a12c0d 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -158,7 +158,7 @@ class Agent(AgentBase, Generic[TContext]): usable with OpenAI models, using the Responses API. """ - handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list) + handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list) """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs, and the agent can choose to delegate to them if relevant. Allows for separation of concerns and modularity. diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index f8a272b53..2bb0f014e 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -244,7 +244,7 @@ def decorator( return InputGuardrail( guardrail_function=f, # If not set, guardrail name uses the function’s name by default. - name=name if name else f.__name__ + name=name if name else f.__name__, ) if func is not None: diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index cb2752e4f..1ad8831f0 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -18,12 +18,15 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - from .agent import Agent + from .agent import Agent, AgentBase # The handoff input type is the type of data passed when the agent is called via a handoff. THandoffInput = TypeVar("THandoffInput", default=Any) +# The agent type that the handoff returns +TAgent = TypeVar("TAgent", bound="AgentBase[Any]", default="Agent[Any]") + OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any] OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any] @@ -52,7 +55,7 @@ class HandoffInputData: @dataclass -class Handoff(Generic[TContext]): +class Handoff(Generic[TContext, TAgent]): """A handoff is when an agent delegates a task to another agent. For example, in a customer support scenario you might have a "triage agent" that determines which agent should handle the user's request, and sub-agents that specialize in different @@ -69,7 +72,7 @@ class Handoff(Generic[TContext]): """The JSON schema for the handoff input. Can be empty if the handoff does not take an input. """ - on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[Agent[TContext]]] + on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[TAgent]] """The function that invokes the handoff. The parameters passed are: 1. The handoff run context 2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty. @@ -100,20 +103,22 @@ class Handoff(Generic[TContext]): True, as it increases the likelihood of correct JSON input. """ - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = ( + True + ) """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable a handoff based on your context/state.""" - def get_transfer_message(self, agent: Agent[Any]) -> str: + def get_transfer_message(self, agent: AgentBase[Any]) -> str: return json.dumps({"assistant": agent.name}) @classmethod - def default_tool_name(cls, agent: Agent[Any]) -> str: + def default_tool_name(cls, agent: AgentBase[Any]) -> str: return _transforms.transform_string_function_style(f"transfer_to_{agent.name}") @classmethod - def default_tool_description(cls, agent: Agent[Any]) -> str: + def default_tool_description(cls, agent: AgentBase[Any]) -> str: return ( f"Handoff to the {agent.name} agent to handle the request. " f"{agent.handoff_description or ''}" @@ -128,7 +133,7 @@ def handoff( tool_description_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, -) -> Handoff[TContext]: ... +) -> Handoff[TContext, Agent[TContext]]: ... @overload @@ -141,7 +146,7 @@ def handoff( tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, -) -> Handoff[TContext]: ... +) -> Handoff[TContext, Agent[TContext]]: ... @overload @@ -153,7 +158,7 @@ def handoff( tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, -) -> Handoff[TContext]: ... +) -> Handoff[TContext, Agent[TContext]]: ... def handoff( @@ -163,8 +168,9 @@ def handoff( on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None, input_type: type[THandoffInput] | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, -) -> Handoff[TContext]: + is_enabled: bool + | Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] = True, +) -> Handoff[TContext, Agent[TContext]]: """Create a handoff from an agent. Args: @@ -202,7 +208,7 @@ def handoff( async def _invoke_handoff( ctx: RunContextWrapper[Any], input_json: str | None = None - ) -> Agent[Any]: + ) -> Agent[TContext]: if input_type is not None and type_adapter is not None: if input_json is None: _error_tracing.attach_error_to_current_span( @@ -239,6 +245,18 @@ async def _invoke_handoff( # If there is a need, we can make this configurable in the future input_json_schema = ensure_strict_json_schema(input_json_schema) + async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool: + from .agent import Agent + + assert callable(is_enabled), "is_enabled must be non-null here" + assert isinstance(agent_base, Agent), "Can't handoff to a non-Agent" + result = is_enabled(ctx, agent_base) + + if inspect.isawaitable(result): + return await result + + return result + return Handoff( tool_name=tool_name, tool_description=tool_description, @@ -246,5 +264,5 @@ async def _invoke_handoff( on_invoke_handoff=_invoke_handoff, input_filter=input_filter, agent_name=agent.name, - is_enabled=is_enabled, + is_enabled=_is_enabled if callable(is_enabled) else is_enabled, ) diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index d3c71c24e..351dc5db7 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -484,7 +484,7 @@ def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam: ) @classmethod - def convert_handoff_tool(cls, handoff: Handoff[Any]) -> ChatCompletionToolParam: + def convert_handoff_tool(cls, handoff: Handoff[Any, Any]) -> ChatCompletionToolParam: return { "type": "function", "function": { diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index 76c67903c..f6da60b08 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -370,7 +370,7 @@ def get_response_format( def convert_tools( cls, tools: list[Tool], - handoffs: list[Handoff[Any]], + handoffs: list[Handoff[Any, Any]], ) -> ConvertedTools: converted_tools: list[ToolParam] = [] includes: list[ResponseIncludable] = [] diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index 0e3e12f75..49c131389 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -30,6 +30,7 @@ RealtimeToolEnd, RealtimeToolStart, ) +from .handoffs import realtime_handoff from .items import ( AssistantMessageItem, AssistantText, @@ -92,6 +93,8 @@ "RealtimeAgentHooks", "RealtimeRunHooks", "RealtimeRunner", + # Handoffs + "realtime_handoff", # Config "RealtimeAudioFormat", "RealtimeClientMessage", diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py index 9bbed8cb4..30e80a95b 100644 --- a/src/agents/realtime/agent.py +++ b/src/agents/realtime/agent.py @@ -3,10 +3,11 @@ import dataclasses import inspect from collections.abc import Awaitable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable, Generic, cast from ..agent import AgentBase +from ..handoffs import Handoff from ..lifecycle import AgentHooksBase, RunHooksBase from ..logger import logger from ..run_context import RunContextWrapper, TContext @@ -53,6 +54,14 @@ class RealtimeAgent(AgentBase, Generic[TContext]): return a string. """ + handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field( + default_factory=list + ) + """Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs, + and the agent can choose to delegate to them if relevant. Allows for separation of concerns and + modularity. + """ + hooks: RealtimeAgentHooks | None = None """A class that receives callbacks on various lifecycle events for this agent. """ diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 7f874cfb0..6e7e3f4be 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -9,6 +9,7 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict from ..guardrail import OutputGuardrail +from ..handoffs import Handoff from ..model_settings import ToolChoice from ..tool import Tool @@ -71,6 +72,7 @@ class RealtimeSessionModelSettings(TypedDict): tool_choice: NotRequired[ToolChoice] tools: NotRequired[list[Tool]] + handoffs: NotRequired[list[Handoff]] tracing: NotRequired[RealtimeModelTracingConfig | None] diff --git a/src/agents/realtime/handoffs.py b/src/agents/realtime/handoffs.py new file mode 100644 index 000000000..a3e5151f6 --- /dev/null +++ b/src/agents/realtime/handoffs.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any, Callable, cast, overload + +from pydantic import TypeAdapter +from typing_extensions import TypeVar + +from ..exceptions import ModelBehaviorError, UserError +from ..handoffs import Handoff +from ..run_context import RunContextWrapper, TContext +from ..strict_schema import ensure_strict_json_schema +from ..tracing.spans import SpanError +from ..util import _error_tracing, _json +from ..util._types import MaybeAwaitable + +if TYPE_CHECKING: + from ..agent import AgentBase + from . import RealtimeAgent + + +# The handoff input type is the type of data passed when the agent is called via a handoff. +THandoffInput = TypeVar("THandoffInput", default=Any) + +OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any] +OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any] + + +@overload +def realtime_handoff( + agent: RealtimeAgent[TContext], + *, + tool_name_override: str | None = None, + tool_description_override: str | None = None, + is_enabled: bool + | Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True, +) -> Handoff[TContext, RealtimeAgent[TContext]]: ... + + +@overload +def realtime_handoff( + agent: RealtimeAgent[TContext], + *, + on_handoff: OnHandoffWithInput[THandoffInput], + input_type: type[THandoffInput], + tool_description_override: str | None = None, + tool_name_override: str | None = None, + is_enabled: bool + | Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True, +) -> Handoff[TContext, RealtimeAgent[TContext]]: ... + + +@overload +def realtime_handoff( + agent: RealtimeAgent[TContext], + *, + on_handoff: OnHandoffWithoutInput, + tool_description_override: str | None = None, + tool_name_override: str | None = None, + is_enabled: bool + | Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True, +) -> Handoff[TContext, RealtimeAgent[TContext]]: ... + + +def realtime_handoff( + agent: RealtimeAgent[TContext], + tool_name_override: str | None = None, + tool_description_override: str | None = None, + on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None, + input_type: type[THandoffInput] | None = None, + is_enabled: bool + | Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True, +) -> Handoff[TContext, RealtimeAgent[TContext]]: + """Create a handoff from a RealtimeAgent. + + Args: + agent: The RealtimeAgent to handoff to, or a function that returns a RealtimeAgent. + tool_name_override: Optional override for the name of the tool that represents the handoff. + tool_description_override: Optional override for the description of the tool that + represents the handoff. + on_handoff: A function that runs when the handoff is invoked. + input_type: the type of the input to the handoff. If provided, the input will be validated + against this type. Only relevant if you pass a function that takes an input. + is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run + context and agent and returns whether the handoff is enabled. Disabled handoffs are + hidden from the LLM at runtime. + + Note: input_filter is not supported for RealtimeAgent handoffs. + """ + assert (on_handoff and input_type) or not (on_handoff and input_type), ( + "You must provide either both on_handoff and input_type, or neither" + ) + type_adapter: TypeAdapter[Any] | None + if input_type is not None: + assert callable(on_handoff), "on_handoff must be callable" + sig = inspect.signature(on_handoff) + if len(sig.parameters) != 2: + raise UserError("on_handoff must take two arguments: context and input") + + type_adapter = TypeAdapter(input_type) + input_json_schema = type_adapter.json_schema() + else: + type_adapter = None + input_json_schema = {} + if on_handoff is not None: + sig = inspect.signature(on_handoff) + if len(sig.parameters) != 1: + raise UserError("on_handoff must take one argument: context") + + async def _invoke_handoff( + ctx: RunContextWrapper[Any], input_json: str | None = None + ) -> RealtimeAgent[TContext]: + if input_type is not None and type_adapter is not None: + if input_json is None: + _error_tracing.attach_error_to_current_span( + SpanError( + message="Handoff function expected non-null input, but got None", + data={"details": "input_json is None"}, + ) + ) + raise ModelBehaviorError("Handoff function expected non-null input, but got None") + + validated_input = _json.validate_json( + json_str=input_json, + type_adapter=type_adapter, + partial=False, + ) + input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) + if inspect.iscoroutinefunction(input_func): + await input_func(ctx, validated_input) + else: + input_func(ctx, validated_input) + elif on_handoff is not None: + no_input_func = cast(OnHandoffWithoutInput, on_handoff) + if inspect.iscoroutinefunction(no_input_func): + await no_input_func(ctx) + else: + no_input_func(ctx) + + return agent + + tool_name = tool_name_override or Handoff.default_tool_name(agent) + tool_description = tool_description_override or Handoff.default_tool_description(agent) + + # Always ensure the input JSON schema is in strict mode + # If there is a need, we can make this configurable in the future + input_json_schema = ensure_strict_json_schema(input_json_schema) + + async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool: + assert callable(is_enabled), "is_enabled must be non-null here" + assert isinstance(agent_base, RealtimeAgent), "Can't handoff to a non-RealtimeAgent" + result = is_enabled(ctx, agent_base) + if inspect.isawaitable(result): + return await result + return result + + return Handoff( + tool_name=tool_name, + tool_description=tool_description, + input_json_schema=input_json_schema, + on_invoke_handoff=_invoke_handoff, + input_filter=None, # Not supported for RealtimeAgent handoffs + agent_name=agent.name, + is_enabled=_is_enabled if callable(is_enabled) else is_enabled, + ) diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 129d89c49..bbffaeefa 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -56,6 +56,7 @@ from typing_extensions import assert_never from websockets.asyncio.client import ClientConnection +from agents.handoffs import Handoff from agents.tool import FunctionTool, Tool from agents.util._types import MaybeAwaitable @@ -519,10 +520,14 @@ def _get_session_config( "tool_choice", DEFAULT_MODEL_SETTINGS.get("tool_choice"), # type: ignore ), - tools=self._tools_to_session_tools(model_settings.get("tools", [])), + tools=self._tools_to_session_tools( + tools=model_settings.get("tools", []), handoffs=model_settings.get("handoffs", []) + ), ) - def _tools_to_session_tools(self, tools: list[Tool]) -> list[OpenAISessionTool]: + def _tools_to_session_tools( + self, tools: list[Tool], handoffs: list[Handoff] + ) -> list[OpenAISessionTool]: converted_tools: list[OpenAISessionTool] = [] for tool in tools: if not isinstance(tool, FunctionTool): @@ -535,6 +540,17 @@ def _tools_to_session_tools(self, tools: list[Tool]) -> list[OpenAISessionTool]: type="function", ) ) + + for handoff in handoffs: + converted_tools.append( + OpenAISessionTool( + name=handoff.tool_name, + description=handoff.tool_description, + parameters=handoff.input_json_schema, + type="function", + ) + ) + return converted_tools diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 07791c8d8..6df35b438 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import inspect from collections.abc import AsyncIterator from typing import Any, cast @@ -31,6 +32,7 @@ RealtimeToolEnd, RealtimeToolStart, ) +from .handoffs import realtime_handoff from .items import InputAudio, InputText, RealtimeItem from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener from .model_events import ( @@ -255,9 +257,12 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None: async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: """Handle a tool call event.""" - all_tools = await self._current_agent.get_all_tools(self._context_wrapper) - function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} - handoff_map = {tool.name: tool for tool in all_tools if isinstance(tool, Handoff)} + tools, handoffs = await asyncio.gather( + self._current_agent.get_all_tools(self._context_wrapper), + self._get_handoffs(self._current_agent, self._context_wrapper), + ) + function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)} + handoff_map = {handoff.tool_name: handoff for handoff in handoffs} if event.name in function_map: await self._put_event( @@ -303,7 +308,9 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None: # Execute the handoff to get the new agent result = await handoff.on_invoke_handoff(self._context_wrapper, event.arguments) if not isinstance(result, RealtimeAgent): - raise UserError(f"Handoff {handoff.name} returned invalid result: {type(result)}") + raise UserError( + f"Handoff {handoff.tool_name} returned invalid result: {type(result)}" + ) # Store previous agent for event previous_agent = self._current_agent @@ -492,11 +499,37 @@ async def _get__updated_model_settings( self, new_agent: RealtimeAgent ) -> RealtimeSessionModelSettings: updated_settings: RealtimeSessionModelSettings = {} - instructions, tools = await asyncio.gather( + instructions, tools, handoffs = await asyncio.gather( new_agent.get_system_prompt(self._context_wrapper), new_agent.get_all_tools(self._context_wrapper), + self._get_handoffs(new_agent, self._context_wrapper), ) updated_settings["instructions"] = instructions or "" updated_settings["tools"] = tools or [] + updated_settings["handoffs"] = handoffs or [] return updated_settings + + @classmethod + async def _get_handoffs( + cls, agent: RealtimeAgent[Any], context_wrapper: RunContextWrapper[Any] + ) -> list[Handoff[Any, RealtimeAgent[Any]]]: + handoffs: list[Handoff[Any, RealtimeAgent[Any]]] = [] + for handoff_item in agent.handoffs: + if isinstance(handoff_item, Handoff): + handoffs.append(handoff_item) + elif isinstance(handoff_item, RealtimeAgent): + handoffs.append(realtime_handoff(handoff_item)) + + async def _check_handoff_enabled(handoff_obj: Handoff[Any, RealtimeAgent[Any]]) -> bool: + attr = handoff_obj.is_enabled + if isinstance(attr, bool): + return attr + res = attr(context_wrapper, agent) + if inspect.isawaitable(res): + return await res + return res + + results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) + enabled = [h for h, ok in zip(handoffs, results) if ok] + return enabled diff --git a/tests/realtime/test_realtime_handoffs.py b/tests/realtime/test_realtime_handoffs.py new file mode 100644 index 000000000..07385fe20 --- /dev/null +++ b/tests/realtime/test_realtime_handoffs.py @@ -0,0 +1,96 @@ +"""Tests for realtime handoff functionality.""" + +from unittest.mock import Mock + +import pytest + +from agents import Agent +from agents.realtime import RealtimeAgent, realtime_handoff + + +def test_realtime_handoff_creation(): + """Test basic realtime handoff creation.""" + realtime_agent = RealtimeAgent(name="test_agent") + handoff_obj = realtime_handoff(realtime_agent) + + assert handoff_obj.agent_name == "test_agent" + assert handoff_obj.tool_name == "transfer_to_test_agent" + assert handoff_obj.input_filter is None # Should not support input filters + assert handoff_obj.is_enabled is True + + +def test_realtime_handoff_with_custom_params(): + """Test realtime handoff with custom parameters.""" + realtime_agent = RealtimeAgent( + name="helper_agent", + handoff_description="Helps with general tasks", + ) + + handoff_obj = realtime_handoff( + realtime_agent, + tool_name_override="custom_handoff", + tool_description_override="Custom handoff description", + is_enabled=False, + ) + + assert handoff_obj.agent_name == "helper_agent" + assert handoff_obj.tool_name == "custom_handoff" + assert handoff_obj.tool_description == "Custom handoff description" + assert handoff_obj.is_enabled is False + + +@pytest.mark.asyncio +async def test_realtime_handoff_execution(): + """Test that realtime handoff returns the correct agent.""" + realtime_agent = RealtimeAgent(name="target_agent") + handoff_obj = realtime_handoff(realtime_agent) + + # Mock context + mock_context = Mock() + + # Execute handoff + result = await handoff_obj.on_invoke_handoff(mock_context, "") + + assert result is realtime_agent + assert isinstance(result, RealtimeAgent) + + +def test_realtime_handoff_with_on_handoff_callback(): + """Test realtime handoff with custom on_handoff callback.""" + realtime_agent = RealtimeAgent(name="callback_agent") + callback_called = [] + + def on_handoff_callback(ctx): + callback_called.append(True) + + handoff_obj = realtime_handoff( + realtime_agent, + on_handoff=on_handoff_callback, + ) + + assert handoff_obj.agent_name == "callback_agent" + + +def test_regular_agent_handoff_still_works(): + """Test that regular Agent handoffs still work with the new generic types.""" + from agents import handoff + + regular_agent = Agent(name="regular_agent") + handoff_obj = handoff(regular_agent) + + assert handoff_obj.agent_name == "regular_agent" + assert handoff_obj.tool_name == "transfer_to_regular_agent" + # Regular agent handoffs should support input filters + assert hasattr(handoff_obj, "input_filter") + + +def test_type_annotations_work(): + """Test that type annotations work correctly.""" + from agents.handoffs import Handoff + from agents.realtime.handoffs import realtime_handoff + + realtime_agent = RealtimeAgent(name="typed_agent") + handoff_obj = realtime_handoff(realtime_agent) + + # This should be typed as Handoff[Any, RealtimeAgent[Any]] + assert isinstance(handoff_obj, Handoff) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 4cc0dae6b..d95e4e33d 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1,5 +1,5 @@ from typing import cast -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, PropertyMock import pytest @@ -101,6 +101,8 @@ async def close(self): def mock_agent(): agent = Mock(spec=RealtimeAgent) agent.get_all_tools = AsyncMock(return_value=[]) + + type(agent).handoffs = PropertyMock(return_value=[]) return agent @@ -794,30 +796,26 @@ async def test_function_tool_with_multiple_tools_available(self, mock_model, moc assert sent_output == "result_two" @pytest.mark.asyncio - async def test_handoff_tool_handling(self, mock_model, mock_agent, mock_handoff): - """Test that handoff tools are properly handled""" - from unittest.mock import AsyncMock - - from agents.realtime.agent import RealtimeAgent - - # Create a mock new agent to be returned by handoff - mock_new_agent = Mock(spec=RealtimeAgent) - mock_new_agent.name = "new_agent" - mock_new_agent.instructions = "New agent instructions" - mock_new_agent.get_all_tools = AsyncMock(return_value=[]) - mock_new_agent.get_system_prompt = AsyncMock(return_value="New agent system prompt") - - # Set up handoff to return the new agent - mock_handoff.on_invoke_handoff = AsyncMock(return_value=mock_new_agent) - mock_handoff.name = "test_handoff" + async def test_handoff_tool_handling(self, mock_model): + first_agent = RealtimeAgent( + name="first_agent", + instructions="first_agent_instructions", + tools=[], + handoffs=[], + ) + second_agent = RealtimeAgent( + name="second_agent", + instructions="second_agent_instructions", + tools=[], + handoffs=[], + ) - # Set up agent to return handoff tool - mock_agent.get_all_tools.return_value = [mock_handoff] + first_agent.handoffs = [second_agent] - session = RealtimeSession(mock_model, mock_agent, None) + session = RealtimeSession(mock_model, first_agent, None) tool_call_event = RealtimeModelToolCallEvent( - name="test_handoff", call_id="call_789", arguments="{}" + name=Handoff.default_tool_name(second_agent), call_id="call_789", arguments="{}" ) await session._handle_tool_call(tool_call_event) @@ -829,7 +827,7 @@ async def test_handoff_tool_handling(self, mock_model, mock_agent, mock_handoff) assert session._event_queue.qsize() >= 1 # Verify agent was updated - assert session._current_agent == mock_new_agent + assert session._current_agent == second_agent @pytest.mark.asyncio async def test_unknown_tool_handling(self, mock_model, mock_agent, mock_function_tool): diff --git a/tests/test_handoff_tool.py b/tests/test_handoff_tool.py index 0f7fc2166..291f0a4f5 100644 --- a/tests/test_handoff_tool.py +++ b/tests/test_handoff_tool.py @@ -1,3 +1,4 @@ +import inspect import json from typing import Any @@ -318,6 +319,8 @@ def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: handoff_callable_enabled = handoff(agent, is_enabled=always_enabled) assert callable(handoff_callable_enabled.is_enabled) result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent) + assert inspect.isawaitable(result) + result = await result assert result is True # Test callable that returns False @@ -327,6 +330,8 @@ def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: handoff_callable_disabled = handoff(agent, is_enabled=always_disabled) assert callable(handoff_callable_disabled.is_enabled) result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent) + assert inspect.isawaitable(result) + result = await result assert result is False # Test async callable