@@ -82,7 +82,7 @@ def __init__(self, span):
 
 
 class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
-    """Base callback handler that can be used to handle callbacks from langchain."""
+    """Callback handler that creates Sentry spans."""
 
     def __init__(self, max_span_map_size, include_prompts):
         # type: (int, bool) -> None
@@ -99,15 +99,18 @@ def gc_span_map(self):
 
     def _handle_error(self, run_id, error):
         # type: (UUID, Any) -> None
-        if not run_id or run_id not in self.span_map:
-            return
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
 
-        span_data = self.span_map.get(run_id)
-        if not span_data:
-            return
-        sentry_sdk.capture_exception(error, span_data.span.scope)
-        span_data.span.__exit__(None, None, None)
-        del self.span_map[run_id]
+            span_data = self.span_map[run_id]
+            span = span_data.span
+            span.set_status("unknown")
+
+            sentry_sdk.capture_exception(error, span.scope)
+
+            span.__exit__(None, None, None)
+            del self.span_map[run_id]
 
     def _normalize_langchain_message(self, message):
         # type: (BaseMessage) -> Any
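
The reworked `_handle_error` assumes spans in `span_map` were entered manually and must be closed by hand on the error path. A minimal sketch of that lifecycle outside the handler (the `op`/`name` values and the raised error are placeholders):

```python
import sentry_sdk

# Sketch: a span opened without `with`, then failed the way
# _handle_error fails it. Illustrative only.
span = sentry_sdk.start_span(op="gen_ai.chat", name="demo call")
span.__enter__()
try:
    raise RuntimeError("provider timeout")  # placeholder failure
except RuntimeError as exc:
    span.set_status("unknown")         # mark the span as errored
    sentry_sdk.capture_exception(exc)  # report, then close the span
    span.__exit__(None, None, None)
```
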
@@ -213,13 +216,13 @@ def _extract_token_usage_from_response(self, response):
 
     def _create_span(self, run_id, parent_id, **kwargs):
         # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
-
         watched_span = None  # type: Optional[WatchedSpan]
         if parent_id:
             parent_span = self.span_map.get(parent_id)  # type: Optional[WatchedSpan]
             if parent_span:
                 watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
                 parent_span.children.append(watched_span)
+
         if watched_span is None:
             watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
 
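
`_create_span` nests spans along the run hierarchy: a child run's span is started from its parent's span whenever the parent is still tracked, and only falls back to a new top-level span otherwise. A hedged sketch of that bookkeeping, assuming `WatchedSpan` is a thin wrapper around the span (the run-id keys are placeholders):

```python
import sentry_sdk

class WatchedSpan:
    # Assumed shape of the wrapper; the real class may carry more state.
    def __init__(self, span):
        self.span = span
        self.children = []
        self.is_pipeline = False

span_map = {}

# Root run: no parent tracked, so start a new top-level span.
root = WatchedSpan(sentry_sdk.start_span(op="gen_ai.pipeline"))
span_map["root-run-id"] = root

# Child run: reuse the parent's span via start_child(**kwargs).
child = WatchedSpan(root.span.start_child(op="gen_ai.chat"))
root.children.append(child)
span_map["child-run-id"] = child
```
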
@@ -235,7 +238,6 @@ def _create_span(self, run_id, parent_id, **kwargs):
 
     def _exit_span(self, span_data, run_id):
         # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
-
         if span_data.is_pipeline:
             set_ai_pipeline_name(None)
 
@@ -258,6 +260,7 @@ def on_llm_start(
         with capture_internal_exceptions():
             if not run_id:
                 return
+
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
 
@@ -302,6 +305,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
         with capture_internal_exceptions():
             if not run_id:
                 return
+
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
 
@@ -349,8 +353,12 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
         """Run when Chat Model ends running."""
         with capture_internal_exceptions():
-            if not run_id:
+            if not run_id or run_id not in self.span_map:
                 return
+
+            span_data = self.span_map[run_id]
+            span = span_data.span
+
             token_usage = None
 
             # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
@@ -370,13 +378,9 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
             elif hasattr(response, "usage_metadata"):
                 token_usage = response.usage_metadata
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
-
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
-                    span_data.span,
+                    span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     [[x.text for x in list_] for list_ in response.generations],
                 )
@@ -396,7 +400,7 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 or total_tokens is not None
             ):
                 record_token_usage(
-                    span_data.span,
+                    span,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
                     total_tokens=total_tokens,
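
The comment above notes that several extraction paths are tried in order. A hedged helper mirroring that fallback chain; the `usage_metadata` branch is visible in the diff context, while the `llm_output["token_usage"]` branch is an assumption about the elided lines, and the helper name is illustrative:

```python
def _extract_usage_sketch(response):
    # Illustrative only: probes a LangChain LLMResult-like object
    # the way on_chat_model_end appears to.
    llm_output = getattr(response, "llm_output", None) or {}
    if "token_usage" in llm_output:
        return llm_output["token_usage"]
    if hasattr(response, "usage_metadata"):
        return response.usage_metadata
    return None
```
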
@@ -407,40 +411,33 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
     def on_llm_new_token(self, token, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, str, UUID, Any) -> Any
         """Run on new LLM token. Only available when streaming is enabled."""
-        # no manual token counting
-        with capture_internal_exceptions():
-            return
+        pass
 
     def on_llm_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
         """Run when LLM ends running."""
         with capture_internal_exceptions():
-            if not run_id:
-                return
-
-            span_data = self.span_map.get(run_id)
-            if not span_data:
+            if not run_id or run_id not in self.span_map:
                 return
 
+            span_data = self.span_map[run_id]
             span = span_data.span
 
             try:
-                generation_result = response.generations[0][0]
+                generation = response.generations[0][0]
             except IndexError:
-                generation_result = None
+                generation = None
 
-            if generation_result is not None:
+            if generation is not None:
                 try:
-                    response_model = generation_result.generation_info.get("model_name")
+                    response_model = generation.generation_info.get("model_name")
                     if response_model is not None:
                         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
                 except AttributeError:
                     pass
 
                 try:
-                    finish_reason = generation_result.generation_info.get(
-                        "finish_reason"
-                    )
+                    finish_reason = generation.generation_info.get("finish_reason")
                     if finish_reason is not None:
                         span.set_data(
                             SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reason
@@ -449,7 +446,7 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                     pass
 
                 try:
-                    tool_calls = getattr(generation_result.message, "tool_calls", None)
+                    tool_calls = getattr(generation.message, "tool_calls", None)
                     if tool_calls is not None:
                         set_data_normalized(
                             span,
@@ -462,7 +459,7 @@ def on_llm_end(self, response, *, run_id, **kwargs):
 
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
-                    span_data.span,
+                    span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     [[x.text for x in list_] for list_ in response.generations],
                 )
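
The `generation_info` reads above are wrapped in try/except because the attribute may be absent or None. For reference, a sketch of the metadata shape being probed (values illustrative):

```python
# Illustrative per-generation metadata consumed by on_llm_end.
generation_info = {
    "model_name": "gpt-4o-mini",  # copied to GEN_AI_RESPONSE_MODEL
    "finish_reason": "stop",      # copied to GEN_AI_RESPONSE_FINISH_REASONS
}
```
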
@@ -506,14 +503,12 @@ def on_llm_end(self, response, *, run_id, **kwargs):
     def on_llm_error(self, error, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
         """Run when LLM errors."""
-        with capture_internal_exceptions():
-            self._handle_error(run_id, error)
+        self._handle_error(run_id, error)
 
     def on_chat_model_error(self, error, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
         """Run when Chat Model errors."""
-        with capture_internal_exceptions():
-            self._handle_error(run_id, error)
+        self._handle_error(run_id, error)
 
     def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
@@ -527,9 +522,7 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
             if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
             self._exit_span(span_data, run_id)
 
     def on_chain_error(self, error, *, run_id, **kwargs):
@@ -543,26 +536,25 @@ def on_agent_action(self, action, *, run_id, **kwargs):
             if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
             self._exit_span(span_data, run_id)
 
     def on_agent_finish(self, finish, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
         with capture_internal_exceptions():
-            if not run_id:
+            if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
+            span = span_data.span
+
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
-                    span_data.span,
+                    span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     finish.return_values.items(),
                 )
+
             self._exit_span(span_data, run_id)
 
     def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
@@ -604,22 +596,17 @@ def on_tool_end(self, output, *, run_id, **kwargs):
             if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
+            span = span_data.span
+
             if should_send_default_pii() and self.include_prompts:
-                set_data_normalized(span_data.span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)
+                set_data_normalized(span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)
+
             self._exit_span(span_data, run_id)
 
     def on_tool_error(self, error, *args, run_id, **kwargs):
         # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
         """Run when tool errors."""
-        # TODO(shellmayr): how to correctly set the status when the tool fails?
-        if run_id and run_id in self.span_map:
-            span_data = self.span_map.get(run_id)
-            if span_data:
-                span_data.span.set_status("unknown")
-
         self._handle_error(run_id, error)
 
 
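
For context, a minimal sketch of wiring a handler like this into a LangChain call by hand (model name and constructor arguments are placeholders; in practice the Sentry LangChain integration attaches the callback itself):

```python
from langchain_openai import ChatOpenAI

# Hypothetical manual wiring, for illustration only.
handler = SentryLangchainCallback(max_span_map_size=1024, include_prompts=False)
llm = ChatOpenAI(model="gpt-4o-mini")
llm.invoke("Say hello", config={"callbacks": [handler]})
```
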