Skip to content

Commit 12dcf5d

Browse files
committed
incrementally compute cumulative message sizes
1 parent 405a096 commit 12dcf5d

File tree

2 files changed

+62
-41
lines changed

2 files changed

+62
-41
lines changed

sentry_sdk/ai/utils.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
2+
from collections import deque
23
from typing import TYPE_CHECKING
4+
from sys import getsizeof
35

46
if TYPE_CHECKING:
5-
from typing import Any, Callable, Dict, List, Optional
7+
from typing import Any, Callable, Dict, List, Optional, Tuple
68

79
from sentry_sdk.tracing import Span
810

@@ -99,21 +101,33 @@ def get_start_span_function():
99101
return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction
100102

101103

102-
def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
103-
# type: (List[Dict[str, Any]], int) -> List[Dict[str, Any]]
104-
if not messages:
105-
return messages
104+
def _find_truncation_index(messages, max_bytes):
105+
# type: (List[Dict[str, Any]], int) -> int
106+
"""
107+
Find the index of the first message that would exceed the max bytes limit.
108+
Compute the individual message sizes, and return the index of the first message from the back
109+
of the list that would exceed the max bytes limit.
110+
"""
111+
running_sum = 0
112+
for idx in range(len(messages) - 1, -1, -1):
113+
size = len(json.dumps(messages[idx], separators=(",", ":")))
114+
running_sum += size
115+
if running_sum > max_bytes:
116+
return idx + 1
106117

107-
truncated_messages = list(messages)
118+
return 0
108119

109-
while len(truncated_messages) > 1:
110-
serialized_json = json.dumps(truncated_messages, separators=(",", ":"))
111-
current_size = len(serialized_json.encode("utf-8"))
112-
if current_size <= max_bytes:
113-
break
114-
truncated_messages.pop(0)
115120

116-
return truncated_messages
121+
def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
122+
# type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
123+
serialized_json = json.dumps(messages, separators=(",", ":"))
124+
current_size = len(serialized_json.encode("utf-8"))
125+
126+
if current_size <= max_bytes:
127+
return messages, 0
128+
129+
truncation_index = _find_truncation_index(messages, max_bytes)
130+
return messages[truncation_index:], truncation_index
117131

118132

119133
def truncate_and_annotate_messages(
@@ -123,16 +137,10 @@ def truncate_and_annotate_messages(
123137
if not messages:
124138
return None
125139

126-
original_count = len(messages)
127-
truncated_messages = truncate_messages_by_size(messages, max_bytes)
128-
129-
if not truncated_messages:
130-
return None
131-
132-
truncated_count = len(truncated_messages)
133-
n_removed = original_count - truncated_count
134-
135-
if n_removed > 0:
136-
scope._gen_ai_messages_truncated[span.span_id] = n_removed
140+
truncated_messages, removed_count = truncate_messages_by_size(messages, max_bytes)
141+
if removed_count > 0:
142+
scope._gen_ai_messages_truncated[span.span_id] = len(messages) - len(
143+
truncated_messages
144+
)
137145

138146
return truncated_messages

tests/test_ai_monitoring.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
set_data_normalized,
1111
truncate_and_annotate_messages,
1212
truncate_messages_by_size,
13+
_find_truncation_index,
1314
)
1415
from sentry_sdk.serializer import serialize
1516
from sentry_sdk.utils import safe_serialize
@@ -209,27 +210,53 @@ def large_messages():
209210
class TestTruncateMessagesBySize:
210211
def test_no_truncation_needed(self, sample_messages):
211212
"""Test that messages under the limit are not truncated"""
212-
result = truncate_messages_by_size(
213+
result, removed_count = truncate_messages_by_size(
213214
sample_messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
214215
)
215216
assert len(result) == len(sample_messages)
216217
assert result == sample_messages
218+
assert removed_count == 0
217219

218220
def test_truncation_removes_oldest_first(self, large_messages):
219221
"""Test that oldest messages are removed first during truncation"""
220222
small_limit = 3000
221-
result = truncate_messages_by_size(large_messages, max_bytes=small_limit)
223+
result, removed_count = truncate_messages_by_size(
224+
large_messages, max_bytes=small_limit
225+
)
222226
assert len(result) < len(large_messages)
223227

224228
if result:
225229
assert result[-1] == large_messages[-1]
230+
assert removed_count == len(large_messages) - len(result)
226231

227232
def test_empty_messages_list(self):
228233
"""Test handling of empty messages list"""
229-
result = truncate_messages_by_size(
234+
result, removed_count = truncate_messages_by_size(
230235
[], max_bytes=MAX_GEN_AI_MESSAGE_BYTES // 500
231236
)
232237
assert result == []
238+
assert removed_count == 0
239+
240+
def test_find_truncation_index(
241+
self,
242+
):
243+
"""Test that the truncation index is found correctly"""
244+
# when represented in JSON, these are each 7 bytes long
245+
messages = ["A" * 5, "B" * 5, "C" * 5, "D" * 5, "E" * 5]
246+
truncation_index = _find_truncation_index(messages, 20)
247+
assert truncation_index == 3
248+
assert messages[truncation_index:] == ["D" * 5, "E" * 5]
249+
250+
messages = ["A" * 5, "B" * 5, "C" * 5, "D" * 5, "E" * 5]
251+
truncation_index = _find_truncation_index(messages, 40)
252+
assert truncation_index == 0
253+
assert messages[truncation_index:] == [
254+
"A" * 5,
255+
"B" * 5,
256+
"C" * 5,
257+
"D" * 5,
258+
"E" * 5,
259+
]
233260

234261
def test_progressive_truncation(self, large_messages):
235262
"""Test that truncation works progressively with different limits"""
@@ -250,20 +277,6 @@ def test_progressive_truncation(self, large_messages):
250277
assert current_count >= 1
251278
prev_count = current_count
252279

253-
def test_exact_size_boundary(self):
254-
"""Test behavior at exact size boundaries"""
255-
messages = [{"role": "user", "content": "test"}]
256-
257-
serialized = serialize(messages, is_vars=False)
258-
json_str = json.dumps(serialized, separators=(",", ":"))
259-
exact_size = len(json_str.encode("utf-8"))
260-
261-
result = truncate_messages_by_size(messages, max_bytes=exact_size)
262-
assert len(result) == 1
263-
264-
result = truncate_messages_by_size(messages, max_bytes=exact_size - 1)
265-
assert len(result) == 1
266-
267280

268281
class TestTruncateAndAnnotateMessages:
269282
def test_no_truncation_returns_list(self, sample_messages):

0 commit comments

Comments
 (0)