33except ImportError :
44 raise ModuleNotFoundError ("Please install LangChain to use this feature: 'pip install langchain'" )
55
6- import json
76import logging
87import time
98import uuid
@@ -59,14 +58,25 @@ class CallbackHandler(BaseCallbackHandler):
5958
6059 _client : Client
6160 """PostHog client instance."""
61+
6262 _distinct_id : Optional [Union [str , int , float , UUID ]]
6363 """Distinct ID of the user to associate the trace with."""
64+
6465 _trace_id : Optional [Union [str , int , float , UUID ]]
6566 """Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
67+
68+ _trace_input : Optional [Any ]
69+ """The input at the start of the trace. Any JSON object."""
70+
71+ _trace_name : Optional [str ]
72+ """Name of the trace, exposed in the UI."""
73+
6674 _properties : Optional [Dict [str , Any ]]
6775 """Global properties to be sent with every event."""
76+
6877 _runs : RunStorage
6978 """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
79+
7080 _parent_tree : Dict [UUID , UUID ]
7181 """
7282 A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
@@ -95,6 +105,8 @@ def __init__(
95105 self ._client = client or default_client
96106 self ._distinct_id = distinct_id
97107 self ._trace_id = trace_id
108+ self ._trace_name = None
109+ self ._trace_input = None
98110 self ._properties = properties or {}
99111 self ._privacy_mode = privacy_mode
100112 self ._groups = groups or {}
@@ -113,7 +125,9 @@ def on_chain_start(
113125 ):
114126 self ._log_debug_event ("on_chain_start" , run_id , parent_run_id , inputs = inputs )
115127 self ._set_parent_of_run (run_id , parent_run_id )
116- self ._set_run_metadata (serialized , run_id , inputs , metadata , ** kwargs )
128+ if parent_run_id is None and self ._trace_name is None :
129+ self ._trace_name = self ._get_langchain_run_name (serialized , ** kwargs )
130+ self ._trace_input = inputs
117131
118132 def on_chat_model_start (
119133 self ,
@@ -151,7 +165,7 @@ def on_llm_new_token(
151165 ** kwargs : Any ,
152166 ) -> Any :
153167 """Run on new LLM token. Only available when streaming is enabled."""
154- self .log . debug ( f"on llm new token: run_id: { run_id } parent_run_id: { parent_run_id } " )
168+ self ._log_debug_event ( "on_llm_new_token" , run_id , parent_run_id , token = token )
155169
156170 def on_tool_start (
157171 self ,
@@ -160,7 +174,6 @@ def on_tool_start(
160174 * ,
161175 run_id : UUID ,
162176 parent_run_id : Optional [UUID ] = None ,
163- tags : Optional [List [str ]] = None ,
164177 metadata : Optional [Dict [str , Any ]] = None ,
165178 ** kwargs : Any ,
166179 ) -> Any :
@@ -192,19 +205,13 @@ def on_chain_end(
192205 * ,
193206 run_id : UUID ,
194207 parent_run_id : Optional [UUID ] = None ,
195- tags : Optional [List [str ]] = None ,
196208 ** kwargs : Any ,
197209 ):
198210 self ._log_debug_event ("on_chain_end" , run_id , parent_run_id , outputs = outputs )
199211 self ._pop_parent_of_run (run_id )
200- run_metadata = self ._pop_run_metadata (run_id )
201212
202213 if parent_run_id is None :
203- self ._end_trace (
204- self ._get_trace_id (run_id ),
205- inputs = run_metadata .get ("messages" ) if run_metadata else None ,
206- outputs = outputs ,
207- )
214+ self ._capture_trace (run_id , outputs = outputs )
208215
209216 def on_chain_error (
210217 self ,
@@ -216,22 +223,16 @@ def on_chain_error(
216223 ):
217224 self ._log_debug_event ("on_chain_error" , run_id , parent_run_id , error = error )
218225 self ._pop_parent_of_run (run_id )
219- run_metadata = self ._pop_run_metadata (run_id )
220226
221227 if parent_run_id is None :
222- self ._end_trace (
223- self ._get_trace_id (run_id ),
224- inputs = run_metadata .get ("messages" ) if run_metadata else None ,
225- outputs = None ,
226- )
228+ self ._capture_trace (run_id , outputs = None )
227229
228230 def on_llm_end (
229231 self ,
230232 response : LLMResult ,
231233 * ,
232234 run_id : UUID ,
233235 parent_run_id : Optional [UUID ] = None ,
234- tags : Optional [List [str ]] = None ,
235236 ** kwargs : Any ,
236237 ):
237238 """
@@ -284,7 +285,6 @@ def on_llm_error(
284285 * ,
285286 run_id : UUID ,
286287 parent_run_id : Optional [UUID ] = None ,
287- tags : Optional [List [str ]] = None ,
288288 ** kwargs : Any ,
289289 ):
290290 self ._log_debug_event ("on_llm_error" , run_id , parent_run_id , error = error )
@@ -322,7 +322,6 @@ def on_retriever_start(
322322 * ,
323323 run_id : UUID ,
324324 parent_run_id : Optional [UUID ] = None ,
325- tags : Optional [List [str ]] = None ,
326325 metadata : Optional [Dict [str , Any ]] = None ,
327326 ** kwargs : Any ,
328327 ) -> Any :
@@ -429,10 +428,41 @@ def _get_trace_id(self, run_id: UUID):
429428 trace_id = uuid .uuid4 ()
430429 return trace_id
431430
432- def _end_trace (self , trace_id : UUID , inputs : Dict [str , Any ], outputs : Optional [Dict [str , Any ]]):
431+ def _get_langchain_run_name (self , serialized : Optional [Dict [str , Any ]], ** kwargs : Any ) -> str :
432+ """Retrieve the name of a serialized LangChain runnable.
433+
434+ The prioritization for the determination of the run name is as follows:
435+ - The value assigned to the "name" key in `kwargs`.
436+ - The value assigned to the "name" key in `serialized`.
437+ - The last entry of the value assigned to the "id" key in `serialized`.
438+ - "<unknown>".
439+
440+ Args:
441+ serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
442+ **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.
443+
444+ Returns:
445+ str: The determined name of the Langchain runnable.
446+ """
447+ if "name" in kwargs and kwargs ["name" ] is not None :
448+ return kwargs ["name" ]
449+
450+ try :
451+ return serialized ["name" ]
452+ except (KeyError , TypeError ):
453+ pass
454+
455+ try :
456+ return serialized ["id" ][- 1 ]
457+ except (KeyError , TypeError ):
458+ pass
459+
460+ def _capture_trace (self , run_id : UUID , * , outputs : Optional [Dict [str , Any ]]):
461+ trace_id = self ._get_trace_id (run_id )
433462 event_properties = {
463+ "$ai_trace_name" : self ._trace_name ,
434464 "$ai_trace_id" : trace_id ,
435- "$ai_input_state" : with_privacy_mode (self ._client , self ._privacy_mode , inputs ),
465+ "$ai_input_state" : with_privacy_mode (self ._client , self ._privacy_mode , self . _trace_input ),
436466 ** self ._properties ,
437467 }
438468 if outputs is not None :
0 commit comments