@@ -176,20 +176,24 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
         embedding_content = str(embedding_content)
         request_model = kwargs.get("model")

+        embedding_token_count = (
+            settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
+            if settings.ai_monitoring.llm_token_count_callback
+            else None
+        )
+
         full_embedding_response_dict = {
             "id": embedding_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": (
-                settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
-                if settings.ai_monitoring.llm_token_count_callback
-                else None
-            ),
             "request.model": request_model,
             "duration": ft.duration * 1000,
             "vendor": "gemini",
             "ingest_source": "Python",
         }
+        if embedding_token_count:
+            full_embedding_response_dict["response.usage.total_tokens"] = embedding_token_count
+
         if settings.ai_monitoring.record_content.enabled:
             full_embedding_response_dict["input"] = embedding_content

@@ -303,15 +307,13 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg
303307 "Unable to parse input message to Gemini LLM. Message content and role will be omitted from "
304308 "corresponding LlmChatCompletionMessage event. "
305309 )
310+ # Extract the input message content and role from the input message if it exists
311+ input_message_content , input_role = _parse_input_message (input_message ) if input_message else (None , None )
306312
307- generation_config = kwargs .get ("config" )
308- if generation_config :
309- request_temperature = getattr (generation_config , "temperature" , None )
310- request_max_tokens = getattr (generation_config , "max_output_tokens" , None )
311- else :
312- request_temperature = None
313- request_max_tokens = None
313+ # Extract data from generation config object
314+ request_temperature , request_max_tokens = _extract_generation_config (kwargs )
314315
316+ # Prepare error attributes
315317 notice_error_attributes = {
316318 "http.statusCode" : getattr (exc , "code" , None ),
317319 "error.message" : getattr (exc , "message" , None ),
@@ -352,15 +354,17 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg

         create_chat_completion_message_event(
             transaction,
-            input_message,
+            input_message_content,
+            input_role,
             completion_id,
             span_id,
             trace_id,
             # Passing the request model as the response model here since we do not have access to a response model
             request_model,
-            request_model,
             llm_metadata,
             output_message_list,
+            # We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
+            True,
             request_timestamp,
         )
     except Exception:
@@ -388,6 +392,7 @@ def _handle_generation_success(
 def _record_generation_success(
     transaction, linking_metadata, completion_id, kwargs, ft, response, request_timestamp=None
 ):
+    settings = transaction.settings or global_settings()
     span_id = linking_metadata.get("span.id")
     trace_id = linking_metadata.get("trace.id")
     try:
@@ -396,12 +401,14 @@ def _record_generation_success(
             # finish_reason is an enum, so grab just the stringified value from it to report
             finish_reason = response.get("candidates")[0].get("finish_reason").value
             output_message_list = [response.get("candidates")[0].get("content")]
+            token_usage = response.get("usage_metadata") or {}
         else:
             # Set all values to NoneTypes since we cannot access them through kwargs or another method that doesn't
             # require the response object
             response_model = None
             output_message_list = []
             finish_reason = None
+            token_usage = {}

         request_model = kwargs.get("model")

@@ -423,13 +430,44 @@ def _record_generation_success(
423430 "corresponding LlmChatCompletionMessage event. "
424431 )
425432
426- generation_config = kwargs .get ("config" )
427- if generation_config :
428- request_temperature = getattr (generation_config , "temperature" , None )
429- request_max_tokens = getattr (generation_config , "max_output_tokens" , None )
433+ input_message_content , input_role = _parse_input_message (input_message ) if input_message else (None , None )
434+
435+ # Parse output message content
436+ # This list should have a length of 1 to represent the output message
437+ # Parse the message text out to pass to any registered token counting callback
438+ output_message_content = output_message_list [0 ].get ("parts" )[0 ].get ("text" ) if output_message_list else None
439+
440+ # Extract token counts from response object
441+ if token_usage :
442+ response_prompt_tokens = token_usage .get ("prompt_token_count" )
443+ response_completion_tokens = token_usage .get ("candidates_token_count" )
444+ response_total_tokens = token_usage .get ("total_token_count" )
445+
430446 else :
431- request_temperature = None
432- request_max_tokens = None
447+ response_prompt_tokens = None
448+ response_completion_tokens = None
449+ response_total_tokens = None
450+
451+ # Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
452+ # to it. If not, then we use the token counts provided in the response object
453+ prompt_tokens = (
454+ settings .ai_monitoring .llm_token_count_callback (request_model , input_message_content )
455+ if settings .ai_monitoring .llm_token_count_callback and input_message_content
456+ else response_prompt_tokens
457+ )
458+ completion_tokens = (
459+ settings .ai_monitoring .llm_token_count_callback (response_model , output_message_content )
460+ if settings .ai_monitoring .llm_token_count_callback and output_message_content
461+ else response_completion_tokens
462+ )
463+ total_tokens = (
464+ prompt_tokens + completion_tokens if all ([prompt_tokens , completion_tokens ]) else response_total_tokens
465+ )
466+
467+ all_token_counts = bool (prompt_tokens and completion_tokens and total_tokens )
468+
469+ # Extract generation config
470+ request_temperature , request_max_tokens = _extract_generation_config (kwargs )
433471
434472 full_chat_completion_summary_dict = {
435473 "id" : completion_id ,
@@ -450,68 +488,80 @@ def _record_generation_success(
450488 "timestamp" : request_timestamp ,
451489 }
452490
491+ if all_token_counts :
492+ full_chat_completion_summary_dict ["response.usage.prompt_tokens" ] = prompt_tokens
493+ full_chat_completion_summary_dict ["response.usage.completion_tokens" ] = completion_tokens
494+ full_chat_completion_summary_dict ["response.usage.total_tokens" ] = total_tokens
495+
453496 llm_metadata = _get_llm_attributes (transaction )
454497 full_chat_completion_summary_dict .update (llm_metadata )
455498 transaction .record_custom_event ("LlmChatCompletionSummary" , full_chat_completion_summary_dict )
456499
457500 create_chat_completion_message_event (
458501 transaction ,
459- input_message ,
502+ input_message_content ,
503+ input_role ,
460504 completion_id ,
461505 span_id ,
462506 trace_id ,
463507 response_model ,
464- request_model ,
465508 llm_metadata ,
466509 output_message_list ,
510+ all_token_counts ,
467511 request_timestamp ,
468512 )
469513 except Exception :
470514 _logger .warning (RECORD_EVENTS_FAILURE_LOG_MESSAGE , exc_info = True )
471515
472516
517+ def _parse_input_message (input_message ):
518+ # The input_message will be a string if generate_content was called directly. In this case, we don't have
519+ # access to the role, so we default to user since this was an input message
520+ if isinstance (input_message , str ):
521+ return input_message , "user"
522+ # The input_message will be a Google Content type if send_message was called, so we parse out the message
523+ # text and role (which should be "user")
524+ elif isinstance (input_message , google .genai .types .Content ):
525+ return input_message .parts [0 ].text , input_message .role
526+ else :
527+ return None , None
528+
529+
530+ def _extract_generation_config (kwargs ):
531+ generation_config = kwargs .get ("config" )
532+ if generation_config :
533+ request_temperature = getattr (generation_config , "temperature" , None )
534+ request_max_tokens = getattr (generation_config , "max_output_tokens" , None )
535+ else :
536+ request_temperature = None
537+ request_max_tokens = None
538+
539+ return request_temperature , request_max_tokens
540+
541+
473542def create_chat_completion_message_event (
474543 transaction ,
475- input_message ,
544+ input_message_content ,
545+ input_role ,
476546 chat_completion_id ,
477547 span_id ,
478548 trace_id ,
479549 response_model ,
480- request_model ,
481550 llm_metadata ,
482551 output_message_list ,
552+ all_token_counts ,
483553 request_timestamp = None ,
484554):
485555 try :
486556 settings = transaction .settings or global_settings ()
487557
488- if input_message :
489- # The input_message will be a string if generate_content was called directly. In this case, we don't have
490- # access to the role, so we default to user since this was an input message
491- if isinstance (input_message , str ):
492- input_message_content = input_message
493- input_role = "user"
494- # The input_message will be a Google Content type if send_message was called, so we parse out the message
495- # text and role (which should be "user")
496- elif isinstance (input_message , google .genai .types .Content ):
497- input_message_content = input_message .parts [0 ].text
498- input_role = input_message .role
499- # Set input data to NoneTypes to ensure token_count callback is not called
500- else :
501- input_message_content = None
502- input_role = None
503-
558+ if input_message_content :
504559 message_id = str (uuid .uuid4 ())
505560
506561 chat_completion_input_message_dict = {
507562 "id" : message_id ,
508563 "span_id" : span_id ,
509564 "trace_id" : trace_id ,
510- "token_count" : (
511- settings .ai_monitoring .llm_token_count_callback (request_model , input_message_content )
512- if settings .ai_monitoring .llm_token_count_callback and input_message_content
513- else None
514- ),
515565 "role" : input_role ,
516566 "completion_id" : chat_completion_id ,
517567 # The input message will always be the first message in our request/ response sequence so this will
@@ -521,6 +571,8 @@ def create_chat_completion_message_event(
521571 "vendor" : "gemini" ,
522572 "ingest_source" : "Python" ,
523573 }
574+ if all_token_counts :
575+ chat_completion_input_message_dict ["token_count" ] = 0
524576
525577 if settings .ai_monitoring .record_content .enabled :
526578 chat_completion_input_message_dict ["content" ] = input_message_content
@@ -539,7 +591,7 @@ def create_chat_completion_message_event(

             # Add one to the index to account for the single input message so our sequence value is accurate for
             # the output message
-            if input_message:
+            if input_message_content:
                 index += 1

             message_id = str(uuid.uuid4())
@@ -548,11 +600,6 @@ def create_chat_completion_message_event(
548600 "id" : message_id ,
549601 "span_id" : span_id ,
550602 "trace_id" : trace_id ,
551- "token_count" : (
552- settings .ai_monitoring .llm_token_count_callback (response_model , message_content )
553- if settings .ai_monitoring .llm_token_count_callback
554- else None
555- ),
556603 "role" : message .get ("role" ),
557604 "completion_id" : chat_completion_id ,
558605 "sequence" : index ,
@@ -562,6 +609,9 @@ def create_chat_completion_message_event(
562609 "is_response" : True ,
563610 }
564611
612+ if all_token_counts :
613+ chat_completion_output_message_dict ["token_count" ] = 0
614+
565615 if settings .ai_monitoring .record_content .enabled :
566616 chat_completion_output_message_dict ["content" ] = message_content
567617 if request_timestamp :
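Note on the token counting callback these hunks consult: `settings.ai_monitoring.llm_token_count_callback` is only populated when the host application registers a callback with the agent. A minimal sketch of that registration, assuming the public `newrelic.agent.set_llm_token_count_callback` API and a purely illustrative counting heuristic:

```python
import newrelic.agent


def count_tokens(model, content):
    # Illustrative stand-in only; a real callback would defer to a tokenizer
    # appropriate for the given model and must return an integer token count.
    return max(1, len(content) // 4)


# Once registered, the callback is surfaced to instrumentation as
# settings.ai_monitoring.llm_token_count_callback, so the code above prefers it
# over the token counts reported by the Gemini response.
newrelic.agent.set_llm_token_count_callback(count_tokens)
```

When no callback is registered, `_record_generation_success` falls back to the counts in the response's `usage_metadata`, and `_record_embedding_success` omits `response.usage.total_tokens` entirely.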
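For the shapes read out of `kwargs` and the response object (model, contents, `config.temperature`, `config.max_output_tokens`, `usage_metadata`), a caller-side sketch, assuming the google-genai client API; the model name and API key are placeholders:

```python
from google import genai
from google.genai import types

client = genai.Client(api_key="YOUR_API_KEY")  # placeholder credentials

response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder model name
    contents="Summarize the plot of Hamlet in one sentence.",
    # _extract_generation_config reads temperature and max_output_tokens from this object
    config=types.GenerateContentConfig(temperature=0.2, max_output_tokens=128),
)

# usage_metadata carries prompt_token_count, candidates_token_count, and
# total_token_count, the fallback values used when no callback is registered.
print(response.usage_metadata)
```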