11from typing import Any , Generic
22
3- from .agent import Agent
3+ from typing_extensions import TypeVar
4+
5+ from .agent import Agent , AgentBase
46from .run_context import RunContextWrapper , TContext
57from .tool import Tool
68
9+ TAgent = TypeVar ("TAgent" , bound = AgentBase , default = AgentBase )
10+
711
8- class RunHooks (Generic [TContext ]):
12+ class RunHooksBase (Generic [TContext , TAgent ]):
913 """A class that receives callbacks on various lifecycle events in an agent run. Subclass and
1014 override the methods you need.
1115 """
1216
13- async def on_agent_start (
14- self , context : RunContextWrapper [TContext ], agent : Agent [TContext ]
15- ) -> None :
17+ async def on_agent_start (self , context : RunContextWrapper [TContext ], agent : TAgent ) -> None :
1618 """Called before the agent is invoked. Called each time the current agent changes."""
1719 pass
1820
1921 async def on_agent_end (
2022 self ,
2123 context : RunContextWrapper [TContext ],
22- agent : Agent [ TContext ] ,
24+ agent : TAgent ,
2325 output : Any ,
2426 ) -> None :
2527 """Called when the agent produces a final output."""
@@ -28,16 +30,16 @@ async def on_agent_end(
2830 async def on_handoff (
2931 self ,
3032 context : RunContextWrapper [TContext ],
31- from_agent : Agent [ TContext ] ,
32- to_agent : Agent [ TContext ] ,
33+ from_agent : TAgent ,
34+ to_agent : TAgent ,
3335 ) -> None :
3436 """Called when a handoff occurs."""
3537 pass
3638
3739 async def on_tool_start (
3840 self ,
3941 context : RunContextWrapper [TContext ],
40- agent : Agent [ TContext ] ,
42+ agent : TAgent ,
4143 tool : Tool ,
4244 ) -> None :
4345 """Called before a tool is invoked."""
@@ -46,30 +48,30 @@ async def on_tool_start(
4648 async def on_tool_end (
4749 self ,
4850 context : RunContextWrapper [TContext ],
49- agent : Agent [ TContext ] ,
51+ agent : TAgent ,
5052 tool : Tool ,
5153 result : str ,
5254 ) -> None :
5355 """Called after a tool is invoked."""
5456 pass
5557
5658
57- class AgentHooks (Generic [TContext ]):
59+ class AgentHooksBase (Generic [TContext , TAgent ]):
5860 """A class that receives callbacks on various lifecycle events for a specific agent. You can
5961 set this on `agent.hooks` to receive events for that specific agent.
6062
6163 Subclass and override the methods you need.
6264 """
6365
64- async def on_start (self , context : RunContextWrapper [TContext ], agent : Agent [ TContext ] ) -> None :
66+ async def on_start (self , context : RunContextWrapper [TContext ], agent : TAgent ) -> None :
6567 """Called before the agent is invoked. Called each time the running agent is changed to this
6668 agent."""
6769 pass
6870
6971 async def on_end (
7072 self ,
7173 context : RunContextWrapper [TContext ],
72- agent : Agent [ TContext ] ,
74+ agent : TAgent ,
7375 output : Any ,
7476 ) -> None :
7577 """Called when the agent produces a final output."""
@@ -78,8 +80,8 @@ async def on_end(
7880 async def on_handoff (
7981 self ,
8082 context : RunContextWrapper [TContext ],
81- agent : Agent [ TContext ] ,
82- source : Agent [ TContext ] ,
83+ agent : TAgent ,
84+ source : TAgent ,
8385 ) -> None :
8486 """Called when the agent is being handed off to. The `source` is the agent that is handing
8587 off to this agent."""
@@ -88,7 +90,7 @@ async def on_handoff(
8890 async def on_tool_start (
8991 self ,
9092 context : RunContextWrapper [TContext ],
91- agent : Agent [ TContext ] ,
93+ agent : TAgent ,
9294 tool : Tool ,
9395 ) -> None :
9496 """Called before a tool is invoked."""
@@ -97,9 +99,16 @@ async def on_tool_start(
9799 async def on_tool_end (
98100 self ,
99101 context : RunContextWrapper [TContext ],
100- agent : Agent [ TContext ] ,
102+ agent : TAgent ,
101103 tool : Tool ,
102104 result : str ,
103105 ) -> None :
104106 """Called after a tool is invoked."""
105107 pass
108+
109+
110+ RunHooks = RunHooksBase [TContext , Agent ]
111+ """Run hooks when using `Agent`."""
112+
113+ AgentHooks = AgentHooksBase [TContext , Agent ]
114+ """Agent hooks for `Agent`s."""
0 commit comments