24
24
from langchain_core .callbacks import (
25
25
manager ,
26
26
BaseCallbackHandler ,
27
+ BaseCallbackManager ,
27
28
Callbacks ,
28
29
)
29
30
from langchain_core .agents import AgentAction , AgentFinish
@@ -302,15 +303,15 @@ def on_llm_end(
302
303
if token_usage :
303
304
record_token_usage (
304
305
span_data .span ,
305
- token_usage .get ("prompt_tokens" ),
306
- token_usage .get ("completion_tokens" ),
307
- token_usage .get ("total_tokens" ),
306
+ input_tokens = token_usage .get ("prompt_tokens" ),
307
+ output_tokens = token_usage .get ("completion_tokens" ),
308
+ total_tokens = token_usage .get ("total_tokens" ),
308
309
)
309
310
else :
310
311
record_token_usage (
311
312
span_data .span ,
312
- span_data .num_prompt_tokens ,
313
- span_data .num_completion_tokens ,
313
+ input_tokens = span_data .num_prompt_tokens ,
314
+ output_tokens = span_data .num_completion_tokens ,
314
315
)
315
316
316
317
self ._exit_span (span_data , run_id )
@@ -499,12 +500,20 @@ def new_configure(
499
500
** kwargs ,
500
501
)
501
502
502
- callbacks_list = local_callbacks or []
503
-
504
- if isinstance (callbacks_list , BaseCallbackHandler ):
505
- callbacks_list = [callbacks_list ]
506
- elif not isinstance (callbacks_list , list ):
507
- logger .debug ("Unknown callback type: %s" , callbacks_list )
503
+ local_callbacks = local_callbacks or []
504
+
505
+ # Handle each possible type of local_callbacks. For each type, we
506
+ # extract the list of callbacks to check for SentryLangchainCallback,
507
+ # and define a function that would add the SentryLangchainCallback
508
+ # to the existing callbacks list.
509
+ if isinstance (local_callbacks , BaseCallbackManager ):
510
+ callbacks_list = local_callbacks .handlers
511
+ elif isinstance (local_callbacks , BaseCallbackHandler ):
512
+ callbacks_list = [local_callbacks ]
513
+ elif isinstance (local_callbacks , list ):
514
+ callbacks_list = local_callbacks
515
+ else :
516
+ logger .debug ("Unknown callback type: %s" , local_callbacks )
508
517
# Just proceed with original function call
509
518
return f (
510
519
callback_manager_cls ,
@@ -514,28 +523,38 @@ def new_configure(
514
523
** kwargs ,
515
524
)
516
525
517
- inheritable_callbacks_list = (
518
- inheritable_callbacks if isinstance (inheritable_callbacks , list ) else []
519
- )
526
+ # Handle each possible type of inheritable_callbacks.
527
+ if isinstance (inheritable_callbacks , BaseCallbackManager ):
528
+ inheritable_callbacks_list = inheritable_callbacks .handlers
529
+ elif isinstance (inheritable_callbacks , list ):
530
+ inheritable_callbacks_list = inheritable_callbacks
531
+ else :
532
+ inheritable_callbacks_list = []
520
533
521
534
if not any (
522
535
isinstance (cb , SentryLangchainCallback )
523
536
for cb in itertools .chain (callbacks_list , inheritable_callbacks_list )
524
537
):
525
- # Avoid mutating the existing callbacks list
526
- callbacks_list = [
527
- * callbacks_list ,
528
- SentryLangchainCallback (
529
- integration .max_spans ,
530
- integration .include_prompts ,
531
- integration .tiktoken_encoding_name ,
532
- ),
533
- ]
538
+ sentry_handler = SentryLangchainCallback (
539
+ integration .max_spans ,
540
+ integration .include_prompts ,
541
+ integration .tiktoken_encoding_name ,
542
+ )
543
+ if isinstance (local_callbacks , BaseCallbackManager ):
544
+ local_callbacks = local_callbacks .copy ()
545
+ local_callbacks .handlers = [
546
+ * local_callbacks .handlers ,
547
+ sentry_handler ,
548
+ ]
549
+ elif isinstance (local_callbacks , BaseCallbackHandler ):
550
+ local_callbacks = [local_callbacks , sentry_handler ]
551
+ else : # local_callbacks is a list
552
+ local_callbacks = [* local_callbacks , sentry_handler ]
534
553
535
554
return f (
536
555
callback_manager_cls ,
537
556
inheritable_callbacks ,
538
- callbacks_list ,
557
+ local_callbacks ,
539
558
* args ,
540
559
** kwargs ,
541
560
)
0 commit comments