@@ -317,7 +317,7 @@ def _create_event_attributes(
317
317
if isinstance (end_timestamp , datetime ):
318
318
attrs [GEN_AI_RUN_STEP_END_TIMESTAMP ] = end_timestamp .isoformat ()
319
319
elif end_timestamp :
320
- # fallback in case int or string gets passed
320
+ # fallback in case string or int string gets passed
321
321
attrs [GEN_AI_RUN_STEP_END_TIMESTAMP ] = str (end_timestamp )
322
322
323
323
if run_step_last_error :
@@ -1324,37 +1324,68 @@ def trace_list_messages(self, function, *args, **kwargs):
1324
1324
server_address = self .get_server_address_from_arg (args [0 ])
1325
1325
thread_id = kwargs .get ("thread_id" )
1326
1326
1327
- span = self .start_list_messages_span (server_address = server_address , thread_id = thread_id )
1328
-
1329
- return _InstrumentedItemPaged (function (* args , ** kwargs ), self .add_thread_message_event , span )
1327
+ return _InstrumentedItemPaged (
1328
+ function (* args , ** kwargs ),
1329
+ start_span_function = self .start_trace_list_messages ,
1330
+ item_instrumentation_function = self .add_thread_message_event ,
1331
+ server_address = server_address ,
1332
+ thread_id = thread_id ,
1333
+ run_id = None ,
1334
+ )
1330
1335
1331
1336
def trace_list_messages_async (self , function , * args , ** kwargs ):
1332
1337
# Note that this method is not async, but it operates on AsyncIterable.
1333
1338
server_address = self .get_server_address_from_arg (args [0 ])
1334
1339
thread_id = kwargs .get ("thread_id" )
1335
1340
1336
- span = self .start_list_messages_span (server_address = server_address , thread_id = thread_id )
1341
+ return _AsyncInstrumentedItemPaged (
1342
+ function (* args , ** kwargs ),
1343
+ start_span_function = self .start_trace_list_messages ,
1344
+ item_instrumentation_function = self .add_thread_message_event ,
1345
+ server_address = server_address ,
1346
+ thread_id = thread_id ,
1347
+ run_id = None ,
1348
+ )
1337
1349
1338
- return _AsyncInstrumentedItemPaged (function (* args , ** kwargs ), self .add_thread_message_event , span )
1350
+ def start_trace_list_messages (
1351
+ self , server_address : Optional [str ] = None , thread_id : Optional [str ] = None , run_id : Optional [str ] = None
1352
+ ):
1353
+ _ = run_id # Unused parameter, but kept for compatibility.
1354
+ return self .start_list_messages_span (server_address = server_address , thread_id = thread_id )
1339
1355
1340
1356
def trace_list_run_steps (self , function , * args , ** kwargs ):
1341
1357
server_address = self .get_server_address_from_arg (args [0 ])
1342
1358
run_id = kwargs .get ("run_id" )
1343
1359
thread_id = kwargs .get ("thread_id" )
1344
1360
1345
- span = self .start_list_run_steps_span (server_address = server_address , run_id = run_id , thread_id = thread_id )
1346
-
1347
- return _InstrumentedItemPaged (function (* args , ** kwargs ), self .add_run_step_event , span )
1361
+ return _InstrumentedItemPaged (
1362
+ function (* args , ** kwargs ),
1363
+ start_span_function = self .start_list_run_steps_span ,
1364
+ item_instrumentation_function = self .add_run_step_event ,
1365
+ server_address = server_address ,
1366
+ thread_id = thread_id ,
1367
+ run_id = run_id ,
1368
+ )
1348
1369
1349
1370
def trace_list_run_steps_async (self , function , * args , ** kwargs ):
1350
1371
# Note that this method is not async, but it operates on AsyncIterable.
1351
1372
server_address = self .get_server_address_from_arg (args [0 ])
1352
1373
run_id = kwargs .get ("run_id" )
1353
1374
thread_id = kwargs .get ("thread_id" )
1354
1375
1355
- span = self .start_list_run_steps_span (server_address = server_address , run_id = run_id , thread_id = thread_id )
1376
+ return _AsyncInstrumentedItemPaged (
1377
+ function (* args , ** kwargs ),
1378
+ start_span_function = self .start_list_run_steps_span ,
1379
+ item_instrumentation_function = self .add_run_step_event ,
1380
+ server_address = server_address ,
1381
+ thread_id = thread_id ,
1382
+ run_id = run_id ,
1383
+ )
1356
1384
1357
- return _AsyncInstrumentedItemPaged (function (* args , ** kwargs ), self .add_run_step_event , span )
1385
+ def start_trace_list_run_steps (
1386
+ self , server_address : Optional [str ] = None , thread_id : Optional [str ] = None , run_id : Optional [str ] = None
1387
+ ):
1388
+ return self .start_list_run_steps_span (server_address = server_address , thread_id = thread_id , run_id = run_id )
1358
1389
1359
1390
def handle_run_stream_exit (self , _function , * args , ** kwargs ):
1360
1391
agent_run_stream = args [0 ]
@@ -2058,6 +2089,7 @@ def on_unhandled_event(self, event_type: str, event_data: Any) -> None: # type:
2058
2089
if self .inner_handler :
2059
2090
return self .inner_handler .on_unhandled_event (event_type , event_data ) # type: ignore
2060
2091
return super ().on_unhandled_event (event_type , event_data ) # type: ignore
2092
+
2061
2093
# pylint: enable=R1710
2062
2094
2063
2095
def __exit__ (self , exc_type , exc_val , exc_tb ):
0 commit comments