
Commit 0b6ff2e

feat(llm-observability): LangChain tracing, with LangGraph tests (#169)
1 parent 80f0b3e commit 0b6ff2e
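
In practice, the handler this commit extends hooks into LangChain's standard callback mechanism. A minimal usage sketch (the model class and prompt are illustrative, not part of this diff; the import path follows the file layout below):

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from posthog.ai.langchain.callbacks import CallbackHandler

handler = CallbackHandler(distinct_id="user_123")  # the client argument becomes optional in this commit

chain = ChatPromptTemplate.from_template("Answer briefly: {question}") | ChatOpenAI(model="gpt-4o-mini")

# The top-level run of this invocation becomes the $ai_trace event added below;
# per-generation events for the nested LLM runs are captured as before.
chain.invoke({"question": "What is LangGraph?"}, config={"callbacks": [handler]})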

File tree

5 files changed: +539 -149 lines changed


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ jobs:

       - name: Lint with flake8
         run: |
-          flake8 posthog --ignore E501
+          flake8 posthog --ignore E501,W503

       - name: Check import order with isort
         run: |

posthog/ai/langchain/callbacks.py

Lines changed: 195 additions & 17 deletions
@@ -19,10 +19,12 @@
 from uuid import UUID

 from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema.agent import AgentAction, AgentFinish
 from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.outputs import ChatGeneration, LLMResult
 from pydantic import BaseModel

+from posthog import default_client
 from posthog.ai.utils import get_model_params, with_privacy_mode
 from posthog.client import Client

@@ -44,19 +46,30 @@ class RunMetadata(TypedDict, total=False):

 class CallbackHandler(BaseCallbackHandler):
     """
-    A callback handler for LangChain that sends events to PostHog LLM Observability.
+    The PostHog LLM observability callback handler for LangChain.
     """

     _client: Client
     """PostHog client instance."""
+
     _distinct_id: Optional[Union[str, int, float, UUID]]
     """Distinct ID of the user to associate the trace with."""
+
     _trace_id: Optional[Union[str, int, float, UUID]]
     """Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
+
+    _trace_input: Optional[Any]
+    """The input at the start of the trace. Any JSON object."""
+
+    _trace_name: Optional[str]
+    """Name of the trace, exposed in the UI."""
+
     _properties: Optional[Dict[str, Any]]
     """Global properties to be sent with every event."""
+
     _runs: RunStorage
     """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
+
     _parent_tree: Dict[UUID, UUID]
     """
     A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
@@ -65,7 +78,8 @@ class CallbackHandler(BaseCallbackHandler):

     def __init__(
         self,
-        client: Client,
+        client: Optional[Client] = None,
+        *,
         distinct_id: Optional[Union[str, int, float, UUID]] = None,
         trace_id: Optional[Union[str, int, float, UUID]] = None,
         properties: Optional[Dict[str, Any]] = None,
@@ -81,9 +95,11 @@ def __init__(
             privacy_mode: Whether to redact the input and output of the trace.
             groups: Optional additional PostHog groups to use for the trace.
         """
-        self._client = client
+        self._client = client or default_client
         self._distinct_id = distinct_id
         self._trace_id = trace_id
+        self._trace_name = None
+        self._trace_input = None
         self._properties = properties or {}
         self._privacy_mode = privacy_mode
         self._groups = groups or {}
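
Two user-visible changes land in the constructor: client is now optional, falling back to the package-level default client, and the bare * makes every later argument keyword-only. A sketch of both call styles (key values are placeholders; the module-level configuration step is an assumption, not shown in this diff):

import posthog
from posthog.client import Client
from posthog.ai.langchain.callbacks import CallbackHandler

# Explicit client, as before:
handler = CallbackHandler(Client("phc_project_api_key"), distinct_id="user_123")

# New: omit the client to fall back to posthog.default_client. Arguments such as
# distinct_id, trace_id, properties, privacy_mode and groups must now be passed by keyword.
posthog.project_api_key = "phc_project_api_key"
handler = CallbackHandler(distinct_id="user_123", trace_id="trace-42")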
@@ -97,9 +113,14 @@ def on_chain_start(
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
+        self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
         self._set_parent_of_run(run_id, parent_run_id)
+        if parent_run_id is None and self._trace_name is None:
+            self._trace_name = self._get_langchain_run_name(serialized, **kwargs)
+            self._trace_input = inputs

     def on_chat_model_start(
         self,
@@ -110,6 +131,7 @@ def on_chat_model_start(
         parent_run_id: Optional[UUID] = None,
         **kwargs,
     ):
+        self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
         self._set_parent_of_run(run_id, parent_run_id)
         input = [_convert_message_to_dict(message) for row in messages for message in row]
         self._set_run_metadata(serialized, run_id, input, **kwargs)
@@ -123,32 +145,93 @@ def on_llm_start(
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ):
+        self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
         self._set_parent_of_run(run_id, parent_run_id)
         self._set_run_metadata(serialized, run_id, prompts, **kwargs)

+    def on_llm_new_token(
+        self,
+        token: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token)
+
+    def on_tool_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        input_str: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
+
+    def on_tool_end(
+        self,
+        output: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
+
+    def on_tool_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
+
     def on_chain_end(
         self,
         outputs: Dict[str, Any],
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ):
+        self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
+        self._pop_parent_of_run(run_id)
+
+        if parent_run_id is None:
+            self._capture_trace(run_id, outputs=outputs)
+
+    def on_chain_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ):
+        self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
         self._pop_parent_of_run(run_id)

+        if parent_run_id is None:
+            self._capture_trace(run_id, outputs=None)
+
     def on_llm_end(
         self,
         response: LLMResult,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ):
         """
         The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
         """
+        self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
         trace_id = self._get_trace_id(run_id)
         self._pop_parent_of_run(run_id)
         run = self._pop_run_metadata(run_id)
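
The docstring above points out that streaming runs only report token usage when the LLM is configured with `stream_usage=True`. A hedged sketch of that configuration (ChatOpenAI is an assumed example; `handler` is the callback handler from this file):

from langchain_openai import ChatOpenAI

# stream_usage=True asks the provider to attach token counts to the final streamed
# chunk, so on_llm_end still sees usage metadata for streaming runs.
llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

for chunk in llm.stream("Say hi", config={"callbacks": [handler]}):
    pass  # consume the stream; usage arrives with the last chunk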
@@ -189,25 +272,15 @@ def on_llm_end(
             groups=self._groups,
         )

-    def on_chain_error(
-        self,
-        error: BaseException,
-        *,
-        run_id: UUID,
-        parent_run_id: Optional[UUID] = None,
-        **kwargs: Any,
-    ):
-        self._pop_parent_of_run(run_id)
-
     def on_llm_error(
         self,
         error: BaseException,
         *,
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
         **kwargs: Any,
     ):
+        self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
         trace_id = self._get_trace_id(run_id)
         self._pop_parent_of_run(run_id)
         run = self._pop_run_metadata(run_id)
@@ -235,6 +308,50 @@ def on_llm_error(
             groups=self._groups,
         )

+    def on_retriever_start(
+        self,
+        serialized: Optional[Dict[str, Any]],
+        query: str,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
+
+    def on_retriever_error(
+        self,
+        error: Union[Exception, KeyboardInterrupt],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when Retriever errors."""
+        self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
+
+    def on_agent_action(
+        self,
+        action: AgentAction,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run on agent action."""
+        self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)
+
+    def on_agent_finish(
+        self,
+        finish: AgentFinish,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
+
     def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
         """
         Set the parent run ID for a chain run. If there is no parent, the run is the root.
@@ -304,6 +421,65 @@ def _get_trace_id(self, run_id: UUID):
             trace_id = uuid.uuid4()
         return trace_id

+    def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
+        """Retrieve the name of a serialized LangChain runnable.
+
+        The prioritization for the determination of the run name is as follows:
+        - The value assigned to the "name" key in `kwargs`.
+        - The value assigned to the "name" key in `serialized`.
+        - The last entry of the value assigned to the "id" key in `serialized`.
+        - "<unknown>".
+
+        Args:
+            serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
+            **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.
+
+        Returns:
+            str: The determined name of the Langchain runnable.
+        """
+        if "name" in kwargs and kwargs["name"] is not None:
+            return kwargs["name"]
+
+        try:
+            return serialized["name"]
+        except (KeyError, TypeError):
+            pass
+
+        try:
+            return serialized["id"][-1]
+        except (KeyError, TypeError):
+            pass
+
+    def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
+        trace_id = self._get_trace_id(run_id)
+        event_properties = {
+            "$ai_trace_name": self._trace_name,
+            "$ai_trace_id": trace_id,
+            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input),
+            **self._properties,
+        }
+        if outputs is not None:
+            event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
+        if self._distinct_id is None:
+            event_properties["$process_person_profile"] = False
+        self._client.capture(
+            distinct_id=self._distinct_id or trace_id,
+            event="$ai_trace",
+            properties=event_properties,
+            groups=self._groups,
+        )
+
+    def _log_debug_event(
+        self,
+        event_name: str,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs,
+    ):
+        log.debug(
+            f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}"
+        )
+

 def _extract_raw_esponse(last_response):
     """Extract the response from the last response of the LLM call."""
@@ -339,7 +515,9 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
     return message_dict


-def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None], Union[int, None]]:
+def _parse_usage_model(
+    usage: Union[BaseModel, Dict],
+) -> Tuple[Union[int, None], Union[int, None]]:
     if isinstance(usage, BaseModel):
         usage = usage.__dict__

posthog/test/ai/langchain/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@

 pytest.importorskip("langchain")
 pytest.importorskip("langchain_community")
+pytest.importorskip("langgraph")
