47 changes: 44 additions & 3 deletions sentry_sdk/ai/utils.py
@@ -1,4 +1,5 @@
import json
from copy import deepcopy
from collections import deque
from typing import TYPE_CHECKING
from sys import getsizeof
@@ -12,6 +13,8 @@
from sentry_sdk.utils import logger

MAX_GEN_AI_MESSAGE_BYTES = 20_000 # 20KB
# Maximum characters kept when only a single message remains after byte-based truncation
MAX_SINGLE_MESSAGE_CONTENT_CHARS = 10_000


class GEN_AI_ALLOWED_MESSAGE_ROLES:
@@ -101,6 +104,23 @@ def get_start_span_function():
return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction


def _truncate_single_message_content_if_present(message, max_chars):
# type: (Dict[str, Any], int) -> Dict[str, Any]
"""
Truncate a single message's `content` field to at most `max_chars` characters,
appending an ellipsis when truncation occurs. Messages without a string
`content` field are returned unchanged.
"""
if not isinstance(message, dict) or "content" not in message:
return message
content = message["content"]

if not isinstance(content, str) or len(content) <= max_chars:
return message

message["content"] = content[:max_chars] + "..."
return message
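
As a quick illustration of the helper above (not part of this diff), here is a hypothetical call with a made-up message and limit; it assumes the helper is imported from sentry_sdk.ai.utils:

from sentry_sdk.ai.utils import _truncate_single_message_content_if_present

# Hypothetical example: cap a longer content string at 12 characters.
message = {"role": "user", "content": "The quick brown fox jumps over the lazy dog"}
result = _truncate_single_message_content_if_present(message, max_chars=12)
# result["content"] == "The quick br..."
# Messages without a string "content" field are returned unchanged.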


def _find_truncation_index(messages, max_bytes):
# type: (List[Dict[str, Any]], int) -> int
"""
@@ -118,16 +138,37 @@ def _find_truncation_index(messages, max_bytes):
return 0


def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
# type: (List[Dict[str, Any]], int) -> Tuple[List[Dict[str, Any]], int]
def truncate_messages_by_size(
messages,
max_bytes=MAX_GEN_AI_MESSAGE_BYTES,
max_single_message_chars=MAX_SINGLE_MESSAGE_CONTENT_CHARS,
):
# type: (List[Dict[str, Any]], int, int) -> Tuple[List[Dict[str, Any]], int]
"""
Returns a truncated messages array that keeps the maximum number of messages, counted
from the end of `messages`, whose total serialized size does not exceed `max_bytes` bytes.
If only a single message remains after this byte-based truncation, its content is
additionally capped at `max_single_message_chars` characters, with an ellipsis appended.
"""
serialized_json = json.dumps(messages, separators=(",", ":"))
current_size = len(serialized_json.encode("utf-8"))

if current_size <= max_bytes:
return messages, 0

truncation_index = _find_truncation_index(messages, max_bytes)
return messages[truncation_index:], truncation_index
truncated_messages = (
messages[truncation_index:]
if truncation_index < len(messages)
else messages[-1:]
)
if len(truncated_messages) == 1:
truncated_messages[0] = _truncate_single_message_content_if_present(
deepcopy(truncated_messages[0]), max_chars=max_single_message_chars
)

return truncated_messages, truncation_index
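
For reference, a minimal sketch of how the updated function behaves end to end. The inputs below are made up for illustration, and the sketch assumes the public names are importable from sentry_sdk.ai.utils:

from sentry_sdk.ai.utils import MAX_GEN_AI_MESSAGE_BYTES, truncate_messages_by_size

# A conversation that fits within the byte budget is returned unchanged.
small = [{"role": "user", "content": "hi"}]
kept, removed = truncate_messages_by_size(small, max_bytes=MAX_GEN_AI_MESSAGE_BYTES)
# kept == small, removed == 0

# A single oversized message triggers the new per-message character cap:
# only the last message survives and its content is shortened with a trailing "...".
oversized = [{"role": "user", "content": "x" * 50_000}]
kept, removed = truncate_messages_by_size(oversized, max_bytes=MAX_GEN_AI_MESSAGE_BYTES)
# kept[0]["content"].endswith("...") is True; the original dict is untouched because
# the character cap is applied to a deepcopy.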


def truncate_and_annotate_messages(
78 changes: 78 additions & 0 deletions tests/test_ai_monitoring.py
@@ -278,6 +278,84 @@ def test_progressive_truncation(self, large_messages):
assert current_count >= 1
prev_count = current_count

def test_individual_message_truncation(self):
large_content = "This is a very long message. " * 1000

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": large_content},
]

result, truncation_index = truncate_messages_by_size(
messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
)

assert len(result) > 0

total_size = len(json.dumps(result, separators=(",", ":")).encode("utf-8"))
assert total_size <= MAX_GEN_AI_MESSAGE_BYTES

for msg in result:
msg_size = len(json.dumps(msg, separators=(",", ":")).encode("utf-8"))
assert msg_size <= MAX_GEN_AI_MESSAGE_BYTES

# If the last message is too large, the system message is not present
system_msgs = [m for m in result if m.get("role") == "system"]
assert len(system_msgs) == 0

# Confirm the user message is truncated with '...'
user_msgs = [m for m in result if m.get("role") == "user"]
assert len(user_msgs) == 1
assert user_msgs[0]["content"].endswith("...")
assert len(user_msgs[0]["content"]) < len(large_content)

def test_combined_individual_and_array_truncation(self):
huge_content = "X" * 25000
medium_content = "Y" * 5000

messages = [
{"role": "system", "content": medium_content},
{"role": "user", "content": huge_content},
{"role": "assistant", "content": medium_content},
{"role": "user", "content": "small"},
]

result, truncation_index = truncate_messages_by_size(
messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
)

assert len(result) > 0

total_size = len(json.dumps(result, separators=(",", ":")).encode("utf-8"))
assert total_size <= MAX_GEN_AI_MESSAGE_BYTES

for msg in result:
msg_size = len(json.dumps(msg, separators=(",", ":")).encode("utf-8"))
assert msg_size <= MAX_GEN_AI_MESSAGE_BYTES

# The last user "small" message should always be present and untruncated
last_user_msgs = [
m for m in result if m.get("role") == "user" and m["content"] == "small"
]
assert len(last_user_msgs) == 1

# If the huge message is present, it must be truncated
for user_msg in [
m for m in result if m.get("role") == "user" and "X" in m["content"]
]:
assert user_msg["content"].endswith("...")
assert len(user_msg["content"]) < len(huge_content)

# The medium messages, if present, should not be truncated
for expected_role in ["system", "assistant"]:
role_msgs = [m for m in result if m.get("role") == expected_role]
if role_msgs:
assert role_msgs[0]["content"].startswith("Y")
assert len(role_msgs[0]["content"]) <= len(medium_content)
assert not role_msgs[0]["content"].endswith("...") or len(
role_msgs[0]["content"]
) == len(medium_content)


class TestTruncateAndAnnotateMessages:
def test_no_truncation_returns_list(self, sample_messages):