Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 155 additions & 2 deletions sentry_sdk/integrations/langgraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import wraps
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import sentry_sdk
from sentry_sdk.ai.utils import (
Expand All @@ -10,6 +10,7 @@
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import _get_value
from sentry_sdk.utils import safe_serialize


Expand Down Expand Up @@ -103,6 +104,127 @@ def _parse_langgraph_messages(state):
return normalized_messages if normalized_messages else None


def _extract_model_from_config(config):
# type: (Any) -> Optional[str]
if not config:
return None

if isinstance(config, dict):
model = config.get("model")
if model:
return str(model)

configurable = config.get("configurable", {})
if isinstance(configurable, dict):
model = configurable.get("model")
if model:
return str(model)

if hasattr(config, "model"):
return str(config.model)

if hasattr(config, "configurable"):
configurable = config.configurable
if isinstance(configurable, dict):
model = configurable.get("model")
if model:
return str(model)
elif hasattr(configurable, "model"):
return str(configurable.model)

return None


def _extract_model_from_pregel(pregel_instance):
# type: (Any) -> Optional[str]
if hasattr(pregel_instance, "config"):
model = _extract_model_from_config(pregel_instance.config)
if model:
return model

if hasattr(pregel_instance, "nodes"):
nodes = pregel_instance.nodes
if isinstance(nodes, dict):
for node_name, node in nodes.items():
if hasattr(node, "bound") and hasattr(node.bound, "model_name"):
return str(node.bound.model_name)
if hasattr(node, "runnable") and hasattr(node.runnable, "model_name"):
return str(node.runnable.model_name)

return None


def _get_token_usage(obj):
    # type: (Any) -> Optional[Dict[str, Any]]
    """Locate token-usage metadata on *obj*.

    Probes the object itself for any of the known usage keys; failing that,
    walks ``obj["messages"]`` from newest to oldest and probes each message.
    Returns the first usage mapping found, else ``None``.
    """
    candidate_keys = ("usage", "token_usage", "usage_metadata")

    for key in candidate_keys:
        found = _get_value(obj, key)
        if found is not None:
            return found

    if isinstance(obj, dict):
        messages = obj.get("messages", [])
        if isinstance(messages, list) and messages:
            # Newest message is most likely to carry the final usage totals.
            for message in reversed(messages):
                for key in candidate_keys:
                    found = _get_value(message, key)
                    if found is not None:
                        return found

    return None


def _extract_tokens(token_usage):
    # type: (Any) -> Tuple[Optional[int], Optional[int], Optional[int]]
    """Normalize provider-specific token-count field names.

    Prefers the OpenAI-style names (``prompt_tokens`` / ``completion_tokens``)
    and falls back to the LangChain/Anthropic-style names (``input_tokens`` /
    ``output_tokens``). Returns ``(input, output, total)``, each ``None`` when
    absent.

    Fix over the previous version: the ``or``-based fallback treated a valid
    count of 0 as missing; ``is not None`` preserves zeros.
    """
    def _first_present(*names):
        # type: (*str) -> Optional[int]
        # Return the first field that is actually present (0 is valid).
        for name in names:
            value = _get_value(token_usage, name)
            if value is not None:
                return value
        return None

    input_tokens = _first_present("prompt_tokens", "input_tokens")
    output_tokens = _first_present("completion_tokens", "output_tokens")
    total_tokens = _first_present("total_tokens")

    return input_tokens, output_tokens, total_tokens


def _record_token_usage(span, response):
    # type: (Any, Any) -> None
    """Attach gen_ai token-usage attributes from *response* onto *span*.

    Silently does nothing when no usage metadata can be found.
    """
    usage = _get_token_usage(response)
    if not usage:
        return

    prompt_count, completion_count, total_count = _extract_tokens(usage)

    # Only set attributes for counts that were actually reported.
    for attribute, value in (
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, prompt_count),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, completion_count),
        (SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_count),
    ):
        if value is not None:
            span.set_data(attribute, value)


def _extract_model_from_response(result):
# type: (Any) -> Optional[str]
if isinstance(result, dict):
messages = result.get("messages", [])
if messages and isinstance(messages, list):
for message in reversed(messages):
if hasattr(message, "response_metadata"):
metadata = message.response_metadata
if isinstance(metadata, dict):
model = metadata.get("model")
if model:
return str(model)
model_name = metadata.get("model_name")
if model_name:
return str(model_name)

return None


def _wrap_state_graph_compile(f):
# type: (Callable[..., Any]) -> Callable[..., Any]
@wraps(f)
Expand Down Expand Up @@ -175,7 +297,14 @@ def new_invoke(self, *args, **kwargs):

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")

# Store input messages to later compare with output
request_model = _extract_model_from_pregel(self)
if not request_model and len(kwargs) > 0:
config = kwargs.get("config")
request_model = _extract_model_from_config(config)

if request_model:
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)

input_messages = None
if (
len(args) > 0
Expand All @@ -199,6 +328,14 @@ def new_invoke(self, *args, **kwargs):

result = f(self, *args, **kwargs)

response_model = _extract_model_from_response(result)
if response_model:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
elif request_model:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, request_model)

_record_token_usage(span, result)

_set_response_attributes(span, input_messages, result, integration)

return result
Expand Down Expand Up @@ -232,6 +369,14 @@ async def new_ainvoke(self, *args, **kwargs):

span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")

request_model = _extract_model_from_pregel(self)
if not request_model and len(kwargs) > 0:
config = kwargs.get("config")
request_model = _extract_model_from_config(config)

if request_model:
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)

input_messages = None
if (
len(args) > 0
Expand All @@ -255,6 +400,14 @@ async def new_ainvoke(self, *args, **kwargs):

result = await f(self, *args, **kwargs)

response_model = _extract_model_from_response(result)
if response_model:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
elif request_model:
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, request_model)

_record_token_usage(span, result)

_set_response_attributes(span, input_messages, result, integration)

return result
Expand Down
180 changes: 180 additions & 0 deletions tests/integrations/langgraph/test_langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,3 +755,183 @@ def original_invoke(self, *args, **kwargs):
assert "small message 4" in str(parsed_messages[0])
assert "small message 5" in str(parsed_messages[1])
assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5


def test_pregel_invoke_with_model_and_usage(sentry_init, capture_events):
    """Test that model and usage information are captured during graph execution."""
    sentry_init(
        integrations=[LanggraphIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    # AI message carrying provider response metadata — the source of the
    # gen_ai.response.model span attribute.
    class MockMessageWithMetadata(MockMessage):
        def __init__(self, content, response_metadata=None):
            super().__init__(content, type="ai")
            self.response_metadata = response_metadata or {}

    # Minimal Pregel stand-in: the request model comes from its dict config,
    # token usage comes from the returned state's "usage_metadata".
    class MockPregelWithModel:
        def __init__(self, model_name):
            self.name = "test_graph_with_model"
            self.config = {"model": model_name}

        def invoke(self, state, config=None):
            return {
                "messages": [
                    MockMessageWithMetadata(
                        "Response from model",
                        response_metadata={"model": "gpt-4"},
                    )
                ],
                "usage_metadata": {
                    "input_tokens": 100,
                    "output_tokens": 50,
                    "total_tokens": 150,
                },
            }

    test_state = {"messages": [MockMessage("Hello, model test")]}
    pregel = MockPregelWithModel("gpt-4")

    # Stand-in for Pregel.invoke; the wrapper receives it as the wrapped f.
    def original_invoke(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)

    with start_transaction():
        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
        wrapped_invoke(pregel, test_state)

    tx = events[0]
    invoke_spans = [
        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
    ]
    assert len(invoke_spans) == 1

    invoke_span = invoke_spans[0]

    # Request model is read from the Pregel instance's config.
    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-4"

    # Response model is taken from the last message's response_metadata.
    assert SPANDATA.GEN_AI_RESPONSE_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gpt-4"

    # Token counts come from the result's usage_metadata mapping.
    assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100

    assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50

    assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150


def test_pregel_ainvoke_with_model_and_usage(sentry_init, capture_events):
    """Test that model and usage information are captured during async graph execution."""
    sentry_init(
        integrations=[LanggraphIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    # AI message carrying provider response metadata — the source of the
    # gen_ai.response.model span attribute.
    class MockMessageWithMetadata(MockMessage):
        def __init__(self, content, response_metadata=None):
            super().__init__(content, type="ai")
            self.response_metadata = response_metadata or {}

    # Async Pregel stand-in: request model from its dict config, token
    # usage from the returned state's "usage_metadata".
    class MockPregelWithModel:
        def __init__(self, model_name):
            self.name = "async_graph_with_model"
            self.config = {"model": model_name}

        async def ainvoke(self, state, config=None):
            return {
                "messages": [
                    MockMessageWithMetadata(
                        "Async response from model",
                        response_metadata={"model": "claude-3"},
                    )
                ],
                "usage_metadata": {
                    "input_tokens": 200,
                    "output_tokens": 75,
                    "total_tokens": 275,
                },
            }

    test_state = {"messages": [MockMessage("Hello, async model test")]}
    pregel = MockPregelWithModel("claude-3")

    # Stand-in for Pregel.ainvoke; the wrapper receives it as the wrapped f.
    async def original_ainvoke(self, *args, **kwargs):
        return await self.ainvoke(*args, **kwargs)

    async def run_test():
        with start_transaction():
            wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
            await wrapped_ainvoke(pregel, test_state)

    asyncio.run(run_test())

    tx = events[0]
    invoke_spans = [
        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
    ]
    assert len(invoke_spans) == 1

    invoke_span = invoke_spans[0]

    # Request model is read from the Pregel instance's config.
    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "claude-3"

    # Response model is taken from the last message's response_metadata.
    assert SPANDATA.GEN_AI_RESPONSE_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "claude-3"

    # Token counts come from the result's usage_metadata mapping.
    assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200

    assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 75

    assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 275


def test_pregel_invoke_with_config_model(sentry_init, capture_events):
    """Test that model information is extracted from config parameter."""
    sentry_init(
        integrations=[LanggraphIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    # A graph whose instance carries no model info at all, so the
    # integration must fall back to the ``config`` kwarg of invoke().
    class MockPregelNoModel:
        def __init__(self):
            self.name = "test_graph_config_model"

        def invoke(self, state, config=None):
            return {
                "messages": [MockMessage("Response")],
            }

    graph = MockPregelNoModel()
    state = {"messages": [MockMessage("Hello")]}
    run_config = {"model": "gpt-3.5-turbo"}

    # Stand-in for Pregel.invoke; the wrapper receives it as the wrapped f.
    def passthrough_invoke(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)

    with start_transaction():
        instrumented = _wrap_pregel_invoke(passthrough_invoke)
        instrumented(graph, state, config=run_config)

    transaction = events[0]
    agent_spans = [
        span for span in transaction["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
    ]
    assert len(agent_spans) == 1

    agent_span = agent_spans[0]

    # The request model must have been read from the config kwarg.
    assert SPANDATA.GEN_AI_REQUEST_MODEL in agent_span["data"]
    assert agent_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-3.5-turbo"