152 changes: 101 additions & 51 deletions newrelic/hooks/mlmodel_gemini.py
@@ -175,20 +175,24 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
embedding_content = str(embedding_content)
request_model = kwargs.get("model")

embedding_token_count = (
settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
if settings.ai_monitoring.llm_token_count_callback
else None
)

full_embedding_response_dict = {
"id": embedding_id,
"span_id": span_id,
"trace_id": trace_id,
"token_count": (
settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
if settings.ai_monitoring.llm_token_count_callback
else None
),
"request.model": request_model,
"duration": ft.duration * 1000,
"vendor": "gemini",
"ingest_source": "Python",
}
if embedding_token_count:
full_embedding_response_dict["response.usage.total_tokens"] = embedding_token_count

if settings.ai_monitoring.record_content.enabled:
full_embedding_response_dict["input"] = embedding_content

@@ -300,15 +304,13 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg
"Unable to parse input message to Gemini LLM. Message content and role will be omitted from "
"corresponding LlmChatCompletionMessage event. "
)
# Extract the input message content and role from the input message if it exists
input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)

generation_config = kwargs.get("config")
if generation_config:
request_temperature = getattr(generation_config, "temperature", None)
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
else:
request_temperature = None
request_max_tokens = None
# Extract data from generation config object
request_temperature, request_max_tokens = _extract_generation_config(kwargs)

# Prepare error attributes
notice_error_attributes = {
"http.statusCode": getattr(exc, "code", None),
"error.message": getattr(exc, "message", None),
@@ -348,15 +350,17 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg

create_chat_completion_message_event(
transaction,
input_message,
input_message_content,
input_role,
completion_id,
span_id,
trace_id,
# Passing the request model as the response model here since we do not have access to a response model
request_model,
request_model,
llm_metadata,
output_message_list,
# We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
all_token_counts=True,
)
except Exception:
_logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)
Expand All @@ -377,6 +381,7 @@ def _handle_generation_success(transaction, linking_metadata, completion_id, kwa


def _record_generation_success(transaction, linking_metadata, completion_id, kwargs, ft, response):
settings = transaction.settings or global_settings()
span_id = linking_metadata.get("span.id")
trace_id = linking_metadata.get("trace.id")
try:
Expand All @@ -385,12 +390,14 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
# finish_reason is an enum, so grab just the stringified value from it to report
finish_reason = response.get("candidates")[0].get("finish_reason").value
output_message_list = [response.get("candidates")[0].get("content")]
token_usage = response.get("usage_metadata") or {}
else:
# Set all values to NoneTypes since we cannot access them through kwargs or another method that doesn't
# require the response object
response_model = None
output_message_list = []
finish_reason = None
token_usage = {}

request_model = kwargs.get("model")

Expand All @@ -412,13 +419,44 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
"corresponding LlmChatCompletionMessage event. "
)

generation_config = kwargs.get("config")
if generation_config:
request_temperature = getattr(generation_config, "temperature", None)
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)

# Parse output message content
# This list should have a length of 1 to represent the output message
# Parse the message text out to pass to any registered token counting callback
output_message_content = output_message_list[0].get("parts")[0].get("text") if output_message_list else None

# Extract token counts from response object
if token_usage:
response_prompt_tokens = token_usage.get("prompt_token_count")
response_completion_tokens = token_usage.get("candidates_token_count")
response_total_tokens = token_usage.get("total_token_count")

else:
request_temperature = None
request_max_tokens = None
response_prompt_tokens = None
response_completion_tokens = None
response_total_tokens = None

# Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
# to it. If not, then we use the token counts provided in the response object
prompt_tokens = (
settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
if settings.ai_monitoring.llm_token_count_callback and input_message_content
else response_prompt_tokens
)
completion_tokens = (
settings.ai_monitoring.llm_token_count_callback(response_model, output_message_content)
if settings.ai_monitoring.llm_token_count_callback and output_message_content
else response_completion_tokens
)
total_tokens = (
prompt_tokens + completion_tokens if all([prompt_tokens, completion_tokens]) else response_total_tokens
)

all_token_counts = bool(prompt_tokens and completion_tokens and total_tokens)

# Extract generation config
request_temperature, request_max_tokens = _extract_generation_config(kwargs)

full_chat_completion_summary_dict = {
"id": completion_id,
Expand All @@ -438,66 +476,78 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
"response.number_of_messages": 1 + len(output_message_list),
}

if all_token_counts:
full_chat_completion_summary_dict["response.usage.prompt_tokens"] = prompt_tokens
full_chat_completion_summary_dict["response.usage.completion_tokens"] = completion_tokens
full_chat_completion_summary_dict["response.usage.total_tokens"] = total_tokens

llm_metadata = _get_llm_attributes(transaction)
full_chat_completion_summary_dict.update(llm_metadata)
transaction.record_custom_event("LlmChatCompletionSummary", full_chat_completion_summary_dict)

create_chat_completion_message_event(
transaction,
input_message,
input_message_content,
input_role,
completion_id,
span_id,
trace_id,
response_model,
request_model,
llm_metadata,
output_message_list,
all_token_counts,
)
except Exception:
_logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)


def _parse_input_message(input_message):
# The input_message will be a string if generate_content was called directly. In this case, we don't have
# access to the role, so we default to user since this was an input message
if isinstance(input_message, str):
return input_message, "user"
# The input_message will be a Google Content type if send_message was called, so we parse out the message
# text and role (which should be "user")
elif isinstance(input_message, google.genai.types.Content):
return input_message.parts[0].text, input_message.role
else:
return None, None


def _extract_generation_config(kwargs):
generation_config = kwargs.get("config")
if generation_config:
request_temperature = getattr(generation_config, "temperature", None)
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
else:
request_temperature = None
request_max_tokens = None

return request_temperature, request_max_tokens


def create_chat_completion_message_event(
transaction,
input_message,
input_message_content,
input_role,
chat_completion_id,
span_id,
trace_id,
response_model,
request_model,
llm_metadata,
output_message_list,
all_token_counts,
):
try:
settings = transaction.settings or global_settings()

if input_message:
# The input_message will be a string if generate_content was called directly. In this case, we don't have
# access to the role, so we default to user since this was an input message
if isinstance(input_message, str):
input_message_content = input_message
input_role = "user"
# The input_message will be a Google Content type if send_message was called, so we parse out the message
# text and role (which should be "user")
elif isinstance(input_message, google.genai.types.Content):
input_message_content = input_message.parts[0].text
input_role = input_message.role
# Set input data to NoneTypes to ensure token_count callback is not called
else:
input_message_content = None
input_role = None

if input_message_content:
message_id = str(uuid.uuid4())

chat_completion_input_message_dict = {
"id": message_id,
"span_id": span_id,
"trace_id": trace_id,
"token_count": (
settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
if settings.ai_monitoring.llm_token_count_callback and input_message_content
else None
),
"role": input_role,
"completion_id": chat_completion_id,
# The input message will always be the first message in our request/response sequence so this will
Expand All @@ -507,6 +557,8 @@ def create_chat_completion_message_event(
"vendor": "gemini",
"ingest_source": "Python",
}
if all_token_counts:
chat_completion_input_message_dict["token_count"] = 0

if settings.ai_monitoring.record_content.enabled:
chat_completion_input_message_dict["content"] = input_message_content
Expand All @@ -523,7 +575,7 @@ def create_chat_completion_message_event(

# Add one to the index to account for the single input message so our sequence value is accurate for
# the output message
if input_message:
if input_message_content:
index += 1

message_id = str(uuid.uuid4())
Expand All @@ -532,11 +584,6 @@ def create_chat_completion_message_event(
"id": message_id,
"span_id": span_id,
"trace_id": trace_id,
"token_count": (
settings.ai_monitoring.llm_token_count_callback(response_model, message_content)
if settings.ai_monitoring.llm_token_count_callback
else None
),
"role": message.get("role"),
"completion_id": chat_completion_id,
"sequence": index,
Expand All @@ -546,6 +593,9 @@ def create_chat_completion_message_event(
"is_response": True,
}

if all_token_counts:
chat_completion_output_message_dict["token_count"] = 0

if settings.ai_monitoring.record_content.enabled:
chat_completion_output_message_dict["content"] = message_content

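For context on the token-count precedence these hooks now implement: the callback consulted via `settings.ai_monitoring.llm_token_count_callback` is one the instrumented application registers with the agent, and the instrumentation only falls back to the prompt/candidates/total counts from the response's `usage_metadata` when no callback result is available. A minimal sketch of registering such a callback, assuming the agent's public `set_llm_token_count_callback` API and a purely illustrative character-based count:

```python
# Illustrative only -- not part of this change. The agent is assumed to expose
# set_llm_token_count_callback for registering the callback these hooks consult.
import newrelic.agent


def count_tokens(model, content):
    # Hypothetical heuristic: roughly four characters per token. A real callback
    # would normally run a proper tokenizer for the given model and return an int.
    if not content:
        return None
    return max(1, len(content) // 4)


newrelic.agent.set_llm_token_count_callback(count_tokens)
```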
6 changes: 3 additions & 3 deletions tests/mlmodel_gemini/test_embeddings.py
@@ -15,7 +15,7 @@
import google.genai
from testing_support.fixtures import override_llm_token_callback_settings, reset_core_stats_engine, validate_attributes
from testing_support.ml_testing_utils import (
add_token_count_to_events,
add_token_count_to_embedding_events,
disabled_ai_monitoring_record_content_settings,
disabled_ai_monitoring_settings,
events_sans_content,
@@ -93,7 +93,7 @@ def test_gemini_embedding_sync_no_content(gemini_dev_client, set_trace_info):

@reset_core_stats_engine()
@override_llm_token_callback_settings(llm_token_count_callback)
@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
@validate_custom_event_count(count=1)
@validate_transaction_metrics(
name="test_embeddings:test_gemini_embedding_sync_with_token_count",
@@ -177,7 +177,7 @@ def test_gemini_embedding_async_no_content(gemini_dev_client, loop, set_trace_in

@reset_core_stats_engine()
@override_llm_token_callback_settings(llm_token_count_callback)
@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
@validate_custom_event_count(count=1)
@validate_transaction_metrics(
name="test_embeddings:test_gemini_embedding_async_with_token_count",
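For reference, the `llm_token_count_callback` passed to `override_llm_token_callback_settings` in these tests is a deterministic stub, so the events built by `add_token_count_to_embedding_events` can assert on a known count. A minimal sketch of the shape such a helper might take (the actual implementation lives in testing_support.ml_testing_utils and may differ):

```python
# Hypothetical stand-in for the test helper; the real fixture in
# testing_support.ml_testing_utils may use a different fixed value.
def llm_token_count_callback(model, content):
    # Return a fixed count so the expected event payloads are deterministic.
    return 105
```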