Skip to content
66 changes: 64 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import dataclasses
import inspect
from collections.abc import Awaitable
from collections.abc import Awaitable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, get_args

from . import _utils
from ._utils import MaybeAwaitable
Expand Down Expand Up @@ -157,3 +157,65 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
logger.error(f"Instructions must be a string or a function, got {self.instructions}")

return None

def _add_item(
self, item: Any, attr_name: str, expected_type: type | tuple[type, ...]
) -> None:
"""
Generic method to add an item to a specified list attribute.
Checks if the item is of the expected type and prevents duplicates.
"""
if not isinstance(item, expected_type):
raise TypeError(f"Expected {expected_type}, got {type(item)}")
attr_list = getattr(self, attr_name)
if item.name not in [existing.name for existing in attr_list]:
attr_list.append(item)
else:
logger.warning(
f"Item {item.name} already exists in {attr_name}. "
"Skipping addition to avoid duplicates."
)

@staticmethod
def _ensure_iterable(item: Any) -> Iterable[Any]:
"""
Ensures the provided item is iterable.
Strings and bytes are treated as non-iterable for our purposes.
"""
if isinstance(item, (str, bytes)):
return [item]
try:
iter(item)
return cast(Iterable[Any], item)
except TypeError:
return [item]

def add_tools(self, tool: Tool | Iterable[Tool]) -> None:
"""Add one or multiple tools to the agent's tool list."""
expected_types = get_args(Tool)
for t in self._ensure_iterable(tool):
self._add_item(t, "tools", expected_types)

def add_handoffs(
self,
handoff: Agent[Any] | Handoff[TContext] | Iterable[Agent[Any] | Handoff[TContext]],
) -> None:
"""Add one or multiple handoffs to the agent's handoff list."""
for h in self._ensure_iterable(handoff):
self._add_item(h, "handoffs", (Agent, Handoff))

def add_input_guardrails(
self,
guardrail: InputGuardrail[TContext] | Iterable[InputGuardrail[TContext]],
) -> None:
"""Add one or multiple input guardrails to the agent's input guardrail list."""
for g in self._ensure_iterable(guardrail):
self._add_item(g, "input_guardrails", InputGuardrail)

def add_output_guardrails(
self,
guardrail: OutputGuardrail[TContext] | Iterable[OutputGuardrail[TContext]],
) -> None:
"""Add one or multiple output guardrails to the agent's output guardrail list."""
for g in self._ensure_iterable(guardrail):
self._add_item(g, "output_guardrails", OutputGuardrail)