diff --git a/src/agents/agent.py b/src/agents/agent.py index 61c0a8966..2416beb44 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -2,9 +2,9 @@ import dataclasses import inspect -from collections.abc import Awaitable +from collections.abc import Awaitable, Iterable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, cast, get_args from . import _utils from ._utils import MaybeAwaitable @@ -157,3 +157,65 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s logger.error(f"Instructions must be a string or a function, got {self.instructions}") return None + + def _add_item( + self, item: Any, attr_name: str, expected_type: type | tuple[type, ...] + ) -> None: + """ + Generic method to add an item to a specified list attribute. + Checks if the item is of the expected type and prevents duplicates. + """ + if not isinstance(item, expected_type): + raise TypeError(f"Expected {expected_type}, got {type(item)}") + attr_list = getattr(self, attr_name) + if item.name not in [existing.name for existing in attr_list]: + attr_list.append(item) + else: + logger.warning( + f"Item {item.name} already exists in {attr_name}. " + "Skipping addition to avoid duplicates." + ) + + @staticmethod + def _ensure_iterable(item: Any) -> Iterable[Any]: + """ + Ensures the provided item is iterable. + Strings and bytes are treated as non-iterable for our purposes. + """ + if isinstance(item, (str, bytes)): + return [item] + try: + iter(item) + return cast(Iterable[Any], item) + except TypeError: + return [item] + + def add_tools(self, tool: Tool | Iterable[Tool]) -> None: + """Add one or multiple tools to the agent's tool list.""" + expected_types = get_args(Tool) + for t in self._ensure_iterable(tool): + self._add_item(t, "tools", expected_types) + + def add_handoffs( + self, + handoff: Agent[Any] | Handoff[TContext] | Iterable[Agent[Any] | Handoff[TContext]], + ) -> None: + """Add one or multiple handoffs to the agent's handoff list.""" + for h in self._ensure_iterable(handoff): + self._add_item(h, "handoffs", (Agent, Handoff)) + + def add_input_guardrails( + self, + guardrail: InputGuardrail[TContext] | Iterable[InputGuardrail[TContext]], + ) -> None: + """Add one or multiple input guardrails to the agent's input guardrail list.""" + for g in self._ensure_iterable(guardrail): + self._add_item(g, "input_guardrails", InputGuardrail) + + def add_output_guardrails( + self, + guardrail: OutputGuardrail[TContext] | Iterable[OutputGuardrail[TContext]], + ) -> None: + """Add one or multiple output guardrails to the agent's output guardrail list.""" + for g in self._ensure_iterable(guardrail): + self._add_item(g, "output_guardrails", OutputGuardrail)