Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ def truncate_and_annotate_messages(

truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
if removed_count > 0:
scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(
truncated_messages
)
scope._gen_ai_original_message_count[span.span_id] = len(messages)

return truncated_messages
9 changes: 3 additions & 6 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,23 +598,20 @@ def _prepare_event(
if event_scrubber:
event_scrubber.scrub_event(event)

if scope is not None and scope._gen_ai_messages_truncated:
if scope is not None and scope._gen_ai_original_message_count:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
if isinstance(spans, list):
for span in spans:
span_id = span.get("span_id", None)
span_data = span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{
"len": scope._gen_ai_messages_truncated[span_id]
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
},
{"len": scope._gen_ai_original_message_count[span_id]},
)
if previous_total_spans is not None:
event["spans"] = AnnotatedValue(
Expand Down
12 changes: 7 additions & 5 deletions sentry_sdk/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class Scope:
"_extras",
"_breadcrumbs",
"_n_breadcrumbs_truncated",
"_gen_ai_messages_truncated",
"_gen_ai_original_message_count",
"_event_processors",
"_error_processors",
"_should_capture",
Expand All @@ -214,7 +214,7 @@ def __init__(self, ty=None, client=None):
self._name = None # type: Optional[str]
self._propagation_context = None # type: Optional[PropagationContext]
self._n_breadcrumbs_truncated = 0 # type: int
self._gen_ai_messages_truncated = {} # type: Dict[str, int]
self._gen_ai_original_message_count = {} # type: Dict[str, int]

self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient

Expand Down Expand Up @@ -249,7 +249,7 @@ def __copy__(self):

rv._breadcrumbs = copy(self._breadcrumbs)
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
rv._gen_ai_original_message_count = self._gen_ai_original_message_count.copy()
rv._event_processors = self._event_processors.copy()
rv._error_processors = self._error_processors.copy()
rv._propagation_context = self._propagation_context
Expand Down Expand Up @@ -1586,8 +1586,10 @@ def update_from_scope(self, scope):
self._n_breadcrumbs_truncated = (
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
)
if scope._gen_ai_messages_truncated:
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
if scope._gen_ai_original_message_count:
self._gen_ai_original_message_count.update(
scope._gen_ai_original_message_count
)
if scope._span:
self._span = scope._span
if scope._attachments:
Expand Down
113 changes: 84 additions & 29 deletions tests/test_ai_monitoring.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import uuid

import pytest

Expand Down Expand Up @@ -210,32 +211,32 @@ def large_messages():
class TestTruncateMessagesBySize:
def test_no_truncation_needed(self, sample_messages):
"""Test that messages under the limit are not truncated"""
result, removed_count = truncate_messages_by_size(
result, truncation_index = truncate_messages_by_size(
sample_messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
)
assert len(result) == len(sample_messages)
assert result == sample_messages
assert removed_count == 0
assert truncation_index == 0

def test_truncation_removes_oldest_first(self, large_messages):
"""Test that oldest messages are removed first during truncation"""
small_limit = 3000
result, removed_count = truncate_messages_by_size(
result, truncation_index = truncate_messages_by_size(
large_messages, max_bytes=small_limit
)
assert len(result) < len(large_messages)

if result:
assert result[-1] == large_messages[-1]
assert removed_count == len(large_messages) - len(result)
assert truncation_index == len(large_messages) - len(result)

def test_empty_messages_list(self):
"""Test handling of empty messages list"""
result, removed_count = truncate_messages_by_size(
result, truncation_index = truncate_messages_by_size(
[], max_bytes=MAX_GEN_AI_MESSAGE_BYTES // 500
)
assert result == []
assert removed_count == 0
assert truncation_index == 0

def test_find_truncation_index(
self,
Expand Down Expand Up @@ -290,7 +291,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

span = MockSpan()
scope = MockScope()
Expand All @@ -300,7 +301,7 @@ def __init__(self):
assert not isinstance(result, AnnotatedValue)
assert len(result) == len(sample_messages)
assert result == sample_messages
assert span.span_id not in scope._gen_ai_messages_truncated
assert span.span_id not in scope._gen_ai_original_message_count

def test_truncation_sets_metadata_on_scope(self, large_messages):
class MockSpan:
Expand All @@ -313,9 +314,9 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 1000
small_limit = 3000
span = MockSpan()
scope = MockScope()
original_count = len(large_messages)
Expand All @@ -326,10 +327,9 @@ def __init__(self):
assert isinstance(result, list)
assert not isinstance(result, AnnotatedValue)
assert len(result) < len(large_messages)
n_removed = original_count - len(result)
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
assert scope._gen_ai_original_message_count[span.span_id] == original_count

def test_scope_tracks_removed_messages(self, large_messages):
def test_scope_tracks_original_message_count(self, large_messages):
class MockSpan:
def __init__(self):
self.span_id = "test_span_id"
Expand All @@ -340,9 +340,9 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 1000
small_limit = 3000
original_count = len(large_messages)
span = MockSpan()
scope = MockScope()
Expand All @@ -351,9 +351,8 @@ def __init__(self):
large_messages, span, scope, max_bytes=small_limit
)

n_removed = original_count - len(result)
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
assert len(result) + n_removed == original_count
assert scope._gen_ai_original_message_count[span.span_id] == original_count
assert len(result) == 1

def test_empty_messages_returns_none(self):
class MockSpan:
Expand All @@ -366,7 +365,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

span = MockSpan()
scope = MockScope()
Expand All @@ -387,7 +386,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
Expand Down Expand Up @@ -416,7 +415,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
Expand All @@ -430,33 +429,89 @@ def __init__(self):
span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages)

# Verify metadata was set on scope
assert span.span_id in scope._gen_ai_messages_truncated
assert scope._gen_ai_messages_truncated[span.span_id] > 0
assert span.span_id in scope._gen_ai_original_message_count
assert scope._gen_ai_original_message_count[span.span_id] > 0

# Simulate what client.py does
event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]}

# Mimic client.py logic - using scope to get the removed count
# Mimic client.py logic - using scope to get the original length
for event_span in event["spans"]:
span_id = event_span.get("span_id")
span_data = event_span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
n_removed = scope._gen_ai_messages_truncated[span_id]
n_remaining = len(messages) if isinstance(messages, list) else 0
original_count_calculated = n_removed + n_remaining
n_original_count = scope._gen_ai_original_message_count[span_id]

span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
safe_serialize(messages),
{"len": original_count_calculated},
{"len": n_original_count},
)

# Verify the annotation happened
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
assert isinstance(messages_value, AnnotatedValue)
assert messages_value.metadata["len"] == original_count
assert isinstance(messages_value.value, str)

def test_annotated_value_shows_correct_original_length(self, large_messages):
"""Test that the annotated value correctly shows the original message count before truncation"""
from sentry_sdk.consts import SPANDATA

class MockSpan:
def __init__(self):
self.span_id = "test_span_456"
self.data = {}

def set_data(self, key, value):
self.data[key] = value

class MockScope:
def __init__(self):
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
scope = MockScope()
original_message_count = len(large_messages)

truncated_messages = truncate_and_annotate_messages(
large_messages, span, scope, max_bytes=small_limit
)

assert len(truncated_messages) < original_message_count

assert span.span_id in scope._gen_ai_original_message_count
stored_original_length = scope._gen_ai_original_message_count[span.span_id]
assert stored_original_length == original_message_count

event = {
"spans": [
{
"span_id": span.span_id,
"data": {SPANDATA.GEN_AI_REQUEST_MESSAGES: truncated_messages},
}
]
}

for event_span in event["spans"]:
span_id = event_span.get("span_id")
span_data = event_span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{"len": scope._gen_ai_original_message_count[span_id]},
)

messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
assert isinstance(messages_value, AnnotatedValue)
assert messages_value.metadata["len"] == stored_original_length
assert len(messages_value.value) == len(truncated_messages)