Skip to content

Commit 972c2d7

Browse files
feat(integrations): add support for embed_content methods in GoogleGenAI integration (#5128)
#### Issues Closes https://linear.app/getsentry/issue/TET-1464/add-embedding-support-for-google-genai
1 parent 027aa6e commit 972c2d7

File tree

3 files changed

+607
-0
lines changed

3 files changed

+607
-0
lines changed

sentry_sdk/integrations/google_genai/__init__.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
set_span_data_for_response,
2727
_capture_exception,
2828
prepare_generate_content_args,
29+
prepare_embed_content_args,
30+
set_span_data_for_embed_request,
31+
set_span_data_for_embed_response,
2932
)
3033
from .streaming import (
3134
set_span_data_for_streaming_response,
@@ -49,6 +52,7 @@ def setup_once():
4952
Models.generate_content_stream = _wrap_generate_content_stream(
5053
Models.generate_content_stream
5154
)
55+
Models.embed_content = _wrap_embed_content(Models.embed_content)
5256

5357
# Patch async methods
5458
AsyncModels.generate_content = _wrap_async_generate_content(
@@ -57,6 +61,7 @@ def setup_once():
5761
AsyncModels.generate_content_stream = _wrap_async_generate_content_stream(
5862
AsyncModels.generate_content_stream
5963
)
64+
AsyncModels.embed_content = _wrap_async_embed_content(AsyncModels.embed_content)
6065

6166

6267
def _wrap_generate_content_stream(f):
@@ -299,3 +304,73 @@ async def new_async_generate_content(self, *args, **kwargs):
299304
return response
300305

301306
return new_async_generate_content
307+
308+
309+
def _wrap_embed_content(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap the synchronous Models.embed_content with Sentry instrumentation."""

    @wraps(f)
    def new_embed_content(self, *args, **kwargs):
        # type: (Any, Any, Any) -> Any
        integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
        if integration is None:
            # Integration not enabled — call through untouched.
            return f(self, *args, **kwargs)

        model_name, contents = prepare_embed_content_args(args, kwargs)

        with sentry_sdk.start_span(
            op=OP.GEN_AI_EMBEDDINGS,
            name=f"embeddings {model_name}",
            origin=ORIGIN,
        ) as span:
            # Static request metadata first, then PII-gated request payload.
            span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
            span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
            set_span_data_for_embed_request(span, integration, contents, kwargs)

            try:
                result = f(self, *args, **kwargs)
            except Exception as error:
                # Report the error, mark the span failed, and let it propagate.
                _capture_exception(error)
                span.set_status(SPANSTATUS.INTERNAL_ERROR)
                raise

            set_span_data_for_embed_response(span, integration, result)
            return result

    return new_embed_content
342+
343+
344+
def _wrap_async_embed_content(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap the asynchronous AsyncModels.embed_content with Sentry instrumentation."""

    @wraps(f)
    async def new_async_embed_content(self, *args, **kwargs):
        # type: (Any, Any, Any) -> Any
        integration = sentry_sdk.get_client().get_integration(GoogleGenAIIntegration)
        if integration is None:
            # Integration not enabled — await the original untouched.
            return await f(self, *args, **kwargs)

        model_name, contents = prepare_embed_content_args(args, kwargs)

        with sentry_sdk.start_span(
            op=OP.GEN_AI_EMBEDDINGS,
            name=f"embeddings {model_name}",
            origin=ORIGIN,
        ) as span:
            # Static request metadata first, then PII-gated request payload.
            span.set_data(SPANDATA.GEN_AI_SYSTEM, GEN_AI_SYSTEM)
            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")
            span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model_name)
            set_span_data_for_embed_request(span, integration, contents, kwargs)

            try:
                result = await f(self, *args, **kwargs)
            except Exception as error:
                # Report the error, mark the span failed, and let it propagate.
                _capture_exception(error)
                span.set_status(SPANSTATUS.INTERNAL_ERROR)
                raise

            set_span_data_for_embed_response(span, integration, result)
            return result

    return new_async_embed_content

sentry_sdk/integrations/google_genai/utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ContentListUnion,
3737
Tool,
3838
Model,
39+
EmbedContentResponse,
3940
)
4041

4142

@@ -574,3 +575,70 @@ def prepare_generate_content_args(args, kwargs):
574575
kwargs["config"] = wrapped_config
575576

576577
return model, contents, model_name
578+
579+
580+
def prepare_embed_content_args(args, kwargs):
    # type: (tuple[Any, ...], dict[str, Any]) -> tuple[str, Any]
    """Extract and prepare common arguments for embed_content methods.

    Returns:
        tuple: (model_name, contents)
    """
    # embed_content takes keyword arguments only; fall back to "unknown"
    # when no model was supplied so the span name stays well-formed.
    raw_model = kwargs.get("model", "unknown")
    return get_model_name(raw_model), kwargs.get("contents")
592+
593+
594+
def set_span_data_for_embed_request(span, integration, contents, kwargs):
    # type: (Span, Any, Any, dict[str, Any]) -> None
    """Set span data for embedding request."""
    # Input contents are only attached when PII capture is permitted.
    if not (should_send_default_pii() and integration.include_prompts):
        return
    if not contents:
        return

    # Normalize the various accepted shapes into a flat list of texts.
    if isinstance(contents, str):
        texts = [contents]
    elif isinstance(contents, list):
        texts = [
            text
            for text in (extract_contents_text(item) for item in contents)
            if text
        ]
    else:
        extracted = extract_contents_text(contents)
        texts = [extracted] if extracted else []

    if texts:
        set_data_normalized(
            span,
            SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
            texts,
            unpack=False,
        )
623+
624+
625+
def set_span_data_for_embed_response(span, integration, response):
    # type: (Span, Any, EmbedContentResponse) -> None
    """Set span data for embedding response."""
    if not response:
        return

    embeddings = getattr(response, "embeddings", None)
    if not embeddings:
        return

    # Per-embedding statistics carry a token_count (populated on Vertex AI
    # only); sum them across all embeddings in the response.
    total = 0
    for embedding in embeddings:
        stats = getattr(embedding, "statistics", None)
        if not stats:
            continue
        token_count = getattr(stats, "token_count", None)
        if token_count is not None:
            total += int(token_count)

    # Only record usage when at least one statistics object was present.
    if total > 0:
        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, total)

0 commit comments

Comments
 (0)