Skip to content

Commit 8369915

Browse files
committed
Format callbacks.py
1 parent 90c0417 commit 8369915

File tree

1 file changed

+101
-30
lines changed

1 file changed

+101
-30
lines changed

posthog/ai/langchain/callbacks.py

Lines changed: 101 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
try:
22
import langchain # noqa: F401
33
except ImportError:
4-
raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
4+
raise ModuleNotFoundError(
5+
"Please install LangChain to use this feature: 'pip install langchain'"
6+
)
57

68
import logging
79
import time
@@ -21,7 +23,14 @@
2123
from langchain.callbacks.base import BaseCallbackHandler
2224
from langchain.schema.agent import AgentAction, AgentFinish
2325
from langchain_core.documents import Document
24-
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
26+
from langchain_core.messages import (
27+
AIMessage,
28+
BaseMessage,
29+
FunctionMessage,
30+
HumanMessage,
31+
SystemMessage,
32+
ToolMessage,
33+
)
2534
from langchain_core.outputs import ChatGeneration, LLMResult
2635
from pydantic import BaseModel
2736

@@ -63,6 +72,7 @@ class GenerationMetadata(SpanMetadata):
6372
tools: Optional[List[Dict[str, Any]]] = None
6473
"""Tools provided to the model."""
6574

75+
6676
RunMetadata = Union[SpanMetadata, GenerationMetadata]
6777
RunMetadataStorage = Dict[UUID, RunMetadata]
6878

@@ -142,7 +152,9 @@ def on_chain_start(
142152
):
143153
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
144154
self._set_parent_of_run(run_id, parent_run_id)
145-
self._set_trace_or_span_metadata(serialized, inputs, run_id, parent_run_id, **kwargs)
155+
self._set_trace_or_span_metadata(
156+
serialized, inputs, run_id, parent_run_id, **kwargs
157+
)
146158

147159
def on_chain_end(
148160
self,
@@ -175,9 +187,13 @@ def on_chat_model_start(
175187
parent_run_id: Optional[UUID] = None,
176188
**kwargs,
177189
):
178-
self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
190+
self._log_debug_event(
191+
"on_chat_model_start", run_id, parent_run_id, messages=messages
192+
)
179193
self._set_parent_of_run(run_id, parent_run_id)
180-
input = [_convert_message_to_dict(message) for row in messages for message in row]
194+
input = [
195+
_convert_message_to_dict(message) for row in messages for message in row
196+
]
181197
self._set_llm_metadata(serialized, run_id, input, **kwargs)
182198

183199
def on_llm_start(
@@ -215,7 +231,9 @@ def on_llm_end(
215231
"""
216232
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
217233
"""
218-
self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
234+
self._log_debug_event(
235+
"on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
236+
)
219237
self._pop_run_and_capture_generation(run_id, parent_run_id, response)
220238

221239
def on_llm_error(
@@ -239,9 +257,13 @@ def on_tool_start(
239257
metadata: Optional[Dict[str, Any]] = None,
240258
**kwargs: Any,
241259
) -> Any:
242-
self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
260+
self._log_debug_event(
261+
"on_tool_start", run_id, parent_run_id, input_str=input_str
262+
)
243263
self._set_parent_of_run(run_id, parent_run_id)
244-
self._set_trace_or_span_metadata(serialized, input_str, run_id, parent_run_id, **kwargs)
264+
self._set_trace_or_span_metadata(
265+
serialized, input_str, run_id, parent_run_id, **kwargs
266+
)
245267

246268
def on_tool_end(
247269
self,
@@ -278,7 +300,9 @@ def on_retriever_start(
278300
) -> Any:
279301
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
280302
self._set_parent_of_run(run_id, parent_run_id)
281-
self._set_trace_or_span_metadata(serialized, query, run_id, parent_run_id, **kwargs)
303+
self._set_trace_or_span_metadata(
304+
serialized, query, run_id, parent_run_id, **kwargs
305+
)
282306

283307
def on_retriever_end(
284308
self,
@@ -288,7 +312,9 @@ def on_retriever_end(
288312
parent_run_id: Optional[UUID] = None,
289313
**kwargs: Any,
290314
):
291-
self._log_debug_event("on_retriever_end", run_id, parent_run_id, documents=documents)
315+
self._log_debug_event(
316+
"on_retriever_end", run_id, parent_run_id, documents=documents
317+
)
292318
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)
293319

294320
def on_retriever_error(
@@ -363,7 +389,9 @@ def _set_trace_or_span_metadata(
363389
):
364390
default_name = "trace" if parent_run_id is None else "span"
365391
run_name = _get_langchain_run_name(serialized, **kwargs) or default_name
366-
self._runs[run_id] = SpanMetadata(name=run_name, input=input, start_time=time.time(), end_time=None)
392+
self._runs[run_id] = SpanMetadata(
393+
name=run_name, input=input, start_time=time.time(), end_time=None
394+
)
367395

368396
def _set_llm_metadata(
369397
self,
@@ -375,7 +403,9 @@ def _set_llm_metadata(
375403
**kwargs,
376404
):
377405
run_name = _get_langchain_run_name(serialized, **kwargs) or "generation"
378-
generation = GenerationMetadata(name=run_name, input=messages, start_time=time.time(), end_time=None)
406+
generation = GenerationMetadata(
407+
name=run_name, input=messages, start_time=time.time(), end_time=None
408+
)
379409
if isinstance(invocation_params, dict):
380410
generation.model_params = get_model_params(invocation_params)
381411
if tools := invocation_params.get("tools"):
@@ -409,25 +439,35 @@ def _get_trace_id(self, run_id: UUID):
409439
return run_id
410440
return trace_id
411441

412-
def _get_parent_run_id(self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]):
442+
def _get_parent_run_id(
443+
self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]
444+
):
413445
"""
414446
Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set.
415447
"""
416448
if parent_run_id is not None and parent_run_id not in self._parent_tree:
417449
return trace_id
418450
return parent_run_id
419451

420-
def _pop_run_and_capture_trace_or_span(self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any):
452+
def _pop_run_and_capture_trace_or_span(
453+
self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any
454+
):
421455
trace_id = self._get_trace_id(run_id)
422456
self._pop_parent_of_run(run_id)
423457
run = self._pop_run_metadata(run_id)
424458
if not run:
425459
return
426460
if isinstance(run, GenerationMetadata):
427-
log.warning(f"Run {run_id} is a generation, but attempted to be captured as a trace or span.")
461+
log.warning(
462+
f"Run {run_id} is a generation, but attempted to be captured as a trace or span."
463+
)
428464
return
429465
self._capture_trace_or_span(
430-
trace_id, run_id, run, outputs, self._get_parent_run_id(trace_id, run_id, parent_run_id)
466+
trace_id,
467+
run_id,
468+
run,
469+
outputs,
470+
self._get_parent_run_id(trace_id, run_id, parent_run_id),
431471
)
432472

433473
def _capture_trace_or_span(
@@ -441,7 +481,9 @@ def _capture_trace_or_span(
441481
event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
442482
event_properties = {
443483
"$ai_trace_id": trace_id,
444-
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.input),
484+
"$ai_input_state": with_privacy_mode(
485+
self._client, self._privacy_mode, run.input
486+
),
445487
"$ai_latency": run.latency,
446488
"$ai_span_name": run.name,
447489
"$ai_span_id": run_id,
@@ -455,7 +497,9 @@ def _capture_trace_or_span(
455497
event_properties["$ai_error"] = _stringify_exception(outputs)
456498
event_properties["$ai_is_error"] = True
457499
elif outputs is not None:
458-
event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
500+
event_properties["$ai_output_state"] = with_privacy_mode(
501+
self._client, self._privacy_mode, outputs
502+
)
459503

460504
if self._distinct_id is None:
461505
event_properties["$process_person_profile"] = False
@@ -468,18 +512,27 @@ def _capture_trace_or_span(
468512
)
469513

470514
def _pop_run_and_capture_generation(
471-
self, run_id: UUID, parent_run_id: Optional[UUID], response: Union[LLMResult, BaseException]
515+
self,
516+
run_id: UUID,
517+
parent_run_id: Optional[UUID],
518+
response: Union[LLMResult, BaseException],
472519
):
473520
trace_id = self._get_trace_id(run_id)
474521
self._pop_parent_of_run(run_id)
475522
run = self._pop_run_metadata(run_id)
476523
if not run:
477524
return
478525
if not isinstance(run, GenerationMetadata):
479-
log.warning(f"Run {run_id} is not a generation, but attempted to be captured as a generation.")
526+
log.warning(
527+
f"Run {run_id} is not a generation, but attempted to be captured as a generation."
528+
)
480529
return
481530
self._capture_generation(
482-
trace_id, run_id, run, response, self._get_parent_run_id(trace_id, run_id, parent_run_id)
531+
trace_id,
532+
run_id,
533+
run,
534+
response,
535+
self._get_parent_run_id(trace_id, run_id, parent_run_id),
483536
)
484537

485538
def _capture_generation(
@@ -524,8 +577,12 @@ def _capture_generation(
524577
for generation in generation_result
525578
]
526579
else:
527-
completions = [_extract_raw_esponse(generation) for generation in generation_result]
528-
event_properties["$ai_output_choices"] = with_privacy_mode(self._client, self._privacy_mode, completions)
580+
completions = [
581+
_extract_raw_esponse(generation) for generation in generation_result
582+
]
583+
event_properties["$ai_output_choices"] = with_privacy_mode(
584+
self._client, self._privacy_mode, completions
585+
)
529586

530587
if self._properties:
531588
event_properties.update(self._properties)
@@ -615,7 +672,9 @@ def _parse_usage_model(
615672
if model_key in usage:
616673
captured_count = usage[model_key]
617674
final_count = (
618-
sum(captured_count) if isinstance(captured_count, list) else captured_count
675+
sum(captured_count)
676+
if isinstance(captured_count, list)
677+
else captured_count
619678
) # For Bedrock, the token count is a list when streamed
620679

621680
parsed_usage[type_key] = final_count
@@ -640,8 +699,12 @@ def _parse_usage(response: LLMResult):
640699
break
641700

642701
for generation_chunk in generation:
643-
if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
644-
llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
702+
if generation_chunk.generation_info and (
703+
"usage_metadata" in generation_chunk.generation_info
704+
):
705+
llm_usage = _parse_usage_model(
706+
generation_chunk.generation_info["usage_metadata"]
707+
)
645708
break
646709

647710
message_chunk = getattr(generation_chunk, "message", {})
@@ -653,13 +716,19 @@ def _parse_usage(response: LLMResult):
653716
else None
654717
)
655718
bedrock_titan_usage = (
656-
response_metadata.get("amazon-bedrock-invocationMetrics", None) # for Bedrock-Titan
719+
response_metadata.get(
720+
"amazon-bedrock-invocationMetrics", None
721+
) # for Bedrock-Titan
657722
if isinstance(response_metadata, dict)
658723
else None
659724
)
660-
ollama_usage = getattr(message_chunk, "usage_metadata", None) # for Ollama
725+
ollama_usage = getattr(
726+
message_chunk, "usage_metadata", None
727+
) # for Ollama
661728

662-
chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
729+
chunk_usage = (
730+
bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
731+
)
663732
if chunk_usage:
664733
llm_usage = _parse_usage_model(chunk_usage)
665734
break
@@ -675,7 +744,9 @@ def _get_http_status(error: BaseException) -> int:
675744
return status_code
676745

677746

678-
def _get_langchain_run_name(serialized: Optional[Dict[str, Any]], **kwargs: Any) -> Optional[str]:
747+
def _get_langchain_run_name(
748+
serialized: Optional[Dict[str, Any]], **kwargs: Any
749+
) -> Optional[str]:
679750
"""Retrieve the name of a serialized LangChain runnable.
680751
681752
The prioritization for the determination of the run name is as follows:

0 commit comments

Comments (0)