1919from uuid import UUID
2020
2121from langchain .callbacks .base import BaseCallbackHandler
22+ from langchain .schema .agent import AgentAction , AgentFinish
2223from langchain_core .messages import AIMessage , BaseMessage , FunctionMessage , HumanMessage , SystemMessage , ToolMessage
2324from langchain_core .outputs import ChatGeneration , LLMResult
2425from pydantic import BaseModel
2526
27+ from posthog import default_client
2628from posthog .ai .utils import get_model_params , with_privacy_mode
2729from posthog .client import Client
2830
@@ -44,19 +46,30 @@ class RunMetadata(TypedDict, total=False):
4446
4547class CallbackHandler (BaseCallbackHandler ):
4648 """
47- A callback handler for LangChain that sends events to PostHog LLM Observability .
49+ The PostHog LLM observability callback handler for LangChain.
4850 """
4951
5052 _client : Client
5153 """PostHog client instance."""
54+
5255 _distinct_id : Optional [Union [str , int , float , UUID ]]
5356 """Distinct ID of the user to associate the trace with."""
57+
5458 _trace_id : Optional [Union [str , int , float , UUID ]]
5559 """Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
60+
61+ _trace_input : Optional [Any ]
62+ """The input at the start of the trace. Any JSON object."""
63+
64+ _trace_name : Optional [str ]
65+ """Name of the trace, exposed in the UI."""
66+
5667 _properties : Optional [Dict [str , Any ]]
5768 """Global properties to be sent with every event."""
69+
5870 _runs : RunStorage
5971 """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
72+
6073 _parent_tree : Dict [UUID , UUID ]
6174 """
6275 A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
@@ -65,7 +78,8 @@ class CallbackHandler(BaseCallbackHandler):
6578
6679 def __init__ (
6780 self ,
68- client : Client ,
81+ client : Optional [Client ] = None ,
82+ * ,
6983 distinct_id : Optional [Union [str , int , float , UUID ]] = None ,
7084 trace_id : Optional [Union [str , int , float , UUID ]] = None ,
7185 properties : Optional [Dict [str , Any ]] = None ,
@@ -81,9 +95,11 @@ def __init__(
8195 privacy_mode: Whether to redact the input and output of the trace.
8296 groups: Optional additional PostHog groups to use for the trace.
8397 """
84- self ._client = client
98+ self ._client = client or default_client
8599 self ._distinct_id = distinct_id
86100 self ._trace_id = trace_id
101+ self ._trace_name = None
102+ self ._trace_input = None
87103 self ._properties = properties or {}
88104 self ._privacy_mode = privacy_mode
89105 self ._groups = groups or {}
@@ -97,9 +113,14 @@ def on_chain_start(
97113 * ,
98114 run_id : UUID ,
99115 parent_run_id : Optional [UUID ] = None ,
116+ metadata : Optional [Dict [str , Any ]] = None ,
100117 ** kwargs ,
101118 ):
119+ self ._log_debug_event ("on_chain_start" , run_id , parent_run_id , inputs = inputs )
102120 self ._set_parent_of_run (run_id , parent_run_id )
121+ if parent_run_id is None and self ._trace_name is None :
122+ self ._trace_name = self ._get_langchain_run_name (serialized , ** kwargs )
123+ self ._trace_input = inputs
103124
104125 def on_chat_model_start (
105126 self ,
@@ -110,6 +131,7 @@ def on_chat_model_start(
110131 parent_run_id : Optional [UUID ] = None ,
111132 ** kwargs ,
112133 ):
134+ self ._log_debug_event ("on_chat_model_start" , run_id , parent_run_id , messages = messages )
113135 self ._set_parent_of_run (run_id , parent_run_id )
114136 input = [_convert_message_to_dict (message ) for row in messages for message in row ]
115137 self ._set_run_metadata (serialized , run_id , input , ** kwargs )
@@ -123,32 +145,93 @@ def on_llm_start(
123145 parent_run_id : Optional [UUID ] = None ,
124146 ** kwargs : Any ,
125147 ):
148+ self ._log_debug_event ("on_llm_start" , run_id , parent_run_id , prompts = prompts )
126149 self ._set_parent_of_run (run_id , parent_run_id )
127150 self ._set_run_metadata (serialized , run_id , prompts , ** kwargs )
128151
def on_llm_new_token(
    self,
    token: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Handle a newly produced LLM token (only emitted when streaming is enabled).

    The token is only written to the debug log; no event is captured here.
    """
    self._log_debug_event(
        "on_llm_new_token",
        run_id,
        parent_run_id,
        token=token,
    )
162+
def on_tool_start(
    self,
    serialized: Optional[Dict[str, Any]],
    input_str: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """Handle the start of a tool run by writing its input to the debug log.

    ``serialized`` and ``metadata`` are accepted but not used here.
    """
    self._log_debug_event(
        "on_tool_start",
        run_id,
        parent_run_id,
        input_str=input_str,
    )
174+
def on_tool_end(
    self,
    output: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Handle the end of a tool run by writing its output to the debug log."""
    self._log_debug_event(
        "on_tool_end",
        run_id,
        parent_run_id,
        output=output,
    )
184+
def on_tool_error(
    self,
    error: Union[Exception, KeyboardInterrupt],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Handle a failed tool run by writing the error to the debug log."""
    self._log_debug_event(
        "on_tool_error",
        run_id,
        parent_run_id,
        error=error,
    )
194+
def on_chain_end(
    self,
    outputs: Dict[str, Any],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
):
    """Handle the successful end of a chain run.

    The run is removed from the parent tree. When the run has no parent
    (i.e. it is the top-level run), a trace event carrying ``outputs`` is
    captured as well.
    """
    self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
    self._pop_parent_of_run(run_id)

    # Nested runs do not produce a trace of their own.
    if parent_run_id is not None:
        return

    self._capture_trace(run_id, outputs=outputs)
208+
def on_chain_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
):
    """Handle a chain run that ended with an error.

    The run is removed from the parent tree. When the run has no parent
    (i.e. it is the top-level run), a trace event without outputs is
    captured as well.
    """
    self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
    self._pop_parent_of_run(run_id)

    # Nested runs do not produce a trace of their own.
    if parent_run_id is not None:
        return

    self._capture_trace(run_id, outputs=None)
222+
140223 def on_llm_end (
141224 self ,
142225 response : LLMResult ,
143226 * ,
144227 run_id : UUID ,
145228 parent_run_id : Optional [UUID ] = None ,
146- tags : Optional [List [str ]] = None ,
147229 ** kwargs : Any ,
148230 ):
149231 """
150232 The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
151233 """
234+ self ._log_debug_event ("on_llm_end" , run_id , parent_run_id , response = response , kwargs = kwargs )
152235 trace_id = self ._get_trace_id (run_id )
153236 self ._pop_parent_of_run (run_id )
154237 run = self ._pop_run_metadata (run_id )
@@ -189,25 +272,15 @@ def on_llm_end(
189272 groups = self ._groups ,
190273 )
191274
192- def on_chain_error (
193- self ,
194- error : BaseException ,
195- * ,
196- run_id : UUID ,
197- parent_run_id : Optional [UUID ] = None ,
198- ** kwargs : Any ,
199- ):
200- self ._pop_parent_of_run (run_id )
201-
202275 def on_llm_error (
203276 self ,
204277 error : BaseException ,
205278 * ,
206279 run_id : UUID ,
207280 parent_run_id : Optional [UUID ] = None ,
208- tags : Optional [List [str ]] = None ,
209281 ** kwargs : Any ,
210282 ):
283+ self ._log_debug_event ("on_llm_error" , run_id , parent_run_id , error = error )
211284 trace_id = self ._get_trace_id (run_id )
212285 self ._pop_parent_of_run (run_id )
213286 run = self ._pop_run_metadata (run_id )
@@ -235,6 +308,50 @@ def on_llm_error(
235308 groups = self ._groups ,
236309 )
237310
def on_retriever_start(
    self,
    serialized: Optional[Dict[str, Any]],
    query: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """Handle the start of a retriever run by writing its query to the debug log.

    ``serialized`` and ``metadata`` are accepted but not used here.
    """
    self._log_debug_event(
        "on_retriever_start",
        run_id,
        parent_run_id,
        query=query,
    )
322+
def on_retriever_error(
    self,
    error: Union[Exception, KeyboardInterrupt],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Handle a failed retriever run by writing the error to the debug log."""
    self._log_debug_event(
        "on_retriever_error",
        run_id,
        parent_run_id,
        error=error,
    )
333+
def on_agent_action(
    self,
    action: AgentAction,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Handle an agent action by writing it to the debug log."""
    self._log_debug_event(
        "on_agent_action",
        run_id,
        parent_run_id,
        action=action,
    )
344+
def on_agent_finish(
    self,
    finish: AgentFinish,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Handle an agent finishing by writing the finish payload to the debug log."""
    self._log_debug_event(
        "on_agent_finish",
        run_id,
        parent_run_id,
        finish=finish,
    )
354+
238355 def _set_parent_of_run (self , run_id : UUID , parent_run_id : Optional [UUID ] = None ):
239356 """
240357 Set the parent run ID for a chain run. If there is no parent, the run is the root.
@@ -304,6 +421,65 @@ def _get_trace_id(self, run_id: UUID):
304421 trace_id = uuid .uuid4 ()
305422 return trace_id
306423
424+ def _get_langchain_run_name (self , serialized : Optional [Dict [str , Any ]], ** kwargs : Any ) -> str :
425+ """Retrieve the name of a serialized LangChain runnable.
426+
427+ The prioritization for the determination of the run name is as follows:
428+ - The value assigned to the "name" key in `kwargs`.
429+ - The value assigned to the "name" key in `serialized`.
430+ - The last entry of the value assigned to the "id" key in `serialized`.
431+ - "<unknown>".
432+
433+ Args:
434+ serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
435+ **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.
436+
437+ Returns:
438+ str: The determined name of the Langchain runnable.
439+ """
440+ if "name" in kwargs and kwargs ["name" ] is not None :
441+ return kwargs ["name" ]
442+
443+ try :
444+ return serialized ["name" ]
445+ except (KeyError , TypeError ):
446+ pass
447+
448+ try :
449+ return serialized ["id" ][- 1 ]
450+ except (KeyError , TypeError ):
451+ pass
452+
def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
    """Capture the ``$ai_trace`` event for a finished top-level run.

    Args:
        run_id: ID of the run the trace is captured for; used to resolve the trace ID.
        outputs: Final outputs of the run. When ``None`` (e.g. the run errored),
            ``$ai_output_state`` is left out of the event.
    """
    trace_id = self._get_trace_id(run_id)

    # Take the per-trace state and clear it right away. `on_chain_start` only
    # records a name/input while `_trace_name` is None, so without this reset a
    # handler reused for another top-level run would keep reporting the first
    # run's name and input.
    trace_name, trace_input = self._trace_name, self._trace_input
    self._trace_name = None
    self._trace_input = None

    event_properties = {
        "$ai_trace_name": trace_name,
        "$ai_trace_id": trace_id,
        "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, trace_input),
        # Spread last: a global property with the same key replaces the value above.
        **self._properties,
    }
    if outputs is not None:
        event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
    if self._distinct_id is None:
        # No user to attribute the trace to, so ask PostHog not to process a person profile.
        event_properties["$process_person_profile"] = False

    self._client.capture(
        distinct_id=self._distinct_id or trace_id,
        event="$ai_trace",
        properties=event_properties,
        groups=self._groups,
    )
471+
def _log_debug_event(
    self,
    event_name: str,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs,
):
    """Write a one-line debug record for a LangChain callback invocation.

    Args:
        event_name: Name of the callback being logged (e.g. ``"on_chain_start"``).
        run_id: ID of the run; only its first five characters are logged.
        parent_run_id: ID of the parent run, if any; also truncated to five characters.
        **kwargs: Callback payload to include in the record.
    """
    # Deferred %-style arguments: the payload (messages, LLM results, ...) is only
    # rendered if the record is actually emitted. This hook runs on every callback,
    # including once per streamed token.
    # NOTE(review): assumes `log` is a standard `logging.Logger` - confirm at its definition.
    log.debug(
        "Event: %s, run_id: %s, parent_run_id: %s, kwargs: %s",
        event_name,
        str(run_id)[:5],
        str(parent_run_id)[:5],
        kwargs,
    )
482+
307483
308484def _extract_raw_esponse (last_response ):
309485 """Extract the response from the last response of the LLM call."""
@@ -339,7 +515,9 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
339515 return message_dict
340516
341517
342- def _parse_usage_model (usage : Union [BaseModel , Dict ]) -> Tuple [Union [int , None ], Union [int , None ]]:
518+ def _parse_usage_model (
519+ usage : Union [BaseModel , Dict ],
520+ ) -> Tuple [Union [int , None ], Union [int , None ]]:
343521 if isinstance (usage , BaseModel ):
344522 usage = usage .__dict__
345523
0 commit comments