25
25
from ..telemetry .metrics import Trace
26
26
from ..telemetry .tracer import get_tracer
27
27
from ..tools ._validator import validate_and_prepare_tools
28
+ from ..types ._events import (
29
+ EventLoopStopEvent ,
30
+ EventLoopThrottleEvent ,
31
+ ForceStopEvent ,
32
+ ModelMessageEvent ,
33
+ StartEvent ,
34
+ StartEventLoopEvent ,
35
+ ToolResultMessageEvent ,
36
+ )
28
37
from ..types .content import Message
29
38
from ..types .exceptions import (
30
39
ContextWindowOverflowException ,
@@ -91,8 +100,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
91
100
cycle_start_time , cycle_trace = agent .event_loop_metrics .start_cycle (attributes = attributes )
92
101
invocation_state ["event_loop_cycle_trace" ] = cycle_trace
93
102
94
- yield { "callback" : { "start" : True }}
95
- yield { "callback" : { "start_event_loop" : True }}
103
+ yield StartEvent ()
104
+ yield StartEventLoopEvent ()
96
105
97
106
# Create tracer span for this event loop cycle
98
107
tracer = get_tracer ()
@@ -175,7 +184,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
175
184
176
185
if isinstance (e , ModelThrottledException ):
177
186
if attempt + 1 == MAX_ATTEMPTS :
178
- yield { "callback" : { "force_stop" : True , "force_stop_reason" : str ( e )}}
187
+ yield ForceStopEvent ( reason = e )
179
188
raise e
180
189
181
190
logger .debug (
@@ -189,7 +198,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
189
198
time .sleep (current_delay )
190
199
current_delay = min (current_delay * 2 , MAX_DELAY )
191
200
192
- yield { "callback" : { "event_loop_throttled_delay" : current_delay , ** invocation_state }}
201
+ yield EventLoopThrottleEvent ( delay = current_delay , invocation_state = invocation_state )
193
202
else :
194
203
raise e
195
204
@@ -201,7 +210,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
201
210
# Add the response message to the conversation
202
211
agent .messages .append (message )
203
212
agent .hooks .invoke_callbacks (MessageAddedEvent (agent = agent , message = message ))
204
- yield { "callback" : { " message" : message }}
213
+ yield ModelMessageEvent ( message = message )
205
214
206
215
# Update metrics
207
216
agent .event_loop_metrics .update_usage (usage )
@@ -235,8 +244,8 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
235
244
cycle_start_time = cycle_start_time ,
236
245
invocation_state = invocation_state ,
237
246
)
238
- async for event in events :
239
- yield event
247
+ async for typed_event in events :
248
+ yield typed_event
240
249
241
250
return
242
251
@@ -264,11 +273,11 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
264
273
tracer .end_span_with_error (cycle_span , str (e ), e )
265
274
266
275
# Handle any other exceptions
267
- yield { "callback" : { "force_stop" : True , "force_stop_reason" : str ( e )}}
276
+ yield ForceStopEvent ( reason = e )
268
277
logger .exception ("cycle failed" )
269
278
raise EventLoopException (e , invocation_state ["request_state" ]) from e
270
279
271
- yield { "stop" : (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])}
280
+ yield EventLoopStopEvent (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])
272
281
273
282
274
283
async def recurse_event_loop (agent : "Agent" , invocation_state : dict [str , Any ]) -> AsyncGenerator [dict [str , Any ], None ]:
@@ -295,7 +304,7 @@ async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -
295
304
recursive_trace = Trace ("Recursive call" , parent_id = cycle_trace .id )
296
305
cycle_trace .add_child (recursive_trace )
297
306
298
- yield { "callback" : { "start" : True }}
307
+ yield StartEvent ()
299
308
300
309
events = event_loop_cycle (agent = agent , invocation_state = invocation_state )
301
310
async for event in events :
@@ -339,7 +348,7 @@ async def _handle_tool_execution(
339
348
validate_and_prepare_tools (message , tool_uses , tool_results , invalid_tool_use_ids )
340
349
tool_uses = [tool_use for tool_use in tool_uses if tool_use .get ("toolUseId" ) not in invalid_tool_use_ids ]
341
350
if not tool_uses :
342
- yield { "stop" : (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])}
351
+ yield EventLoopStopEvent (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])
343
352
return
344
353
345
354
tool_events = agent .tool_executor ._execute (
@@ -358,15 +367,15 @@ async def _handle_tool_execution(
358
367
359
368
agent .messages .append (tool_result_message )
360
369
agent .hooks .invoke_callbacks (MessageAddedEvent (agent = agent , message = tool_result_message ))
361
- yield { "callback" : { " message" : tool_result_message }}
370
+ yield ToolResultMessageEvent ( message = message )
362
371
363
372
if cycle_span :
364
373
tracer = get_tracer ()
365
374
tracer .end_event_loop_cycle_span (span = cycle_span , message = message , tool_result_message = tool_result_message )
366
375
367
376
if invocation_state ["request_state" ].get ("stop_event_loop" , False ):
368
377
agent .event_loop_metrics .end_cycle (cycle_start_time , cycle_trace )
369
- yield { "stop" : (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])}
378
+ yield EventLoopStopEvent (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])
370
379
return
371
380
372
381
events = recurse_event_loop (agent = agent , invocation_state = invocation_state )
0 commit comments