
Commit e780da0

Fix formatting
1 parent 199dfd2 commit e780da0

2 files changed: +47 -142 lines changed


posthog/ai/langchain/callbacks.py

Lines changed: 20 additions & 59 deletions
@@ -1,9 +1,7 @@
 try:
     import langchain  # noqa: F401
 except ImportError:
-    raise ModuleNotFoundError(
-        "Please install LangChain to use this feature: 'pip install langchain'"
-    )
+    raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
 
 import json
 import logging
@@ -126,13 +124,9 @@ def on_chat_model_start(
         parent_run_id: Optional[UUID] = None,
         **kwargs,
     ):
-        self._log_debug_event(
-            "on_chat_model_start", run_id, parent_run_id, messages=messages
-        )
+        self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
         self._set_parent_of_run(run_id, parent_run_id)
-        input = [
-            _convert_message_to_dict(message) for row in messages for message in row
-        ]
+        input = [_convert_message_to_dict(message) for row in messages for message in row]
         self._set_run_metadata(serialized, run_id, input, **kwargs)
 
     def on_llm_start(
@@ -157,9 +151,7 @@ def on_llm_new_token(
         **kwargs: Any,
     ) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
-        self.log.debug(
-            f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}"
-        )
+        self.log.debug(f"on llm new token: run_id: {run_id} parent_run_id: {parent_run_id}")
 
     def on_tool_start(
         self,
@@ -172,9 +164,7 @@ def on_tool_start(
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
-        self._log_debug_event(
-            "on_tool_start", run_id, parent_run_id, input_str=input_str
-        )
+        self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
 
     def on_tool_end(
         self,
@@ -247,9 +237,7 @@ def on_llm_end(
         """
         The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
         """
-        self._log_debug_event(
-            "on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
-        )
+        self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
         trace_id = self._get_trace_id(run_id)
         self._pop_parent_of_run(run_id)
         run = self._pop_run_metadata(run_id)
@@ -262,24 +250,17 @@ def on_llm_end(
         generation_result = response.generations[-1]
         if isinstance(generation_result[-1], ChatGeneration):
             output = [
-                _convert_message_to_dict(cast(ChatGeneration, generation).message)
-                for generation in generation_result
+                _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result
             ]
         else:
-            output = [
-                _extract_raw_esponse(generation) for generation in generation_result
-            ]
+            output = [_extract_raw_esponse(generation) for generation in generation_result]
 
         event_properties = {
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(
-                self._client, self._privacy_mode, run.get("messages")
-            ),
-            "$ai_output_choices": with_privacy_mode(
-                self._client, self._privacy_mode, output
-            ),
+            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
+            "$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output),
             "$ai_http_status": 200,
             "$ai_input_tokens": input_tokens,
             "$ai_output_tokens": output_tokens,
@@ -318,9 +299,7 @@ def on_llm_error(
             "$ai_provider": run.get("provider"),
             "$ai_model": run.get("model"),
             "$ai_model_parameters": run.get("model_params"),
-            "$ai_input": with_privacy_mode(
-                self._client, self._privacy_mode, run.get("messages")
-            ),
+            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
             "$ai_http_status": _get_http_status(error),
             "$ai_latency": latency,
             "$ai_trace_id": trace_id,
@@ -450,20 +429,14 @@ def _get_trace_id(self, run_id: UUID):
             trace_id = uuid.uuid4()
         return trace_id
 
-    def _end_trace(
-        self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]]
-    ):
+    def _end_trace(self, trace_id: UUID, inputs: Dict[str, Any], outputs: Optional[Dict[str, Any]]):
         event_properties = {
             "$ai_trace_id": trace_id,
-            "$ai_input_state": with_privacy_mode(
-                self._client, self._privacy_mode, inputs
-            ),
+            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, inputs),
             **self._properties,
         }
         if outputs is not None:
-            event_properties["$ai_output_state"] = with_privacy_mode(
-                self._client, self._privacy_mode, outputs
-            )
+            event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)
         if self._distinct_id is None:
             event_properties["$process_person_profile"] = False
         self._client.capture(
@@ -545,9 +518,7 @@ def _parse_usage_model(
         if model_key in usage:
             captured_count = usage[model_key]
             final_count = (
-                sum(captured_count)
-                if isinstance(captured_count, list)
-                else captured_count
+                sum(captured_count) if isinstance(captured_count, list) else captured_count
             )  # For Bedrock, the token count is a list when streamed
 
             parsed_usage[type_key] = final_count
@@ -568,12 +539,8 @@ def _parse_usage(response: LLMResult):
     if hasattr(response, "generations"):
         for generation in response.generations:
             for generation_chunk in generation:
-                if generation_chunk.generation_info and (
-                    "usage_metadata" in generation_chunk.generation_info
-                ):
-                    llm_usage = _parse_usage_model(
-                        generation_chunk.generation_info["usage_metadata"]
-                    )
+                if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
+                    llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
                     break
 
                 message_chunk = getattr(generation_chunk, "message", {})
@@ -585,19 +552,13 @@ def _parse_usage(response: LLMResult):
                     else None
                 )
                 bedrock_titan_usage = (
-                    response_metadata.get(
-                        "amazon-bedrock-invocationMetrics", None
-                    )  # for Bedrock-Titan
+                    response_metadata.get("amazon-bedrock-invocationMetrics", None)  # for Bedrock-Titan
                     if isinstance(response_metadata, dict)
                     else None
                 )
-                ollama_usage = getattr(
-                    message_chunk, "usage_metadata", None
-                )  # for Ollama
+                ollama_usage = getattr(message_chunk, "usage_metadata", None)  # for Ollama
 
-                chunk_usage = (
-                    bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
-                )
+                chunk_usage = bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
                 if chunk_usage:
                     llm_usage = _parse_usage_model(chunk_usage)
                     break

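Several of the reformatted properties pass their payload through `with_privacy_mode(self._client, self._privacy_mode, value)`. The helper's body is not part of this diff; the following is only a plausible illustration of what a function with that signature might do, namely drop the value when either the client-level or handler-level privacy flag is set:

```python
# Hypothetical illustration of with_privacy_mode -- the real helper lives elsewhere
# in the SDK and may differ; only its call signature is visible in this diff.
from typing import Any, Optional


def with_privacy_mode(client: Any, privacy_mode: bool, value: Any) -> Optional[Any]:
    # When privacy mode is on (globally on the client or per handler),
    # redact the payload instead of sending prompts/completions with the event.
    if getattr(client, "privacy_mode", False) or privacy_mode:
        return None
    return value
```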