Skip to content

Commit 70afcf1

Browse files
committed
fix(ai): correct size calculation, rename internal property for message truncation & add test
1 parent 814cd5a commit 70afcf1

File tree

4 files changed

+84
-30
lines changed

4 files changed

+84
-30
lines changed

sentry_sdk/ai/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ def truncate_and_annotate_messages(
139139

140140
truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
141141
if removed_count > 0:
142-
scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(
143-
truncated_messages
144-
)
142+
scope._gen_ai_original_message_count[span.span_id] = len(messages)
145143

146144
return truncated_messages

sentry_sdk/client.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -598,23 +598,20 @@ def _prepare_event(
598598
if event_scrubber:
599599
event_scrubber.scrub_event(event)
600600

601-
if scope is not None and scope._gen_ai_messages_truncated:
601+
if scope is not None and scope._gen_ai_original_message_count:
602602
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
603603
if isinstance(spans, list):
604604
for span in spans:
605605
span_id = span.get("span_id", None)
606606
span_data = span.get("data", {})
607607
if (
608608
span_id
609-
and span_id in scope._gen_ai_messages_truncated
609+
and span_id in scope._gen_ai_original_message_count
610610
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
611611
):
612612
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
613613
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
614-
{
615-
"len": scope._gen_ai_messages_truncated[span_id]
616-
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
617-
},
614+
{"len": scope._gen_ai_original_message_count[span_id]},
618615
)
619616
if previous_total_spans is not None:
620617
event["spans"] = AnnotatedValue(

sentry_sdk/scope.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ class Scope:
188188
"_extras",
189189
"_breadcrumbs",
190190
"_n_breadcrumbs_truncated",
191-
"_gen_ai_messages_truncated",
191+
"_gen_ai_original_message_count",
192192
"_event_processors",
193193
"_error_processors",
194194
"_should_capture",
@@ -214,7 +214,7 @@ def __init__(self, ty=None, client=None):
214214
self._name = None # type: Optional[str]
215215
self._propagation_context = None # type: Optional[PropagationContext]
216216
self._n_breadcrumbs_truncated = 0 # type: int
217-
self._gen_ai_messages_truncated = {} # type: Dict[str, int]
217+
self._gen_ai_original_message_count = {} # type: Dict[str, int]
218218

219219
self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient
220220

@@ -249,7 +249,7 @@ def __copy__(self):
249249

250250
rv._breadcrumbs = copy(self._breadcrumbs)
251251
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
252-
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
252+
rv._gen_ai_original_message_count = self._gen_ai_original_message_count.copy()
253253
rv._event_processors = self._event_processors.copy()
254254
rv._error_processors = self._error_processors.copy()
255255
rv._propagation_context = self._propagation_context
@@ -1586,8 +1586,10 @@ def update_from_scope(self, scope):
15861586
self._n_breadcrumbs_truncated = (
15871587
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
15881588
)
1589-
if scope._gen_ai_messages_truncated:
1590-
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
1589+
if scope._gen_ai_original_message_count:
1590+
self._gen_ai_original_message_count.update(
1591+
scope._gen_ai_original_message_count
1592+
)
15911593
if scope._span:
15921594
self._span = scope._span
15931595
if scope._attachments:

tests/test_ai_monitoring.py

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import uuid
23

34
import pytest
45

@@ -290,7 +291,7 @@ def set_data(self, key, value):
290291

291292
class MockScope:
292293
def __init__(self):
293-
self._gen_ai_messages_truncated = {}
294+
self._gen_ai_original_message_count = {}
294295

295296
span = MockSpan()
296297
scope = MockScope()
@@ -300,7 +301,7 @@ def __init__(self):
300301
assert not isinstance(result, AnnotatedValue)
301302
assert len(result) == len(sample_messages)
302303
assert result == sample_messages
303-
assert span.span_id not in scope._gen_ai_messages_truncated
304+
assert span.span_id not in scope._gen_ai_original_message_count
304305

305306
def test_truncation_sets_metadata_on_scope(self, large_messages):
306307
class MockSpan:
@@ -313,7 +314,7 @@ def set_data(self, key, value):
313314

314315
class MockScope:
315316
def __init__(self):
316-
self._gen_ai_messages_truncated = {}
317+
self._gen_ai_original_message_count = {}
317318

318319
small_limit = 1000
319320
span = MockSpan()
@@ -327,7 +328,7 @@ def __init__(self):
327328
assert not isinstance(result, AnnotatedValue)
328329
assert len(result) < len(large_messages)
329330
n_removed = original_count - len(result)
330-
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
331+
assert scope._gen_ai_original_message_count[span.span_id] == n_removed
331332

332333
def test_scope_tracks_removed_messages(self, large_messages):
333334
class MockSpan:
@@ -340,7 +341,7 @@ def set_data(self, key, value):
340341

341342
class MockScope:
342343
def __init__(self):
343-
self._gen_ai_messages_truncated = {}
344+
self._gen_ai_original_message_count = {}
344345

345346
small_limit = 1000
346347
original_count = len(large_messages)
@@ -352,7 +353,7 @@ def __init__(self):
352353
)
353354

354355
n_removed = original_count - len(result)
355-
assert scope._gen_ai_messages_truncated[span.span_id] == n_removed
356+
assert scope._gen_ai_original_message_count[span.span_id] == n_removed
356357
assert len(result) + n_removed == original_count
357358

358359
def test_empty_messages_returns_none(self):
@@ -366,7 +367,7 @@ def set_data(self, key, value):
366367

367368
class MockScope:
368369
def __init__(self):
369-
self._gen_ai_messages_truncated = {}
370+
self._gen_ai_original_message_count = {}
370371

371372
span = MockSpan()
372373
scope = MockScope()
@@ -387,7 +388,7 @@ def set_data(self, key, value):
387388

388389
class MockScope:
389390
def __init__(self):
390-
self._gen_ai_messages_truncated = {}
391+
self._gen_ai_original_message_count = {}
391392

392393
small_limit = 3000
393394
span = MockSpan()
@@ -416,7 +417,7 @@ def set_data(self, key, value):
416417

417418
class MockScope:
418419
def __init__(self):
419-
self._gen_ai_messages_truncated = {}
420+
self._gen_ai_original_message_count = {}
420421

421422
small_limit = 3000
422423
span = MockSpan()
@@ -430,8 +431,8 @@ def __init__(self):
430431
span.set_data(SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated_messages)
431432

432433
# Verify metadata was set on scope
433-
assert span.span_id in scope._gen_ai_messages_truncated
434-
assert scope._gen_ai_messages_truncated[span.span_id] > 0
434+
assert span.span_id in scope._gen_ai_original_message_count
435+
assert scope._gen_ai_original_message_count[span.span_id] > 0
435436

436437
# Simulate what client.py does
437438
event = {"spans": [{"span_id": span.span_id, "data": span.data.copy()}]}
@@ -442,21 +443,77 @@ def __init__(self):
442443
span_data = event_span.get("data", {})
443444
if (
444445
span_id
445-
and span_id in scope._gen_ai_messages_truncated
446+
and span_id in scope._gen_ai_original_message_count
446447
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
447448
):
448449
messages = span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES]
449-
n_removed = scope._gen_ai_messages_truncated[span_id]
450-
n_remaining = len(messages) if isinstance(messages, list) else 0
451-
original_count_calculated = n_removed + n_remaining
450+
n_original_count = scope._gen_ai_original_message_count[span_id]
452451

453452
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
454453
safe_serialize(messages),
455-
{"len": original_count_calculated},
454+
{"len": n_original_count},
456455
)
457456

458457
# Verify the annotation happened
459458
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
460459
assert isinstance(messages_value, AnnotatedValue)
461460
assert messages_value.metadata["len"] == original_count
462461
assert isinstance(messages_value.value, str)
462+
463+
def test_annotated_value_shows_correct_original_length(self, large_messages):
464+
"""Test that the annotated value correctly shows the original message count before truncation"""
465+
from sentry_sdk.consts import SPANDATA
466+
467+
class MockSpan:
468+
def __init__(self):
469+
self.span_id = "test_span_456"
470+
self.data = {}
471+
472+
def set_data(self, key, value):
473+
self.data[key] = value
474+
475+
class MockScope:
476+
def __init__(self):
477+
self._gen_ai_original_message_count = {}
478+
479+
small_limit = 3000
480+
span = MockSpan()
481+
scope = MockScope()
482+
original_message_count = len(large_messages)
483+
484+
truncated_messages = truncate_and_annotate_messages(
485+
large_messages, span, scope, max_bytes=small_limit
486+
)
487+
488+
assert len(truncated_messages) < original_message_count
489+
490+
assert span.span_id in scope._gen_ai_original_message_count
491+
stored_original_length = scope._gen_ai_original_message_count[span.span_id]
492+
assert stored_original_length == original_message_count
493+
494+
event = {
495+
"spans": [
496+
{
497+
"span_id": span.span_id,
498+
"data": {SPANDATA.GEN_AI_REQUEST_MESSAGES: truncated_messages},
499+
}
500+
]
501+
}
502+
503+
for event_span in event["spans"]:
504+
span_id = event_span.get("span_id")
505+
span_data = event_span.get("data", {})
506+
if (
507+
span_id
508+
and span_id in scope._gen_ai_original_message_count
509+
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
510+
):
511+
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
512+
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
513+
{"len": scope._gen_ai_original_message_count[span_id]},
514+
)
515+
516+
messages_value = event["spans"][0]["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
517+
assert isinstance(messages_value, AnnotatedValue)
518+
assert messages_value.metadata["len"] == stored_original_length
519+
assert len(messages_value.value) == len(truncated_messages)

0 commit comments

Comments
 (0)