@@ -1,9 +1,7 @@
 try:
     import langchain
 except ImportError:
-    raise ModuleNotFoundError(
-        "Please install LangChain to use this feature: 'pip install langchain'"
-    )
+    raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")

 import logging
 import time
@@ -31,22 +29,11 @@
 PosthogProperties = dict[str, Any]


-class ModelParams(TypedDict, total=False):
-    temperature: Optional[float]
-    max_tokens: Optional[int]
-    top_p: Optional[float]
-    frequency_penalty: Optional[float]
-    presence_penalty: Optional[float]
-    n: Optional[int]
-    stop: Optional[list[str]]
-    stream: Optional[bool]
-
-
 class RunMetadata(TypedDict, total=False):
     messages: list[dict[str, Any]] | list[str]
     provider: str
     model: str
-    model_params: ModelParams
+    model_params: dict[str, Any]
     start_time: float
     end_time: float

@@ -114,9 +101,7 @@ def on_chat_model_start(
         **kwargs,
     ):
         self._set_parent_of_run(run_id, parent_run_id)
-        input = [
-            _convert_message_to_dict(message) for row in messages for message in row
-        ]
+        input = [_convert_message_to_dict(message) for row in messages for message in row]
         self._set_run_metadata(run_id, input, **kwargs)

     def on_llm_start(
@@ -166,13 +151,10 @@ def on_llm_end(
         generation_result = response.generations[-1]
         if isinstance(generation_result[-1], ChatGeneration):
             output = [
-                _convert_message_to_dict(cast(ChatGeneration, generation).message)
-                for generation in generation_result
+                _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result
             ]
         else:
-            output = [
-                _extract_raw_esponse(generation) for generation in generation_result
-            ]
+            output = [_extract_raw_esponse(generation) for generation in generation_result]

         event_properties = {
             "$ai_provider": run.get("provider"),
@@ -276,7 +258,7 @@ def _set_run_metadata(
             "start_time": time.time(),
         }
         if isinstance(invocation_params, dict):
-            run["model_params"] = cast(ModelParams, get_model_params(invocation_params))
+            run["model_params"] = get_model_params(invocation_params)
         if isinstance(metadata, dict):
             if model := metadata.get("ls_model_name"):
                 run["model"] = model
@@ -361,9 +343,7 @@ def _parse_usage_model(usage: Union[BaseModel, dict]) -> tuple[int | None, int |
         if model_key in usage:
             captured_count = usage[model_key]
             final_count = (
-                sum(captured_count)
-                if isinstance(captured_count, list)
-                else captured_count
+                sum(captured_count) if isinstance(captured_count, list) else captured_count
             )  # For Bedrock, the token count is a list when streamed

             parsed_usage[type_key] = final_count
@@ -384,12 +364,8 @@ def _parse_usage(response: LLMResult):
     if hasattr(response, "generations"):
         for generation in response.generations:
             for generation_chunk in generation:
-                if generation_chunk.generation_info and (
-                    "usage_metadata" in generation_chunk.generation_info
-                ):
-                    llm_usage = _parse_usage_model(
-                        generation_chunk.generation_info["usage_metadata"]
-                    )
+                if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
+                    llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
                     break

                 message_chunk = getattr(generation_chunk, "message", {})
@@ -402,9 +378,7 @@ def _parse_usage(response: LLMResult):
                         else None
                     )
                     or (
-                        response_metadata.get(
-                            "amazon-bedrock-invocationMetrics", None
-                        )  # for Bedrock-Titan
+                        response_metadata.get("amazon-bedrock-invocationMetrics", None)  # for Bedrock-Titan
                         if isinstance(response_metadata, dict)
                         else None
                     )