Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion sentry_sdk/ai/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import json

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Callable

from sentry_sdk.tracing import Span

from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk.utils import logger

if TYPE_CHECKING:
from typing import Any, Dict, List, Optional

MAX_GEN_AI_MESSAGE_BYTES = 20_000 # 20KB


class GEN_AI_ALLOWED_MESSAGE_ROLES:
SYSTEM = "system"
Expand Down Expand Up @@ -95,3 +102,78 @@ def get_start_span_function():
current_span is not None and current_span.containing_transaction is not None
)
return sentry_sdk.start_span if transaction_exists else sentry_sdk.start_transaction


def truncate_messages_by_size(messages, max_bytes=MAX_GEN_AI_MESSAGE_BYTES):
    # type: (List[Dict[str, Any]], int) -> List[Dict[str, Any]]
    """Shrink a message list until its compact-JSON encoding fits in max_bytes.

    Oldest messages (front of the list) are dropped first, always keeping at
    least one. If the single remaining message is still over budget, its
    "content" is reduced: string content is cut down iteratively, list
    content loses trailing items. The input list is never mutated; at most a
    shallow copy of the last message is made.
    """
    if not messages:
        return messages

    def _encoded_size(msgs):
        # type: (List[Dict[str, Any]]) -> int
        # Compact separators match the size the SDK would actually send.
        return len(json.dumps(msgs, separators=(",", ":")).encode("utf-8"))

    remaining = list(messages)

    # Phase 1: drop whole messages from the front while more than one remains
    # and the serialized payload is still too large.
    while len(remaining) > 1 and _encoded_size(remaining) > max_bytes:
        remaining.pop(0)

    # Phase 2: a lone oversized message gets its content shrunk instead.
    if len(remaining) == 1 and _encoded_size(remaining) > max_bytes:
        last = remaining[0].copy()
        content = last.get("content", "")

        if isinstance(content, str):
            # Start at half the budget and shave 10% per round until the
            # serialized message fits; below 100 chars, give up and blank it.
            limit = max_bytes // 2
            while True:
                last["content"] = content[:limit]
                if _encoded_size([last]) <= max_bytes:
                    break
                limit = int(limit * 0.9)
                if limit < 100:
                    last["content"] = ""
                    break

            remaining = [last]
        elif isinstance(content, list):
            # Trim trailing content parts until the serialized message fits.
            parts = list(content)
            while parts:
                last["content"] = parts
                if _encoded_size([last]) <= max_bytes:
                    break
                parts = parts[:-1]

            if not parts:
                last["content"] = []

            remaining = [last]

    return remaining


def truncate_and_annotate_messages(
    messages, span, scope, max_bytes=MAX_GEN_AI_MESSAGE_BYTES
):
    # type: (Optional[List[Dict[str, Any]]], Any, Any, int) -> Optional[List[Dict[str, Any]]]
    """Truncate messages to fit max_bytes and remember how many were dropped.

    The number of removed messages is stored on
    scope._gen_ai_messages_truncated keyed by span.span_id so the client can
    later attach an AnnotatedValue to the span's data.

    Returns the (possibly truncated) message list, or None when there is
    nothing to record.
    """
    if not messages:
        return None

    truncated = truncate_messages_by_size(messages, max_bytes)
    if not truncated:
        return None

    removed = len(messages) - len(truncated)
    if removed > 0:
        # NOTE: only whole-message drops are counted; shrinking the content
        # of a single surviving message leaves the count unchanged.
        scope._gen_ai_messages_truncated[span.span_id] = removed

    return truncated
18 changes: 18 additions & 0 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,23 @@ def _prepare_event(
if event_scrubber:
event_scrubber.scrub_event(event)

if scope is not None and scope._gen_ai_messages_truncated:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue[List[Dict[str, Any]]]
for span in spans:
span_id = span.get("span_id", None)
span_data = span.get("data", {})
if (
span_id
and span_id in scope._gen_ai_messages_truncated
and SPANDATA.GEN_AI_REQUEST_MESSAGES in span_data
):
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] = AnnotatedValue(
span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES],
{
"len": scope._gen_ai_messages_truncated[span_id]
+ len(span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES])
},
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Incorrect Message Count Calculation

The calculation for the original message count in AnnotatedValue for SPANDATA.GEN_AI_REQUEST_MESSAGES is incorrect. span_data[SPANDATA.GEN_AI_REQUEST_MESSAGES] is a JSON string (serialized by set_data_normalized), so len() returns its character count instead of the number of messages. This results in an inaccurate "len" metadata value.

Fix in Cursor Fix in Web

if previous_total_spans is not None:
event["spans"] = AnnotatedValue(
event.get("spans", []), {"len": previous_total_spans}
Expand All @@ -606,6 +623,7 @@ def _prepare_event(
event["breadcrumbs"] = AnnotatedValue(
event.get("breadcrumbs", []), {"len": previous_total_breadcrumbs}
)

# Postprocess the event here so that annotated types do
# generally not surface in before_send
if event is not None:
Expand Down
18 changes: 12 additions & 6 deletions sentry_sdk/integrations/openai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from functools import wraps
from collections.abc import Iterable

import sentry_sdk
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import set_data_normalized, normalize_message_roles
from sentry_sdk.ai.utils import (
set_data_normalized,
normalize_message_roles,
truncate_and_annotate_messages,
)
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
Expand All @@ -18,7 +21,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, List, Optional, Callable, AsyncIterator, Iterator
from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator
from sentry_sdk.tracing import Span

try:
Expand Down Expand Up @@ -189,9 +192,12 @@ def _set_input_data(span, kwargs, operation, integration):
and integration.include_prompts
):
normalized_messages = normalize_message_roles(messages)
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, normalized_messages, unpack=False
)
scope = sentry_sdk.get_current_scope()
messages_data = truncate_and_annotate_messages(normalized_messages, span, scope)
if messages_data is not None:
set_data_normalized(
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
)

# Input attributes: Common
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, "openai")
Expand Down
5 changes: 5 additions & 0 deletions sentry_sdk/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ class Scope:
"_extras",
"_breadcrumbs",
"_n_breadcrumbs_truncated",
"_gen_ai_messages_truncated",
"_event_processors",
"_error_processors",
"_should_capture",
Expand All @@ -213,6 +214,7 @@ def __init__(self, ty=None, client=None):
self._name = None # type: Optional[str]
self._propagation_context = None # type: Optional[PropagationContext]
self._n_breadcrumbs_truncated = 0 # type: int
self._gen_ai_messages_truncated = {} # type: Dict[str, int]

self.client = NonRecordingClient() # type: sentry_sdk.client.BaseClient

Expand Down Expand Up @@ -247,6 +249,7 @@ def __copy__(self):

rv._breadcrumbs = copy(self._breadcrumbs)
rv._n_breadcrumbs_truncated = self._n_breadcrumbs_truncated
rv._gen_ai_messages_truncated = self._gen_ai_messages_truncated.copy()
rv._event_processors = self._event_processors.copy()
rv._error_processors = self._error_processors.copy()
rv._propagation_context = self._propagation_context
Expand Down Expand Up @@ -1583,6 +1586,8 @@ def update_from_scope(self, scope):
self._n_breadcrumbs_truncated = (
self._n_breadcrumbs_truncated + scope._n_breadcrumbs_truncated
)
if scope._gen_ai_messages_truncated:
self._gen_ai_messages_truncated.update(scope._gen_ai_messages_truncated)
if scope._span:
self._span = scope._span
if scope._attachments:
Expand Down
63 changes: 58 additions & 5 deletions tests/integrations/openai/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import pytest

from sentry_sdk.utils import package_version
Expand All @@ -6,7 +7,6 @@
from openai import NOT_GIVEN
except ImportError:
NOT_GIVEN = None

try:
from openai import omit
except ImportError:
Expand Down Expand Up @@ -44,6 +44,9 @@
OpenAIIntegration,
_calculate_token_usage,
)
from sentry_sdk.ai.utils import MAX_GEN_AI_MESSAGE_BYTES
from sentry_sdk._types import AnnotatedValue
from sentry_sdk.serializer import serialize

from unittest import mock # python 3.3 and above

Expand Down Expand Up @@ -1456,6 +1459,7 @@ def test_empty_tools_in_chat_completion(sentry_init, capture_events, tools):

def test_openai_message_role_mapping(sentry_init, capture_events):
"""Test that OpenAI integration properly maps message roles like 'ai' to 'assistant'"""

sentry_init(
integrations=[OpenAIIntegration(include_prompts=True)],
traces_sample_rate=1.0,
Expand All @@ -1465,7 +1469,6 @@ def test_openai_message_role_mapping(sentry_init, capture_events):

client = OpenAI(api_key="z")
client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)

# Test messages with mixed roles including "ai" that should be mapped to "assistant"
test_messages = [
{"role": "system", "content": "You are helpful."},
Expand All @@ -1476,11 +1479,9 @@ def test_openai_message_role_mapping(sentry_init, capture_events):

with start_transaction(name="openai tx"):
client.chat.completions.create(model="test-model", messages=test_messages)

# Verify that the span was created correctly
(event,) = events
span = event["spans"][0]

# Verify that the span was created correctly
assert span["op"] == "gen_ai.chat"
assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]

Expand All @@ -1505,3 +1506,55 @@ def test_openai_message_role_mapping(sentry_init, capture_events):
# Verify no "ai" roles remain
roles = [msg["role"] for msg in stored_messages]
assert "ai" not in roles


def test_openai_message_truncation(sentry_init, capture_events):
    """Test that large messages are truncated properly in OpenAI integration."""
    sentry_init(
        integrations=[OpenAIIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    client = OpenAI(api_key="z")
    client.chat.completions._post = mock.Mock(return_value=EXAMPLE_CHAT_COMPLETION)

    # A payload far beyond the truncation budget, repeated across roles.
    oversized = (
        "This is a very long message that will exceed our size limits. " * 1000
    )
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": oversized},
        {"role": "assistant", "content": oversized},
        {"role": "user", "content": oversized},
    ]

    with start_transaction(name="openai tx"):
        client.chat.completions.create(
            model="some-model",
            messages=conversation,
        )

    (event,) = events
    span = event["spans"][0]
    assert SPANDATA.GEN_AI_REQUEST_MESSAGES in span["data"]

    # Messages are stored as a JSON string; it must parse back to a list no
    # longer than what was sent.
    raw_messages = span["data"][SPANDATA.GEN_AI_REQUEST_MESSAGES]
    assert isinstance(raw_messages, str)

    restored = json.loads(raw_messages)
    assert isinstance(restored, list)
    assert len(restored) <= len(conversation)

    # When messages were dropped, the _meta tree should carry the original
    # length annotation for the span's messages attribute.
    if "_meta" in event and len(restored) < len(conversation):
        meta = event["_meta"]
        if "spans" in meta and "0" in meta["spans"] and "data" in meta["spans"]["0"]:
            span_meta = meta["spans"]["0"]["data"]
            if SPANDATA.GEN_AI_REQUEST_MESSAGES in span_meta:
                messages_meta = span_meta[SPANDATA.GEN_AI_REQUEST_MESSAGES]
                assert "len" in messages_meta.get("", {})
Loading
Loading