Commit 199dfd2

LangChain tracing, with LangGraph tests

1 parent 0384b8c commit 199dfd2

File tree

4 files changed (+533, −147 lines)


posthog/ai/langchain/callbacks.py

Lines changed: 173 additions & 13 deletions
@@ -5,6 +5,7 @@
         "Please install LangChain to use this feature: 'pip install langchain'"
     )
 
+import json
 import logging
 import time
 import uuid
@@ -30,10 +31,12 @@
     ToolMessage,
 )
 from langchain_core.outputs import ChatGeneration, LLMResult
+from langchain.schema.agent import AgentAction, AgentFinish
 from pydantic import BaseModel
 
 from posthog.ai.utils import get_model_params, with_privacy_mode
 from posthog.client import Client
+from posthog import default_client
 
 log = logging.getLogger("posthog")
 
@@ -53,7 +56,7 @@ class RunMetadata(TypedDict, total=False):
 
 class CallbackHandler(BaseCallbackHandler):
     """
-    A callback handler for LangChain that sends events to PostHog LLM Observability.
+    The PostHog LLM observability callback handler for LangChain.
     """
 
     _client: Client
@@ -74,7 +77,8 @@ class CallbackHandler(BaseCallbackHandler):
 
     def __init__(
         self,
-        client: Client,
+        client: Optional[Client] = None,
+        *,
        distinct_id: Optional[Union[str, int, float, UUID]] = None,
        trace_id: Optional[Union[str, int, float, UUID]] = None,
        properties: Optional[Dict[str, Any]] = None,
@@ -90,7 +94,7 @@ def __init__(
             privacy_mode: Whether to redact the input and output of the trace.
             groups: Optional additional PostHog groups to use for the trace.
         """
-        self._client = client
+        self._client = client or default_client
         self._distinct_id = distinct_id
         self._trace_id = trace_id
         self._properties = properties or {}
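
Note: with `client` now optional (and the arguments after it keyword-only), the handler can be constructed without explicitly wiring up a Client, falling back to the module-level default. A minimal sketch, assuming the global posthog client is configured first:

import posthog
from posthog.ai.langchain import CallbackHandler

# Assumption: the module-level default client is configured as usual.
posthog.project_api_key = "phc_test"  # placeholder key

# No explicit Client argument: _client falls back to posthog.default_client.
handler = CallbackHandler(distinct_id="user_123")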
@@ -106,9 +110,12 @@ def on_chain_start(
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
+        self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
         self._set_parent_of_run(run_id, parent_run_id)
+        self._set_run_metadata(serialized, run_id, inputs, metadata, **kwargs)
 
     def on_chat_model_start(
         self,
@@ -119,6 +126,9 @@ def on_chat_model_start(
         parent_run_id: Optional[UUID] = None,
         **kwargs,
     ):
+        self._log_debug_event(
+            "on_chat_model_start", run_id, parent_run_id, messages=messages
+        )
         self._set_parent_of_run(run_id, parent_run_id)
         input = [
             _convert_message_to_dict(message) for row in messages for message in row
@@ -134,9 +144,58 @@ def on_llm_start(
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ):
+        self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
         self._set_parent_of_run(run_id, parent_run_id)
         self._set_run_metadata(serialized, run_id, prompts, **kwargs)
 
+    def on_llm_new_token(
+        self,
+        token: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        log.debug(
+            f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}"
+        )
+
+    def on_tool_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        input_str: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event(
+            "on_tool_start", run_id, parent_run_id, input_str=input_str
+        )
+
+    def on_tool_end(
+        self,
+        output: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
+
+    def on_tool_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
+
     def on_chain_end(
         self,
         outputs: Dict[str, Any],
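
Since these callbacks all log through the module's "posthog" logger, the new instrumentation can be observed by raising that logger's level. A sketch:

import logging

# The handler logs each LangChain event via logging.getLogger("posthog"),
# so DEBUG level prints every on_* callback as it fires.
logging.basicConfig()
logging.getLogger("posthog").setLevel(logging.DEBUG)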
@@ -146,7 +205,35 @@ def on_chain_end(
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ):
+        self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
         self._pop_parent_of_run(run_id)
+        run_metadata = self._pop_run_metadata(run_id)
+
+        if parent_run_id is None:
+            self._end_trace(
+                self._get_trace_id(run_id),
+                inputs=run_metadata.get("messages") if run_metadata else None,
+                outputs=outputs,
+            )
+
+    def on_chain_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ):
+        self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
+        self._pop_parent_of_run(run_id)
+        run_metadata = self._pop_run_metadata(run_id)
+
+        if parent_run_id is None:
+            self._end_trace(
+                self._get_trace_id(run_id),
+                inputs=run_metadata.get("messages") if run_metadata else None,
+                outputs=None,
+            )
+
     def on_llm_end(
         self,
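
With `on_chain_end` and `on_chain_error` now closing the trace whenever `parent_run_id` is None, a single top-level invocation captures a $ai_trace event. A sketch, assuming an OpenAI chat model (any runnable works):

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI  # assumption: any LangChain chat model
from posthog.ai.langchain import CallbackHandler

handler = CallbackHandler(distinct_id="user_123")
prompt = ChatPromptTemplate.from_messages([("user", "Tell me a joke about {topic}")])
chain = prompt | ChatOpenAI()

# The root run has parent_run_id=None, so when it finishes,
# on_chain_end calls _end_trace and captures one $ai_trace event.
chain.invoke({"topic": "owls"}, config={"callbacks": [handler]})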
@@ -160,6 +247,9 @@
         """
         The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
         """
+        self._log_debug_event(
+            "on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
+        )
         trace_id = self._get_trace_id(run_id)
         self._pop_parent_of_run(run_id)
         run = self._pop_run_metadata(run_id)
@@ -207,16 +297,6 @@ def on_llm_end(
             groups=self._groups,
         )
 
-    def on_chain_error(
-        self,
-        error: BaseException,
-        *,
-        run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        **kwargs: Any,
-    ):
-        self._pop_parent_of_run(run_id)
-
     def on_llm_error(
         self,
         error: BaseException,
@@ -226,6 +306,7 @@
         tags: Optional[List[str]] = None,
         **kwargs: Any,
     ):
+        self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
         trace_id = self._get_trace_id(run_id)
         self._pop_parent_of_run(run_id)
         run = self._pop_run_metadata(run_id)
@@ -255,6 +336,51 @@
             groups=self._groups,
         )
 
+    def on_retriever_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        query: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
+
+    def on_retriever_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when Retriever errors."""
+        self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
+
+    def on_agent_action(
+        self,
+        action: AgentAction,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on agent action."""
+        self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)
+
+    def on_agent_finish(
+        self,
+        finish: AgentFinish,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
+
     def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
         """
         Set the parent run ID for a chain run. If there is no parent, the run is the root.
@@ -324,6 +450,40 @@ def _get_trace_id(self, run_id: UUID):
             trace_id = uuid.uuid4()
         return trace_id
 
+    def _end_trace(
+        self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]]
+    ):
+        event_properties = {
+            "$ai_trace_id": trace_id,
+            "$ai_input_state": with_privacy_mode(
+                self._client, self._privacy_mode, inputs
+            ),
+            **self._properties,
+        }
+        if outputs is not None:
+            event_properties["$ai_output_state"] = with_privacy_mode(
+                self._client, self._privacy_mode, outputs
+            )
+        if self._distinct_id is None:
+            event_properties["$process_person_profile"] = False
+        self._client.capture(
+            distinct_id=self._distinct_id or trace_id,
+            event="$ai_trace",
+            properties=event_properties,
+            groups=self._groups,
+        )
+
+    def _log_debug_event(
+        self,
+        event_name: str,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs,
+    ):
+        log.debug(
+            f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}"
+        )
+
 
 def _extract_raw_esponse(last_response):
     """Extract the response from the last response of the LLM call."""

posthog/test/ai/langchain/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 
 pytest.importorskip("langchain")
 pytest.importorskip("langchain_community")
+pytest.importorskip("langgraph")
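
The new tests are skipped unless langgraph is importable. A minimal graph of the kind such tests can exercise, reusing the `handler` from the earlier sketches (illustrative only, not the actual test code; the state shape and node are made up):

from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class State(TypedDict):
    messages: list

def greet(state: State) -> State:
    return {"messages": state["messages"] + ["hello"]}

graph = StateGraph(State)
graph.add_node("greet", greet)
graph.add_edge(START, "greet")
graph.add_edge("greet", END)
app = graph.compile()

# A compiled LangGraph app is a LangChain runnable, so the same
# CallbackHandler traces it end to end.
app.invoke({"messages": []}, config={"callbacks": [handler]})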
