Commit e6cb2bb

Add response token count logic to Gemini instrumentation. (#1486)
* Add response token count logic to Gemini instrumentation.
* Update token counting util functions.
* Linting
* Add response token count logic to Gemini instrumentation.
* Update token counting util functions.
* [MegaLinter] Apply linters fixes
* Bump tests.

---------

Co-authored-by: Tim Pansino <[email protected]>
1 parent 38d7547 commit e6cb2bb

File tree

6 files changed: +139 -193 lines changed

newrelic/hooks/mlmodel_gemini.py

Lines changed: 101 additions & 51 deletions
@@ -176,20 +176,24 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
         embedding_content = str(embedding_content)
         request_model = kwargs.get("model")

+        embedding_token_count = (
+            settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
+            if settings.ai_monitoring.llm_token_count_callback
+            else None
+        )
+
         full_embedding_response_dict = {
             "id": embedding_id,
             "span_id": span_id,
             "trace_id": trace_id,
-            "token_count": (
-                settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
-                if settings.ai_monitoring.llm_token_count_callback
-                else None
-            ),
             "request.model": request_model,
             "duration": ft.duration * 1000,
             "vendor": "gemini",
             "ingest_source": "Python",
         }
+        if embedding_token_count:
+            full_embedding_response_dict["response.usage.total_tokens"] = embedding_token_count
+
         if settings.ai_monitoring.record_content.enabled:
             full_embedding_response_dict["input"] = embedding_content

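Note: with this change the embedding event only reports a token figure when a token counting callback is registered, and it is now emitted as response.usage.total_tokens rather than the removed per-event token_count attribute. A minimal sketch of registering such a callback, assuming the agent's set_llm_token_count_callback API; the word-count heuristic is purely illustrative:

```python
import newrelic.agent


def llm_token_count_callback(model, content):
    # Illustrative only: approximate tokens by whitespace-separated words.
    # A real callback would typically use a model-aware tokenizer.
    return max(1, len(content.split()))


# Once registered, the Gemini instrumentation above calls this with the
# request model and the embedding input to populate response.usage.total_tokens.
newrelic.agent.set_llm_token_count_callback(llm_token_count_callback)
```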
@@ -303,15 +307,13 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg
                 "Unable to parse input message to Gemini LLM. Message content and role will be omitted from "
                 "corresponding LlmChatCompletionMessage event. "
             )
+        # Extract the input message content and role from the input message if it exists
+        input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)

-        generation_config = kwargs.get("config")
-        if generation_config:
-            request_temperature = getattr(generation_config, "temperature", None)
-            request_max_tokens = getattr(generation_config, "max_output_tokens", None)
-        else:
-            request_temperature = None
-            request_max_tokens = None
+        # Extract data from generation config object
+        request_temperature, request_max_tokens = _extract_generation_config(kwargs)

+        # Prepare error attributes
         notice_error_attributes = {
             "http.statusCode": getattr(exc, "code", None),
             "error.message": getattr(exc, "message", None),
@@ -352,15 +354,17 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg

         create_chat_completion_message_event(
             transaction,
-            input_message,
+            input_message_content,
+            input_role,
             completion_id,
             span_id,
             trace_id,
             # Passing the request model as the response model here since we do not have access to a response model
             request_model,
-            request_model,
             llm_metadata,
             output_message_list,
+            # We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
+            True,
             request_timestamp,
         )
     except Exception:
@@ -388,6 +392,7 @@ def _handle_generation_success(
 def _record_generation_success(
     transaction, linking_metadata, completion_id, kwargs, ft, response, request_timestamp=None
 ):
+    settings = transaction.settings or global_settings()
     span_id = linking_metadata.get("span.id")
     trace_id = linking_metadata.get("trace.id")
     try:
@@ -396,12 +401,14 @@ def _record_generation_success(
             # finish_reason is an enum, so grab just the stringified value from it to report
             finish_reason = response.get("candidates")[0].get("finish_reason").value
             output_message_list = [response.get("candidates")[0].get("content")]
+            token_usage = response.get("usage_metadata") or {}
         else:
             # Set all values to NoneTypes since we cannot access them through kwargs or another method that doesn't
             # require the response object
             response_model = None
             output_message_list = []
             finish_reason = None
+            token_usage = {}

         request_model = kwargs.get("model")

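The success path now also captures usage_metadata from the response. The instrumentation accesses the response like a dict, so the fields it reads look roughly like this (a hand-written sketch, not real SDK output; values illustrative):

```python
# Shape of the usage fields read by the instrumentation above.
response = {
    "usage_metadata": {
        "prompt_token_count": 12,      # becomes response_prompt_tokens
        "candidates_token_count": 7,   # becomes response_completion_tokens
        "total_token_count": 19,       # becomes response_total_tokens
    }
}

token_usage = response.get("usage_metadata") or {}
```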
@@ -423,13 +430,44 @@ def _record_generation_success(
                 "corresponding LlmChatCompletionMessage event. "
             )

-        generation_config = kwargs.get("config")
-        if generation_config:
-            request_temperature = getattr(generation_config, "temperature", None)
-            request_max_tokens = getattr(generation_config, "max_output_tokens", None)
+        input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)
+
+        # Parse output message content
+        # This list should have a length of 1 to represent the output message
+        # Parse the message text out to pass to any registered token counting callback
+        output_message_content = output_message_list[0].get("parts")[0].get("text") if output_message_list else None
+
+        # Extract token counts from response object
+        if token_usage:
+            response_prompt_tokens = token_usage.get("prompt_token_count")
+            response_completion_tokens = token_usage.get("candidates_token_count")
+            response_total_tokens = token_usage.get("total_token_count")
+
         else:
-            request_temperature = None
-            request_max_tokens = None
+            response_prompt_tokens = None
+            response_completion_tokens = None
+            response_total_tokens = None
+
+        # Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
+        # to it. If not, then we use the token counts provided in the response object
+        prompt_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
+            if settings.ai_monitoring.llm_token_count_callback and input_message_content
+            else response_prompt_tokens
+        )
+        completion_tokens = (
+            settings.ai_monitoring.llm_token_count_callback(response_model, output_message_content)
+            if settings.ai_monitoring.llm_token_count_callback and output_message_content
+            else response_completion_tokens
+        )
+        total_tokens = (
+            prompt_tokens + completion_tokens if all([prompt_tokens, completion_tokens]) else response_total_tokens
+        )
+
+        all_token_counts = bool(prompt_tokens and completion_tokens and total_tokens)
+
+        # Extract generation config
+        request_temperature, request_max_tokens = _extract_generation_config(kwargs)

         full_chat_completion_summary_dict = {
             "id": completion_id,
@@ -450,68 +488,80 @@ def _record_generation_success(
             "timestamp": request_timestamp,
         }

+        if all_token_counts:
+            full_chat_completion_summary_dict["response.usage.prompt_tokens"] = prompt_tokens
+            full_chat_completion_summary_dict["response.usage.completion_tokens"] = completion_tokens
+            full_chat_completion_summary_dict["response.usage.total_tokens"] = total_tokens
+
         llm_metadata = _get_llm_attributes(transaction)
         full_chat_completion_summary_dict.update(llm_metadata)
         transaction.record_custom_event("LlmChatCompletionSummary", full_chat_completion_summary_dict)

         create_chat_completion_message_event(
             transaction,
-            input_message,
+            input_message_content,
+            input_role,
             completion_id,
             span_id,
             trace_id,
             response_model,
-            request_model,
             llm_metadata,
             output_message_list,
+            all_token_counts,
             request_timestamp,
         )
     except Exception:
         _logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)


+def _parse_input_message(input_message):
+    # The input_message will be a string if generate_content was called directly. In this case, we don't have
+    # access to the role, so we default to user since this was an input message
+    if isinstance(input_message, str):
+        return input_message, "user"
+    # The input_message will be a Google Content type if send_message was called, so we parse out the message
+    # text and role (which should be "user")
+    elif isinstance(input_message, google.genai.types.Content):
+        return input_message.parts[0].text, input_message.role
+    else:
+        return None, None
+
+
+def _extract_generation_config(kwargs):
+    generation_config = kwargs.get("config")
+    if generation_config:
+        request_temperature = getattr(generation_config, "temperature", None)
+        request_max_tokens = getattr(generation_config, "max_output_tokens", None)
+    else:
+        request_temperature = None
+        request_max_tokens = None
+
+    return request_temperature, request_max_tokens
+
+
 def create_chat_completion_message_event(
     transaction,
-    input_message,
+    input_message_content,
+    input_role,
     chat_completion_id,
     span_id,
     trace_id,
     response_model,
-    request_model,
     llm_metadata,
     output_message_list,
+    all_token_counts,
     request_timestamp=None,
 ):
     try:
         settings = transaction.settings or global_settings()

-        if input_message:
-            # The input_message will be a string if generate_content was called directly. In this case, we don't have
-            # access to the role, so we default to user since this was an input message
-            if isinstance(input_message, str):
-                input_message_content = input_message
-                input_role = "user"
-            # The input_message will be a Google Content type if send_message was called, so we parse out the message
-            # text and role (which should be "user")
-            elif isinstance(input_message, google.genai.types.Content):
-                input_message_content = input_message.parts[0].text
-                input_role = input_message.role
-            # Set input data to NoneTypes to ensure token_count callback is not called
-            else:
-                input_message_content = None
-                input_role = None
-
+        if input_message_content:
             message_id = str(uuid.uuid4())

             chat_completion_input_message_dict = {
                 "id": message_id,
                 "span_id": span_id,
                 "trace_id": trace_id,
-                "token_count": (
-                    settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
-                    if settings.ai_monitoring.llm_token_count_callback and input_message_content
-                    else None
-                ),
                 "role": input_role,
                 "completion_id": chat_completion_id,
                 # The input message will always be the first message in our request/ response sequence so this will
521571
"vendor": "gemini",
522572
"ingest_source": "Python",
523573
}
574+
if all_token_counts:
575+
chat_completion_input_message_dict["token_count"] = 0
524576

525577
if settings.ai_monitoring.record_content.enabled:
526578
chat_completion_input_message_dict["content"] = input_message_content
@@ -539,7 +591,7 @@ def create_chat_completion_message_event(
539591

540592
# Add one to the index to account for the single input message so our sequence value is accurate for
541593
# the output message
542-
if input_message:
594+
if input_message_content:
543595
index += 1
544596

545597
message_id = str(uuid.uuid4())
@@ -548,11 +600,6 @@ def create_chat_completion_message_event(
548600
"id": message_id,
549601
"span_id": span_id,
550602
"trace_id": trace_id,
551-
"token_count": (
552-
settings.ai_monitoring.llm_token_count_callback(response_model, message_content)
553-
if settings.ai_monitoring.llm_token_count_callback
554-
else None
555-
),
556603
"role": message.get("role"),
557604
"completion_id": chat_completion_id,
558605
"sequence": index,
@@ -562,6 +609,9 @@ def create_chat_completion_message_event(
562609
"is_response": True,
563610
}
564611

612+
if all_token_counts:
613+
chat_completion_output_message_dict["token_count"] = 0
614+
565615
if settings.ai_monitoring.record_content.enabled:
566616
chat_completion_output_message_dict["content"] = message_content
567617
if request_timestamp:

tests/mlmodel_gemini/test_embeddings.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 import google.genai
 from testing_support.fixtures import override_llm_token_callback_settings, reset_core_stats_engine, validate_attributes
 from testing_support.ml_testing_utils import (
-    add_token_count_to_events,
+    add_token_count_to_embedding_events,
     disabled_ai_monitoring_record_content_settings,
     disabled_ai_monitoring_settings,
     events_sans_content,
@@ -93,7 +93,7 @@ def test_gemini_embedding_sync_no_content(gemini_dev_client, set_trace_info):

 @reset_core_stats_engine()
 @override_llm_token_callback_settings(llm_token_count_callback)
-@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
+@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
 @validate_custom_event_count(count=1)
 @validate_transaction_metrics(
     name="test_embeddings:test_gemini_embedding_sync_with_token_count",
@@ -177,7 +177,7 @@ def test_gemini_embedding_async_no_content(gemini_dev_client, loop, set_trace_in

 @reset_core_stats_engine()
 @override_llm_token_callback_settings(llm_token_count_callback)
-@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
+@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
 @validate_custom_event_count(count=1)
 @validate_transaction_metrics(
     name="test_embeddings:test_gemini_embedding_async_with_token_count",

0 commit comments