# LangChain is an optional dependency: fail fast with an actionable message
# instead of surfacing an opaque ImportError from a deeper call site later.
try:
    import langchain  # noqa: F401
except ImportError as e:
    # Chain the original ImportError explicitly so the real failure
    # (e.g. a broken partial install) stays visible in the traceback.
    raise ModuleNotFoundError(
        "Please install LangChain to use this feature: 'pip install langchain'"
    ) from e
57
68import logging
79import time
2123from langchain .callbacks .base import BaseCallbackHandler
2224from langchain .schema .agent import AgentAction , AgentFinish
2325from langchain_core .documents import Document
24- from langchain_core .messages import AIMessage , BaseMessage , FunctionMessage , HumanMessage , SystemMessage , ToolMessage
26+ from langchain_core .messages import (
27+ AIMessage ,
28+ BaseMessage ,
29+ FunctionMessage ,
30+ HumanMessage ,
31+ SystemMessage ,
32+ ToolMessage ,
33+ )
2534from langchain_core .outputs import ChatGeneration , LLMResult
2635from pydantic import BaseModel
2736
@@ -63,6 +72,7 @@ class GenerationMetadata(SpanMetadata):
6372 tools : Optional [List [Dict [str , Any ]]] = None
6473 """Tools provided to the model."""
6574
75+
# Metadata recorded for a single LangChain run: either a plain trace/span
# or an LLM generation (which additionally carries model information).
RunMetadata = Union[SpanMetadata, GenerationMetadata]
# In-flight bookkeeping: maps a run's UUID to its captured metadata until
# the corresponding end/error callback pops it.
RunMetadataStorage = Dict[UUID, RunMetadata]
6878
@@ -142,7 +152,9 @@ def on_chain_start(
142152 ):
143153 self ._log_debug_event ("on_chain_start" , run_id , parent_run_id , inputs = inputs )
144154 self ._set_parent_of_run (run_id , parent_run_id )
145- self ._set_trace_or_span_metadata (serialized , inputs , run_id , parent_run_id , ** kwargs )
155+ self ._set_trace_or_span_metadata (
156+ serialized , inputs , run_id , parent_run_id , ** kwargs
157+ )
146158
147159 def on_chain_end (
148160 self ,
@@ -175,9 +187,13 @@ def on_chat_model_start(
175187 parent_run_id : Optional [UUID ] = None ,
176188 ** kwargs ,
177189 ):
178- self ._log_debug_event ("on_chat_model_start" , run_id , parent_run_id , messages = messages )
190+ self ._log_debug_event (
191+ "on_chat_model_start" , run_id , parent_run_id , messages = messages
192+ )
179193 self ._set_parent_of_run (run_id , parent_run_id )
180- input = [_convert_message_to_dict (message ) for row in messages for message in row ]
194+ input = [
195+ _convert_message_to_dict (message ) for row in messages for message in row
196+ ]
181197 self ._set_llm_metadata (serialized , run_id , input , ** kwargs )
182198
183199 def on_llm_start (
@@ -215,7 +231,9 @@ def on_llm_end(
215231 """
216232 The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
217233 """
218- self ._log_debug_event ("on_llm_end" , run_id , parent_run_id , response = response , kwargs = kwargs )
234+ self ._log_debug_event (
235+ "on_llm_end" , run_id , parent_run_id , response = response , kwargs = kwargs
236+ )
219237 self ._pop_run_and_capture_generation (run_id , parent_run_id , response )
220238
221239 def on_llm_error (
@@ -239,9 +257,13 @@ def on_tool_start(
239257 metadata : Optional [Dict [str , Any ]] = None ,
240258 ** kwargs : Any ,
241259 ) -> Any :
242- self ._log_debug_event ("on_tool_start" , run_id , parent_run_id , input_str = input_str )
260+ self ._log_debug_event (
261+ "on_tool_start" , run_id , parent_run_id , input_str = input_str
262+ )
243263 self ._set_parent_of_run (run_id , parent_run_id )
244- self ._set_trace_or_span_metadata (serialized , input_str , run_id , parent_run_id , ** kwargs )
264+ self ._set_trace_or_span_metadata (
265+ serialized , input_str , run_id , parent_run_id , ** kwargs
266+ )
245267
246268 def on_tool_end (
247269 self ,
@@ -278,7 +300,9 @@ def on_retriever_start(
278300 ) -> Any :
279301 self ._log_debug_event ("on_retriever_start" , run_id , parent_run_id , query = query )
280302 self ._set_parent_of_run (run_id , parent_run_id )
281- self ._set_trace_or_span_metadata (serialized , query , run_id , parent_run_id , ** kwargs )
303+ self ._set_trace_or_span_metadata (
304+ serialized , query , run_id , parent_run_id , ** kwargs
305+ )
282306
283307 def on_retriever_end (
284308 self ,
@@ -288,7 +312,9 @@ def on_retriever_end(
288312 parent_run_id : Optional [UUID ] = None ,
289313 ** kwargs : Any ,
290314 ):
291- self ._log_debug_event ("on_retriever_end" , run_id , parent_run_id , documents = documents )
315+ self ._log_debug_event (
316+ "on_retriever_end" , run_id , parent_run_id , documents = documents
317+ )
292318 self ._pop_run_and_capture_trace_or_span (run_id , parent_run_id , documents )
293319
294320 def on_retriever_error (
@@ -363,7 +389,9 @@ def _set_trace_or_span_metadata(
363389 ):
364390 default_name = "trace" if parent_run_id is None else "span"
365391 run_name = _get_langchain_run_name (serialized , ** kwargs ) or default_name
366- self ._runs [run_id ] = SpanMetadata (name = run_name , input = input , start_time = time .time (), end_time = None )
392+ self ._runs [run_id ] = SpanMetadata (
393+ name = run_name , input = input , start_time = time .time (), end_time = None
394+ )
367395
368396 def _set_llm_metadata (
369397 self ,
@@ -375,7 +403,9 @@ def _set_llm_metadata(
375403 ** kwargs ,
376404 ):
377405 run_name = _get_langchain_run_name (serialized , ** kwargs ) or "generation"
378- generation = GenerationMetadata (name = run_name , input = messages , start_time = time .time (), end_time = None )
406+ generation = GenerationMetadata (
407+ name = run_name , input = messages , start_time = time .time (), end_time = None
408+ )
379409 if isinstance (invocation_params , dict ):
380410 generation .model_params = get_model_params (invocation_params )
381411 if tools := invocation_params .get ("tools" ):
@@ -409,25 +439,35 @@ def _get_trace_id(self, run_id: UUID):
409439 return run_id
410440 return trace_id
411441
412- def _get_parent_run_id (self , trace_id : Any , run_id : UUID , parent_run_id : Optional [UUID ]):
442+ def _get_parent_run_id (
443+ self , trace_id : Any , run_id : UUID , parent_run_id : Optional [UUID ]
444+ ):
413445 """
414446 Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set.
415447 """
416448 if parent_run_id is not None and parent_run_id not in self ._parent_tree :
417449 return trace_id
418450 return parent_run_id
419451
def _pop_run_and_capture_trace_or_span(
    self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any
):
    """Finalize a finished run and emit it as a trace or span event.

    Detaches the run from the parent tree, pops its stored metadata, and
    forwards it to the capture step. Generation runs are rejected here:
    they must go through the generation capture path instead.
    """
    trace_id = self._get_trace_id(run_id)
    self._pop_parent_of_run(run_id)
    metadata = self._pop_run_metadata(run_id)
    if not metadata:
        return
    if isinstance(metadata, GenerationMetadata):
        # Wrong capture path for an LLM generation; warn and bail out.
        log.warning(
            f"Run {run_id} is a generation, but attempted to be captured as a trace or span."
        )
        return
    effective_parent = self._get_parent_run_id(trace_id, run_id, parent_run_id)
    self._capture_trace_or_span(trace_id, run_id, metadata, outputs, effective_parent)
433473 def _capture_trace_or_span (
@@ -441,7 +481,9 @@ def _capture_trace_or_span(
441481 event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
442482 event_properties = {
443483 "$ai_trace_id" : trace_id ,
444- "$ai_input_state" : with_privacy_mode (self ._client , self ._privacy_mode , run .input ),
484+ "$ai_input_state" : with_privacy_mode (
485+ self ._client , self ._privacy_mode , run .input
486+ ),
445487 "$ai_latency" : run .latency ,
446488 "$ai_span_name" : run .name ,
447489 "$ai_span_id" : run_id ,
@@ -455,7 +497,9 @@ def _capture_trace_or_span(
455497 event_properties ["$ai_error" ] = _stringify_exception (outputs )
456498 event_properties ["$ai_is_error" ] = True
457499 elif outputs is not None :
458- event_properties ["$ai_output_state" ] = with_privacy_mode (self ._client , self ._privacy_mode , outputs )
500+ event_properties ["$ai_output_state" ] = with_privacy_mode (
501+ self ._client , self ._privacy_mode , outputs
502+ )
459503
460504 if self ._distinct_id is None :
461505 event_properties ["$process_person_profile" ] = False
@@ -468,18 +512,27 @@ def _capture_trace_or_span(
468512 )
469513
def _pop_run_and_capture_generation(
    self,
    run_id: UUID,
    parent_run_id: Optional[UUID],
    response: Union[LLMResult, BaseException],
):
    """Finalize a finished LLM run and emit it as a generation event.

    Detaches the run from the parent tree, pops its stored metadata, and
    forwards it (with the LLM result or the raised exception) to the
    generation capture step. Non-generation runs are rejected here.
    """
    trace_id = self._get_trace_id(run_id)
    self._pop_parent_of_run(run_id)
    metadata = self._pop_run_metadata(run_id)
    if not metadata:
        return
    if not isinstance(metadata, GenerationMetadata):
        # Only generation runs may take this path; warn and bail out.
        log.warning(
            f"Run {run_id} is not a generation, but attempted to be captured as a generation."
        )
        return
    effective_parent = self._get_parent_run_id(trace_id, run_id, parent_run_id)
    self._capture_generation(trace_id, run_id, metadata, response, effective_parent)
485538 def _capture_generation (
@@ -524,8 +577,12 @@ def _capture_generation(
524577 for generation in generation_result
525578 ]
526579 else :
527- completions = [_extract_raw_esponse (generation ) for generation in generation_result ]
528- event_properties ["$ai_output_choices" ] = with_privacy_mode (self ._client , self ._privacy_mode , completions )
580+ completions = [
581+ _extract_raw_esponse (generation ) for generation in generation_result
582+ ]
583+ event_properties ["$ai_output_choices" ] = with_privacy_mode (
584+ self ._client , self ._privacy_mode , completions
585+ )
529586
530587 if self ._properties :
531588 event_properties .update (self ._properties )
@@ -615,7 +672,9 @@ def _parse_usage_model(
615672 if model_key in usage :
616673 captured_count = usage [model_key ]
617674 final_count = (
618- sum (captured_count ) if isinstance (captured_count , list ) else captured_count
675+ sum (captured_count )
676+ if isinstance (captured_count , list )
677+ else captured_count
619678 ) # For Bedrock, the token count is a list when streamed
620679
621680 parsed_usage [type_key ] = final_count
@@ -640,8 +699,12 @@ def _parse_usage(response: LLMResult):
640699 break
641700
642701 for generation_chunk in generation :
643- if generation_chunk .generation_info and ("usage_metadata" in generation_chunk .generation_info ):
644- llm_usage = _parse_usage_model (generation_chunk .generation_info ["usage_metadata" ])
702+ if generation_chunk .generation_info and (
703+ "usage_metadata" in generation_chunk .generation_info
704+ ):
705+ llm_usage = _parse_usage_model (
706+ generation_chunk .generation_info ["usage_metadata" ]
707+ )
645708 break
646709
647710 message_chunk = getattr (generation_chunk , "message" , {})
@@ -653,13 +716,19 @@ def _parse_usage(response: LLMResult):
653716 else None
654717 )
655718 bedrock_titan_usage = (
656- response_metadata .get ("amazon-bedrock-invocationMetrics" , None ) # for Bedrock-Titan
719+ response_metadata .get (
720+ "amazon-bedrock-invocationMetrics" , None
721+ ) # for Bedrock-Titan
657722 if isinstance (response_metadata , dict )
658723 else None
659724 )
660- ollama_usage = getattr (message_chunk , "usage_metadata" , None ) # for Ollama
725+ ollama_usage = getattr (
726+ message_chunk , "usage_metadata" , None
727+ ) # for Ollama
661728
662- chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
729+ chunk_usage = (
730+ bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
731+ )
663732 if chunk_usage :
664733 llm_usage = _parse_usage_model (chunk_usage )
665734 break
@@ -675,7 +744,9 @@ def _get_http_status(error: BaseException) -> int:
675744 return status_code
676745
677746
678- def _get_langchain_run_name (serialized : Optional [Dict [str , Any ]], ** kwargs : Any ) -> Optional [str ]:
747+ def _get_langchain_run_name (
748+ serialized : Optional [Dict [str , Any ]], ** kwargs : Any
749+ ) -> Optional [str ]:
679750 """Retrieve the name of a serialized LangChain runnable.
680751
681752 The prioritization for the determination of the run name is as follows:
0 commit comments