diff --git a/sentry_sdk/integrations/langgraph.py b/sentry_sdk/integrations/langgraph.py
index 5bb0e0fd08..4b0e5f3ad4 100644
--- a/sentry_sdk/integrations/langgraph.py
+++ b/sentry_sdk/integrations/langgraph.py
@@ -1,5 +1,5 @@
 from functools import wraps
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import sentry_sdk
 from sentry_sdk.ai.utils import (
@@ -10,6 +10,7 @@
 from sentry_sdk.consts import OP, SPANDATA
 from sentry_sdk.integrations import DidNotEnable, Integration
 from sentry_sdk.scope import should_send_default_pii
+from sentry_sdk.tracing_utils import _get_value
 from sentry_sdk.utils import safe_serialize
 
 
@@ -103,6 +104,127 @@
     return normalized_messages if normalized_messages else None
 
 
+def _extract_model_from_config(config):
+    # type: (Any) -> Optional[str]
+    if not config:
+        return None
+
+    if isinstance(config, dict):
+        model = config.get("model")
+        if model:
+            return str(model)
+
+        configurable = config.get("configurable", {})
+        if isinstance(configurable, dict):
+            model = configurable.get("model")
+            if model:
+                return str(model)
+
+    if hasattr(config, "model"):
+        return str(config.model)
+
+    if hasattr(config, "configurable"):
+        configurable = config.configurable
+        if isinstance(configurable, dict):
+            model = configurable.get("model")
+            if model:
+                return str(model)
+        elif hasattr(configurable, "model"):
+            return str(configurable.model)
+
+    return None
+
+
+def _extract_model_from_pregel(pregel_instance):
+    # type: (Any) -> Optional[str]
+    if hasattr(pregel_instance, "config"):
+        model = _extract_model_from_config(pregel_instance.config)
+        if model:
+            return model
+
+    if hasattr(pregel_instance, "nodes"):
+        nodes = pregel_instance.nodes
+        if isinstance(nodes, dict):
+            for node_name, node in nodes.items():
+                if hasattr(node, "bound") and hasattr(node.bound, "model_name"):
+                    return str(node.bound.model_name)
+                if hasattr(node, "runnable") and hasattr(node.runnable, "model_name"):
+                    return str(node.runnable.model_name)
+
+    return None
+
+
+def _get_token_usage(obj):
+    # type: (Any) -> Optional[Dict[str, Any]]
+    possible_names = ("usage", "token_usage", "usage_metadata")
+
+    for name in possible_names:
+        usage = _get_value(obj, name)
+        if usage is not None:
+            return usage
+
+    if isinstance(obj, dict):
+        messages = obj.get("messages", [])
+        if messages and isinstance(messages, list):
+            for message in reversed(messages):
+                for name in possible_names:
+                    usage = _get_value(message, name)
+                    if usage is not None:
+                        return usage
+
+    return None
+
+
+def _extract_tokens(token_usage):
+    # type: (Any) -> Tuple[Optional[int], Optional[int], Optional[int]]
+    input_tokens = _get_value(token_usage, "prompt_tokens") or _get_value(
+        token_usage, "input_tokens"
+    )
+    output_tokens = _get_value(token_usage, "completion_tokens") or _get_value(
+        token_usage, "output_tokens"
+    )
+    total_tokens = _get_value(token_usage, "total_tokens")
+
+    return input_tokens, output_tokens, total_tokens
+
+
+def _record_token_usage(span, response):
+    # type: (Any, Any) -> None
+    token_usage = _get_token_usage(response)
+    if not token_usage:
+        return
+
+    input_tokens, output_tokens, total_tokens = _extract_tokens(token_usage)
+
+    if input_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if output_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
+
+    if total_tokens is not None:
+        span.set_data(SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens)
+
+
+def _extract_model_from_response(result):
+    # type: (Any) -> Optional[str]
+    if isinstance(result, dict):
+        messages = result.get("messages", [])
+        if messages and isinstance(messages, list):
+            for message in reversed(messages):
+                if hasattr(message, "response_metadata"):
+                    metadata = message.response_metadata
+                    if isinstance(metadata, dict):
+                        model = metadata.get("model")
+                        if model:
+                            return str(model)
+                        model_name = metadata.get("model_name")
+                        if model_name:
+                            return str(model_name)
+
+    return None
+
+
 def _wrap_state_graph_compile(f):
     # type: (Callable[..., Any]) -> Callable[..., Any]
     @wraps(f)
@@ -175,7 +297,14 @@ def new_invoke(self, *args, **kwargs):
 
         span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
 
-        # Store input messages to later compare with output
+        request_model = _extract_model_from_pregel(self)
+        if not request_model and len(kwargs) > 0:
+            config = kwargs.get("config")
+            request_model = _extract_model_from_config(config)
+
+        if request_model:
+            span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)
+
         input_messages = None
         if (
             len(args) > 0
@@ -199,6 +328,14 @@
 
         result = f(self, *args, **kwargs)
 
+        response_model = _extract_model_from_response(result)
+        if response_model:
+            span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
+        elif request_model:
+            span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, request_model)
+
+        _record_token_usage(span, result)
+
         _set_response_attributes(span, input_messages, result, integration)
 
         return result
@@ -232,6 +369,14 @@ async def new_ainvoke(self, *args, **kwargs):
 
         span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
 
+        request_model = _extract_model_from_pregel(self)
+        if not request_model and len(kwargs) > 0:
+            config = kwargs.get("config")
+            request_model = _extract_model_from_config(config)
+
+        if request_model:
+            span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)
+
         input_messages = None
         if (
             len(args) > 0
@@ -255,6 +400,14 @@
 
         result = await f(self, *args, **kwargs)
 
+        response_model = _extract_model_from_response(result)
+        if response_model:
+            span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
+        elif request_model:
+            span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, request_model)
+
+        _record_token_usage(span, result)
+
         _set_response_attributes(span, input_messages, result, integration)
 
         return result
diff --git a/tests/integrations/langgraph/test_langgraph.py b/tests/integrations/langgraph/test_langgraph.py
index df574dd2c3..1060583663 100644
--- a/tests/integrations/langgraph/test_langgraph.py
+++ b/tests/integrations/langgraph/test_langgraph.py
@@ -755,3 +755,183 @@
     assert "small message 4" in str(parsed_messages[0])
     assert "small message 5" in str(parsed_messages[1])
     assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
+
+
+def test_pregel_invoke_with_model_and_usage(sentry_init, capture_events):
+    """Test that model and usage information are captured during graph execution."""
+    sentry_init(
+        integrations=[LanggraphIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    class MockMessageWithMetadata(MockMessage):
+        def __init__(self, content, response_metadata=None):
+            super().__init__(content, type="ai")
+            self.response_metadata = response_metadata or {}
+
+    class MockPregelWithModel:
+        def __init__(self, model_name):
+            self.name = "test_graph_with_model"
+            self.config = {"model": model_name}
+
+        def invoke(self, state, config=None):
+            return {
+                "messages": [
+                    MockMessageWithMetadata(
+                        "Response from model",
+                        response_metadata={"model": "gpt-4"},
+                    )
+                ],
+                "usage_metadata": {
+                    "input_tokens": 100,
+                    "output_tokens": 50,
+                    "total_tokens": 150,
+                },
+            }
+
+    test_state = {"messages": [MockMessage("Hello, model test")]}
+    pregel = MockPregelWithModel("gpt-4")
+
+    def original_invoke(self, *args, **kwargs):
+        return self.invoke(*args, **kwargs)
+
+    with start_transaction():
+        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+        wrapped_invoke(pregel, test_state)
+
+    tx = events[0]
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+
+    invoke_span = invoke_spans[0]
+
+    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-4"
+
+    assert SPANDATA.GEN_AI_RESPONSE_MODEL in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gpt-4"
+
+    assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100
+
+    assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50
+
+    assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150
+
+
+def test_pregel_ainvoke_with_model_and_usage(sentry_init, capture_events):
+    """Test that model and usage information are captured during async graph execution."""
+    sentry_init(
+        integrations=[LanggraphIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    class MockMessageWithMetadata(MockMessage):
+        def __init__(self, content, response_metadata=None):
+            super().__init__(content, type="ai")
+            self.response_metadata = response_metadata or {}
+
+    class MockPregelWithModel:
+        def __init__(self, model_name):
+            self.name = "async_graph_with_model"
+            self.config = {"model": model_name}
+
+        async def ainvoke(self, state, config=None):
+            return {
+                "messages": [
+                    MockMessageWithMetadata(
+                        "Async response from model",
+                        response_metadata={"model": "claude-3"},
+                    )
+                ],
+                "usage_metadata": {
+                    "input_tokens": 200,
+                    "output_tokens": 75,
+                    "total_tokens": 275,
+                },
+            }
+
+    test_state = {"messages": [MockMessage("Hello, async model test")]}
+    pregel = MockPregelWithModel("claude-3")
+
+    async def original_ainvoke(self, *args, **kwargs):
+        return await self.ainvoke(*args, **kwargs)
+
+    async def run_test():
+        with start_transaction():
+            wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+            await wrapped_ainvoke(pregel, test_state)
+
+    asyncio.run(run_test())
+
+    tx = events[0]
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+
+    invoke_span = invoke_spans[0]
+
+    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "claude-3"
+
+    assert SPANDATA.GEN_AI_RESPONSE_MODEL in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "claude-3"
+
+    assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200
+
+    assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 75
+
+    assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 275
+
+
+def test_pregel_invoke_with_config_model(sentry_init, capture_events):
+    """Test that model information is extracted from config parameter."""
+    sentry_init(
+        integrations=[LanggraphIntegration(include_prompts=True)],
+        traces_sample_rate=1.0,
+        send_default_pii=True,
+    )
+    events = capture_events()
+
+    class MockPregelNoModel:
+        def __init__(self):
+            self.name = "test_graph_config_model"
+
+        def invoke(self, state, config=None):
+            return {
+                "messages": [MockMessage("Response")],
+            }
+
+    test_state = {"messages": [MockMessage("Hello")]}
+    pregel = MockPregelNoModel()
+    config = {"model": "gpt-3.5-turbo"}
+
+    def original_invoke(self, *args, **kwargs):
+        return self.invoke(*args, **kwargs)
+
+    with start_transaction():
+        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+        wrapped_invoke(pregel, test_state, config=config)
+
+    tx = events[0]
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+
+    invoke_span = invoke_spans[0]
+
+    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
+    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-3.5-turbo"