
Commit 0384b8c

Format callbacks.py
1 parent 80f0b3e commit 0384b8c

File tree

1 file changed: +49, -15 lines

posthog/ai/langchain/callbacks.py

Lines changed: 49 additions & 15 deletions
@@ -1,7 +1,9 @@
 try:
     import langchain  # noqa: F401
 except ImportError:
-    raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
+    raise ModuleNotFoundError(
+        "Please install LangChain to use this feature: 'pip install langchain'"
+    )
 
 import logging
 import time
@@ -19,7 +21,14 @@
 from uuid import UUID
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.outputs import ChatGeneration, LLMResult
 from pydantic import BaseModel
 
@@ -111,7 +120,9 @@ def on_chat_model_start(
         **kwargs,
     ):
         self._set_parent_of_run(run_id, parent_run_id)
-        input = [_convert_message_to_dict(message) for row in messages for message in row]
+        input = [
+            _convert_message_to_dict(message) for row in messages for message in row
+        ]
         self._set_run_metadata(serialized, run_id, input, **kwargs)
 
     def on_llm_start(
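
For context on the reformatted comprehension above: LangChain hands on_chat_model_start its prompts as a list of message lists (one inner list per prompt), so the handler flattens them before converting each message to a dict. A minimal stand-alone sketch of that flattening, with made-up example messages (not code from this commit):

# Sketch: flattening List[List[BaseMessage]] the way on_chat_model_start does.
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    [SystemMessage(content="You are helpful."), HumanMessage(content="Hi!")],
    [HumanMessage(content="Second prompt")],
]

flattened = [message for row in messages for message in row]
# -> [SystemMessage(...), HumanMessage(...), HumanMessage(...)]
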
@@ -161,17 +172,24 @@ def on_llm_end(
         generation_result = response.generations[-1]
         if isinstance(generation_result[-1], ChatGeneration):
             output = [
-                _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result
+                _convert_message_to_dict(cast(ChatGeneration, generation).message)
+                for generation in generation_result
             ]
         else:
-            output = [_extract_raw_esponse(generation) for generation in generation_result]
+            output = [
+                _extract_raw_esponse(generation) for generation in generation_result
+            ]
 
         event_properties = {
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
-            "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output),
+            "$ai_input": with_privacy_mode(
+                self._client, self._privacy_mode, run.get("messages")
+            ),
+            "$ai_output_choices": with_privacy_mode(
+                self._client, self._privacy_mode, output
+            ),
             "$ai_http_status": 200,
             "$ai_input_tokens": input_tokens,
             "$ai_output_tokens": output_tokens,
@@ -219,7 +237,9 @@ def on_llm_error(
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
+            "$ai_input": with_privacy_mode(
+                self._client, self._privacy_mode, run.get("messages")
+            ),
             "$ai_http_status": _get_http_status(error),
             "$ai_latency": latency,
             "$ai_trace_id": trace_id,
@@ -339,7 +359,9 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
     return message_dict
 
 
-def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None], Union[int, None]]:
+def _parse_usage_model(
+    usage: Union[BaseModel, Dict],
+) -> Tuple[Union[int, None], Union[int, None]]:
     if isinstance(usage, BaseModel):
         usage = usage.__dict__
 
@@ -363,7 +385,9 @@ def _parse_usage_model(usage: Union[BaseModel, Dict]) -> Tuple[Union[int, None],
         if model_key in usage:
             captured_count = usage[model_key]
             final_count = (
-                sum(captured_count) if isinstance(captured_count, list) else captured_count
+                sum(captured_count)
+                if isinstance(captured_count, list)
+                else captured_count
             )  # For Bedrock, the token count is a list when streamed
 
             parsed_usage[type_key] = final_count
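
The conditional reformatted above exists because, per the in-line comment, Bedrock reports streamed token counts as a list of per-chunk values rather than a single integer. A small stand-alone sketch of the same normalization (illustrative only, not code from the commit):

# Sketch: normalizing a token count that may be an int or, for streamed
# Bedrock responses, a list of per-chunk counts.
from typing import List, Union


def normalize_token_count(captured_count: Union[int, List[int]]) -> int:
    return sum(captured_count) if isinstance(captured_count, list) else captured_count


assert normalize_token_count(42) == 42
assert normalize_token_count([10, 20, 12]) == 42
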
@@ -384,8 +408,12 @@ def _parse_usage(response: LLMResult):
     if hasattr(response, "generations"):
         for generation in response.generations:
             for generation_chunk in generation:
-                if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
-                    llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
+                if generation_chunk.generation_info and (
+                    "usage_metadata" in generation_chunk.generation_info
+                ):
+                    llm_usage = _parse_usage_model(
+                        generation_chunk.generation_info["usage_metadata"]
+                    )
                     break
 
                 message_chunk = getattr(generation_chunk, "message", {})
@@ -397,13 +425,19 @@ def _parse_usage(response: LLMResult):
                     else None
                 )
                 bedrock_titan_usage = (
-                    response_metadata.get("amazon-bedrock-invocationMetrics", None)  # for Bedrock-Titan
+                    response_metadata.get(
+                        "amazon-bedrock-invocationMetrics", None
+                    )  # for Bedrock-Titan
                     if isinstance(response_metadata, dict)
                     else None
                 )
-                ollama_usage = getattr(message_chunk, "usage_metadata", None)  # for Ollama
+                ollama_usage = getattr(
+                    message_chunk, "usage_metadata", None
+                )  # for Ollama
 
-                chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
+                chunk_usage = (
+                    bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
+                )
                 if chunk_usage:
                     llm_usage = _parse_usage_model(chunk_usage)
                     break
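
For readers landing on this formatting-only commit: the file implements a LangChain callback handler that captures $ai_* events to PostHog. A hedged usage sketch, assuming the handler is exported as CallbackHandler from posthog.ai.langchain and using ChatOpenAI purely as an example model (neither assumption is confirmed by this diff):

# Usage sketch only -- import path, class name, and model choice are assumptions.
from posthog import Posthog
from posthog.ai.langchain import CallbackHandler
from langchain_openai import ChatOpenAI

posthog_client = Posthog("<project_api_key>", host="https://us.i.posthog.com")
callback = CallbackHandler(client=posthog_client)

model = ChatOpenAI(model="gpt-4o-mini")
# Passing the handler via `callbacks` lets on_chat_model_start / on_llm_end / on_llm_error
# capture the $ai_input, $ai_output_choices, token, and latency properties seen in the diff.
response = model.invoke("Hello!", config={"callbacks": [callback]})
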
