diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py index 1fb291bdac..06c9a23604 100644 --- a/sentry_sdk/ai/utils.py +++ b/sentry_sdk/ai/utils.py @@ -139,8 +139,6 @@ def truncate_and_annotate_messages( truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes) if removed_count > 0: - scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len( - truncated_messages - ) + scope._gen_ai_original_message_count[span.span_id] = len(messages) return truncated_messages diff --git a/sentry_sdk/client.py b/sentry_sdk/client.py index ffd899b545..b4a3e8bb6b 100644 --- a/sentry_sdk/client.py +++ b/sentry_sdk/client.py @@ -598,7 +598,7 @@ def _prepare_event( if event_scrubber: event_scrubber.scrub_event(event) - if scope is not None and scope._gen_ai_messages_truncated: + if scope is not None and scope._gen_ai_original_message_count: spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue if isinstance(spans, list): for span in spans: @@ -606,15 +606,12 @@ def _prepare_event( span_data = span.get("data", {}) if ( span_id - and span_id in scope._gen_ai_messages_truncated + and span_id in scope._gen_ai_original_message_count and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data ): span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue( span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES], - { - "len": scope._gen_ai_messages_truncated[span_id] - + len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]) - }, + {"len": scope._gen_ai_original_message_count[span_id]}, ) if previous_total_spans is not None: event["spans"] = AnnotatedValue( diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py index 5815a65440..ecb8f370c5 100644 --- a/sentry_sdk/scope.py +++ b/sentry_sdk/scope.py @@ -188,7 +188,7 @@ class Scope: "_extras", "_breadcrumbs", "_n_breadcrumbs_truncated", - "_gen_ai_messages_truncated", + "_gen_ai_original_message_count", "_event_processors", "_error_processors", "_should_capture", @@ -214,7 +214,7 @@ def __init__(self, ty=None, client=None): self._name = None # type: Optional[str] self._propagation_context = None # type: Optional[PropagationContext] self._n_breadcrumbs_truncated = 0 # type: int - self._gen_ai_messages_truncated = {} # type: Dict[str, int] + self._gen_ai_original_message_count = {} # type: Dict[str, int] self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient @@ -249,7 +249,7 @@ def __copy__(self): rv._breadcrumbs = copy(self._breadcrumbs) rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated - rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy() + rv._gen_ai_original_message_count = self._gen_ai_original_message_count.copy() rv._event_processors = self._event_processors.copy() rv._error_processors = self._error_processors.copy() rv._propagation_context = self._propagation_context @@ -1586,8 +1586,10 @@ def update_from_scope(self, scope): self._n_breadcrumbs_truncated = ( self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated ) - if scope._gen_ai_messages_truncated: - self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated) + if scope._gen_ai_original_message_count: + self._gen_ai_original_message_count.update( + scope._gen_ai_original_message_count + ) if scope._span: self._span = scope._span if scope._attachments: diff --git a/tests/test_ai_monitoring.py b/tests/test_ai_monitoring.py index be66860384..5ff136f810 100644 --- a/tests/test_ai_monitoring.py +++ b/tests/test_ai_monitoring.py @@ -1,4 +1,5 @@ import json +import uuid import pytest @@ -210,32 +211,32 @@ def large_messages(): class TestTruncateMessagesBySize: def test_no_truncation_needed(self, sample_messages): """Test that messages under the limit are not truncated""" - result, removed_count = truncate_messages_by_size( + result, truncation_index = truncate_messages_by_size( sample_messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES ) assert len(result) == len(sample_messages) assert result == sample_messages - assert removed_count == 0 + assert truncation_index == 0 def test_truncation_removes_oldest_first(self, large_messages): """Test that oldest messages are removed first during truncation""" small_limit = 3000 - result, removed_count = truncate_messages_by_size( + result, truncation_index = truncate_messages_by_size( large_messages, max_bytes=small_limit ) assert len(result) < len(large_messages) if result: assert result[-1] == large_messages[-1] - assert removed_count == len(large_messages) - len(result) + assert truncation_index == len(large_messages) - len(result) def test_empty_messages_list(self): """Test handling of empty messages list""" - result, removed_count = truncate_messages_by_size( + result, truncation_index = truncate_messages_by_size( [], max_bytes=MAX_GEN_AI_MESSAGE_BYTES // 500 ) assert result == [] - assert removed_count == 0 + assert truncation_index == 0 def test_find_truncation_index( self, @@ -290,7 +291,7 @@ def set_data(self, key, value): class MockScope: def __init__(self): - self._gen_ai_messages_truncated = {} + self._gen_ai_original_message_count = {} span = MockSpan() scope = MockScope() @@ -300,7 +301,7 @@ def __init__(self): assert not isinstance(result, AnnotatedValue) assert len(result) == len(sample_messages) assert result == sample_messages - assert span.span_id not in scope._gen_ai_messages_truncated + assert span.span_id not in scope._gen_ai_original_message_count def test_truncation_sets_metadata_on_scope(self, large_messages): class MockSpan: @@ -313,9 +314,9 @@ def set_data(self, key, value): class MockScope: def __init__(self): - self._gen_ai_messages_truncated = {} + self._gen_ai_original_message_count = {} - small_limit = 1000 + small_limit = 3000 span = MockSpan() scope = MockScope() original_count = len(large_messages) @@ -326,10 +327,9 @@ def __init__(self): assert isinstance(result, list) assert not isinstance(result, AnnotatedValue) assert len(result) < len(large_messages) - n_removed = original_count - len(result) - assert scope._gen_ai_messages_truncated[span.span_id] == n_removed + assert scope._gen_ai_original_message_count[span.span_id] == original_count - def test_scope_tracks_removed_messages(self, large_messages): + def test_scope_tracks_original_message_count(self, large_messages): class MockSpan: def __init__(self): self.span_id = "test_span_id" @@ -340,9 +340,9 @@ def set_data(self, key, value): class MockScope: def __init__(self): - self._gen_ai_messages_truncated = {} + self._gen_ai_original_message_count = {} - small_limit = 1000 + small_limit = 3000 original_count = len(large_messages) span = MockSpan() scope = MockScope() @@ -351,9 +351,8 @@ def __init__(self): large_messages, span, scope, max_bytes=small_limit ) - n_removed = original_count - len(result) - assert scope._gen_ai_messages_truncated[span.span_id] == n_removed - assert len(result) + n_removed == original_count + assert scope._gen_ai_original_message_count[span.span_id] == original_count + assert len(result) == 1 def test_empty_messages_returns_none(self): class MockSpan: @@ -366,7 +365,7 @@ def set_data(self, key, value): class MockScope: def __init__(self): - self._gen_ai_messages_truncated = {} + self._gen_ai_original_message_count = {} span = MockSpan() scope = MockScope() @@ -387,7 +386,7 @@ def set_data(self, key, value): class MockScope: def __init__(self): - self._gen_ai_messages_truncated = {} + self._gen_ai_original_message_count = {} small_limit = 3000 span = MockSpan() @@ -416,7 +415,7 @@ def set_data(self, key, value): class MockScope: def __init__(self): - self._gen_ai_messages_truncated = {} + self._gen_ai_original_message_count = {} small_limit = 3000 span = MockSpan() @@ -430,29 +429,27 @@ def __init__(self): span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages) # Verify metadata was set on scope - assert span.span_id in scope._gen_ai_messages_truncated - assert scope._gen_ai_messages_truncated[span.span_id] > 0 + assert span.span_id in scope._gen_ai_original_message_count + assert scope._gen_ai_original_message_count[span.span_id] > 0 # Simulate what client.py does event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]} - # Mimic client.py logic - using scope to get the removed count + # Mimic client.py logic - using scope to get the original length for event_span in event["spans"]: span_id = event_span.get("span_id") span_data = event_span.get("data", {}) if ( span_id - and span_id in scope._gen_ai_messages_truncated + and span_id in scope._gen_ai_original_message_count and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data ): messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] - n_removed = scope._gen_ai_messages_truncated[span_id] - n_remaining = len(messages) if isinstance(messages, list) else 0 - original_count_calculated = n_removed + n_remaining + n_original_count = scope._gen_ai_original_message_count[span_id] span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue( safe_serialize(messages), - {"len": original_count_calculated}, + {"len": n_original_count}, ) # Verify the annotation happened @@ -460,3 +457,61 @@ def __init__(self): assert isinstance(messages_value, AnnotatedValue) assert messages_value.metadata["len"] == original_count assert isinstance(messages_value.value, str) + + def test_annotated_value_shows_correct_original_length(self, large_messages): + """Test that the annotated value correctly shows the original message count before truncation""" + from sentry_sdk.consts import SPANDATA + + class MockSpan: + def __init__(self): + self.span_id = "test_span_456" + self.data = {} + + def set_data(self, key, value): + self.data[key] = value + + class MockScope: + def __init__(self): + self._gen_ai_original_message_count = {} + + small_limit = 3000 + span = MockSpan() + scope = MockScope() + original_message_count = len(large_messages) + + truncated_messages = truncate_and_annotate_messages( + large_messages, span, scope, max_bytes=small_limit + ) + + assert len(truncated_messages) < original_message_count + + assert span.span_id in scope._gen_ai_original_message_count + stored_original_length = scope._gen_ai_original_message_count[span.span_id] + assert stored_original_length == original_message_count + + event = { + "spans": [ + { + "span_id": span.span_id, + "data": {SPANDATA.GEN_AI_REQUEST_MESSAGES: truncated_messages}, + } + ] + } + + for event_span in event["spans"]: + span_id = event_span.get("span_id") + span_data = event_span.get("data", {}) + if ( + span_id + and span_id in scope._gen_ai_original_message_count + and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data + ): + span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue( + span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES], + {"len": scope._gen_ai_original_message_count[span_id]}, + ) + + messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES] + assert isinstance(messages_value, AnnotatedValue) + assert messages_value.metadata["len"] == stored_original_length + assert len(messages_value.value) == len(truncated_messages)