Skip to content

Commit d7fa0e4

Browse files
committed
fix(langgraph): add missing attributes to invoke_agent span
1 parent 828e513 commit d7fa0e4

File tree

2 files changed

+335
-2
lines changed

2 files changed

+335
-2
lines changed

sentry_sdk/integrations/langgraph.py

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import wraps
2-
from typing import Any, Callable, List, Optional
2+
from typing import Any, Callable, Dict, List, Optional, Tuple
33

44
import sentry_sdk
55
from sentry_sdk.ai.utils import (
@@ -10,6 +10,7 @@
1010
from sentry_sdk.consts import OP, SPANDATA
1111
from sentry_sdk.integrations import DidNotEnable, Integration
1212
from sentry_sdk.scope import should_send_default_pii
13+
from sentry_sdk.tracing_utils import _get_value
1314
from sentry_sdk.utils import safe_serialize
1415

1516

@@ -103,6 +104,127 @@ def _parse_langgraph_messages(state):
103104
return normalized_messages if normalized_messages else None
104105

105106

107+
def _extract_model_from_config(config):
108+
# type: (Any) -> Optional[str]
109+
if not config:
110+
return None
111+
112+
if isinstance(config, dict):
113+
model = config.get("model")
114+
if model:
115+
return str(model)
116+
117+
configurable = config.get("configurable", {})
118+
if isinstance(configurable, dict):
119+
model = configurable.get("model")
120+
if model:
121+
return str(model)
122+
123+
if hasattr(config, "model"):
124+
return str(config.model)
125+
126+
if hasattr(config, "configurable"):
127+
configurable = config.configurable
128+
if isinstance(configurable, dict):
129+
model = configurable.get("model")
130+
if model:
131+
return str(model)
132+
elif hasattr(configurable, "model"):
133+
return str(configurable.model)
134+
135+
return None
136+
137+
138+
def _extract_model_from_pregel(pregel_instance):
139+
# type: (Any) -> Optional[str]
140+
if hasattr(pregel_instance, "config"):
141+
model = _extract_model_from_config(pregel_instance.config)
142+
if model:
143+
return model
144+
145+
if hasattr(pregel_instance, "nodes"):
146+
nodes = pregel_instance.nodes
147+
if isinstance(nodes, dict):
148+
for node_name, node in nodes.items():
149+
if hasattr(node, "bound") and hasattr(node.bound, "model_name"):
150+
return str(node.bound.model_name)
151+
if hasattr(node, "runnable") and hasattr(node.runnable, "model_name"):
152+
return str(node.runnable.model_name)
153+
154+
return None
155+
156+
157+
def _get_token_usage(obj):
    # type: (Any) -> Optional[Dict[str, Any]]
    """Find token-usage info on ``obj`` or, failing that, on its messages.

    Looks for any of the known usage attribute/key names first on the
    object itself, then on each entry of ``obj["messages"]`` starting
    from the most recent one. Returns ``None`` when nothing is found.
    """
    usage_keys = ("usage", "token_usage", "usage_metadata")

    def _lookup(candidate):
        # type: (Any) -> Optional[Dict[str, Any]]
        for key in usage_keys:
            found = _get_value(candidate, key)
            if found is not None:
                return found
        return None

    usage = _lookup(obj)
    if usage is not None:
        return usage

    if isinstance(obj, dict):
        messages = obj.get("messages", [])
        if messages and isinstance(messages, list):
            # Newest message is most likely to carry the final usage data.
            for message in reversed(messages):
                usage = _lookup(message)
                if usage is not None:
                    return usage

    return None
176+
177+
178+
def _extract_tokens(token_usage):
    # type: (Any) -> Tuple[Optional[int], Optional[int], Optional[int]]
    """Normalize token counts from a usage object or dict.

    Supports both OpenAI-style (``prompt_tokens``/``completion_tokens``)
    and LangChain-style (``input_tokens``/``output_tokens``) field names.
    Returns ``(input, output, total)``, each ``None`` when absent.
    """

    def _first(*names):
        # type: (*str) -> Optional[int]
        # Explicit ``is not None`` so a legitimate count of 0 is kept
        # instead of being discarded as falsy and shadowed by the
        # alternate field name.
        for name in names:
            value = _get_value(token_usage, name)
            if value is not None:
                return value
        return None

    input_tokens = _first("prompt_tokens", "input_tokens")
    output_tokens = _first("completion_tokens", "output_tokens")
    total_tokens = _first("total_tokens")

    return input_tokens, output_tokens, total_tokens
189+
190+
191+
def _record_token_usage(span, response):
    # type: (Any, Any) -> None
    """Attach gen_ai token-usage span attributes when available.

    No-op when ``response`` carries no discoverable usage information.
    """
    usage = _get_token_usage(response)
    if not usage:
        return

    input_tokens, output_tokens, total_tokens = _extract_tokens(usage)

    # Only set the attributes that were actually reported.
    for data_key, count in (
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens),
        (SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total_tokens),
    ):
        if count is not None:
            span.set_data(data_key, count)
207+
208+
209+
def _extract_model_from_response(result):
210+
# type: (Any) -> Optional[str]
211+
if isinstance(result, dict):
212+
messages = result.get("messages", [])
213+
if messages and isinstance(messages, list):
214+
for message in reversed(messages):
215+
if hasattr(message, "response_metadata"):
216+
metadata = message.response_metadata
217+
if isinstance(metadata, dict):
218+
model = metadata.get("model")
219+
if model:
220+
return str(model)
221+
model_name = metadata.get("model_name")
222+
if model_name:
223+
return str(model_name)
224+
225+
return None
226+
227+
106228
def _wrap_state_graph_compile(f):
107229
# type: (Callable[..., Any]) -> Callable[..., Any]
108230
@wraps(f)
@@ -175,7 +297,14 @@ def new_invoke(self, *args, **kwargs):
175297

176298
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
177299

178-
# Store input messages to later compare with output
300+
request_model = _extract_model_from_pregel(self)
301+
if not request_model and len(kwargs) > 0:
302+
config = kwargs.get("config")
303+
request_model = _extract_model_from_config(config)
304+
305+
if request_model:
306+
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)
307+
179308
input_messages = None
180309
if (
181310
len(args) > 0
@@ -199,6 +328,14 @@ def new_invoke(self, *args, **kwargs):
199328

200329
result = f(self, *args, **kwargs)
201330

331+
response_model = _extract_model_from_response(result)
332+
if response_model:
333+
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
334+
elif request_model:
335+
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, request_model)
336+
337+
_record_token_usage(span, result)
338+
202339
_set_response_attributes(span, input_messages, result, integration)
203340

204341
return result
@@ -232,6 +369,14 @@ async def new_ainvoke(self, *args, **kwargs):
232369

233370
span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "invoke_agent")
234371

372+
request_model = _extract_model_from_pregel(self)
373+
if not request_model and len(kwargs) > 0:
374+
config = kwargs.get("config")
375+
request_model = _extract_model_from_config(config)
376+
377+
if request_model:
378+
span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)
379+
235380
input_messages = None
236381
if (
237382
len(args) > 0
@@ -255,6 +400,14 @@ async def new_ainvoke(self, *args, **kwargs):
255400

256401
result = await f(self, *args, **kwargs)
257402

403+
response_model = _extract_model_from_response(result)
404+
if response_model:
405+
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
406+
elif request_model:
407+
span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, request_model)
408+
409+
_record_token_usage(span, result)
410+
258411
_set_response_attributes(span, input_messages, result, integration)
259412

260413
return result

tests/integrations/langgraph/test_langgraph.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,3 +755,183 @@ def original_invoke(self, *args, **kwargs):
755755
assert "small message 4" in str(parsed_messages[0])
756756
assert "small message 5" in str(parsed_messages[1])
757757
assert tx["_meta"]["spans"]["0"]["data"]["gen_ai.request.messages"][""]["len"] == 5
758+
759+
760+
def test_pregel_invoke_with_model_and_usage(sentry_init, capture_events):
    """Test that model and usage information are captured during graph execution."""
    sentry_init(
        integrations=[LanggraphIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    # Message subclass that carries response_metadata, mirroring what
    # LangChain AI messages expose after a model call.
    class MockMessageWithMetadata(MockMessage):
        def __init__(self, content, response_metadata=None):
            super().__init__(content, type="ai")
            self.response_metadata = response_metadata or {}

    # Fake Pregel graph whose config advertises the request model and
    # whose invoke() result carries both response metadata and usage.
    class MockPregelWithModel:
        def __init__(self, model_name):
            self.name = "test_graph_with_model"
            self.config = {"model": model_name}

        def invoke(self, state, config=None):
            return {
                "messages": [
                    MockMessageWithMetadata(
                        "Response from model",
                        response_metadata={"model": "gpt-4"},
                    )
                ],
                "usage_metadata": {
                    "input_tokens": 100,
                    "output_tokens": 50,
                    "total_tokens": 150,
                },
            }

    test_state = {"messages": [MockMessage("Hello, model test")]}
    pregel = MockPregelWithModel("gpt-4")

    # Stand-in for Pregel.invoke so the wrapper can be exercised directly.
    def original_invoke(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)

    with start_transaction():
        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
        wrapped_invoke(pregel, test_state)

    tx = events[0]
    invoke_spans = [
        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
    ]
    assert len(invoke_spans) == 1

    invoke_span = invoke_spans[0]

    # Request model comes from the graph's config.
    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-4"

    # Response model comes from the last message's response_metadata.
    assert SPANDATA.GEN_AI_RESPONSE_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "gpt-4"

    # Token counts come from the result's usage_metadata.
    assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 100

    assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 50

    assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 150
826+
827+
828+
def test_pregel_ainvoke_with_model_and_usage(sentry_init, capture_events):
    """Test that model and usage information are captured during async graph execution."""
    sentry_init(
        integrations=[LanggraphIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    # Message subclass that carries response_metadata, mirroring what
    # LangChain AI messages expose after a model call.
    class MockMessageWithMetadata(MockMessage):
        def __init__(self, content, response_metadata=None):
            super().__init__(content, type="ai")
            self.response_metadata = response_metadata or {}

    # Async fake Pregel graph: config advertises the request model and
    # ainvoke() returns both response metadata and usage.
    class MockPregelWithModel:
        def __init__(self, model_name):
            self.name = "async_graph_with_model"
            self.config = {"model": model_name}

        async def ainvoke(self, state, config=None):
            return {
                "messages": [
                    MockMessageWithMetadata(
                        "Async response from model",
                        response_metadata={"model": "claude-3"},
                    )
                ],
                "usage_metadata": {
                    "input_tokens": 200,
                    "output_tokens": 75,
                    "total_tokens": 275,
                },
            }

    test_state = {"messages": [MockMessage("Hello, async model test")]}
    pregel = MockPregelWithModel("claude-3")

    # Stand-in for Pregel.ainvoke so the wrapper can be exercised directly.
    async def original_ainvoke(self, *args, **kwargs):
        return await self.ainvoke(*args, **kwargs)

    async def run_test():
        with start_transaction():
            wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
            await wrapped_ainvoke(pregel, test_state)

    asyncio.run(run_test())

    tx = events[0]
    invoke_spans = [
        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
    ]
    assert len(invoke_spans) == 1

    invoke_span = invoke_spans[0]

    # Request model comes from the graph's config.
    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "claude-3"

    # Response model comes from the last message's response_metadata.
    assert SPANDATA.GEN_AI_RESPONSE_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_RESPONSE_MODEL] == "claude-3"

    # Token counts come from the result's usage_metadata.
    assert SPANDATA.GEN_AI_USAGE_INPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 200

    assert SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS] == 75

    assert SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS] == 275
897+
898+
899+
def test_pregel_invoke_with_config_model(sentry_init, capture_events):
    """Test that model information is extracted from config parameter."""
    sentry_init(
        integrations=[LanggraphIntegration(include_prompts=True)],
        traces_sample_rate=1.0,
        send_default_pii=True,
    )
    events = capture_events()

    # Graph deliberately has no model in its own config, so the wrapper
    # must fall back to the config= keyword passed at invoke time.
    class MockPregelNoModel:
        def __init__(self):
            self.name = "test_graph_config_model"

        def invoke(self, state, config=None):
            return {
                "messages": [MockMessage("Response")],
            }

    test_state = {"messages": [MockMessage("Hello")]}
    pregel = MockPregelNoModel()
    config = {"model": "gpt-3.5-turbo"}

    # Stand-in for Pregel.invoke so the wrapper can be exercised directly.
    def original_invoke(self, *args, **kwargs):
        return self.invoke(*args, **kwargs)

    with start_transaction():
        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
        wrapped_invoke(pregel, test_state, config=config)

    tx = events[0]
    invoke_spans = [
        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
    ]
    assert len(invoke_spans) == 1

    invoke_span = invoke_spans[0]

    # Request model must be taken from the invoke-time config kwarg.
    assert SPANDATA.GEN_AI_REQUEST_MODEL in invoke_span["data"]
    assert invoke_span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "gpt-3.5-turbo"

0 commit comments

Comments
 (0)