2424 from langchain_core .callbacks import (
2525 manager ,
2626 BaseCallbackHandler ,
27+ BaseCallbackManager ,
2728 Callbacks ,
2829 )
2930 from langchain_core .agents import AgentAction , AgentFinish
@@ -302,15 +303,15 @@ def on_llm_end(
302303 if token_usage :
303304 record_token_usage (
304305 span_data .span ,
305- token_usage .get ("prompt_tokens" ),
306- token_usage .get ("completion_tokens" ),
307- token_usage .get ("total_tokens" ),
306+ input_tokens = token_usage .get ("prompt_tokens" ),
307+ output_tokens = token_usage .get ("completion_tokens" ),
308+ total_tokens = token_usage .get ("total_tokens" ),
308309 )
309310 else :
310311 record_token_usage (
311312 span_data .span ,
312- span_data .num_prompt_tokens ,
313- span_data .num_completion_tokens ,
313+ input_tokens = span_data .num_prompt_tokens ,
314+ output_tokens = span_data .num_completion_tokens ,
314315 )
315316
316317 self ._exit_span (span_data , run_id )
@@ -499,12 +500,20 @@ def new_configure(
499500 ** kwargs ,
500501 )
501502
502- callbacks_list = local_callbacks or []
503-
504- if isinstance (callbacks_list , BaseCallbackHandler ):
505- callbacks_list = [callbacks_list ]
506- elif not isinstance (callbacks_list , list ):
507- logger .debug ("Unknown callback type: %s" , callbacks_list )
503+ local_callbacks = local_callbacks or []
504+
505+ # Handle each possible type of local_callbacks. For each type, we
506+ # extract the list of callbacks to check for SentryLangchainCallback,
507+ # and define a function that would add the SentryLangchainCallback
508+ # to the existing callbacks list.
509+ if isinstance (local_callbacks , BaseCallbackManager ):
510+ callbacks_list = local_callbacks .handlers
511+ elif isinstance (local_callbacks , BaseCallbackHandler ):
512+ callbacks_list = [local_callbacks ]
513+ elif isinstance (local_callbacks , list ):
514+ callbacks_list = local_callbacks
515+ else :
516+ logger .debug ("Unknown callback type: %s" , local_callbacks )
508517 # Just proceed with original function call
509518 return f (
510519 callback_manager_cls ,
@@ -514,28 +523,38 @@ def new_configure(
514523 ** kwargs ,
515524 )
516525
517- inheritable_callbacks_list = (
518- inheritable_callbacks if isinstance (inheritable_callbacks , list ) else []
519- )
526+ # Handle each possible type of inheritable_callbacks.
527+ if isinstance (inheritable_callbacks , BaseCallbackManager ):
528+ inheritable_callbacks_list = inheritable_callbacks .handlers
529+ elif isinstance (inheritable_callbacks , list ):
530+ inheritable_callbacks_list = inheritable_callbacks
531+ else :
532+ inheritable_callbacks_list = []
520533
521534 if not any (
522535 isinstance (cb , SentryLangchainCallback )
523536 for cb in itertools .chain (callbacks_list , inheritable_callbacks_list )
524537 ):
525- # Avoid mutating the existing callbacks list
526- callbacks_list = [
527- * callbacks_list ,
528- SentryLangchainCallback (
529- integration .max_spans ,
530- integration .include_prompts ,
531- integration .tiktoken_encoding_name ,
532- ),
533- ]
538+ sentry_handler = SentryLangchainCallback (
539+ integration .max_spans ,
540+ integration .include_prompts ,
541+ integration .tiktoken_encoding_name ,
542+ )
543+ if isinstance (local_callbacks , BaseCallbackManager ):
544+ local_callbacks = local_callbacks .copy ()
545+ local_callbacks .handlers = [
546+ * local_callbacks .handlers ,
547+ sentry_handler ,
548+ ]
549+ elif isinstance (local_callbacks , BaseCallbackHandler ):
550+ local_callbacks = [local_callbacks , sentry_handler ]
551+ else : # local_callbacks is a list
552+ local_callbacks = [* local_callbacks , sentry_handler ]
534553
535554 return f (
536555 callback_manager_cls ,
537556 inheritable_callbacks ,
538- callbacks_list ,
557+ local_callbacks ,
539558 * args ,
540559 ** kwargs ,
541560 )
0 commit comments