diff --git a/src/agents/run.py b/src/agents/run.py index 42339eb50..722aa1558 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -53,7 +53,7 @@ ToolCallItemTypes, TResponseInputItem, ) -from .lifecycle import RunHooks +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase from .logger import logger from .memory import Session, SessionInputCallback from .model_settings import ModelSettings @@ -461,13 +461,11 @@ async def run( ) -> RunResult: context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = kwargs.get("hooks") + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") - if hooks is None: - hooks = RunHooks[Any]() if run_config is None: run_config = RunConfig() @@ -668,14 +666,12 @@ def run_streamed( ) -> RunResultStreaming: context = kwargs.get("context") max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) - hooks = kwargs.get("hooks") + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) run_config = kwargs.get("run_config") previous_response_id = kwargs.get("previous_response_id") conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") - if hooks is None: - hooks = RunHooks[Any]() if run_config is None: run_config = RunConfig() @@ -732,6 +728,23 @@ def run_streamed( ) return streamed_result + @staticmethod + def _validate_run_hooks( + hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, + ) -> RunHooks[Any]: + if hooks is None: + return RunHooks[Any]() + input_hook_type = type(hooks).__name__ + if isinstance(hooks, AgentHooksBase): + raise TypeError( + "Run hooks must be instances of RunHooks. " + f"Received agent-scoped hooks ({input_hook_type}). " + "Attach AgentHooks to an Agent via Agent(..., hooks=...)." + ) + if not isinstance(hooks, RunHooksBase): + raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") + return hooks + @classmethod async def _maybe_filter_model_input( cls, diff --git a/tests/test_run_hooks.py b/tests/test_run_hooks.py index 988cd6dc2..f5a2ed478 100644 --- a/tests/test_run_hooks.py +++ b/tests/test_run_hooks.py @@ -1,11 +1,11 @@ from collections import defaultdict -from typing import Any, Optional +from typing import Any, Optional, cast import pytest from agents.agent import Agent from agents.items import ItemHelpers, ModelResponse, TResponseInputItem -from agents.lifecycle import RunHooks +from agents.lifecycle import AgentHooks, RunHooks from agents.models.interface import Model from agents.run import Runner from agents.run_context import RunContextWrapper, TContext @@ -191,6 +191,29 @@ async def boom(*args, **kwargs): assert hooks.events["on_agent_end"] == 0 +class DummyAgentHooks(AgentHooks): + """Agent-scoped hooks used to verify runtime validation.""" + + +@pytest.mark.asyncio +async def test_runner_run_rejects_agent_hooks(): + model = FakeModel() + agent = Agent(name="A", model=model) + hooks = cast(RunHooks, DummyAgentHooks()) + + with pytest.raises(TypeError, match="Run hooks must be instances of RunHooks"): + await Runner.run(agent, input="hello", hooks=hooks) + + +def test_runner_run_streamed_rejects_agent_hooks(): + model = FakeModel() + agent = Agent(name="A", model=model) + hooks = cast(RunHooks, DummyAgentHooks()) + + with pytest.raises(TypeError, match="Run hooks must be instances of RunHooks"): + Runner.run_streamed(agent, input="hello", hooks=hooks) + + class BoomModel(Model): async def get_response(self, *a, **k): raise AssertionError("get_response should not be called in streaming test")