Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ def truncate_and_annotate_messages(

truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
if removed_count > 0:
scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(
truncated_messages
)
scope._gen_ai_original_message_count[span.span_id] = len(messages)

return truncated_messages
9 changes: 3 additions & 6 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,23 +598,20 @@ def _prepare_event(
if event_scrubber:
event_scrubber.scrub_event(event)

if scope is not None and scope._gen_ai_messages_truncated:
if scope is not None and scope._gen_ai_original_message_count:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
if isinstance(spans, list):
for span in spans:
span_id = span.get("span_id", None)
span_data = span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{
"len": scope._gen_ai_messages_truncated[span_id]
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
},
{"len": scope._gen_ai_original_message_count[span_id]},
)
if previous_total_spans is not None:
event["spans"] = AnnotatedValue(
Expand Down
12 changes: 7 additions & 5 deletions sentry_sdk/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class Scope:
"_extras",
"_breadcrumbs",
"_n_breadcrumbs_truncated",
"_gen_ai_messages_truncated",
"_gen_ai_original_message_count",
"_event_processors",
"_error_processors",
"_should_capture",
Expand All @@ -214,7 +214,7 @@ def __init__(self, ty=None, client=None):
self._name = None # type: Optional[str]
self._propagation_context = None # type: Optional[PropagationContext]
self._n_breadcrumbs_truncated = 0 # type: int
self._gen_ai_messages_truncated = {} # type: Dict[str, int]
self._gen_ai_original_message_count = {} # type: Dict[str, int]

self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient

Expand Down Expand Up @@ -249,7 +249,7 @@ def __copy__(self):

rv._breadcrumbs = copy(self._breadcrumbs)
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
rv._gen_ai_original_message_count = self._gen_ai_original_message_count.copy()
rv._event_processors = self._event_processors.copy()
rv._error_processors = self._error_processors.copy()
rv._propagation_context = self._propagation_context
Expand Down Expand Up @@ -1586,8 +1586,10 @@ def update_from_scope(self, scope):
self._n_breadcrumbs_truncated = (
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
)
if scope._gen_ai_messages_truncated:
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
if scope._gen_ai_original_message_count:
self._gen_ai_original_message_count.update(
scope._gen_ai_original_message_count
)
if scope._span:
self._span = scope._span
if scope._attachments:
Expand Down
89 changes: 73 additions & 16 deletions tests/test_ai_monitoring.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import uuid

import pytest

Expand Down Expand Up @@ -290,7 +291,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

span = MockSpan()
scope = MockScope()
Expand All @@ -300,7 +301,7 @@ def __init__(self):
assert not isinstance(result, AnnotatedValue)
assert len(result) == len(sample_messages)
assert result == sample_messages
assert span.span_id not in scope._gen_ai_messages_truncated
assert span.span_id not in scope._gen_ai_original_message_count

def test_truncation_sets_metadata_on_scope(self, large_messages):
class MockSpan:
Expand All @@ -313,7 +314,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 1000
span = MockSpan()
Expand All @@ -327,7 +328,7 @@ def __init__(self):
assert not isinstance(result, AnnotatedValue)
assert len(result) < len(large_messages)
n_removed = original_count - len(result)
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
assert scope._gen_ai_original_message_count[span.span_id] == n_removed
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Test Assertion Mismatch: Original vs Removed Messages

The _gen_ai_original_message_count property now stores the total original message count, but the tests incorrectly assert it against the number of removed messages (n_removed). This assertion should instead compare against the original_count.

Additional Locations (1)

Fix in Cursor Fix in Web

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please check what's going on here before merging @shellmayr

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, thank you. Coincidentally these were the same value, fixed the test to be clearer about the intent here.


def test_scope_tracks_removed_messages(self, large_messages):
class MockSpan:
Expand All @@ -340,7 +341,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 1000
original_count = len(large_messages)
Expand All @@ -352,7 +353,7 @@ def __init__(self):
)

n_removed = original_count - len(result)
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
assert scope._gen_ai_original_message_count[span.span_id] == n_removed
assert len(result) + n_removed == original_count

def test_empty_messages_returns_none(self):
Expand All @@ -366,7 +367,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

span = MockSpan()
scope = MockScope()
Expand All @@ -387,7 +388,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
Expand Down Expand Up @@ -416,7 +417,7 @@ def set_data(self, key, value):

class MockScope:
def __init__(self):
self._gen_ai_messages_truncated = {}
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
Expand All @@ -430,8 +431,8 @@ def __init__(self):
span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages)

# Verify metadata was set on scope
assert span.span_id in scope._gen_ai_messages_truncated
assert scope._gen_ai_messages_truncated[span.span_id] > 0
assert span.span_id in scope._gen_ai_original_message_count
assert scope._gen_ai_original_message_count[span.span_id] > 0

# Simulate what client.py does
event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]}
Expand All @@ -442,21 +443,77 @@ def __init__(self):
span_data = event_span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
n_removed = scope._gen_ai_messages_truncated[span_id]
n_remaining = len(messages) if isinstance(messages, list) else 0
original_count_calculated = n_removed + n_remaining
n_original_count = scope._gen_ai_original_message_count[span_id]

span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
safe_serialize(messages),
{"len": original_count_calculated},
{"len": n_original_count},
)

# Verify the annotation happened
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
assert isinstance(messages_value, AnnotatedValue)
assert messages_value.metadata["len"] == original_count
assert isinstance(messages_value.value, str)

def test_annotated_value_shows_correct_original_length(self, large_messages):
"""Test that the annotated value correctly shows the original message count before truncation"""
from sentry_sdk.consts import SPANDATA

class MockSpan:
def __init__(self):
self.span_id = "test_span_456"
self.data = {}

def set_data(self, key, value):
self.data[key] = value

class MockScope:
def __init__(self):
self._gen_ai_original_message_count = {}

small_limit = 3000
span = MockSpan()
scope = MockScope()
original_message_count = len(large_messages)

truncated_messages = truncate_and_annotate_messages(
large_messages, span, scope, max_bytes=small_limit
)

assert len(truncated_messages) < original_message_count

assert span.span_id in scope._gen_ai_original_message_count
stored_original_length = scope._gen_ai_original_message_count[span.span_id]
assert stored_original_length == original_message_count

event = {
"spans": [
{
"span_id": span.span_id,
"data": {SPANDATA.GEN_AI_REQUEST_MESSAGES: truncated_messages},
}
]
}

for event_span in event["spans"]:
span_id = event_span.get("span_id")
span_data = event_span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_original_message_count
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{"len": scope._gen_ai_original_message_count[span_id]},
)

messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
assert isinstance(messages_value, AnnotatedValue)
assert messages_value.metadata["len"] == stored_original_length
assert len(messages_value.value) == len(truncated_messages)