Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ def truncate_and_annotate_messages(

truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
if removed_count > 0:
scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(
truncated_messages
)
scope._gen_ai_original_message_count[span.span_id] = len(messages)

return truncated_messages
9 changes: 3 additions & 6 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,23 +598,20 @@ def _prepare_event(
if event_scrubber:
event_scrubber.scrub_event(event)

if scope is not None and scope._gen_ai_messages_truncated:
if scope is not None and scope._gen_ai_original_message_count:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
if isinstance(spans, list):
for span in spans:
span_id = span.get("span_id", None)
span_data = span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{
"len": scope._gen_ai_messages_truncated[span_id]
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
},
{"len": scope._gen_ai_original_message_count[span_id]},
)
if previous_total_spans is not None:
event["spans"] = AnnotatedValue(
Expand Down
12 changes: 7 additions & 5 deletions sentry_sdk/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class Scope:
"_extras",
"_breadcrumbs",
"_n_breadcrumbs_truncated",
"_gen_ai_messages_truncated",
"_gen_ai_original_message_count",
"_event_processors",
"_error_processors",
"_should_capture",
Expand All @@ -214,7 +214,7 @@ def __init__(self, ty=None, client=None):
self._name = None # type: Optional[str]
self._propagation_context = None # type: Optional[PropagationContext]
self._n_breadcrumbs_truncated = 0 # type: int
self._gen_ai_messages_truncated = {} # type: Dict[str, int]
self._gen_ai_original_message_count = {} # type: Dict[str, int]

self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient

Expand Down Expand Up @@ -249,7 +249,7 @@ def __copy__(self):

rv._breadcrumbs = copy(self._breadcrumbs)
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
rv._gen_ai_original_message_count = self._gen_ai_original_message_count.copy()
rv._event_processors = self._event_processors.copy()
rv._error_processors = self._error_processors.copy()
rv._propagation_context = self._propagation_context
Expand Down Expand Up @@ -1586,8 +1586,10 @@ def update_from_scope(self, scope):
self._n_breadcrumbs_truncated = (
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
)
if scope._gen_ai_messages_truncated:
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
if scope._gen_ai_original_message_count:
self._gen_ai_original_message_count.update(
scope._gen_ai_original_message_count
)
if scope._span:
self._span = scope._span
if scope._attachments:
Expand Down
92 changes: 74 additions & 18 deletions tests/test_ai_monitoring.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import uuid

import pytest

Expand Down Expand Up @@ -290,7 +291,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

span = MockSpan()
scope = MockScope()
Expand All @@ -300,7 +301,7 @@ def __init__(self):
assert not isinstance(result, AnnotatedValue)
assert len(result) == len(sample_messages)
assert result == sample_messages
assert span.span_id not in scope._gen_ai_messages_truncated
assert span.span_id not in scope._gen_ai_original_message_count

def test_truncation_sets_metadata_on_scope(self, large_messages):
class MockSpan:
Expand All @@ -313,9 +314,9 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 1000
small_limit = 3000
span = MockSpan()
scope = MockScope()
original_count = len(large_messages)
Expand All @@ -326,8 +327,7 @@ def __init__(self):
assert isinstance(result, list)
assert not isinstance(result, AnnotatedValue)
assert len(result) < len(large_messages)
n_removed = original_count - len(result)
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
assert scope._gen_ai_original_message_count[span.span_id] == original_count

def test_scope_tracks_removed_messages(self, large_messages):
class MockSpan:
Expand All @@ -340,7 +340,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 1000
original_count = len(large_messages)
Expand All @@ -352,7 +352,7 @@ def __init__(self):
)

n_removed = original_count - len(result)
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
assert scope._gen_ai_original_message_count[span.span_id] == n_removed
assert len(result) + n_removed == original_count

def test_empty_messages_returns_none(self):
Expand All @@ -366,7 +366,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

span = MockSpan()
scope = MockScope()
Expand All @@ -387,7 +387,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
Expand Down Expand Up @@ -416,7 +416,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
Expand All @@ -430,8 +430,8 @@ def __init__(self):
span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages)

# Verify metadata was set on scope
assert span.span_id in scope._gen_ai_messages_truncated
assert scope._gen_ai_messages_truncated[span.span_id] > 0
assert span.span_id in scope._gen_ai_original_message_count
assert scope._gen_ai_original_message_count[span.span_id] > 0

# Simulate what client.py does
event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]}
Expand All @@ -442,21 +442,77 @@ def __init__(self):
span_data = event_span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
n_removed = scope._gen_ai_messages_truncated[span_id]
n_remaining = len(messages) if isinstance(messages, list) else 0
original_count_calculated = n_removed + n_remaining
n_original_count = scope._gen_ai_original_message_count[span_id]

span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
safe_serialize(messages),
{"len": original_count_calculated},
{"len": n_original_count},
)

# Verify the annotation happened
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
assert isinstance(messages_value, AnnotatedValue)
assert messages_value.metadata["len"] == original_count
assert isinstance(messages_value.value, str)

def test_annotated_value_shows_correct_original_length(self, large_messages):
    """The ``len`` annotation must carry the pre-truncation message count."""
    from sentry_sdk.consts import SPANDATA

    class MockSpan:
        # Minimal span stand-in: an id plus a dict that set_data writes into.
        def __init__(self):
            self.span_id = "test_span_456"
            self.data = {}

        def set_data(self, key, value):
            self.data[key] = value

    class MockScope:
        # Only the per-span original-count mapping is needed here.
        def __init__(self):
            self._gen_ai_original_message_count = {}

    byte_budget = 3000
    span = MockSpan()
    scope = MockScope()
    expected_total = len(large_messages)

    kept = truncate_and_annotate_messages(
        large_messages, span, scope, max_bytes=byte_budget
    )

    # Truncation must actually have dropped something.
    assert len(kept) < expected_total

    # The scope records the original count under the span id.
    counts = scope._gen_ai_original_message_count
    assert span.span_id in counts
    stored_original_length = counts[span.span_id]
    assert stored_original_length == expected_total

    messages_key = SPANDATA.GEN_AI_REQUEST_MESSAGES
    span_entry = {"span_id": span.span_id, "data": {messages_key: kept}}
    event = {"spans": [span_entry]}

    # Hand-rolled replay of the annotation step over the event's spans.
    for entry in event["spans"]:
        entry_id = entry.get("span_id")
        payload = entry.get("data", {})
        if not entry_id or entry_id not in scope._gen_ai_original_message_count:
            continue
        if messages_key not in payload:
            continue
        payload[messages_key] = AnnotatedValue(
            payload[messages_key],
            {"len": scope._gen_ai_original_message_count[entry_id]},
        )

    annotated = event["spans"][0]["data"][messages_key]
    assert isinstance(annotated, AnnotatedValue)
    assert annotated.metadata["len"] == stored_original_length
    assert len(annotated.value) == len(kept)