Commit f629591

test: langchain-openai for v2 and v3
1 parent 2d6513e commit f629591

File tree

5 files changed: +387 −77 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -15,4 +15,5 @@ posthog-analytics
 .idea
 .python-version
 .coverage
-pyrightconfig.json
+pyrightconfig.json
+.env

posthog/ai/providers/langchain.py

Lines changed: 10 additions & 36 deletions
@@ -1,9 +1,7 @@
 try:
     import langchain
 except ImportError:
-    raise ModuleNotFoundError(
-        "Please install LangChain to use this feature: 'pip install langchain'"
-    )
+    raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")
 
 import logging
 import time
@@ -31,22 +29,11 @@
 PosthogProperties = dict[str, Any]
 
 
-class ModelParams(TypedDict, total=False):
-    temperature: Optional[float]
-    max_tokens: Optional[int]
-    top_p: Optional[float]
-    frequency_penalty: Optional[float]
-    presence_penalty: Optional[float]
-    n: Optional[int]
-    stop: Optional[list[str]]
-    stream: Optional[bool]
-
-
 class RunMetadata(TypedDict, total=False):
     messages: list[dict[str, Any]] | list[str]
     provider: str
     model: str
-    model_params: ModelParams
+    model_params: dict[str, Any]
     start_time: float
     end_time: float
 
@@ -114,9 +101,7 @@ def on_chat_model_start(
         **kwargs,
     ):
         self._set_parent_of_run(run_id, parent_run_id)
-        input = [
-            _convert_message_to_dict(message) for row in messages for message in row
-        ]
+        input = [_convert_message_to_dict(message) for row in messages for message in row]
         self._set_run_metadata(run_id, input, **kwargs)
 
     def on_llm_start(
@@ -166,13 +151,10 @@ def on_llm_end(
         generation_result = response.generations[-1]
         if isinstance(generation_result[-1], ChatGeneration):
             output = [
-                _convert_message_to_dict(cast(ChatGeneration, generation).message)
-                for generation in generation_result
+                _convert_message_to_dict(cast(ChatGeneration, generation).message) for generation in generation_result
             ]
         else:
-            output = [
-                _extract_raw_esponse(generation) for generation in generation_result
-            ]
+            output = [_extract_raw_esponse(generation) for generation in generation_result]
 
         event_properties = {
             "$ai_provider": run.get("provider"),
@@ -276,7 +258,7 @@ def _set_run_metadata(
             "start_time": time.time(),
         }
         if isinstance(invocation_params, dict):
-            run["model_params"] = cast(ModelParams, get_model_params(invocation_params))
+            run["model_params"] = get_model_params(invocation_params)
         if isinstance(metadata, dict):
             if model := metadata.get("ls_model_name"):
                 run["model"] = model
@@ -361,9 +343,7 @@ def _parse_usage_model(usage: Union[BaseModel, dict]) -> tuple[int | None, int |
         if model_key in usage:
             captured_count = usage[model_key]
             final_count = (
-                sum(captured_count)
-                if isinstance(captured_count, list)
-                else captured_count
+                sum(captured_count) if isinstance(captured_count, list) else captured_count
             )  # For Bedrock, the token count is a list when streamed
 
             parsed_usage[type_key] = final_count
@@ -384,12 +364,8 @@ def _parse_usage(response: LLMResult):
     if hasattr(response, "generations"):
         for generation in response.generations:
             for generation_chunk in generation:
-                if generation_chunk.generation_info and (
-                    "usage_metadata" in generation_chunk.generation_info
-                ):
-                    llm_usage = _parse_usage_model(
-                        generation_chunk.generation_info["usage_metadata"]
-                    )
+                if generation_chunk.generation_info and ("usage_metadata" in generation_chunk.generation_info):
+                    llm_usage = _parse_usage_model(generation_chunk.generation_info["usage_metadata"])
                     break
 
             message_chunk = getattr(generation_chunk, "message", {})
@@ -402,9 +378,7 @@ def _parse_usage(response: LLMResult):
                 else None
             )
             or (
-                response_metadata.get(
-                    "amazon-bedrock-invocationMetrics", None
-                )  # for Bedrock-Titan
+                response_metadata.get("amazon-bedrock-invocationMetrics", None)  # for Bedrock-Titan
                 if isinstance(response_metadata, dict)
                 else None
             )
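For context, this file implements LangChain's callback interface (on_chat_model_start, on_llm_start, on_llm_end) and captures an $ai_* analytics event per run. A minimal usage sketch against langchain-openai follows; the CallbackHandler class name, its import path, and its constructor arguments are assumptions for illustration, not a confirmed public API:

# Hypothetical wiring of the callback handler exercised by this commit's tests.
# CallbackHandler and its import path are assumed from this file's location.
import posthog
from langchain_openai import ChatOpenAI
from posthog.ai.providers.langchain import CallbackHandler  # assumed export

posthog.project_api_key = "phc_..."  # placeholder PostHog project key

handler = CallbackHandler()  # constructor arguments assumed
model = ChatOpenAI(model="gpt-4o-mini")

# Passing callbacks via config is standard LangChain; the handler's
# on_chat_model_start / on_llm_end hooks fire during this call.
response = model.invoke("Hello!", config={"callbacks": [handler]})

Per the commit message, the same wiring is what the accompanying tests exercise against both langchain-openai v2 and v3.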

posthog/ai/utils.py

Lines changed: 3 additions & 2 deletions
@@ -14,15 +14,16 @@ def get_model_params(kwargs: Dict[str, Any]) -> Dict[str, Any]:
     model_params = {}
     for param in [
         "temperature",
-        "max_tokens",
+        "max_tokens",  # Deprecated field
+        "max_completion_tokens",
         "top_p",
         "frequency_penalty",
         "presence_penalty",
         "n",
         "stop",
         "stream",
     ]:
-        if param in kwargs:
+        if param in kwargs and kwargs[param] is not None:
             model_params[param] = kwargs[param]
     return model_params
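To illustrate the behavioral change: get_model_params now also captures max_completion_tokens (OpenAI's replacement for the deprecated max_tokens) and skips parameters whose value is explicitly None, which invocation params presumably carry in newer langchain-openai versions. A quick sketch, assuming the function is importable from posthog.ai.utils as the file path suggests:

from posthog.ai.utils import get_model_params

# Invocation params may carry max_tokens=None alongside the newer
# max_completion_tokens; explicit None values are now dropped.
params = get_model_params(
    {
        "temperature": 0.7,
        "max_tokens": None,            # skipped: value is None
        "max_completion_tokens": 256,  # newly captured field
        "stream": False,               # falsy but not None, so kept
        "model": "gpt-4o-mini",        # not in the allow-list, ignored
    }
)
print(params)  # {'temperature': 0.7, 'max_completion_tokens': 256, 'stream': False}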
