Skip to content

Commit c17aba1

Browse files
committed
feat: refactor to dataclasses
1 parent 45dc933 commit c17aba1

File tree

1 file changed

+96
-62
lines changed

1 file changed

+96
-62
lines changed

posthog/ai/langchain/callbacks.py

Lines changed: 96 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -12,11 +12,11 @@
1212
List,
1313
Optional,
1414
Tuple,
15-
TypedDict,
1615
Union,
1716
cast,
1817
)
1918
from uuid import UUID
19+
from dataclasses import dataclass
2020

2121
from langchain.callbacks.base import BaseCallbackHandler
2222
from langchain.schema.agent import AgentAction, AgentFinish
@@ -31,26 +31,48 @@
3131
log = logging.getLogger("posthog")
3232

3333

34-
class RunMetadata(TypedDict, total=False):
35-
input: Any
36-
"""Input of the run: messages, prompt variables, etc."""
34+
@dataclass
35+
class RunMetadata:
3736
name: str
3837
"""Name of the run: chain name, model name, etc."""
39-
provider: str
38+
start_time: float
39+
"""Start time of the run."""
40+
end_time: Optional[float]
41+
"""End time of the run."""
42+
input: Optional[Any]
43+
"""Input of the run: messages, prompt variables, etc."""
44+
45+
@property
46+
def latency(self) -> float:
47+
if not self.end_time:
48+
return 0
49+
return self.end_time - self.start_time
50+
51+
52+
@dataclass
53+
class TraceMetadata(RunMetadata):
54+
pass
55+
56+
57+
@dataclass
58+
class GenerationMetadata(RunMetadata):
59+
provider: Optional[str] = None
4060
"""Provider of the run: OpenAI, Anthropic"""
41-
model: str
61+
model: Optional[str] = None
4262
"""Model used in the run"""
43-
model_params: Dict[str, Any]
63+
model_params: Optional[Dict[str, Any]] = None
4464
"""Model parameters of the run: temperature, max_tokens, etc."""
45-
base_url: str
65+
base_url: Optional[str] = None
4666
"""Base URL of the provider's API used in the run."""
47-
start_time: float
48-
"""Start time of the run."""
49-
end_time: float
50-
"""End time of the run."""
5167

5268

53-
RunStorage = Dict[UUID, RunMetadata]
69+
@dataclass
70+
class SpanMetadata(RunMetadata):
71+
pass
72+
73+
74+
RunMetadataUnion = Union[TraceMetadata, GenerationMetadata, SpanMetadata]
75+
RunMetadataStorage = Dict[UUID, RunMetadataUnion]
5476

5577

5678
class CallbackHandler(BaseCallbackHandler):
@@ -76,7 +98,7 @@ class CallbackHandler(BaseCallbackHandler):
7698
_properties: Optional[Dict[str, Any]]
7799
"""Global properties to be sent with every event."""
78100

79-
_runs: RunStorage
101+
_runs: RunMetadataStorage
80102
"""Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
81103

82104
_parent_tree: Dict[UUID, UUID]
@@ -107,8 +129,6 @@ def __init__(
107129
self._client = client or default_client
108130
self._distinct_id = distinct_id
109131
self._trace_id = trace_id
110-
self._trace_name = None
111-
self._trace_input = None
112132
self._properties = properties or {}
113133
self._privacy_mode = privacy_mode
114134
self._groups = groups or {}
@@ -127,8 +147,7 @@ def on_chain_start(
127147
):
128148
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
129149
self._set_parent_of_run(run_id, parent_run_id)
130-
if parent_run_id is None and self._trace_name is None:
131-
self._set_span_metadata(run_id, self._get_langchain_run_name(serialized, **kwargs), inputs)
150+
self._set_trace_or_span_metadata(serialized, inputs, run_id, parent_run_id, **kwargs)
132151

133152
def on_chat_model_start(
134153
self,
@@ -179,6 +198,7 @@ def on_tool_start(
179198
**kwargs: Any,
180199
) -> Any:
181200
self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
201+
self._set_trace_or_span_metadata(serialized, input_str, run_id, parent_run_id, **kwargs)
182202

183203
def on_tool_end(
184204
self,
@@ -192,10 +212,11 @@ def on_tool_end(
192212

193213
def on_tool_error(
194214
self,
195-
error: Union[Exception, KeyboardInterrupt],
215+
error: BaseException,
196216
*,
197217
run_id: UUID,
198218
parent_run_id: Optional[UUID] = None,
219+
tags: Optional[list[str]] = None,
199220
**kwargs: Any,
200221
) -> Any:
201222
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
@@ -243,10 +264,9 @@ def on_llm_end(
243264
trace_id = self._get_trace_id(run_id)
244265
self._pop_parent_of_run(run_id)
245266
run = self._pop_run_metadata(run_id)
246-
if not run:
267+
if not run or not isinstance(run, GenerationMetadata):
247268
return
248269

249-
latency = run.get("end_time", 0) - run.get("start_time", 0)
250270
input_tokens, output_tokens = _parse_usage(response)
251271

252272
generation_result = response.generations[-1]
@@ -258,19 +278,20 @@ def on_llm_end(
258278
output = [_extract_raw_esponse(generation) for generation in generation_result]
259279

260280
event_properties = {
261-
"$ai_provider": run.get("provider"),
262-
"$ai_model": run.get("model"),
263-
"$ai_model_parameters": run.get("model_params"),
264-
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
281+
"$ai_provider": run.provider,
282+
"$ai_model": run.model,
283+
"$ai_model_parameters": run.model_params,
284+
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.input),
265285
"$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output),
266286
"$ai_http_status": 200,
267287
"$ai_input_tokens": input_tokens,
268288
"$ai_output_tokens": output_tokens,
269-
"$ai_latency": latency,
289+
"$ai_latency": run.latency,
270290
"$ai_trace_id": trace_id,
271-
"$ai_base_url": run.get("base_url"),
272-
**self._properties,
291+
"$ai_base_url": run.base_url,
273292
}
293+
if self._properties:
294+
event_properties.update(self._properties)
274295
if self._distinct_id is None:
275296
event_properties["$process_person_profile"] = False
276297
self._client.capture(
@@ -292,21 +313,21 @@ def on_llm_error(
292313
trace_id = self._get_trace_id(run_id)
293314
self._pop_parent_of_run(run_id)
294315
run = self._pop_run_metadata(run_id)
295-
if not run:
316+
if not run or not isinstance(run, GenerationMetadata):
296317
return
297318

298-
latency = run.get("end_time", 0) - run.get("start_time", 0)
299319
event_properties = {
300-
"$ai_provider": run.get("provider"),
301-
"$ai_model": run.get("model"),
302-
"$ai_model_parameters": run.get("model_params"),
303-
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
320+
"$ai_provider": run.provider,
321+
"$ai_model": run.model,
322+
"$ai_model_parameters": run.model_params,
323+
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.input),
304324
"$ai_http_status": _get_http_status(error),
305-
"$ai_latency": latency,
325+
"$ai_latency": run.latency,
306326
"$ai_trace_id": trace_id,
307-
"$ai_base_url": run.get("base_url"),
308-
**self._properties,
327+
"$ai_base_url": run.base_url,
309328
}
329+
if self._properties:
330+
event_properties.update(self._properties)
310331
if self._distinct_id is None:
311332
event_properties["$process_person_profile"] = False
312333
self._client.capture(
@@ -327,13 +348,15 @@ def on_retriever_start(
327348
**kwargs: Any,
328349
) -> Any:
329350
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
351+
self._set_trace_or_span_metadata(serialized, query, run_id, parent_run_id, **kwargs)
330352

331353
def on_retriever_error(
332354
self,
333-
error: Union[Exception, KeyboardInterrupt],
355+
error: BaseException,
334356
*,
335357
run_id: UUID,
336358
parent_run_id: Optional[UUID] = None,
359+
tags: Optional[list[str]] = None,
337360
**kwargs: Any,
338361
) -> Any:
339362
"""Run when Retriever errors."""
@@ -385,12 +408,23 @@ def _find_root_run(self, run_id: UUID) -> UUID:
385408
id = self._parent_tree[id]
386409
return id
387410

388-
def _set_span_metadata(self, run_id: UUID, name: str, input: Any):
389-
self._runs[run_id] = {
390-
"name": name,
391-
"input": input,
392-
"start_time": time.time(),
393-
}
411+
def _set_trace_or_span_metadata(
412+
self,
413+
serialized: Optional[Dict[str, Any]],
414+
input: Any,
415+
run_id: UUID,
416+
parent_run_id: Optional[UUID] = None,
417+
**kwargs,
418+
):
419+
run_name = self._get_langchain_run_name(serialized, **kwargs)
420+
if parent_run_id is None:
421+
self._runs[run_id] = TraceMetadata(
422+
name=run_name or "trace", input=input, start_time=time.time(), end_time=None
423+
)
424+
else:
425+
self._runs[run_id] = SpanMetadata(
426+
name=run_name or "span", input=input, start_time=time.time(), end_time=None
427+
)
394428

395429
def _set_llm_metadata(
396430
self,
@@ -401,33 +435,31 @@ def _set_llm_metadata(
401435
invocation_params: Optional[Dict[str, Any]] = None,
402436
**kwargs,
403437
):
404-
run: RunMetadata = {
405-
"input": messages,
406-
"start_time": time.time(),
407-
}
438+
run_name = self._get_langchain_run_name(serialized, **kwargs) or "generation"
439+
generation = GenerationMetadata(name=run_name, input=messages, start_time=time.time(), end_time=None)
408440
if isinstance(invocation_params, dict):
409-
run["model_params"] = get_model_params(invocation_params)
441+
generation.model_params = get_model_params(invocation_params)
410442
if isinstance(metadata, dict):
411443
if model := metadata.get("ls_model_name"):
412-
run["model"] = model
444+
generation.model = model
413445
if provider := metadata.get("ls_provider"):
414-
run["provider"] = provider
446+
generation.provider = provider
415447
try:
416448
base_url = serialized["kwargs"]["openai_api_base"]
417449
if base_url is not None:
418-
run["base_url"] = base_url
450+
generation.base_url = base_url
419451
except KeyError:
420452
pass
421-
self._runs[run_id] = run
453+
self._runs[run_id] = generation
422454

423-
def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
455+
def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadataUnion]:
424456
end_time = time.time()
425457
try:
426458
run = self._runs.pop(run_id)
427459
except KeyError:
428460
log.warning(f"No run metadata found for run {run_id}")
429461
return None
430-
run["end_time"] = end_time
462+
run.end_time = end_time
431463
return run
432464

433465
def _get_trace_id(self, run_id: UUID):
@@ -436,7 +468,7 @@ def _get_trace_id(self, run_id: UUID):
436468
trace_id = uuid.uuid4()
437469
return trace_id
438470

439-
def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
471+
def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> Optional[str]:
440472
"""Retrieve the name of a serialized LangChain runnable.
441473
442474
The prioritization for the determination of the run name is as follows:
@@ -454,29 +486,31 @@ def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs
454486
"""
455487
if "name" in kwargs and kwargs["name"] is not None:
456488
return kwargs["name"]
457-
489+
if serialized is None:
490+
return None
458491
try:
459492
return serialized["name"]
460493
except (KeyError, TypeError):
461494
pass
462-
463495
try:
464496
return serialized["id"][-1]
465497
except (KeyError, TypeError):
466498
pass
499+
return None
467500

468501
def _pop_trace_and_capture(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
469502
trace_id = self._get_trace_id(run_id)
470503
run = self._pop_run_metadata(run_id)
471504
if not run:
472505
return
473506
event_properties = {
474-
"$ai_trace_name": run.get("name"),
507+
"$ai_trace_name": run.name,
475508
"$ai_trace_id": trace_id,
476-
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
477-
"$ai_latency": run.get("end_time", 0) - run.get("start_time", 0),
478-
**self._properties,
509+
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.input),
510+
"$ai_latency": run.latency,
479511
}
512+
if self._properties:
513+
event_properties.update(self._properties)
480514
if outputs is not None:
481515
event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
482516
if self._distinct_id is None:

0 commit comments

Comments
 (0)