|
53 | 53 | ToolCallItemTypes,
|
54 | 54 | TResponseInputItem,
|
55 | 55 | )
|
56 |
| -from .lifecycle import RunHooks |
| 56 | +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase |
57 | 57 | from .logger import logger
|
58 | 58 | from .memory import Session, SessionInputCallback
|
59 | 59 | from .model_settings import ModelSettings
|
@@ -417,13 +417,11 @@ async def run(
|
417 | 417 | ) -> RunResult:
|
418 | 418 | context = kwargs.get("context")
|
419 | 419 | max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
420 |
| - hooks = kwargs.get("hooks") |
| 420 | + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) |
421 | 421 | run_config = kwargs.get("run_config")
|
422 | 422 | previous_response_id = kwargs.get("previous_response_id")
|
423 | 423 | conversation_id = kwargs.get("conversation_id")
|
424 | 424 | session = kwargs.get("session")
|
425 |
| - if hooks is None: |
426 |
| - hooks = RunHooks[Any]() |
427 | 425 | if run_config is None:
|
428 | 426 | run_config = RunConfig()
|
429 | 427 |
|
@@ -624,14 +622,12 @@ def run_streamed(
|
624 | 622 | ) -> RunResultStreaming:
|
625 | 623 | context = kwargs.get("context")
|
626 | 624 | max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)
|
627 |
| - hooks = kwargs.get("hooks") |
| 625 | + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) |
628 | 626 | run_config = kwargs.get("run_config")
|
629 | 627 | previous_response_id = kwargs.get("previous_response_id")
|
630 | 628 | conversation_id = kwargs.get("conversation_id")
|
631 | 629 | session = kwargs.get("session")
|
632 | 630 |
|
633 |
| - if hooks is None: |
634 |
| - hooks = RunHooks[Any]() |
635 | 631 | if run_config is None:
|
636 | 632 | run_config = RunConfig()
|
637 | 633 |
|
@@ -688,6 +684,23 @@ def run_streamed(
|
688 | 684 | )
|
689 | 685 | return streamed_result
|
690 | 686 |
|
| 687 | + @staticmethod |
| 688 | + def _validate_run_hooks( |
| 689 | + hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, |
| 690 | + ) -> RunHooks[Any]: |
| 691 | + if hooks is None: |
| 692 | + return RunHooks[Any]() |
| 693 | + input_hook_type = type(hooks).__name__ |
| 694 | + if isinstance(hooks, AgentHooksBase): |
| 695 | + raise TypeError( |
| 696 | + "Run hooks must be instances of RunHooks. " |
| 697 | + f"Received agent-scoped hooks ({input_hook_type}). " |
| 698 | + "Attach AgentHooks to an Agent via Agent(..., hooks=...)." |
| 699 | + ) |
| 700 | + if not isinstance(hooks, RunHooksBase): |
| 701 | + raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") |
| 702 | + return hooks |
| 703 | + |
691 | 704 | @classmethod
|
692 | 705 | async def _maybe_filter_model_input(
|
693 | 706 | cls,
|
|
0 commit comments