Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class Agent(AgentBase, Generic[TContext]):
usable with OpenAI models, using the Responses API.
"""

handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
modularity.
Expand Down
2 changes: 1 addition & 1 deletion src/agents/guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def decorator(
return InputGuardrail(
guardrail_function=f,
# If not set, guardrail name uses the function’s name by default.
name=name if name else f.__name__
name=name if name else f.__name__,
)

if func is not None:
Expand Down
46 changes: 32 additions & 14 deletions src/agents/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
from .agent import Agent
from .agent import Agent, AgentBase


# The handoff input type is the type of data passed when the agent is called via a handoff.
THandoffInput = TypeVar("THandoffInput", default=Any)

# The agent type that the handoff returns
TAgent = TypeVar("TAgent", bound="AgentBase[Any]", default="Agent[Any]")

OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]

Expand Down Expand Up @@ -52,7 +55,7 @@ class HandoffInputData:


@dataclass
class Handoff(Generic[TContext]):
class Handoff(Generic[TContext, TAgent]):
"""A handoff is when an agent delegates a task to another agent.
For example, in a customer support scenario you might have a "triage agent" that determines
which agent should handle the user's request, and sub-agents that specialize in different
Expand All @@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
"""The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
"""

on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[Agent[TContext]]]
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[TAgent]]
"""The function that invokes the handoff. The parameters passed are:
1. The handoff run context
2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
Expand Down Expand Up @@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
True, as it increases the likelihood of correct JSON input.
"""

is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
True
)
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
a handoff based on your context/state."""

def get_transfer_message(self, agent: Agent[Any]) -> str:
def get_transfer_message(self, agent: AgentBase[Any]) -> str:
return json.dumps({"assistant": agent.name})

@classmethod
def default_tool_name(cls, agent: Agent[Any]) -> str:
def default_tool_name(cls, agent: AgentBase[Any]) -> str:
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")

@classmethod
def default_tool_description(cls, agent: Agent[Any]) -> str:
def default_tool_description(cls, agent: AgentBase[Any]) -> str:
return (
f"Handoff to the {agent.name} agent to handle the request. "
f"{agent.handoff_description or ''}"
Expand All @@ -128,7 +133,7 @@ def handoff(
tool_description_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]: ...
) -> Handoff[TContext, Agent[TContext]]: ...


@overload
Expand All @@ -141,7 +146,7 @@ def handoff(
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]: ...
) -> Handoff[TContext, Agent[TContext]]: ...


@overload
Expand All @@ -153,7 +158,7 @@ def handoff(
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]: ...
) -> Handoff[TContext, Agent[TContext]]: ...


def handoff(
Expand All @@ -163,8 +168,9 @@ def handoff(
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
input_type: type[THandoffInput] | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext]:
is_enabled: bool
| Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext, Agent[TContext]]:
"""Create a handoff from an agent.

Args:
Expand Down Expand Up @@ -202,7 +208,7 @@ def handoff(

async def _invoke_handoff(
ctx: RunContextWrapper[Any], input_json: str | None = None
) -> Agent[Any]:
) -> Agent[TContext]:
if input_type is not None and type_adapter is not None:
if input_json is None:
_error_tracing.attach_error_to_current_span(
Expand Down Expand Up @@ -239,12 +245,24 @@ async def _invoke_handoff(
# If there is a need, we can make this configurable in the future
input_json_schema = ensure_strict_json_schema(input_json_schema)

async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

making this consistent looks good; with this, we may be able to switch the is_enabled property of the class object from Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] to Callable[[..., ..., Awaitable[bool]] in future releases

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true

from .agent import Agent

assert callable(is_enabled), "is_enabled must be non-null here"
assert isinstance(agent_base, Agent), "Can't handoff to a non-Agent"
result = is_enabled(ctx, agent_base)

if inspect.isawaitable(result):
return await result

return result

return Handoff(
tool_name=tool_name,
tool_description=tool_description,
input_json_schema=input_json_schema,
on_invoke_handoff=_invoke_handoff,
input_filter=input_filter,
agent_name=agent.name,
is_enabled=is_enabled,
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
)
2 changes: 1 addition & 1 deletion src/agents/models/chatcmpl_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam:
)

@classmethod
def convert_handoff_tool(cls, handoff: Handoff[Any]) -> ChatCompletionToolParam:
def convert_handoff_tool(cls, handoff: Handoff[Any, Any]) -> ChatCompletionToolParam:
return {
"type": "function",
"function": {
Expand Down
2 changes: 1 addition & 1 deletion src/agents/models/openai_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def get_response_format(
def convert_tools(
cls,
tools: list[Tool],
handoffs: list[Handoff[Any]],
handoffs: list[Handoff[Any, Any]],
) -> ConvertedTools:
converted_tools: list[ToolParam] = []
includes: list[ResponseIncludable] = []
Expand Down
3 changes: 3 additions & 0 deletions src/agents/realtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RealtimeToolEnd,
RealtimeToolStart,
)
from .handoffs import realtime_handoff
from .items import (
AssistantMessageItem,
AssistantText,
Expand Down Expand Up @@ -92,6 +93,8 @@
"RealtimeAgentHooks",
"RealtimeRunHooks",
"RealtimeRunner",
# Handoffs
"realtime_handoff",
# Config
"RealtimeAudioFormat",
"RealtimeClientMessage",
Expand Down
11 changes: 10 additions & 1 deletion src/agents/realtime/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, cast

from ..agent import AgentBase
from ..handoffs import Handoff
from ..lifecycle import AgentHooksBase, RunHooksBase
from ..logger import logger
from ..run_context import RunContextWrapper, TContext
Expand Down Expand Up @@ -53,6 +54,14 @@ class RealtimeAgent(AgentBase, Generic[TContext]):
return a string.
"""

handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field(
default_factory=list
)
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
modularity.
"""

hooks: RealtimeAgentHooks | None = None
"""A class that receives callbacks on various lifecycle events for this agent.
"""
Expand Down
2 changes: 2 additions & 0 deletions src/agents/realtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import NotRequired, TypeAlias, TypedDict

from ..guardrail import OutputGuardrail
from ..handoffs import Handoff
from ..model_settings import ToolChoice
from ..tool import Tool

Expand Down Expand Up @@ -71,6 +72,7 @@ class RealtimeSessionModelSettings(TypedDict):

tool_choice: NotRequired[ToolChoice]
tools: NotRequired[list[Tool]]
handoffs: NotRequired[list[Handoff]]

tracing: NotRequired[RealtimeModelTracingConfig | None]

Expand Down
165 changes: 165 additions & 0 deletions src/agents/realtime/handoffs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Callable, cast, overload

from pydantic import TypeAdapter
from typing_extensions import TypeVar

from ..exceptions import ModelBehaviorError, UserError
from ..handoffs import Handoff
from ..run_context import RunContextWrapper, TContext
from ..strict_schema import ensure_strict_json_schema
from ..tracing.spans import SpanError
from ..util import _error_tracing, _json
from ..util._types import MaybeAwaitable

if TYPE_CHECKING:
from ..agent import AgentBase
from . import RealtimeAgent


# The handoff input type is the type of data passed when the agent is called via a handoff.
THandoffInput = TypeVar("THandoffInput", default=Any)

OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]


@overload
def realtime_handoff(
agent: RealtimeAgent[TContext],
*,
tool_name_override: str | None = None,
tool_description_override: str | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext, RealtimeAgent[TContext]]: ...


@overload
def realtime_handoff(
agent: RealtimeAgent[TContext],
*,
on_handoff: OnHandoffWithInput[THandoffInput],
input_type: type[THandoffInput],
tool_description_override: str | None = None,
tool_name_override: str | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext, RealtimeAgent[TContext]]: ...


@overload
def realtime_handoff(
agent: RealtimeAgent[TContext],
*,
on_handoff: OnHandoffWithoutInput,
tool_description_override: str | None = None,
tool_name_override: str | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext, RealtimeAgent[TContext]]: ...


def realtime_handoff(
agent: RealtimeAgent[TContext],
tool_name_override: str | None = None,
tool_description_override: str | None = None,
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
input_type: type[THandoffInput] | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
) -> Handoff[TContext, RealtimeAgent[TContext]]:
"""Create a handoff from a RealtimeAgent.

Args:
agent: The RealtimeAgent to handoff to, or a function that returns a RealtimeAgent.
tool_name_override: Optional override for the name of the tool that represents the handoff.
tool_description_override: Optional override for the description of the tool that
represents the handoff.
on_handoff: A function that runs when the handoff is invoked.
input_type: the type of the input to the handoff. If provided, the input will be validated
against this type. Only relevant if you pass a function that takes an input.
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the handoff is enabled. Disabled handoffs are
hidden from the LLM at runtime.

Note: input_filter is not supported for RealtimeAgent handoffs.
"""
assert (on_handoff and input_type) or not (on_handoff and input_type), (
"You must provide either both on_handoff and input_type, or neither"
)
type_adapter: TypeAdapter[Any] | None
if input_type is not None:
assert callable(on_handoff), "on_handoff must be callable"
sig = inspect.signature(on_handoff)
if len(sig.parameters) != 2:
raise UserError("on_handoff must take two arguments: context and input")

type_adapter = TypeAdapter(input_type)
input_json_schema = type_adapter.json_schema()
else:
type_adapter = None
input_json_schema = {}
if on_handoff is not None:
sig = inspect.signature(on_handoff)
if len(sig.parameters) != 1:
raise UserError("on_handoff must take one argument: context")

async def _invoke_handoff(
ctx: RunContextWrapper[Any], input_json: str | None = None
) -> RealtimeAgent[TContext]:
if input_type is not None and type_adapter is not None:
if input_json is None:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Handoff function expected non-null input, but got None",
data={"details": "input_json is None"},
)
)
raise ModelBehaviorError("Handoff function expected non-null input, but got None")

validated_input = _json.validate_json(
json_str=input_json,
type_adapter=type_adapter,
partial=False,
)
input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff)
if inspect.iscoroutinefunction(input_func):
await input_func(ctx, validated_input)
else:
input_func(ctx, validated_input)
elif on_handoff is not None:
no_input_func = cast(OnHandoffWithoutInput, on_handoff)
if inspect.iscoroutinefunction(no_input_func):
await no_input_func(ctx)
else:
no_input_func(ctx)

return agent

tool_name = tool_name_override or Handoff.default_tool_name(agent)
tool_description = tool_description_override or Handoff.default_tool_description(agent)

# Always ensure the input JSON schema is in strict mode
# If there is a need, we can make this configurable in the future
input_json_schema = ensure_strict_json_schema(input_json_schema)

async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
assert callable(is_enabled), "is_enabled must be non-null here"
assert isinstance(agent_base, RealtimeAgent), "Can't handoff to a non-RealtimeAgent"
result = is_enabled(ctx, agent_base)
if inspect.isawaitable(result):
return await result
return result

return Handoff(
tool_name=tool_name,
tool_description=tool_description,
input_json_schema=input_json_schema,
on_invoke_handoff=_invoke_handoff,
input_filter=None, # Not supported for RealtimeAgent handoffs
agent_name=agent.name,
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
)
Loading