1212 List ,
1313 Optional ,
1414 Tuple ,
15- TypedDict ,
1615 Union ,
1716 cast ,
1817)
1918from uuid import UUID
19+ from dataclasses import dataclass
2020
2121from langchain .callbacks .base import BaseCallbackHandler
2222from langchain .schema .agent import AgentAction , AgentFinish
3131log = logging .getLogger ("posthog" )
3232
3333
34- class RunMetadata (TypedDict , total = False ):
35- input : Any
36- """Input of the run: messages, prompt variables, etc."""
34+ @dataclass
35+ class RunMetadata :
3736 name : str
3837 """Name of the run: chain name, model name, etc."""
39- provider : str
38+ start_time : float
39+ """Start time of the run."""
40+ end_time : Optional [float ]
41+ """End time of the run."""
42+ input : Optional [Any ]
43+ """Input of the run: messages, prompt variables, etc."""
44+
45+ @property
46+ def latency (self ) -> float :
47+ if not self .end_time :
48+ return 0
49+ return self .end_time - self .start_time
50+
51+
52+ @dataclass
53+ class TraceMetadata (RunMetadata ):
54+ pass
55+
56+
57+ @dataclass
58+ class GenerationMetadata (RunMetadata ):
59+ provider : Optional [str ] = None
4060 """Provider of the run: OpenAI, Anthropic"""
41- model : str
61+ model : Optional [ str ] = None
4262 """Model used in the run"""
43- model_params : Dict [str , Any ]
63+ model_params : Optional [ Dict [str , Any ]] = None
4464 """Model parameters of the run: temperature, max_tokens, etc."""
45- base_url : str
65+ base_url : Optional [ str ] = None
4666 """Base URL of the provider's API used in the run."""
47- start_time : float
48- """Start time of the run."""
49- end_time : float
50- """End time of the run."""
5167
5268
53- RunStorage = Dict [UUID , RunMetadata ]
69+ @dataclass
70+ class SpanMetadata (RunMetadata ):
71+ pass
72+
73+
74+ RunMetadataUnion = Union [TraceMetadata , GenerationMetadata , SpanMetadata ]
75+ RunMetadataStorage = Dict [UUID , RunMetadataUnion ]
5476
5577
5678class CallbackHandler (BaseCallbackHandler ):
@@ -76,7 +98,7 @@ class CallbackHandler(BaseCallbackHandler):
7698 _properties : Optional [Dict [str , Any ]]
7799 """Global properties to be sent with every event."""
78100
79- _runs : RunStorage
101+ _runs : RunMetadataStorage
80102 """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
81103
82104 _parent_tree : Dict [UUID , UUID ]
@@ -107,8 +129,6 @@ def __init__(
107129 self ._client = client or default_client
108130 self ._distinct_id = distinct_id
109131 self ._trace_id = trace_id
110- self ._trace_name = None
111- self ._trace_input = None
112132 self ._properties = properties or {}
113133 self ._privacy_mode = privacy_mode
114134 self ._groups = groups or {}
@@ -127,8 +147,7 @@ def on_chain_start(
127147 ):
128148 self ._log_debug_event ("on_chain_start" , run_id , parent_run_id , inputs = inputs )
129149 self ._set_parent_of_run (run_id , parent_run_id )
130- if parent_run_id is None and self ._trace_name is None :
131- self ._set_span_metadata (run_id , self ._get_langchain_run_name (serialized , ** kwargs ), inputs )
150+ self ._set_trace_or_span_metadata (serialized , inputs , run_id , parent_run_id , ** kwargs )
132151
133152 def on_chat_model_start (
134153 self ,
@@ -179,6 +198,7 @@ def on_tool_start(
179198 ** kwargs : Any ,
180199 ) -> Any :
181200 self ._log_debug_event ("on_tool_start" , run_id , parent_run_id , input_str = input_str )
201+ self ._set_trace_or_span_metadata (serialized , input_str , run_id , parent_run_id , ** kwargs )
182202
183203 def on_tool_end (
184204 self ,
@@ -192,10 +212,11 @@ def on_tool_end(
192212
193213 def on_tool_error (
194214 self ,
195- error : Union [ Exception , KeyboardInterrupt ] ,
215+ error : BaseException ,
196216 * ,
197217 run_id : UUID ,
198218 parent_run_id : Optional [UUID ] = None ,
219+ tags : Optional [list [str ]] = None ,
199220 ** kwargs : Any ,
200221 ) -> Any :
201222 self ._log_debug_event ("on_tool_error" , run_id , parent_run_id , error = error )
@@ -243,10 +264,9 @@ def on_llm_end(
243264 trace_id = self ._get_trace_id (run_id )
244265 self ._pop_parent_of_run (run_id )
245266 run = self ._pop_run_metadata (run_id )
246- if not run :
267+ if not run or not isinstance ( run , GenerationMetadata ) :
247268 return
248269
249- latency = run .get ("end_time" , 0 ) - run .get ("start_time" , 0 )
250270 input_tokens , output_tokens = _parse_usage (response )
251271
252272 generation_result = response .generations [- 1 ]
@@ -258,19 +278,20 @@ def on_llm_end(
258278 output = [_extract_raw_esponse (generation ) for generation in generation_result ]
259279
260280 event_properties = {
261- "$ai_provider" : run .get ( " provider" ) ,
262- "$ai_model" : run .get ( " model" ) ,
263- "$ai_model_parameters" : run .get ( " model_params" ) ,
264- "$ai_input" : with_privacy_mode (self ._client , self ._privacy_mode , run .get ( " input" ) ),
281+ "$ai_provider" : run .provider ,
282+ "$ai_model" : run .model ,
283+ "$ai_model_parameters" : run .model_params ,
284+ "$ai_input" : with_privacy_mode (self ._client , self ._privacy_mode , run .input ),
265285 "$ai_output_choices" : with_privacy_mode (self ._client , self ._privacy_mode , output ),
266286 "$ai_http_status" : 200 ,
267287 "$ai_input_tokens" : input_tokens ,
268288 "$ai_output_tokens" : output_tokens ,
269- "$ai_latency" : latency ,
289+ "$ai_latency" : run . latency ,
270290 "$ai_trace_id" : trace_id ,
271- "$ai_base_url" : run .get ("base_url" ),
272- ** self ._properties ,
291+ "$ai_base_url" : run .base_url ,
273292 }
293+ if self ._properties :
294+ event_properties .update (self ._properties )
274295 if self ._distinct_id is None :
275296 event_properties ["$process_person_profile" ] = False
276297 self ._client .capture (
@@ -292,21 +313,21 @@ def on_llm_error(
292313 trace_id = self ._get_trace_id (run_id )
293314 self ._pop_parent_of_run (run_id )
294315 run = self ._pop_run_metadata (run_id )
295- if not run :
316+ if not run or not isinstance ( run , GenerationMetadata ) :
296317 return
297318
298- latency = run .get ("end_time" , 0 ) - run .get ("start_time" , 0 )
299319 event_properties = {
300- "$ai_provider" : run .get ( " provider" ) ,
301- "$ai_model" : run .get ( " model" ) ,
302- "$ai_model_parameters" : run .get ( " model_params" ) ,
303- "$ai_input" : with_privacy_mode (self ._client , self ._privacy_mode , run .get ( " input" ) ),
320+ "$ai_provider" : run .provider ,
321+ "$ai_model" : run .model ,
322+ "$ai_model_parameters" : run .model_params ,
323+ "$ai_input" : with_privacy_mode (self ._client , self ._privacy_mode , run .input ),
304324 "$ai_http_status" : _get_http_status (error ),
305- "$ai_latency" : latency ,
325+ "$ai_latency" : run . latency ,
306326 "$ai_trace_id" : trace_id ,
307- "$ai_base_url" : run .get ("base_url" ),
308- ** self ._properties ,
327+ "$ai_base_url" : run .base_url ,
309328 }
329+ if self ._properties :
330+ event_properties .update (self ._properties )
310331 if self ._distinct_id is None :
311332 event_properties ["$process_person_profile" ] = False
312333 self ._client .capture (
@@ -327,13 +348,15 @@ def on_retriever_start(
327348 ** kwargs : Any ,
328349 ) -> Any :
329350 self ._log_debug_event ("on_retriever_start" , run_id , parent_run_id , query = query )
351+ self ._set_trace_or_span_metadata (serialized , query , run_id , parent_run_id , ** kwargs )
330352
331353 def on_retriever_error (
332354 self ,
333- error : Union [ Exception , KeyboardInterrupt ] ,
355+ error : BaseException ,
334356 * ,
335357 run_id : UUID ,
336358 parent_run_id : Optional [UUID ] = None ,
359+ tags : Optional [list [str ]] = None ,
337360 ** kwargs : Any ,
338361 ) -> Any :
339362 """Run when Retriever errors."""
@@ -385,12 +408,23 @@ def _find_root_run(self, run_id: UUID) -> UUID:
385408 id = self ._parent_tree [id ]
386409 return id
387410
388- def _set_span_metadata (self , run_id : UUID , name : str , input : Any ):
389- self ._runs [run_id ] = {
390- "name" : name ,
391- "input" : input ,
392- "start_time" : time .time (),
393- }
411+ def _set_trace_or_span_metadata (
412+ self ,
413+ serialized : Optional [Dict [str , Any ]],
414+ input : Any ,
415+ run_id : UUID ,
416+ parent_run_id : Optional [UUID ] = None ,
417+ ** kwargs ,
418+ ):
419+ run_name = self ._get_langchain_run_name (serialized , ** kwargs )
420+ if parent_run_id is None :
421+ self ._runs [run_id ] = TraceMetadata (
422+ name = run_name or "trace" , input = input , start_time = time .time (), end_time = None
423+ )
424+ else :
425+ self ._runs [run_id ] = SpanMetadata (
426+ name = run_name or "span" , input = input , start_time = time .time (), end_time = None
427+ )
394428
395429 def _set_llm_metadata (
396430 self ,
@@ -401,33 +435,31 @@ def _set_llm_metadata(
401435 invocation_params : Optional [Dict [str , Any ]] = None ,
402436 ** kwargs ,
403437 ):
404- run : RunMetadata = {
405- "input" : messages ,
406- "start_time" : time .time (),
407- }
438+ run_name = self ._get_langchain_run_name (serialized , ** kwargs ) or "generation"
439+ generation = GenerationMetadata (name = run_name , input = messages , start_time = time .time (), end_time = None )
408440 if isinstance (invocation_params , dict ):
409- run [ " model_params" ] = get_model_params (invocation_params )
441+ generation . model_params = get_model_params (invocation_params )
410442 if isinstance (metadata , dict ):
411443 if model := metadata .get ("ls_model_name" ):
412- run [ " model" ] = model
444+ generation . model = model
413445 if provider := metadata .get ("ls_provider" ):
414- run [ " provider" ] = provider
446+ generation . provider = provider
415447 try :
416448 base_url = serialized ["kwargs" ]["openai_api_base" ]
417449 if base_url is not None :
418- run [ " base_url" ] = base_url
450+ generation . base_url = base_url
419451 except KeyError :
420452 pass
421- self ._runs [run_id ] = run
453+ self ._runs [run_id ] = generation
422454
423- def _pop_run_metadata (self , run_id : UUID ) -> Optional [RunMetadata ]:
455+ def _pop_run_metadata (self , run_id : UUID ) -> Optional [RunMetadataUnion ]:
424456 end_time = time .time ()
425457 try :
426458 run = self ._runs .pop (run_id )
427459 except KeyError :
428460 log .warning (f"No run metadata found for run { run_id } " )
429461 return None
430- run [ " end_time" ] = end_time
462+ run . end_time = end_time
431463 return run
432464
433465 def _get_trace_id (self , run_id : UUID ):
@@ -436,7 +468,7 @@ def _get_trace_id(self, run_id: UUID):
436468 trace_id = uuid .uuid4 ()
437469 return trace_id
438470
439- def _get_langchain_run_name (self , serialized : Optional [Dict [str , Any ]], ** kwargs : Any ) -> str :
471+ def _get_langchain_run_name (self , serialized : Optional [Dict [str , Any ]], ** kwargs : Any ) -> Optional [ str ] :
440472 """Retrieve the name of a serialized LangChain runnable.
441473
442474 The prioritization for the determination of the run name is as follows:
@@ -454,29 +486,31 @@ def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs
454486 """
455487 if "name" in kwargs and kwargs ["name" ] is not None :
456488 return kwargs ["name" ]
457-
489+ if serialized is None :
490+ return None
458491 try :
459492 return serialized ["name" ]
460493 except (KeyError , TypeError ):
461494 pass
462-
463495 try :
464496 return serialized ["id" ][- 1 ]
465497 except (KeyError , TypeError ):
466498 pass
499+ return None
467500
468501 def _pop_trace_and_capture (self , run_id : UUID , * , outputs : Optional [Dict [str , Any ]]):
469502 trace_id = self ._get_trace_id (run_id )
470503 run = self ._pop_run_metadata (run_id )
471504 if not run :
472505 return
473506 event_properties = {
474- "$ai_trace_name" : run .get ( " name" ) ,
507+ "$ai_trace_name" : run .name ,
475508 "$ai_trace_id" : trace_id ,
476- "$ai_input_state" : with_privacy_mode (self ._client , self ._privacy_mode , run .get ("input" )),
477- "$ai_latency" : run .get ("end_time" , 0 ) - run .get ("start_time" , 0 ),
478- ** self ._properties ,
509+ "$ai_input_state" : with_privacy_mode (self ._client , self ._privacy_mode , run .input ),
510+ "$ai_latency" : run .latency ,
479511 }
512+ if self ._properties :
513+ event_properties .update (self ._properties )
480514 if outputs is not None :
481515 event_properties ["$ai_output_state" ] = with_privacy_mode (self ._client , self ._privacy_mode , outputs )
482516 if self ._distinct_id is None :
0 commit comments