11try :
22 import langchain # noqa: F401
33except ImportError :
4- raise ModuleNotFoundError (
5- "Please install LangChain to use this feature: 'pip install langchain'"
6- )
4+ raise ModuleNotFoundError ("Please install LangChain to use this feature: 'pip install langchain'" )
75
86import logging
97import time
@@ -152,9 +150,7 @@ def on_chain_start(
152150 ):
153151 self ._log_debug_event ("on_chain_start" , run_id , parent_run_id , inputs = inputs )
154152 self ._set_parent_of_run (run_id , parent_run_id )
155- self ._set_trace_or_span_metadata (
156- serialized , inputs , run_id , parent_run_id , ** kwargs
157- )
153+ self ._set_trace_or_span_metadata (serialized , inputs , run_id , parent_run_id , ** kwargs )
158154
159155 def on_chain_end (
160156 self ,
@@ -187,13 +183,9 @@ def on_chat_model_start(
187183 parent_run_id : Optional [UUID ] = None ,
188184 ** kwargs ,
189185 ):
190- self ._log_debug_event (
191- "on_chat_model_start" , run_id , parent_run_id , messages = messages
192- )
186+ self ._log_debug_event ("on_chat_model_start" , run_id , parent_run_id , messages = messages )
193187 self ._set_parent_of_run (run_id , parent_run_id )
194- input = [
195- _convert_message_to_dict (message ) for row in messages for message in row
196- ]
188+ input = [_convert_message_to_dict (message ) for row in messages for message in row ]
197189 self ._set_llm_metadata (serialized , run_id , input , ** kwargs )
198190
199191 def on_llm_start (
@@ -231,9 +223,7 @@ def on_llm_end(
231223 """
232224 The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
233225 """
234- self ._log_debug_event (
235- "on_llm_end" , run_id , parent_run_id , response = response , kwargs = kwargs
236- )
226+ self ._log_debug_event ("on_llm_end" , run_id , parent_run_id , response = response , kwargs = kwargs )
237227 self ._pop_run_and_capture_generation (run_id , parent_run_id , response )
238228
239229 def on_llm_error (
@@ -257,13 +247,9 @@ def on_tool_start(
257247 metadata : Optional [Dict [str , Any ]] = None ,
258248 ** kwargs : Any ,
259249 ) -> Any :
260- self ._log_debug_event (
261- "on_tool_start" , run_id , parent_run_id , input_str = input_str
262- )
250+ self ._log_debug_event ("on_tool_start" , run_id , parent_run_id , input_str = input_str )
263251 self ._set_parent_of_run (run_id , parent_run_id )
264- self ._set_trace_or_span_metadata (
265- serialized , input_str , run_id , parent_run_id , ** kwargs
266- )
252+ self ._set_trace_or_span_metadata (serialized , input_str , run_id , parent_run_id , ** kwargs )
267253
268254 def on_tool_end (
269255 self ,
@@ -300,9 +286,7 @@ def on_retriever_start(
300286 ) -> Any :
301287 self ._log_debug_event ("on_retriever_start" , run_id , parent_run_id , query = query )
302288 self ._set_parent_of_run (run_id , parent_run_id )
303- self ._set_trace_or_span_metadata (
304- serialized , query , run_id , parent_run_id , ** kwargs
305- )
289+ self ._set_trace_or_span_metadata (serialized , query , run_id , parent_run_id , ** kwargs )
306290
307291 def on_retriever_end (
308292 self ,
@@ -312,9 +296,7 @@ def on_retriever_end(
312296 parent_run_id : Optional [UUID ] = None ,
313297 ** kwargs : Any ,
314298 ):
315- self ._log_debug_event (
316- "on_retriever_end" , run_id , parent_run_id , documents = documents
317- )
299+ self ._log_debug_event ("on_retriever_end" , run_id , parent_run_id , documents = documents )
318300 self ._pop_run_and_capture_trace_or_span (run_id , parent_run_id , documents )
319301
320302 def on_retriever_error (
@@ -389,9 +371,7 @@ def _set_trace_or_span_metadata(
389371 ):
390372 default_name = "trace" if parent_run_id is None else "span"
391373 run_name = _get_langchain_run_name (serialized , ** kwargs ) or default_name
392- self ._runs [run_id ] = SpanMetadata (
393- name = run_name , input = input , start_time = time .time (), end_time = None
394- )
374+ self ._runs [run_id ] = SpanMetadata (name = run_name , input = input , start_time = time .time (), end_time = None )
395375
396376 def _set_llm_metadata (
397377 self ,
@@ -403,9 +383,7 @@ def _set_llm_metadata(
403383 ** kwargs ,
404384 ):
405385 run_name = _get_langchain_run_name (serialized , ** kwargs ) or "generation"
406- generation = GenerationMetadata (
407- name = run_name , input = messages , start_time = time .time (), end_time = None
408- )
386+ generation = GenerationMetadata (name = run_name , input = messages , start_time = time .time (), end_time = None )
409387 if isinstance (invocation_params , dict ):
410388 generation .model_params = get_model_params (invocation_params )
411389 if tools := invocation_params .get ("tools" ):
@@ -439,28 +417,22 @@ def _get_trace_id(self, run_id: UUID):
439417 return run_id
440418 return trace_id
441419
442- def _get_parent_run_id (
443- self , trace_id : Any , run_id : UUID , parent_run_id : Optional [UUID ]
444- ):
420+ def _get_parent_run_id (self , trace_id : Any , run_id : UUID , parent_run_id : Optional [UUID ]):
445421 """
446422 Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set.
447423 """
448424 if parent_run_id is not None and parent_run_id not in self ._parent_tree :
449425 return trace_id
450426 return parent_run_id
451427
452- def _pop_run_and_capture_trace_or_span (
453- self , run_id : UUID , parent_run_id : Optional [UUID ], outputs : Any
454- ):
428+ def _pop_run_and_capture_trace_or_span (self , run_id : UUID , parent_run_id : Optional [UUID ], outputs : Any ):
455429 trace_id = self ._get_trace_id (run_id )
456430 self ._pop_parent_of_run (run_id )
457431 run = self ._pop_run_metadata (run_id )
458432 if not run :
459433 return
460434 if isinstance (run , GenerationMetadata ):
461- log .warning (
462- f"Run { run_id } is a generation, but attempted to be captured as a trace or span."
463- )
435+ log .warning (f"Run { run_id } is a generation, but attempted to be captured as a trace or span." )
464436 return
465437 self ._capture_trace_or_span (
466438 trace_id ,
@@ -481,9 +453,7 @@ def _capture_trace_or_span(
481453 event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
482454 event_properties = {
483455 "$ai_trace_id" : trace_id ,
484- "$ai_input_state" : with_privacy_mode (
485- self ._client , self ._privacy_mode , run .input
486- ),
456+ "$ai_input_state" : with_privacy_mode (self ._client , self ._privacy_mode , run .input ),
487457 "$ai_latency" : run .latency ,
488458 "$ai_span_name" : run .name ,
489459 "$ai_span_id" : run_id ,
@@ -497,9 +467,7 @@ def _capture_trace_or_span(
497467 event_properties ["$ai_error" ] = _stringify_exception (outputs )
498468 event_properties ["$ai_is_error" ] = True
499469 elif outputs is not None :
500- event_properties ["$ai_output_state" ] = with_privacy_mode (
501- self ._client , self ._privacy_mode , outputs
502- )
470+ event_properties ["$ai_output_state" ] = with_privacy_mode (self ._client , self ._privacy_mode , outputs )
503471
504472 if self ._distinct_id is None :
505473 event_properties ["$process_person_profile" ] = False
@@ -523,9 +491,7 @@ def _pop_run_and_capture_generation(
523491 if not run :
524492 return
525493 if not isinstance (run , GenerationMetadata ):
526- log .warning (
527- f"Run { run_id } is not a generation, but attempted to be captured as a generation."
528- )
494+ log .warning (f"Run { run_id } is not a generation, but attempted to be captured as a generation." )
529495 return
530496 self ._capture_generation (
531497 trace_id ,
@@ -577,12 +543,8 @@ def _capture_generation(
577543 for generation in generation_result
578544 ]
579545 else :
580- completions = [
581- _extract_raw_esponse (generation ) for generation in generation_result
582- ]
583- event_properties ["$ai_output_choices" ] = with_privacy_mode (
584- self ._client , self ._privacy_mode , completions
585- )
546+ completions = [_extract_raw_esponse (generation ) for generation in generation_result ]
547+ event_properties ["$ai_output_choices" ] = with_privacy_mode (self ._client , self ._privacy_mode , completions )
586548
587549 if self ._properties :
588550 event_properties .update (self ._properties )
@@ -672,9 +634,7 @@ def _parse_usage_model(
672634 if model_key in usage :
673635 captured_count = usage [model_key ]
674636 final_count = (
675- sum (captured_count )
676- if isinstance (captured_count , list )
677- else captured_count
637+ sum (captured_count ) if isinstance (captured_count , list ) else captured_count
678638 ) # For Bedrock, the token count is a list when streamed
679639
680640 parsed_usage [type_key ] = final_count
@@ -699,12 +659,8 @@ def _parse_usage(response: LLMResult):
699659 break
700660
701661 for generation_chunk in generation :
702- if generation_chunk .generation_info and (
703- "usage_metadata" in generation_chunk .generation_info
704- ):
705- llm_usage = _parse_usage_model (
706- generation_chunk .generation_info ["usage_metadata" ]
707- )
662+ if generation_chunk .generation_info and ("usage_metadata" in generation_chunk .generation_info ):
663+ llm_usage = _parse_usage_model (generation_chunk .generation_info ["usage_metadata" ])
708664 break
709665
710666 message_chunk = getattr (generation_chunk , "message" , {})
@@ -716,19 +672,13 @@ def _parse_usage(response: LLMResult):
716672 else None
717673 )
718674 bedrock_titan_usage = (
719- response_metadata .get (
720- "amazon-bedrock-invocationMetrics" , None
721- ) # for Bedrock-Titan
675+ response_metadata .get ("amazon-bedrock-invocationMetrics" , None ) # for Bedrock-Titan
722676 if isinstance (response_metadata , dict )
723677 else None
724678 )
725- ollama_usage = getattr (
726- message_chunk , "usage_metadata" , None
727- ) # for Ollama
679+ ollama_usage = getattr (message_chunk , "usage_metadata" , None ) # for Ollama
728680
729- chunk_usage = (
730- bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
731- )
681+ chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
732682 if chunk_usage :
733683 llm_usage = _parse_usage_model (chunk_usage )
734684 break
@@ -744,9 +694,7 @@ def _get_http_status(error: BaseException) -> int:
744694 return status_code
745695
746696
747- def _get_langchain_run_name (
748- serialized : Optional [Dict [str , Any ]], ** kwargs : Any
749- ) -> Optional [str ]:
697+ def _get_langchain_run_name (serialized : Optional [Dict [str , Any ]], ** kwargs : Any ) -> Optional [str ]:
750698 """Retrieve the name of a serialized LangChain runnable.
751699
752700 The prioritization for the determination of the run name is as follows:
0 commit comments