Skip to content

Commit 3f4d142

Browse files
committed
Reformat callbacks.py
1 parent 6623d18 commit 3f4d142

File tree

1 file changed: +26 additions, −78 deletions

posthog/ai/langchain/callbacks.py

Lines changed: 26 additions & 78 deletions
Columns: original file line number · diff line number · diff line change
@@ -1,9 +1,7 @@
11
try:
22
import langchain # noqa: F401
33
except ImportError:
4-
raise ModuleNotFoundError(
5-
"Please install LangChain to use this feature: 'pip install langchain'"
6-
)
4+
raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
75

86
import logging
97
import time
@@ -152,9 +150,7 @@ def on_chain_start(
152150
):
153151
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
154152
self._set_parent_of_run(run_id, parent_run_id)
155-
self._set_trace_or_span_metadata(
156-
serialized, inputs, run_id, parent_run_id, **kwargs
157-
)
153+
self._set_trace_or_span_metadata(serialized, inputs, run_id, parent_run_id, **kwargs)
158154

159155
def on_chain_end(
160156
self,
@@ -187,13 +183,9 @@ def on_chat_model_start(
187183
parent_run_id: Optional[UUID] = None,
188184
**kwargs,
189185
):
190-
self._log_debug_event(
191-
"on_chat_model_start", run_id, parent_run_id, messages=messages
192-
)
186+
self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
193187
self._set_parent_of_run(run_id, parent_run_id)
194-
input = [
195-
_convert_message_to_dict(message) for row in messages for message in row
196-
]
188+
input = [_convert_message_to_dict(message) for row in messages for message in row]
197189
self._set_llm_metadata(serialized, run_id, input, **kwargs)
198190

199191
def on_llm_start(
@@ -231,9 +223,7 @@ def on_llm_end(
231223
"""
232224
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
233225
"""
234-
self._log_debug_event(
235-
"on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
236-
)
226+
self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
237227
self._pop_run_and_capture_generation(run_id, parent_run_id, response)
238228

239229
def on_llm_error(
@@ -257,13 +247,9 @@ def on_tool_start(
257247
metadata: Optional[Dict[str, Any]] = None,
258248
**kwargs: Any,
259249
) -> Any:
260-
self._log_debug_event(
261-
"on_tool_start", run_id, parent_run_id, input_str=input_str
262-
)
250+
self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
263251
self._set_parent_of_run(run_id, parent_run_id)
264-
self._set_trace_or_span_metadata(
265-
serialized, input_str, run_id, parent_run_id, **kwargs
266-
)
252+
self._set_trace_or_span_metadata(serialized, input_str, run_id, parent_run_id, **kwargs)
267253

268254
def on_tool_end(
269255
self,
@@ -300,9 +286,7 @@ def on_retriever_start(
300286
) -> Any:
301287
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
302288
self._set_parent_of_run(run_id, parent_run_id)
303-
self._set_trace_or_span_metadata(
304-
serialized, query, run_id, parent_run_id, **kwargs
305-
)
289+
self._set_trace_or_span_metadata(serialized, query, run_id, parent_run_id, **kwargs)
306290

307291
def on_retriever_end(
308292
self,
@@ -312,9 +296,7 @@ def on_retriever_end(
312296
parent_run_id: Optional[UUID] = None,
313297
**kwargs: Any,
314298
):
315-
self._log_debug_event(
316-
"on_retriever_end", run_id, parent_run_id, documents=documents
317-
)
299+
self._log_debug_event("on_retriever_end", run_id, parent_run_id, documents=documents)
318300
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)
319301

320302
def on_retriever_error(
@@ -389,9 +371,7 @@ def _set_trace_or_span_metadata(
389371
):
390372
default_name = "trace" if parent_run_id is None else "span"
391373
run_name = _get_langchain_run_name(serialized, **kwargs) or default_name
392-
self._runs[run_id] = SpanMetadata(
393-
name=run_name, input=input, start_time=time.time(), end_time=None
394-
)
374+
self._runs[run_id] = SpanMetadata(name=run_name, input=input, start_time=time.time(), end_time=None)
395375

396376
def _set_llm_metadata(
397377
self,
@@ -403,9 +383,7 @@ def _set_llm_metadata(
403383
**kwargs,
404384
):
405385
run_name = _get_langchain_run_name(serialized, **kwargs) or "generation"
406-
generation = GenerationMetadata(
407-
name=run_name, input=messages, start_time=time.time(), end_time=None
408-
)
386+
generation = GenerationMetadata(name=run_name, input=messages, start_time=time.time(), end_time=None)
409387
if isinstance(invocation_params, dict):
410388
generation.model_params = get_model_params(invocation_params)
411389
if tools := invocation_params.get("tools"):
@@ -439,28 +417,22 @@ def _get_trace_id(self, run_id: UUID):
439417
return run_id
440418
return trace_id
441419

442-
def _get_parent_run_id(
443-
self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]
444-
):
420+
def _get_parent_run_id(self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]):
445421
"""
446422
Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set.
447423
"""
448424
if parent_run_id is not None and parent_run_id not in self._parent_tree:
449425
return trace_id
450426
return parent_run_id
451427

452-
def _pop_run_and_capture_trace_or_span(
453-
self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any
454-
):
428+
def _pop_run_and_capture_trace_or_span(self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any):
455429
trace_id = self._get_trace_id(run_id)
456430
self._pop_parent_of_run(run_id)
457431
run = self._pop_run_metadata(run_id)
458432
if not run:
459433
return
460434
if isinstance(run, GenerationMetadata):
461-
log.warning(
462-
f"Run {run_id} is a generation, but attempted to be captured as a trace or span."
463-
)
435+
log.warning(f"Run {run_id} is a generation, but attempted to be captured as a trace or span.")
464436
return
465437
self._capture_trace_or_span(
466438
trace_id,
@@ -481,9 +453,7 @@ def _capture_trace_or_span(
481453
event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
482454
event_properties = {
483455
"$ai_trace_id": trace_id,
484-
"$ai_input_state": with_privacy_mode(
485-
self._client, self._privacy_mode, run.input
486-
),
456+
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.input),
487457
"$ai_latency": run.latency,
488458
"$ai_span_name": run.name,
489459
"$ai_span_id": run_id,
@@ -497,9 +467,7 @@ def _capture_trace_or_span(
497467
event_properties["$ai_error"] = _stringify_exception(outputs)
498468
event_properties["$ai_is_error"] = True
499469
elif outputs is not None:
500-
event_properties["$ai_output_state"] = with_privacy_mode(
501-
self._client, self._privacy_mode, outputs
502-
)
470+
event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
503471

504472
if self._distinct_id is None:
505473
event_properties["$process_person_profile"] = False
@@ -523,9 +491,7 @@ def _pop_run_and_capture_generation(
523491
if not run:
524492
return
525493
if not isinstance(run, GenerationMetadata):
526-
log.warning(
527-
f"Run {run_id} is not a generation, but attempted to be captured as a generation."
528-
)
494+
log.warning(f"Run {run_id} is not a generation, but attempted to be captured as a generation.")
529495
return
530496
self._capture_generation(
531497
trace_id,
@@ -577,12 +543,8 @@ def _capture_generation(
577543
for generation in generation_result
578544
]
579545
else:
580-
completions = [
581-
_extract_raw_esponse(generation) for generation in generation_result
582-
]
583-
event_properties["$ai_output_choices"] = with_privacy_mode(
584-
self._client, self._privacy_mode, completions
585-
)
546+
completions = [_extract_raw_esponse(generation) for generation in generation_result]
547+
event_properties["$ai_output_choices"] = with_privacy_mode(self._client, self._privacy_mode, completions)
586548

587549
if self._properties:
588550
event_properties.update(self._properties)
@@ -672,9 +634,7 @@ def _parse_usage_model(
672634
if model_key in usage:
673635
captured_count = usage[model_key]
674636
final_count = (
675-
sum(captured_count)
676-
if isinstance(captured_count, list)
677-
else captured_count
637+
sum(captured_count) if isinstance(captured_count, list) else captured_count
678638
) # For Bedrock, the token count is a list when streamed
679639

680640
parsed_usage[type_key] = final_count
@@ -699,12 +659,8 @@ def _parse_usage(response: LLMResult):
699659
break
700660

701661
for generation_chunk in generation:
702-
if generation_chunk.generation_info and (
703-
"usage_metadata" in generation_chunk.generation_info
704-
):
705-
llm_usage = _parse_usage_model(
706-
generation_chunk.generation_info["usage_metadata"]
707-
)
662+
if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
663+
llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
708664
break
709665

710666
message_chunk = getattr(generation_chunk, "message", {})
@@ -716,19 +672,13 @@ def _parse_usage(response: LLMResult):
716672
else None
717673
)
718674
bedrock_titan_usage = (
719-
response_metadata.get(
720-
"amazon-bedrock-invocationMetrics", None
721-
) # for Bedrock-Titan
675+
response_metadata.get("amazon-bedrock-invocationMetrics", None) # for Bedrock-Titan
722676
if isinstance(response_metadata, dict)
723677
else None
724678
)
725-
ollama_usage = getattr(
726-
message_chunk, "usage_metadata", None
727-
) # for Ollama
679+
ollama_usage = getattr(message_chunk, "usage_metadata", None) # for Ollama
728680

729-
chunk_usage = (
730-
bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
731-
)
681+
chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
732682
if chunk_usage:
733683
llm_usage = _parse_usage_model(chunk_usage)
734684
break
@@ -744,9 +694,7 @@ def _get_http_status(error: BaseException) -> int:
744694
return status_code
745695

746696

747-
def _get_langchain_run_name(
748-
serialized: Optional[Dict[str, Any]], **kwargs: Any
749-
) -> Optional[str]:
697+
def _get_langchain_run_name(serialized: Optional[Dict[str, Any]], **kwargs: Any) -> Optional[str]:
750698
"""Retrieve the name of a serialized LangChain runnable.
751699
752700
The prioritization for the determination of the run name is as follows:

0 commit comments