11from functools import wraps
2- from typing import Any , Callable , List , Optional
2+ from typing import Any , Callable , Dict , List , Optional , Tuple
33
44import sentry_sdk
55from sentry_sdk .ai .utils import (
1010from sentry_sdk .consts import OP , SPANDATA
1111from sentry_sdk .integrations import DidNotEnable , Integration
1212from sentry_sdk .scope import should_send_default_pii
13+ from sentry_sdk .tracing_utils import _get_value
1314from sentry_sdk .utils import safe_serialize
1415
1516
@@ -103,6 +104,127 @@ def _parse_langgraph_messages(state):
103104 return normalized_messages if normalized_messages else None
104105
105106
107+ def _extract_model_from_config (config ):
108+ # type: (Any) -> Optional[str]
109+ if not config :
110+ return None
111+
112+ if isinstance (config , dict ):
113+ model = config .get ("model" )
114+ if model :
115+ return str (model )
116+
117+ configurable = config .get ("configurable" , {})
118+ if isinstance (configurable , dict ):
119+ model = configurable .get ("model" )
120+ if model :
121+ return str (model )
122+
123+ if hasattr (config , "model" ):
124+ return str (config .model )
125+
126+ if hasattr (config , "configurable" ):
127+ configurable = config .configurable
128+ if isinstance (configurable , dict ):
129+ model = configurable .get ("model" )
130+ if model :
131+ return str (model )
132+ elif hasattr (configurable , "model" ):
133+ return str (configurable .model )
134+
135+ return None
136+
137+
138+ def _extract_model_from_pregel (pregel_instance ):
139+ # type: (Any) -> Optional[str]
140+ if hasattr (pregel_instance , "config" ):
141+ model = _extract_model_from_config (pregel_instance .config )
142+ if model :
143+ return model
144+
145+ if hasattr (pregel_instance , "nodes" ):
146+ nodes = pregel_instance .nodes
147+ if isinstance (nodes , dict ):
148+ for node_name , node in nodes .items ():
149+ if hasattr (node , "bound" ) and hasattr (node .bound , "model_name" ):
150+ return str (node .bound .model_name )
151+ if hasattr (node , "runnable" ) and hasattr (node .runnable , "model_name" ):
152+ return str (node .runnable .model_name )
153+
154+ return None
155+
156+
def _get_token_usage(obj):
    # type: (Any) -> Optional[Dict[str, Any]]
    """Locate token-usage metadata on a response object.

    Checks the object itself for one of the known usage keys/attributes
    ("usage", "token_usage", "usage_metadata"); failing that, walks a
    dict-style response's ``messages`` list from newest to oldest and
    checks each message. Returns the first usage payload found, else None.
    """
    usage_keys = ("usage", "token_usage", "usage_metadata")

    def lookup(candidate):
        # type: (Any) -> Optional[Dict[str, Any]]
        for key in usage_keys:
            found = _get_value(candidate, key)
            if found is not None:
                return found
        return None

    usage = lookup(obj)
    if usage is not None:
        return usage

    if isinstance(obj, dict):
        messages = obj.get("messages", [])
        if isinstance(messages, list) and messages:
            # Newest message is most likely to carry the final usage totals.
            for message in reversed(messages):
                usage = lookup(message)
                if usage is not None:
                    return usage

    return None
176+
177+
def _extract_tokens(token_usage):
    # type: (Any) -> Tuple[Optional[int], Optional[int], Optional[int]]
    """Read (input, output, total) token counts from a usage payload.

    Accepts both OpenAI-style names (``prompt_tokens``/``completion_tokens``)
    and LangChain-style names (``input_tokens``/``output_tokens``). Uses
    explicit ``is not None`` checks rather than ``or``-chaining so that a
    legitimate count of 0 is not discarded as falsy. Missing values are
    returned as ``None``.
    """

    def first_present(*names):
        # type: (*str) -> Optional[int]
        # Return the first name that is actually present (including 0).
        for name in names:
            value = _get_value(token_usage, name)
            if value is not None:
                return value
        return None

    input_tokens = first_present("prompt_tokens", "input_tokens")
    output_tokens = first_present("completion_tokens", "output_tokens")
    total_tokens = _get_value(token_usage, "total_tokens")

    return input_tokens, output_tokens, total_tokens
189+
190+
def _record_token_usage(span, response):
    # type: (Any, Any) -> None
    """Attach gen_ai token-usage span data from *response*, when present.

    Silently does nothing if the response carries no usage metadata; each
    of the three counts is recorded only when available.
    """
    usage = _get_token_usage(response)
    if not usage:
        return

    prompt_tokens, completion_tokens, total = _extract_tokens(usage)

    for attribute, value in (
        (SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, prompt_tokens),
        (SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, completion_tokens),
        (SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS, total),
    ):
        if value is not None:
            span.set_data(attribute, value)
207+
208+
209+ def _extract_model_from_response (result ):
210+ # type: (Any) -> Optional[str]
211+ if isinstance (result , dict ):
212+ messages = result .get ("messages" , [])
213+ if messages and isinstance (messages , list ):
214+ for message in reversed (messages ):
215+ if hasattr (message , "response_metadata" ):
216+ metadata = message .response_metadata
217+ if isinstance (metadata , dict ):
218+ model = metadata .get ("model" )
219+ if model :
220+ return str (model )
221+ model_name = metadata .get ("model_name" )
222+ if model_name :
223+ return str (model_name )
224+
225+ return None
226+
227+
106228def _wrap_state_graph_compile (f ):
107229 # type: (Callable[..., Any]) -> Callable[..., Any]
108230 @wraps (f )
@@ -175,7 +297,14 @@ def new_invoke(self, *args, **kwargs):
175297
176298 span .set_data (SPANDATA .GEN_AI_OPERATION_NAME , "invoke_agent" )
177299
178- # Store input messages to later compare with output
300+ request_model = _extract_model_from_pregel (self )
301+ if not request_model and len (kwargs ) > 0 :
302+ config = kwargs .get ("config" )
303+ request_model = _extract_model_from_config (config )
304+
305+ if request_model :
306+ span .set_data (SPANDATA .GEN_AI_REQUEST_MODEL , request_model )
307+
179308 input_messages = None
180309 if (
181310 len (args ) > 0
@@ -199,6 +328,14 @@ def new_invoke(self, *args, **kwargs):
199328
200329 result = f (self , * args , ** kwargs )
201330
331+ response_model = _extract_model_from_response (result )
332+ if response_model :
333+ span .set_data (SPANDATA .GEN_AI_RESPONSE_MODEL , response_model )
334+ elif request_model :
335+ span .set_data (SPANDATA .GEN_AI_RESPONSE_MODEL , request_model )
336+
337+ _record_token_usage (span , result )
338+
202339 _set_response_attributes (span , input_messages , result , integration )
203340
204341 return result
@@ -232,6 +369,14 @@ async def new_ainvoke(self, *args, **kwargs):
232369
233370 span .set_data (SPANDATA .GEN_AI_OPERATION_NAME , "invoke_agent" )
234371
372+ request_model = _extract_model_from_pregel (self )
373+ if not request_model and len (kwargs ) > 0 :
374+ config = kwargs .get ("config" )
375+ request_model = _extract_model_from_config (config )
376+
377+ if request_model :
378+ span .set_data (SPANDATA .GEN_AI_REQUEST_MODEL , request_model )
379+
235380 input_messages = None
236381 if (
237382 len (args ) > 0
@@ -255,6 +400,14 @@ async def new_ainvoke(self, *args, **kwargs):
255400
256401 result = await f (self , * args , ** kwargs )
257402
403+ response_model = _extract_model_from_response (result )
404+ if response_model :
405+ span .set_data (SPANDATA .GEN_AI_RESPONSE_MODEL , response_model )
406+ elif request_model :
407+ span .set_data (SPANDATA .GEN_AI_RESPONSE_MODEL , request_model )
408+
409+ _record_token_usage (span , result )
410+
258411 _set_response_attributes (span , input_messages , result , integration )
259412
260413 return result
0 commit comments