
Commit e4c4884

fix(llma): fix types

1 parent 686da52

10 files changed: 128 additions, 104 deletions

posthog/ai/anthropic/anthropic.py

Lines changed: 11 additions & 7 deletions

@@ -10,8 +10,10 @@
 import uuid
 from typing import Any, Dict, List, Optional

+from posthog.ai.types import StreamingContentBlock, ToolInProgress
 from posthog.ai.utils import (
     call_llm_and_track_usage,
+    merge_usage_stats,
 )
 from posthog.ai.anthropic.anthropic_converter import (
     extract_anthropic_usage_from_event,

@@ -126,9 +128,9 @@ def _create_streaming(
         start_time = time.time()
         usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
         accumulated_content = ""
-        content_blocks: List[Dict[str, Any]] = []
-        tools_in_progress: Dict[str, Dict[str, Any]] = {}
-        current_text_block: Optional[Dict[str, Any]] = None
+        content_blocks: List[StreamingContentBlock] = []
+        tools_in_progress: Dict[str, ToolInProgress] = {}
+        current_text_block: Optional[StreamingContentBlock] = None
         response = super().create(**kwargs)

         def generator():

@@ -142,7 +144,7 @@ def generator():
             for event in response:
                 # Extract usage stats from event
                 event_usage = extract_anthropic_usage_from_event(event)
-                usage_stats.update(event_usage)
+                merge_usage_stats(usage_stats, event_usage)

                 # Handle content block start events
                 if hasattr(event, "type") and event.type == "content_block_start":

@@ -157,7 +159,9 @@ def generator():
                         current_text_block = None

                     if tool:
-                        tools_in_progress[tool["block"]["id"]] = tool
+                        tool_id = tool["block"].get("id")
+                        if tool_id:
+                            tools_in_progress[tool_id] = tool

                 # Handle text delta events
                 delta_text = handle_anthropic_text_delta(event, current_text_block)

@@ -208,7 +212,7 @@ def _capture_streaming_event(
         kwargs: Dict[str, Any],
         usage_stats: Dict[str, int],
         latency: float,
-        content_blocks: List[Dict[str, Any]],
+        content_blocks: List[StreamingContentBlock],
         accumulated_content: str,
     ):
         from posthog.ai.types import StreamingEventData

@@ -225,7 +229,7 @@ def _capture_streaming_event(

         event_data = StreamingEventData(
             provider="anthropic",
-            model=kwargs.get("model"),
+            model=kwargs.get("model", "unknown"),
             base_url=str(self._client.base_url),
             kwargs=kwargs,
             formatted_input=sanitized_input,
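
Note on the usage_stats.update() -> merge_usage_stats() change: dict.update()
overwrites each key with the latest event's value, so a provider that reports
input_tokens in one streaming event and output_tokens in a later one (as
Anthropic does) risks having earlier counts clobbered by events that omit
them. The tool_id guard similarly avoids indexing on an id the typed block may
omit. merge_usage_stats lives in posthog.ai.utils and its body is not part of
this diff; the following is only a sketch of the behavior the call sites
imply, including the mode="cumulative" parameter used in the gemini.py diff
below.

from typing import Dict

def merge_usage_stats(
    target: Dict[str, int],
    source: Dict[str, int],
    mode: str = "incremental",
) -> None:
    """Merge per-event token counts into the running totals, in place."""
    for key, value in source.items():
        if mode == "cumulative":
            # The event carries the running total so far; keep the latest.
            target[key] = value
        else:
            # The event carries a delta; add it to the total.
            target[key] = target.get(key, 0) + value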

posthog/ai/anthropic/anthropic_async.py

Lines changed: 10 additions & 6 deletions

@@ -11,11 +11,13 @@
 from typing import Any, Dict, List, Optional

 from posthog import setup
+from posthog.ai.types import StreamingContentBlock, ToolInProgress
 from posthog.ai.utils import (
     call_llm_and_track_usage_async,
     extract_available_tool_calls,
     get_model_params,
     merge_system_prompt,
+    merge_usage_stats,
     with_privacy_mode,
 )
 from posthog.ai.anthropic.anthropic_converter import (

@@ -131,9 +133,9 @@ async def _create_streaming(
         start_time = time.time()
         usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
         accumulated_content = ""
-        content_blocks: List[Dict[str, Any]] = []
-        tools_in_progress: Dict[str, Dict[str, Any]] = {}
-        current_text_block: Optional[Dict[str, Any]] = None
+        content_blocks: List[StreamingContentBlock] = []
+        tools_in_progress: Dict[str, ToolInProgress] = {}
+        current_text_block: Optional[StreamingContentBlock] = None
         response = super().create(**kwargs)

         async def generator():

@@ -147,7 +149,7 @@ async def generator():
             async for event in response:
                 # Extract usage stats from event
                 event_usage = extract_anthropic_usage_from_event(event)
-                usage_stats.update(event_usage)
+                merge_usage_stats(usage_stats, event_usage)

                 # Handle content block start events
                 if hasattr(event, "type") and event.type == "content_block_start":

@@ -162,7 +164,9 @@ async def generator():
                         current_text_block = None

                     if tool:
-                        tools_in_progress[tool["block"]["id"]] = tool
+                        tool_id = tool["block"].get("id")
+                        if tool_id:
+                            tools_in_progress[tool_id] = tool

                 # Handle text delta events
                 delta_text = handle_anthropic_text_delta(event, current_text_block)

@@ -213,7 +217,7 @@ async def _capture_streaming_event(
         kwargs: Dict[str, Any],
         usage_stats: Dict[str, int],
         latency: float,
-        content_blocks: List[Dict[str, Any]],
+        content_blocks: List[StreamingContentBlock],
         accumulated_content: str,
     ):
         if posthog_trace_id is None:

posthog/ai/anthropic/anthropic_converter.py

Lines changed: 16 additions & 14 deletions

@@ -31,7 +31,7 @@ def format_anthropic_response(response: Any) -> List[FormattedMessage]:
     List of formatted messages with role and content
     """

-    output = []
+    output: List[FormattedMessage] = []

     if response is None:
         return output

@@ -127,7 +127,7 @@ def extract_anthropic_tools(kwargs: Dict[str, Any]) -> Optional[Any]:


 def format_anthropic_streaming_content(
-    content_blocks: List[Dict[str, Any]],
+    content_blocks: List[StreamingContentBlock],
 ) -> List[FormattedContentItem]:
     """
     Format content blocks from Anthropic streaming response.

@@ -145,19 +145,17 @@
     for block in content_blocks:
         if block.get("type") == "text":
-            text_content: FormattedTextContent = {
+            formatted.append({
                 "type": "text",
-                "text": block.get("text", ""),
-            }
-            formatted.append(text_content)
+                "text": block.get("text") or "",
+            })

         elif block.get("type") == "function":
-            function_call: FormattedFunctionCall = {
+            formatted.append({
                 "type": "function",
                 "id": block.get("id"),
-                "function": block.get("function", {}),
-            }
-            formatted.append(function_call)
+                "function": block.get("function") or {},
+            })

     return formatted

@@ -222,13 +220,13 @@ def handle_anthropic_content_block_start(
         return content_block, None

     elif block.type == "tool_use":
-        content_block: StreamingContentBlock = {
+        tool_block: StreamingContentBlock = {
             "type": "function",
             "id": getattr(block, "id", ""),
             "function": {"name": getattr(block, "name", ""), "arguments": {}},
         }
-        tool_in_progress: ToolInProgress = {"block": content_block, "input_string": ""}
-        return content_block, tool_in_progress
+        tool_in_progress: ToolInProgress = {"block": tool_block, "input_string": ""}
+        return tool_block, tool_in_progress

     return None, None

@@ -251,7 +249,11 @@ def handle_anthropic_text_delta(
     delta_text = event.delta.text or ""

     if current_block is not None and current_block.get("type") == "text":
-        current_block["text"] = current_block.get("text", "") + delta_text
+        text_val = current_block.get("text")
+        if text_val is not None:
+            current_block["text"] = text_val + delta_text
+        else:
+            current_block["text"] = delta_text

     return delta_text
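
The StreamingContentBlock and ToolInProgress types are imported from
posthog.ai.types, whose definitions are not part of this commit. Judging from
the access patterns above (the .get() calls with or-fallbacks for possibly
missing keys, and the content_block -> tool_block rename that sidesteps
re-annotating an already-typed name), their shapes are plausibly something
like this sketch:

from typing import Any, Dict, TypedDict

class StreamingContentBlock(TypedDict, total=False):
    # total=False: a text block carries no "id"/"function" and a function
    # block carries no "text", which is why call sites use .get().
    type: str                 # "text" or "function"
    text: str
    id: str
    function: Dict[str, Any]  # {"name": ..., "arguments": ...}

class ToolInProgress(TypedDict):
    block: StreamingContentBlock  # the function block being assembled
    input_string: str             # raw JSON argument fragments from deltas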

posthog/ai/gemini/gemini.py

Lines changed: 3 additions & 1 deletion

@@ -14,6 +14,7 @@
 from posthog.ai.utils import (
     call_llm_and_track_usage,
     capture_streaming_event,
+    merge_usage_stats,
 )
 from posthog.ai.gemini.gemini_converter import (
     format_gemini_input,

@@ -308,7 +309,8 @@ def generator():
                 chunk_usage = extract_gemini_usage_from_chunk(chunk)

                 if chunk_usage:
-                    usage_stats.update(chunk_usage)
+                    # Gemini reports cumulative totals, not incremental values
+                    merge_usage_stats(usage_stats, chunk_usage, mode="cumulative")

                 # Extract content from chunk (now returns content blocks)
                 content_block = extract_gemini_content_from_chunk(chunk)
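
As the new comment says, Gemini's streamed usage metadata carries running
totals for the whole response so far rather than per-chunk deltas, so adding
chunks together would double-count. A short illustration, reusing the
merge_usage_stats sketch from the anthropic.py section above (token numbers
invented):

chunks = [
    {"input_tokens": 12, "output_tokens": 5},   # totals after chunk 1
    {"input_tokens": 12, "output_tokens": 11},  # totals after chunk 2
]

usage = {"input_tokens": 0, "output_tokens": 0}
for chunk_usage in chunks:
    merge_usage_stats(usage, chunk_usage, mode="cumulative")

assert usage == {"input_tokens": 12, "output_tokens": 11}
# The default incremental mode would have produced
# {"input_tokens": 24, "output_tokens": 16} - double-counted.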

posthog/ai/gemini/gemini_converter.py

Lines changed: 18 additions & 28 deletions

@@ -9,9 +9,7 @@

 from posthog.ai.types import (
     FormattedContentItem,
-    FormattedFunctionCall,
     FormattedMessage,
-    FormattedTextContent,
     StreamingUsageStats,
     TokenUsage,
 )

@@ -164,7 +162,7 @@ def format_gemini_response(response: Any) -> List[FormattedMessage]:
     List of formatted messages with role and content
     """

-    output = []
+    output: List[FormattedMessage] = []

     if response is None:
         return output

@@ -177,43 +175,38 @@ def format_gemini_response(response: Any) -> List[FormattedMessage]:
         if hasattr(candidate.content, "parts") and candidate.content.parts:
             for part in candidate.content.parts:
                 if hasattr(part, "text") and part.text:
-                    text_content: FormattedTextContent = {
+                    content.append({
                         "type": "text",
                         "text": part.text,
-                    }
-                    content.append(text_content)
+                    })

                 elif hasattr(part, "function_call") and part.function_call:
                     function_call = part.function_call
-                    func_content: FormattedFunctionCall = {
+                    content.append({
                         "type": "function",
                         "function": {
                             "name": function_call.name,
                             "arguments": function_call.args,
                         },
-                    }
-                    content.append(func_content)
+                    })

             if content:
-                message: FormattedMessage = {
+                output.append({
                     "role": "assistant",
                     "content": content,
-                }
-                output.append(message)
+                })

         elif hasattr(candidate, "text") and candidate.text:
-            message: FormattedMessage = {
+            output.append({
                 "role": "assistant",
                 "content": [{"type": "text", "text": candidate.text}],
-            }
-            output.append(message)
+            })

     elif hasattr(response, "text") and response.text:
-        message: FormattedMessage = {
+        output.append({
             "role": "assistant",
             "content": [{"type": "text", "text": response.text}],
-        }
-        output.append(message)
+        })

     return output

@@ -258,7 +251,7 @@ def format_gemini_input(contents: Any) -> List[FormattedMessage]:

     # Handle list input
     if isinstance(contents, list):
-        formatted = []
+        formatted: List[FormattedMessage] = []

         for item in contents:
             if isinstance(item, str):

@@ -383,27 +376,24 @@ def format_gemini_streaming_output(
         elif item.get("type") == "function":
             # If we have accumulated text, add it first
             if text_parts:
-                text_content: FormattedTextContent = {
+                content.append({
                     "type": "text",
                     "text": "".join(text_parts),
-                }
-                content.append(text_content)
+                })
                 text_parts = []

             # Add the function call
-            func_content: FormattedFunctionCall = {
+            content.append({
                 "type": "function",
                 "function": item.get("function", {}),
-            }
-            content.append(func_content)
+            })

     # Add any remaining text
     if text_parts:
-        text_content: FormattedTextContent = {
+        content.append({
             "type": "text",
             "text": "".join(text_parts),
-        }
-        content.append(text_content)
+        })

     # If we have content, return it
     if content:
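
A likely motivation for folding the intermediate variables into the append()
calls, beyond brevity: annotating the same name (message: FormattedMessage) in
several branches trips mypy's name-redefinition check, whereas a dict literal
passed straight to append() is checked against the list's declared element
type with no intermediate annotation. That is also why the
FormattedTextContent and FormattedFunctionCall imports could be dropped. A
reduced, hypothetical example of the pattern:

from typing import List, TypedDict

class Message(TypedDict):
    role: str
    text: str

def build(flag: bool) -> List[Message]:
    out: List[Message] = []
    # Writing `msg: Message = {...}` in both branches would make mypy report
    # 'Name "msg" already defined'; appending the literal directly avoids the
    # second annotation and still type-checks against List[Message].
    if flag:
        out.append({"role": "assistant", "text": "hello"})
    else:
        out.append({"role": "assistant", "text": "goodbye"})
    return out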

posthog/ai/openai/openai.py

Lines changed: 5 additions & 4 deletions

@@ -12,6 +12,7 @@
 from posthog.ai.utils import (
     call_llm_and_track_usage,
     extract_available_tool_calls,
+    merge_usage_stats,
     with_privacy_mode,
 )
 from posthog.ai.openai.openai_converter import (

@@ -133,7 +134,7 @@ def generator():
                 chunk_usage = extract_openai_usage_from_chunk(chunk, "responses")

                 if chunk_usage:
-                    usage_stats.update(chunk_usage)
+                    merge_usage_stats(usage_stats, chunk_usage)

                 # Extract content from chunk
                 content = extract_openai_content_from_chunk(chunk, "responses")

@@ -189,7 +190,7 @@ def _capture_streaming_event(

         event_data = StreamingEventData(
             provider="openai",
-            model=kwargs.get("model"),
+            model=kwargs.get("model", "unknown"),
             base_url=str(self._client.base_url),
             kwargs=kwargs,
             formatted_input=sanitized_input,

@@ -334,7 +335,7 @@ def generator():
                 chunk_usage = extract_openai_usage_from_chunk(chunk, "chat")

                 if chunk_usage:
-                    usage_stats.update(chunk_usage)
+                    merge_usage_stats(usage_stats, chunk_usage)

                 # Extract content from chunk
                 content = extract_openai_content_from_chunk(chunk, "chat")

@@ -406,7 +407,7 @@ def _capture_streaming_event(

         event_data = StreamingEventData(
             provider="openai",
-            model=kwargs.get("model"),
+            model=kwargs.get("model", "unknown"),
             base_url=str(self._client.base_url),
             kwargs=kwargs,
             formatted_input=sanitized_input,

posthog/ai/openai/openai_async.py

Lines changed: 3 additions & 2 deletions

@@ -14,6 +14,7 @@
     call_llm_and_track_usage_async,
     extract_available_tool_calls,
     get_model_params,
+    merge_usage_stats,
     with_privacy_mode,
 )
 from posthog.ai.openai.openai_converter import (

@@ -137,7 +138,7 @@ async def async_generator():
                 chunk_usage = extract_openai_usage_from_chunk(chunk, "responses")

                 if chunk_usage:
-                    usage_stats.update(chunk_usage)
+                    merge_usage_stats(usage_stats, chunk_usage)

                 # Extract content from chunk
                 content = extract_openai_content_from_chunk(chunk, "responses")

@@ -354,7 +355,7 @@ async def async_generator():
                 # Extract usage stats from chunk
                 chunk_usage = extract_openai_usage_from_chunk(chunk, "chat")
                 if chunk_usage:
-                    usage_stats.update(chunk_usage)
+                    merge_usage_stats(usage_stats, chunk_usage)

                 # Extract content from chunk
                 content = extract_openai_content_from_chunk(chunk, "chat")
