Skip to content

Commit 08b868d

Browse files
committed
fix #1750 better error message when passing AgentHooks to Runner
1 parent 456d284 commit 08b868d

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

src/agents/run.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
ToolCallItemTypes,
5454
TResponseInputItem,
5555
)
56-
from .lifecycle import RunHooks
56+
from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase
5757
from .logger import logger
5858
from .memory import Session, SessionInputCallback
5959
from .model_settings import ModelSettings
@@ -417,13 +417,11 @@ async def run(
417417
) -> RunResult:
418418
context = kwargs.get("context")
419419
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
420-
hooks = kwargs.get("hooks")
420+
hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks")))
421421
run_config = kwargs.get("run_config")
422422
previous_response_id = kwargs.get("previous_response_id")
423423
conversation_id = kwargs.get("conversation_id")
424424
session = kwargs.get("session")
425-
if hooks is None:
426-
hooks = RunHooks[Any]()
427425
if run_config is None:
428426
run_config = RunConfig()
429427

@@ -624,14 +622,12 @@ def run_streamed(
624622
) -> RunResultStreaming:
625623
context = kwargs.get("context")
626624
max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
627-
hooks = kwargs.get("hooks")
625+
hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks")))
628626
run_config = kwargs.get("run_config")
629627
previous_response_id = kwargs.get("previous_response_id")
630628
conversation_id = kwargs.get("conversation_id")
631629
session = kwargs.get("session")
632630

633-
if hooks is None:
634-
hooks = RunHooks[Any]()
635631
if run_config is None:
636632
run_config = RunConfig()
637633

@@ -688,6 +684,23 @@ def run_streamed(
688684
)
689685
return streamed_result
690686

687+
@staticmethod
688+
def _validate_run_hooks(
689+
hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None,
690+
) -> RunHooks[Any]:
691+
if hooks is None:
692+
return RunHooks[Any]()
693+
input_hook_type = type(hooks).__name__
694+
if isinstance(hooks, AgentHooksBase):
695+
raise TypeError(
696+
"Run hooks must be instances of RunHooks. "
697+
f"Received agent-scoped hooks ({input_hook_type}). "
698+
"Attach AgentHooks to an Agent via Agent(..., hooks=...)."
699+
)
700+
if not isinstance(hooks, RunHooksBase):
701+
raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.")
702+
return hooks
703+
691704
@classmethod
692705
async def _maybe_filter_model_input(
693706
cls,

tests/test_agent_as_tool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import pytest
24
from pydantic import BaseModel
35

tests/test_run_hooks.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from collections import defaultdict
2-
from typing import Any, Optional
2+
from typing import Any, Optional, cast
33

44
import pytest
55

66
from agents.agent import Agent
77
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
8-
from agents.lifecycle import RunHooks
8+
from agents.lifecycle import AgentHooks, RunHooks
99
from agents.models.interface import Model
1010
from agents.run import Runner
1111
from agents.run_context import RunContextWrapper, TContext
@@ -191,6 +191,29 @@ async def boom(*args, **kwargs):
191191
assert hooks.events["on_agent_end"] == 0
192192

193193

194+
class DummyAgentHooks(AgentHooks):
195+
"""Agent-scoped hooks used to verify runtime validation."""
196+
197+
198+
@pytest.mark.asyncio
199+
async def test_runner_run_rejects_agent_hooks():
200+
model = FakeModel()
201+
agent = Agent(name="A", model=model)
202+
hooks = cast(RunHooks, DummyAgentHooks())
203+
204+
with pytest.raises(TypeError, match="Run hooks must be instances of RunHooks"):
205+
await Runner.run(agent, input="hello", hooks=hooks)
206+
207+
208+
def test_runner_run_streamed_rejects_agent_hooks():
209+
model = FakeModel()
210+
agent = Agent(name="A", model=model)
211+
hooks = cast(RunHooks, DummyAgentHooks())
212+
213+
with pytest.raises(TypeError, match="Run hooks must be instances of RunHooks"):
214+
Runner.run_streamed(agent, input="hello", hooks=hooks)
215+
216+
194217
class BoomModel(Model):
195218
async def get_response(self, *a, **k):
196219
raise AssertionError("get_response should not be called in streaming test")

0 commit comments

Comments
 (0)