Skip to content

Commit e81d61d

Browse files
committed
Fix input capture
1 parent e780da0 commit e81d61d

File tree

2 files changed

+143
-80
lines changed

2 files changed

+143
-80
lines changed

posthog/ai/langchain/callbacks.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
except ImportError:
44
raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
55

6-
import json
76
import logging
87
import time
98
import uuid
@@ -59,14 +58,25 @@ class CallbackHandler(BaseCallbackHandler):
5958

6059
_client: Client
6160
"""PostHog client instance."""
61+
6262
_distinct_id: Optional[Union[str, int, float, UUID]]
6363
"""Distinct ID of the user to associate the trace with."""
64+
6465
_trace_id: Optional[Union[str, int, float, UUID]]
6566
"""Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
67+
68+
_trace_input: Optional[Any]
69+
"""The input at the start of the trace. Any JSON object."""
70+
71+
_trace_name: Optional[str]
72+
"""Name of the trace, exposed in the UI."""
73+
6674
_properties: Optional[Dict[str, Any]]
6775
"""Global properties to be sent with every event."""
76+
6877
_runs: RunStorage
6978
"""Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
79+
7080
_parent_tree: Dict[UUID, UUID]
7181
"""
7282
A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
@@ -95,6 +105,8 @@ def __init__(
95105
self._client = client or default_client
96106
self._distinct_id = distinct_id
97107
self._trace_id = trace_id
108+
self._trace_name = None
109+
self._trace_input = None
98110
self._properties = properties or {}
99111
self._privacy_mode = privacy_mode
100112
self._groups = groups or {}
@@ -113,7 +125,9 @@ def on_chain_start(
113125
):
114126
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
115127
self._set_parent_of_run(run_id, parent_run_id)
116-
self._set_run_metadata(serialized, run_id, inputs, metadata, **kwargs)
128+
if parent_run_id is None and self._trace_name is None:
129+
self._trace_name = self._get_langchain_run_name(serialized, **kwargs)
130+
self._trace_input = inputs
117131

118132
def on_chat_model_start(
119133
self,
@@ -151,7 +165,7 @@ def on_llm_new_token(
151165
**kwargs: Any,
152166
) -> Any:
153167
"""Run on new LLM token. Only available when streaming is enabled."""
154-
self.log.debug(f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}")
168+
self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token)
155169

156170
def on_tool_start(
157171
self,
@@ -160,7 +174,6 @@ def on_tool_start(
160174
*,
161175
run_id: UUID,
162176
parent_run_id: Optional[UUID] = None,
163-
tags: Optional[List[str]] = None,
164177
metadata: Optional[Dict[str, Any]] = None,
165178
**kwargs: Any,
166179
) -> Any:
@@ -192,19 +205,13 @@ def on_chain_end(
192205
*,
193206
run_id: UUID,
194207
parent_run_id: Optional[UUID] = None,
195-
tags: Optional[List[str]] = None,
196208
**kwargs: Any,
197209
):
198210
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
199211
self._pop_parent_of_run(run_id)
200-
run_metadata = self._pop_run_metadata(run_id)
201212

202213
if parent_run_id is None:
203-
self._end_trace(
204-
self._get_trace_id(run_id),
205-
inputs=run_metadata.get("messages") if run_metadata else None,
206-
outputs=outputs,
207-
)
214+
self._capture_trace(run_id, outputs=outputs)
208215

209216
def on_chain_error(
210217
self,
@@ -216,22 +223,16 @@ def on_chain_error(
216223
):
217224
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
218225
self._pop_parent_of_run(run_id)
219-
run_metadata = self._pop_run_metadata(run_id)
220226

221227
if parent_run_id is None:
222-
self._end_trace(
223-
self._get_trace_id(run_id),
224-
inputs=run_metadata.get("messages") if run_metadata else None,
225-
outputs=None,
226-
)
228+
self._capture_trace(run_id, outputs=None)
227229

228230
def on_llm_end(
229231
self,
230232
response: LLMResult,
231233
*,
232234
run_id: UUID,
233235
parent_run_id: Optional[UUID] = None,
234-
tags: Optional[List[str]] = None,
235236
**kwargs: Any,
236237
):
237238
"""
@@ -284,7 +285,6 @@ def on_llm_error(
284285
*,
285286
run_id: UUID,
286287
parent_run_id: Optional[UUID] = None,
287-
tags: Optional[List[str]] = None,
288288
**kwargs: Any,
289289
):
290290
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
@@ -322,7 +322,6 @@ def on_retriever_start(
322322
*,
323323
run_id: UUID,
324324
parent_run_id: Optional[UUID] = None,
325-
tags: Optional[List[str]] = None,
326325
metadata: Optional[Dict[str, Any]] = None,
327326
**kwargs: Any,
328327
) -> Any:
@@ -429,10 +428,41 @@ def _get_trace_id(self, run_id: UUID):
429428
trace_id = uuid.uuid4()
430429
return trace_id
431430

432-
def _end_trace(self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]]):
431+
def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
432+
"""Retrieve the name of a serialized LangChain runnable.
433+
434+
The prioritization for the determination of the run name is as follows:
435+
- The value assigned to the "name" key in `kwargs`.
436+
- The value assigned to the "name" key in `serialized`.
437+
- The last entry of the value assigned to the "id" key in `serialized`.
438+
- "<unknown>".
439+
440+
Args:
441+
serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
442+
**kwargs (Any): Additional keyword arguments, potentially including the 'name' override.
443+
444+
Returns:
445+
str: The determined name of the Langchain runnable.
446+
"""
447+
if "name" in kwargs and kwargs["name"] is not None:
448+
return kwargs["name"]
449+
450+
try:
451+
return serialized["name"]
452+
except (KeyError, TypeError):
453+
pass
454+
455+
try:
456+
return serialized["id"][-1]
457+
except (KeyError, TypeError):
458+
pass
459+
460+
def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
461+
trace_id = self._get_trace_id(run_id)
433462
event_properties = {
463+
"$ai_trace_name": self._trace_name,
434464
"$ai_trace_id": trace_id,
435-
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, inputs),
465+
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input),
436466
**self._properties,
437467
}
438468
if outputs is not None:

0 commit comments

Comments
 (0)