
Commit 58d79f7

Cleanup
1 parent 2deb597 commit 58d79f7

1 file changed

sentry_sdk/integrations/langchain.py

Lines changed: 47 additions & 60 deletions
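The cleanup follows one pattern throughout: each callback replaces the old `if not run_id` check plus a later `span_map.get(run_id)`/`None` check with a single combined guard, then indexes the map directly and binds a local `span` alias. A minimal sketch of the before/after shape; `Handler`, `handle_event_old`, and `handle_event_new` are illustrative names, not code from the diff:

```python
class Handler:
    def __init__(self):
        self.span_map = {}  # maps run_id -> per-run span data

    def handle_event_old(self, run_id):
        # Old shape: two separate bail-out points, with a .get() that
        # can return None.
        if not run_id:
            return None
        span_data = self.span_map.get(run_id)
        if not span_data:
            return None
        return span_data  # ... use span_data ...

    def handle_event_new(self, run_id):
        # New shape: one combined guard; direct indexing afterwards is
        # safe because membership was just checked.
        if not run_id or run_id not in self.span_map:
            return None
        span_data = self.span_map[run_id]
        return span_data  # ... use span_data ...
```

Direct indexing after the membership check cannot raise `KeyError`, and the second `None` check was a dead branch once the guard guaranteed the entry exists.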
@@ -82,7 +82,7 @@ def __init__(self, span):
 
 
 class SentryLangchainCallback(BaseCallbackHandler):  # type: ignore[misc]
-    """Base callback handler that can be used to handle callbacks from langchain."""
+    """Callback handler that creates Sentry spans."""
 
     def __init__(self, max_span_map_size, include_prompts):
         # type: (int, bool) -> None
@@ -99,15 +99,18 @@ def gc_span_map(self):
 
     def _handle_error(self, run_id, error):
         # type: (UUID, Any) -> None
-        if not run_id or run_id not in self.span_map:
-            return
+        with capture_internal_exceptions():
+            if not run_id or run_id not in self.span_map:
+                return
 
-        span_data = self.span_map.get(run_id)
-        if not span_data:
-            return
-        sentry_sdk.capture_exception(error, span_data.span.scope)
-        span_data.span.__exit__(None, None, None)
-        del self.span_map[run_id]
+            span_data = self.span_map[run_id]
+            span = span_data.span
+            span.set_status("unknown")
+
+            sentry_sdk.capture_exception(error, span.scope)
+
+            span.__exit__(None, None, None)
+            del self.span_map[run_id]
 
     def _normalize_langchain_message(self, message):
         # type: (BaseMessage) -> Any
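Note: `_handle_error` now does three things differently. It wraps its whole body in `capture_internal_exceptions()` (so the per-callsite wrappers removed further down become unnecessary), it marks the span's status as "unknown" before capturing the exception, and it indexes `span_map` directly behind the combined guard. A minimal sketch of what the wrapper buys, assuming only that `capture_internal_exceptions()` suppresses and internally reports exceptions raised inside its block:

```python
from sentry_sdk.utils import capture_internal_exceptions

def instrumentation_step():
    # If anything in this block raises, the exception is swallowed
    # (and reported SDK-internally) instead of crashing the caller's
    # LangChain pipeline.
    with capture_internal_exceptions():
        raise RuntimeError("bug in instrumentation code")

instrumentation_step()  # returns normally; the error does not reach the app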
@@ -213,13 +216,13 @@ def _extract_token_usage_from_response(self, response):
 
     def _create_span(self, run_id, parent_id, **kwargs):
         # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan
-
         watched_span = None  # type: Optional[WatchedSpan]
         if parent_id:
             parent_span = self.span_map.get(parent_id)  # type: Optional[WatchedSpan]
             if parent_span:
                 watched_span = WatchedSpan(parent_span.span.start_child(**kwargs))
                 parent_span.children.append(watched_span)
+
         if watched_span is None:
             watched_span = WatchedSpan(sentry_sdk.start_span(**kwargs))
 
@@ -235,7 +238,6 @@ def _create_span(self, run_id, parent_id, **kwargs):
 
     def _exit_span(self, span_data, run_id):
         # type: (SentryLangchainCallback, WatchedSpan, UUID) -> None
-
         if span_data.is_pipeline:
             set_ai_pipeline_name(None)
 
@@ -258,6 +260,7 @@ def on_llm_start(
         with capture_internal_exceptions():
             if not run_id:
                 return
+
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
 
@@ -302,6 +305,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
         with capture_internal_exceptions():
             if not run_id:
                 return
+
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
 
@@ -349,8 +353,12 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
         """Run when Chat Model ends running."""
         with capture_internal_exceptions():
-            if not run_id:
+            if not run_id or run_id not in self.span_map:
                 return
+
+            span_data = self.span_map[run_id]
+            span = span_data.span
+
             token_usage = None
 
             # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
@@ -370,13 +378,9 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
             elif hasattr(response, "usage_metadata"):
                 token_usage = response.usage_metadata
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
-
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
-                    span_data.span,
+                    span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     [[x.text for x in list_] for list_ in response.generations],
                 )
@@ -396,7 +400,7 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                 or total_tokens is not None
             ):
                 record_token_usage(
-                    span_data.span,
+                    span,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
                     total_tokens=total_tokens,
@@ -407,40 +411,33 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
     def on_llm_new_token(self, token, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, str, UUID, Any) -> Any
         """Run on new LLM token. Only available when streaming is enabled."""
-        # no manual token counting
-        with capture_internal_exceptions():
-            return
+        pass
 
     def on_llm_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
         """Run when LLM ends running."""
         with capture_internal_exceptions():
-            if not run_id:
-                return
-
-            span_data = self.span_map.get(run_id)
-            if not span_data:
+            if not run_id or run_id not in self.span_map:
                 return
 
+            span_data = self.span_map[run_id]
             span = span_data.span
 
             try:
-                generation_result = response.generations[0][0]
+                generation = response.generations[0][0]
             except IndexError:
-                generation_result = None
+                generation = None
 
-            if generation_result is not None:
+            if generation is not None:
                 try:
-                    response_model = generation_result.generation_info.get("model_name")
+                    response_model = generation.generation_info.get("model_name")
                     if response_model is not None:
                         span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)
                 except AttributeError:
                     pass
 
                 try:
-                    finish_reason = generation_result.generation_info.get(
-                        "finish_reason"
-                    )
+                    finish_reason = generation.generation_info.get("finish_reason")
                     if finish_reason is not None:
                         span.set_data(
                             SPANDATA.GEN_AI_RESPONSE_FINISH_REASONS, finish_reason
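Aside from the consolidated guard at the top, the rest of this hunk (and the next) is a mechanical rename of `generation_result` to `generation`, plus collapsing the multi-line `generation_info.get(` call onto one line; behavior is unchanged.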
@@ -449,7 +446,7 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                     pass
 
                 try:
-                    tool_calls = getattr(generation_result.message, "tool_calls", None)
+                    tool_calls = getattr(generation.message, "tool_calls", None)
                     if tool_calls is not None:
                         set_data_normalized(
                             span,
@@ -462,7 +459,7 @@ def on_llm_end(self, response, *, run_id, **kwargs):
 
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
-                    span_data.span,
+                    span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     [[x.text for x in list_] for list_ in response.generations],
                 )
@@ -506,14 +503,12 @@ def on_llm_end(self, response, *, run_id, **kwargs):
     def on_llm_error(self, error, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
         """Run when LLM errors."""
-        with capture_internal_exceptions():
-            self._handle_error(run_id, error)
+        self._handle_error(run_id, error)
 
     def on_chat_model_error(self, error, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
         """Run when Chat Model errors."""
-        with capture_internal_exceptions():
-            self._handle_error(run_id, error)
+        self._handle_error(run_id, error)
 
     def on_chain_start(self, serialized, inputs, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, Dict[str, Any], Dict[str, Any], UUID, Any) -> Any
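Because `_handle_error` now installs its own `capture_internal_exceptions()` guard (see the first `_handle_error` hunk above), `on_llm_error` and `on_chat_model_error` can delegate directly without re-wrapping the call.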
@@ -527,9 +522,7 @@ def on_chain_end(self, outputs, *, run_id, **kwargs):
             if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
             self._exit_span(span_data, run_id)
 
     def on_chain_error(self, error, *, run_id, **kwargs):
@@ -543,26 +536,25 @@ def on_agent_action(self, action, *, run_id, **kwargs):
             if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
             self._exit_span(span_data, run_id)
 
     def on_agent_finish(self, finish, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, AgentFinish, UUID, Any) -> Any
         with capture_internal_exceptions():
-            if not run_id:
+            if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
+            span = span_data.span
+
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
-                    span_data.span,
+                    span,
                     SPANDATA.GEN_AI_RESPONSE_TEXT,
                     finish.return_values.items(),
                 )
+
             self._exit_span(span_data, run_id)
 
     def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
@@ -604,22 +596,17 @@ def on_tool_end(self, output, *, run_id, **kwargs):
             if not run_id or run_id not in self.span_map:
                 return
 
-            span_data = self.span_map.get(run_id)
-            if not span_data:
-                return
+            span_data = self.span_map[run_id]
+            span = span_data.span
+
             if should_send_default_pii() and self.include_prompts:
-                set_data_normalized(span_data.span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)
+                set_data_normalized(span, SPANDATA.GEN_AI_TOOL_OUTPUT, output)
+
             self._exit_span(span_data, run_id)
 
     def on_tool_error(self, error, *args, run_id, **kwargs):
         # type: (SentryLangchainCallback, Union[Exception, KeyboardInterrupt], UUID, Any) -> Any
         """Run when tool errors."""
-        # TODO(shellmayr): how to correctly set the status when the tool fails?
-        if run_id and run_id in self.span_map:
-            span_data = self.span_map.get(run_id)
-            if span_data:
-                span_data.span.set_status("unknown")
-
         self._handle_error(run_id, error)
 
 
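Note: the removed TODO block in `on_tool_error` (which manually set the span status to "unknown") is now redundant, since `_handle_error` sets that status on every error path.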