55 "Please install LangChain to use this feature: 'pip install langchain'"
66 )
77
8+ import json
89import logging
910import time
1011import uuid
3031 ToolMessage ,
3132)
3233from langchain_core .outputs import ChatGeneration , LLMResult
34+ from langchain .schema .agent import AgentAction , AgentFinish
3335from pydantic import BaseModel
3436
3537from posthog .ai .utils import get_model_params , with_privacy_mode
3638from posthog .client import Client
39+ from posthog import default_client
3740
3841log = logging .getLogger ("posthog" )
3942
@@ -53,7 +56,7 @@ class RunMetadata(TypedDict, total=False):
5356
5457class CallbackHandler (BaseCallbackHandler ):
5558 """
56- A callback handler for LangChain that sends events to PostHog LLM Observability .
59+ The PostHog LLM observability callback handler for LangChain.
5760 """
5861
5962 _client : Client
@@ -74,7 +77,8 @@ class CallbackHandler(BaseCallbackHandler):
7477
7578 def __init__ (
7679 self ,
77- client : Client ,
80+ client : Optional [Client ] = None ,
81+ * ,
7882 distinct_id : Optional [Union [str , int , float , UUID ]] = None ,
7983 trace_id : Optional [Union [str , int , float , UUID ]] = None ,
8084 properties : Optional [Dict [str , Any ]] = None ,
@@ -90,7 +94,7 @@ def __init__(
9094 privacy_mode: Whether to redact the input and output of the trace.
9195 groups: Optional additional PostHog groups to use for the trace.
9296 """
93- self ._client = client
97+ self ._client = client or default_client
9498 self ._distinct_id = distinct_id
9599 self ._trace_id = trace_id
96100 self ._properties = properties or {}
@@ -106,9 +110,12 @@ def on_chain_start(
106110 * ,
107111 run_id : UUID ,
108112 parent_run_id : Optional [UUID ] = None ,
113+ metadata : Optional [Dict [str , Any ]] = None ,
109114 ** kwargs ,
110115 ):
116+ self ._log_debug_event ("on_chain_start" , run_id , parent_run_id , inputs = inputs )
111117 self ._set_parent_of_run (run_id , parent_run_id )
118+ self ._set_run_metadata (serialized , run_id , inputs , metadata , ** kwargs )
112119
113120 def on_chat_model_start (
114121 self ,
@@ -119,6 +126,9 @@ def on_chat_model_start(
119126 parent_run_id : Optional [UUID ] = None ,
120127 ** kwargs ,
121128 ):
129+ self ._log_debug_event (
130+ "on_chat_model_start" , run_id , parent_run_id , messages = messages
131+ )
122132 self ._set_parent_of_run (run_id , parent_run_id )
123133 input = [
124134 _convert_message_to_dict (message ) for row in messages for message in row
@@ -134,9 +144,58 @@ def on_llm_start(
134144 parent_run_id : Optional [UUID ] = None ,
135145 ** kwargs : Any ,
136146 ):
147+ self ._log_debug_event ("on_llm_start" , run_id , parent_run_id , prompts = prompts )
137148 self ._set_parent_of_run (run_id , parent_run_id )
138149 self ._set_run_metadata (serialized , run_id , prompts , ** kwargs )
139150
151+ def on_llm_new_token (
152+ self ,
153+ token : str ,
154+ * ,
155+ run_id : UUID ,
156+ parent_run_id : Optional [UUID ] = None ,
157+ ** kwargs : Any ,
158+ ) -> Any :
159+ """Run on new LLM token. Only available when streaming is enabled."""
160+ self .log .debug (
161+ f"on llm new token: run_id: { run_id } parent_run_id: { parent_run_id } "
162+ )
163+
164+ def on_tool_start (
165+ self ,
166+ serialized : Optional [Dict [str , Any ]],
167+ input_str : str ,
168+ * ,
169+ run_id : UUID ,
170+ parent_run_id : Optional [UUID ] = None ,
171+ tags : Optional [List [str ]] = None ,
172+ metadata : Optional [Dict [str , Any ]] = None ,
173+ ** kwargs : Any ,
174+ ) -> Any :
175+ self ._log_debug_event (
176+ "on_tool_start" , run_id , parent_run_id , input_str = input_str
177+ )
178+
179+ def on_tool_end (
180+ self ,
181+ output : str ,
182+ * ,
183+ run_id : UUID ,
184+ parent_run_id : Optional [UUID ] = None ,
185+ ** kwargs : Any ,
186+ ) -> Any :
187+ self ._log_debug_event ("on_tool_end" , run_id , parent_run_id , output = output )
188+
189+ def on_tool_error (
190+ self ,
191+ error : Union [Exception , KeyboardInterrupt ],
192+ * ,
193+ run_id : UUID ,
194+ parent_run_id : Optional [UUID ] = None ,
195+ ** kwargs : Any ,
196+ ) -> Any :
197+ self ._log_debug_event ("on_tool_error" , run_id , parent_run_id , error = error )
198+
140199 def on_chain_end (
141200 self ,
142201 outputs : Dict [str , Any ],
@@ -146,7 +205,35 @@ def on_chain_end(
146205 tags : Optional [List [str ]] = None ,
147206 ** kwargs : Any ,
148207 ):
208+ self ._log_debug_event ("on_chain_end" , run_id , parent_run_id , outputs = outputs )
149209 self ._pop_parent_of_run (run_id )
210+ run_metadata = self ._pop_run_metadata (run_id )
211+
212+ if parent_run_id is None :
213+ self ._end_trace (
214+ self ._get_trace_id (run_id ),
215+ inputs = run_metadata .get ("messages" ) if run_metadata else None ,
216+ outputs = outputs ,
217+ )
218+
219+ def on_chain_error (
220+ self ,
221+ error : BaseException ,
222+ * ,
223+ run_id : UUID ,
224+ parent_run_id : Optional [UUID ] = None ,
225+ ** kwargs : Any ,
226+ ):
227+ self ._log_debug_event ("on_chain_error" , run_id , parent_run_id , error = error )
228+ self ._pop_parent_of_run (run_id )
229+ run_metadata = self ._pop_run_metadata (run_id )
230+
231+ if parent_run_id is None :
232+ self ._end_trace (
233+ self ._get_trace_id (run_id ),
234+ inputs = run_metadata .get ("messages" ) if run_metadata else None ,
235+ outputs = None ,
236+ )
150237
151238 def on_llm_end (
152239 self ,
@@ -160,6 +247,9 @@ def on_llm_end(
160247 """
161248 The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
162249 """
250+ self ._log_debug_event (
251+ "on_llm_end" , run_id , parent_run_id , response = response , kwargs = kwargs
252+ )
163253 trace_id = self ._get_trace_id (run_id )
164254 self ._pop_parent_of_run (run_id )
165255 run = self ._pop_run_metadata (run_id )
@@ -207,16 +297,6 @@ def on_llm_end(
207297 groups = self ._groups ,
208298 )
209299
210- def on_chain_error (
211- self ,
212- error : BaseException ,
213- * ,
214- run_id : UUID ,
215- parent_run_id : Optional [UUID ] = None ,
216- ** kwargs : Any ,
217- ):
218- self ._pop_parent_of_run (run_id )
219-
220300 def on_llm_error (
221301 self ,
222302 error : BaseException ,
@@ -226,6 +306,7 @@ def on_llm_error(
226306 tags : Optional [List [str ]] = None ,
227307 ** kwargs : Any ,
228308 ):
309+ self ._log_debug_event ("on_llm_error" , run_id , parent_run_id , error = error )
229310 trace_id = self ._get_trace_id (run_id )
230311 self ._pop_parent_of_run (run_id )
231312 run = self ._pop_run_metadata (run_id )
@@ -255,6 +336,51 @@ def on_llm_error(
255336 groups = self ._groups ,
256337 )
257338
339+ def on_retriever_start (
340+ self ,
341+ serialized : Optional [Dict [str , Any ]],
342+ query : str ,
343+ * ,
344+ run_id : UUID ,
345+ parent_run_id : Optional [UUID ] = None ,
346+ tags : Optional [List [str ]] = None ,
347+ metadata : Optional [Dict [str , Any ]] = None ,
348+ ** kwargs : Any ,
349+ ) -> Any :
350+ self ._log_debug_event ("on_retriever_start" , run_id , parent_run_id , query = query )
351+
352+ def on_retriever_error (
353+ self ,
354+ error : Union [Exception , KeyboardInterrupt ],
355+ * ,
356+ run_id : UUID ,
357+ parent_run_id : Optional [UUID ] = None ,
358+ ** kwargs : Any ,
359+ ) -> Any :
360+ """Run when Retriever errors."""
361+ self ._log_debug_event ("on_retriever_error" , run_id , parent_run_id , error = error )
362+
363+ def on_agent_action (
364+ self ,
365+ action : AgentAction ,
366+ * ,
367+ run_id : UUID ,
368+ parent_run_id : Optional [UUID ] = None ,
369+ ** kwargs : Any ,
370+ ) -> Any :
371+ """Run on agent action."""
372+ self ._log_debug_event ("on_agent_action" , run_id , parent_run_id , action = action )
373+
374+ def on_agent_finish (
375+ self ,
376+ finish : AgentFinish ,
377+ * ,
378+ run_id : UUID ,
379+ parent_run_id : Optional [UUID ] = None ,
380+ ** kwargs : Any ,
381+ ) -> Any :
382+ self ._log_debug_event ("on_agent_finish" , run_id , parent_run_id , finish = finish )
383+
258384 def _set_parent_of_run (self , run_id : UUID , parent_run_id : Optional [UUID ] = None ):
259385 """
260386 Set the parent run ID for a chain run. If there is no parent, the run is the root.
@@ -324,6 +450,40 @@ def _get_trace_id(self, run_id: UUID):
324450 trace_id = uuid .uuid4 ()
325451 return trace_id
326452
    def _end_trace(
        self,
        trace_id: UUID,
        inputs: Optional[Dict[str, Any]],
        outputs: Optional[Dict[str, Any]],
    ):
        """Capture a `$ai_trace` event in PostHog, closing out a root run.

        Args:
            trace_id: The ID shared by all events of this trace.
            inputs: The trace's input state; both call sites pass ``None`` when
                no run metadata was recorded, hence the Optional annotation.
            outputs: The trace's output state; ``None`` when the run errored.
        """
        event_properties = {
            "$ai_trace_id": trace_id,
            # Inputs/outputs go through the privacy filter so redaction is
            # applied when privacy mode is on.
            "$ai_input_state": with_privacy_mode(
                self._client, self._privacy_mode, inputs
            ),
            **self._properties,
        }
        if outputs is not None:
            event_properties["$ai_output_state"] = with_privacy_mode(
                self._client, self._privacy_mode, outputs
            )
        # Without a caller-supplied distinct_id, avoid creating a person
        # profile for the synthetic trace-id identity.
        if self._distinct_id is None:
            event_properties["$process_person_profile"] = False
        self._client.capture(
            distinct_id=self._distinct_id or trace_id,
            event="$ai_trace",
            properties=event_properties,
            groups=self._groups,
        )
475+
476+ def _log_debug_event (
477+ self ,
478+ event_name : str ,
479+ run_id : UUID ,
480+ parent_run_id : Optional [UUID ] = None ,
481+ ** kwargs ,
482+ ):
483+ log .debug (
484+ f"Event: { event_name } , run_id: { str (run_id )[:5 ]} , parent_run_id: { str (parent_run_id )[:5 ]} , kwargs: { kwargs } "
485+ )
486+
327487
328488def _extract_raw_esponse (last_response ):
329489 """Extract the response from the last response of the LLM call."""
0 commit comments