Skip to content

Commit b3a7f72

Browse files
committed
fix(llma): continuation of DRY refactoring
1 parent 43af2c3 commit b3a7f72

File tree

13 files changed

+885
-512
lines changed

13 files changed

+885
-512
lines changed

posthog/ai/anthropic/anthropic.py

Lines changed: 47 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,14 @@
1717
merge_system_prompt,
1818
with_privacy_mode,
1919
)
20+
from posthog.ai.anthropic.anthropic_converter import (
21+
format_anthropic_streaming_content,
22+
extract_anthropic_usage_from_event,
23+
handle_anthropic_content_block_start,
24+
handle_anthropic_text_delta,
25+
handle_anthropic_tool_delta,
26+
finalize_anthropic_tool_input,
27+
)
2028
from posthog.client import Client as PostHogClient
2129
from posthog import setup
2230

@@ -62,6 +70,7 @@ def create(
6270
posthog_groups: Optional group analytics properties
6371
**kwargs: Arguments passed to Anthropic's messages.create
6472
"""
73+
6574
if posthog_trace_id is None:
6675
posthog_trace_id = str(uuid.uuid4())
6776

@@ -132,79 +141,45 @@ def generator():
132141
nonlocal content_blocks
133142
nonlocal tools_in_progress
134143
nonlocal current_text_block
144+
135145
try:
136146
for event in response:
137-
# Handle usage stats from message_start event
138-
if hasattr(event, "type") and event.type == "message_start":
139-
if hasattr(event, "message") and hasattr(event.message, "usage"):
140-
usage_stats["input_tokens"] = getattr(event.message.usage, "input_tokens", 0)
141-
usage_stats["cache_creation_input_tokens"] = getattr(event.message.usage, "cache_creation_input_tokens", 0)
142-
usage_stats["cache_read_input_tokens"] = getattr(event.message.usage, "cache_read_input_tokens", 0)
143-
144-
# Handle usage stats from message_delta event
145-
if hasattr(event, "usage") and event.usage:
146-
usage_stats["output_tokens"] = getattr(event.usage, "output_tokens", 0)
147+
# Extract usage stats from event
148+
event_usage = extract_anthropic_usage_from_event(event)
149+
usage_stats.update(event_usage)
147150

148151
# Handle content block start events
149152
if hasattr(event, "type") and event.type == "content_block_start":
150-
if hasattr(event, "content_block"):
151-
block = event.content_block
152-
if hasattr(block, "type"):
153-
if block.type == "text":
154-
current_text_block = {
155-
"type": "text",
156-
"text": ""
157-
}
158-
content_blocks.append(current_text_block)
159-
elif block.type == "tool_use":
160-
tool_block = {
161-
"type": "function",
162-
"id": getattr(block, "id", ""),
163-
"function": {
164-
"name": getattr(block, "name", ""),
165-
"arguments": {}
166-
}
167-
}
168-
content_blocks.append(tool_block)
169-
tools_in_progress[block.id] = {
170-
"block": tool_block,
171-
"input_string": ""
172-
}
173-
current_text_block = None
153+
block, tool = handle_anthropic_content_block_start(event)
154+
155+
if block:
156+
content_blocks.append(block)
157+
158+
if block.get("type") == "text":
159+
current_text_block = block
160+
else:
161+
current_text_block = None
162+
163+
if tool:
164+
tools_in_progress[tool["block"]["id"]] = tool
174165

175166
# Handle text delta events
176-
if hasattr(event, "delta"):
177-
if hasattr(event.delta, "text"):
178-
delta_text = event.delta.text or ""
179-
accumulated_content += delta_text
180-
if current_text_block is not None:
181-
current_text_block["text"] += delta_text
167+
delta_text = handle_anthropic_text_delta(event, current_text_block)
168+
169+
if delta_text:
170+
accumulated_content += delta_text
182171

183172
# Handle tool input delta events
184-
if hasattr(event, "type") and event.type == "content_block_delta":
185-
if hasattr(event, "delta") and hasattr(event.delta, "type") and event.delta.type == "input_json_delta":
186-
if hasattr(event, "index") and event.index < len(content_blocks):
187-
block = content_blocks[event.index]
188-
if block.get("type") == "function" and block.get("id") in tools_in_progress:
189-
tool = tools_in_progress[block["id"]]
190-
partial_json = getattr(event.delta, "partial_json", "")
191-
tool["input_string"] += partial_json
173+
handle_anthropic_tool_delta(
174+
event, content_blocks, tools_in_progress
175+
)
192176

193177
# Handle content block stop events
194178
if hasattr(event, "type") and event.type == "content_block_stop":
195179
current_text_block = None
196-
# Parse accumulated tool input
197-
if hasattr(event, "index") and event.index < len(content_blocks):
198-
block = content_blocks[event.index]
199-
if block.get("type") == "function" and block.get("id") in tools_in_progress:
200-
tool = tools_in_progress[block["id"]]
201-
try:
202-
import json
203-
block["function"]["arguments"] = json.loads(tool["input_string"])
204-
except (json.JSONDecodeError, Exception):
205-
# Keep empty dict if parsing fails
206-
pass
207-
del tools_in_progress[block["id"]]
180+
finalize_anthropic_tool_input(
181+
event, content_blocks, tools_in_progress
182+
)
208183

209184
yield event
210185

@@ -243,19 +218,20 @@ def _capture_streaming_event(
243218
if posthog_trace_id is None:
244219
posthog_trace_id = str(uuid.uuid4())
245220

246-
# Format output to match non-streaming version
221+
# Format output using converter
222+
formatted_content = format_anthropic_streaming_content(content_blocks)
247223
formatted_output = []
248-
if content_blocks:
249-
formatted_output = [{
250-
"role": "assistant",
251-
"content": content_blocks
252-
}]
224+
225+
if formatted_content:
226+
formatted_output = [{"role": "assistant", "content": formatted_content}]
253227
else:
254228
# Fallback to accumulated content if no blocks
255-
formatted_output = [{
256-
"role": "assistant",
257-
"content": [{"type": "text", "text": accumulated_content}]
258-
}]
229+
formatted_output = [
230+
{
231+
"role": "assistant",
232+
"content": [{"type": "text", "text": accumulated_content}],
233+
}
234+
]
259235

260236
event_properties = {
261237
"$ai_provider": "anthropic",
@@ -288,6 +264,7 @@ def _capture_streaming_event(
288264

289265
# Add tools if available
290266
available_tools = extract_available_tool_calls("anthropic", kwargs)
267+
291268
if available_tools:
292269
event_properties["$ai_tools"] = available_tools
293270

posthog/ai/anthropic/anthropic_async.py

Lines changed: 48 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,14 @@
1818
merge_system_prompt,
1919
with_privacy_mode,
2020
)
21+
from posthog.ai.anthropic.anthropic_converter import (
22+
format_anthropic_streaming_content,
23+
extract_anthropic_usage_from_event,
24+
handle_anthropic_content_block_start,
25+
handle_anthropic_text_delta,
26+
handle_anthropic_tool_delta,
27+
finalize_anthropic_tool_input,
28+
)
2129
from posthog.client import Client as PostHogClient
2230

2331

@@ -62,6 +70,7 @@ async def create(
6270
posthog_groups: Optional group analytics properties
6371
**kwargs: Arguments passed to Anthropic's messages.create
6472
"""
73+
6574
if posthog_trace_id is None:
6675
posthog_trace_id = str(uuid.uuid4())
6776

@@ -124,87 +133,53 @@ async def _create_streaming(
124133
content_blocks: List[Dict[str, Any]] = []
125134
tools_in_progress: Dict[str, Dict[str, Any]] = {}
126135
current_text_block: Optional[Dict[str, Any]] = None
127-
response = await super().create(**kwargs)
136+
response = super().create(**kwargs)
128137

129138
async def generator():
130139
nonlocal usage_stats
131140
nonlocal accumulated_content
132141
nonlocal content_blocks
133142
nonlocal tools_in_progress
134143
nonlocal current_text_block
144+
135145
try:
136146
async for event in response:
137-
# Handle usage stats from message_start event
138-
if hasattr(event, "type") and event.type == "message_start":
139-
if hasattr(event, "message") and hasattr(event.message, "usage"):
140-
usage_stats["input_tokens"] = getattr(event.message.usage, "input_tokens", 0)
141-
usage_stats["cache_creation_input_tokens"] = getattr(event.message.usage, "cache_creation_input_tokens", 0)
142-
usage_stats["cache_read_input_tokens"] = getattr(event.message.usage, "cache_read_input_tokens", 0)
143-
144-
# Handle usage stats from message_delta event
145-
if hasattr(event, "usage") and event.usage:
146-
usage_stats["output_tokens"] = getattr(event.usage, "output_tokens", 0)
147+
# Extract usage stats from event
148+
event_usage = extract_anthropic_usage_from_event(event)
149+
usage_stats.update(event_usage)
147150

148151
# Handle content block start events
149152
if hasattr(event, "type") and event.type == "content_block_start":
150-
if hasattr(event, "content_block"):
151-
block = event.content_block
152-
if hasattr(block, "type"):
153-
if block.type == "text":
154-
current_text_block = {
155-
"type": "text",
156-
"text": ""
157-
}
158-
content_blocks.append(current_text_block)
159-
elif block.type == "tool_use":
160-
tool_block = {
161-
"type": "function",
162-
"id": getattr(block, "id", ""),
163-
"function": {
164-
"name": getattr(block, "name", ""),
165-
"arguments": {}
166-
}
167-
}
168-
content_blocks.append(tool_block)
169-
tools_in_progress[block.id] = {
170-
"block": tool_block,
171-
"input_string": ""
172-
}
173-
current_text_block = None
153+
block, tool = handle_anthropic_content_block_start(event)
154+
155+
if block:
156+
content_blocks.append(block)
157+
158+
if block.get("type") == "text":
159+
current_text_block = block
160+
else:
161+
current_text_block = None
162+
163+
if tool:
164+
tools_in_progress[tool["block"]["id"]] = tool
174165

175166
# Handle text delta events
176-
if hasattr(event, "delta"):
177-
if hasattr(event.delta, "text"):
178-
delta_text = event.delta.text or ""
179-
accumulated_content += delta_text
180-
if current_text_block is not None:
181-
current_text_block["text"] += delta_text
167+
delta_text = handle_anthropic_text_delta(event, current_text_block)
168+
169+
if delta_text:
170+
accumulated_content += delta_text
182171

183172
# Handle tool input delta events
184-
if hasattr(event, "type") and event.type == "content_block_delta":
185-
if hasattr(event, "delta") and hasattr(event.delta, "type") and event.delta.type == "input_json_delta":
186-
if hasattr(event, "index") and event.index < len(content_blocks):
187-
block = content_blocks[event.index]
188-
if block.get("type") == "function" and block.get("id") in tools_in_progress:
189-
tool = tools_in_progress[block["id"]]
190-
partial_json = getattr(event.delta, "partial_json", "")
191-
tool["input_string"] += partial_json
173+
handle_anthropic_tool_delta(
174+
event, content_blocks, tools_in_progress
175+
)
192176

193177
# Handle content block stop events
194178
if hasattr(event, "type") and event.type == "content_block_stop":
195179
current_text_block = None
196-
# Parse accumulated tool input
197-
if hasattr(event, "index") and event.index < len(content_blocks):
198-
block = content_blocks[event.index]
199-
if block.get("type") == "function" and block.get("id") in tools_in_progress:
200-
tool = tools_in_progress[block["id"]]
201-
try:
202-
import json
203-
block["function"]["arguments"] = json.loads(tool["input_string"])
204-
except (json.JSONDecodeError, Exception):
205-
# Keep empty dict if parsing fails
206-
pass
207-
del tools_in_progress[block["id"]]
180+
finalize_anthropic_tool_input(
181+
event, content_blocks, tools_in_progress
182+
)
208183

209184
yield event
210185

@@ -243,19 +218,20 @@ async def _capture_streaming_event(
243218
if posthog_trace_id is None:
244219
posthog_trace_id = str(uuid.uuid4())
245220

246-
# Format output to match non-streaming version
221+
# Format output using converter
222+
formatted_content = format_anthropic_streaming_content(content_blocks)
247223
formatted_output = []
248-
if content_blocks:
249-
formatted_output = [{
250-
"role": "assistant",
251-
"content": content_blocks
252-
}]
224+
225+
if formatted_content:
226+
formatted_output = [{"role": "assistant", "content": formatted_content}]
253227
else:
254228
# Fallback to accumulated content if no blocks
255-
formatted_output = [{
256-
"role": "assistant",
257-
"content": [{"type": "text", "text": accumulated_content}]
258-
}]
229+
formatted_output = [
230+
{
231+
"role": "assistant",
232+
"content": [{"type": "text", "text": accumulated_content}],
233+
}
234+
]
259235

260236
event_properties = {
261237
"$ai_provider": "anthropic",
@@ -288,6 +264,7 @@ async def _capture_streaming_event(
288264

289265
# Add tools if available
290266
available_tools = extract_available_tool_calls("anthropic", kwargs)
267+
291268
if available_tools:
292269
event_properties["$ai_tools"] = available_tools
293270

0 commit comments

Comments
 (0)