Skip to content

Commit 00b4e5a

Browse files
committed
fix(llma): tool calls in streaming Anthropic
1 parent: 6e00d57 · commit: 00b4e5a

File tree

4 files changed

+516
-160
lines changed

4 files changed

+516
-160
lines changed

posthog/ai/anthropic/anthropic.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88

99
import time
1010
import uuid
11-
from typing import Any, Dict, Optional
11+
from typing import Any, Dict, List, Optional
1212

1313
from posthog.ai.utils import (
1414
call_llm_and_track_usage,
15+
extract_available_tool_calls,
1516
get_model_params,
1617
merge_system_prompt,
1718
with_privacy_mode,
@@ -119,34 +120,97 @@ def _create_streaming(
119120
):
120121
start_time = time.time()
121122
usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
122-
accumulated_content = []
123+
accumulated_content = ""
124+
content_blocks: List[Dict[str, Any]] = []
125+
tools_in_progress: Dict[str, Dict[str, Any]] = {}
126+
current_text_block: Optional[Dict[str, Any]] = None
123127
response = super().create(**kwargs)
124128

125129
def generator():
126130
nonlocal usage_stats
127-
nonlocal accumulated_content # noqa: F824
131+
nonlocal accumulated_content
132+
nonlocal content_blocks
133+
nonlocal tools_in_progress
134+
nonlocal current_text_block
128135
try:
129136
for event in response:
137+
# Handle usage stats from message_start event
138+
if hasattr(event, "type") and event.type == "message_start":
139+
if hasattr(event, "message") and hasattr(event.message, "usage"):
140+
usage_stats["input_tokens"] = getattr(event.message.usage, "input_tokens", 0)
141+
usage_stats["cache_creation_input_tokens"] = getattr(event.message.usage, "cache_creation_input_tokens", 0)
142+
usage_stats["cache_read_input_tokens"] = getattr(event.message.usage, "cache_read_input_tokens", 0)
143+
144+
# Handle usage stats from message_delta event
130145
if hasattr(event, "usage") and event.usage:
131-
usage_stats = {
132-
k: getattr(event.usage, k, 0)
133-
for k in [
134-
"input_tokens",
135-
"output_tokens",
136-
"cache_read_input_tokens",
137-
"cache_creation_input_tokens",
138-
]
139-
}
140-
141-
if hasattr(event, "content") and event.content:
142-
accumulated_content.append(event.content)
146+
usage_stats["output_tokens"] = getattr(event.usage, "output_tokens", 0)
147+
148+
# Handle content block start events
149+
if hasattr(event, "type") and event.type == "content_block_start":
150+
if hasattr(event, "content_block"):
151+
block = event.content_block
152+
if hasattr(block, "type"):
153+
if block.type == "text":
154+
current_text_block = {
155+
"type": "text",
156+
"text": ""
157+
}
158+
content_blocks.append(current_text_block)
159+
elif block.type == "tool_use":
160+
tool_block = {
161+
"type": "function",
162+
"id": getattr(block, "id", ""),
163+
"function": {
164+
"name": getattr(block, "name", ""),
165+
"arguments": {}
166+
}
167+
}
168+
content_blocks.append(tool_block)
169+
tools_in_progress[block.id] = {
170+
"block": tool_block,
171+
"input_string": ""
172+
}
173+
current_text_block = None
174+
175+
# Handle text delta events
176+
if hasattr(event, "delta"):
177+
if hasattr(event.delta, "text"):
178+
delta_text = event.delta.text or ""
179+
accumulated_content += delta_text
180+
if current_text_block is not None:
181+
current_text_block["text"] += delta_text
182+
183+
# Handle tool input delta events
184+
if hasattr(event, "type") and event.type == "content_block_delta":
185+
if hasattr(event, "delta") and hasattr(event.delta, "type") and event.delta.type == "input_json_delta":
186+
if hasattr(event, "index") and event.index < len(content_blocks):
187+
block = content_blocks[event.index]
188+
if block.get("type") == "function" and block.get("id") in tools_in_progress:
189+
tool = tools_in_progress[block["id"]]
190+
partial_json = getattr(event.delta, "partial_json", "")
191+
tool["input_string"] += partial_json
192+
193+
# Handle content block stop events
194+
if hasattr(event, "type") and event.type == "content_block_stop":
195+
current_text_block = None
196+
# Parse accumulated tool input
197+
if hasattr(event, "index") and event.index < len(content_blocks):
198+
block = content_blocks[event.index]
199+
if block.get("type") == "function" and block.get("id") in tools_in_progress:
200+
tool = tools_in_progress[block["id"]]
201+
try:
202+
import json
203+
block["function"]["arguments"] = json.loads(tool["input_string"])
204+
except (json.JSONDecodeError, Exception):
205+
# Keep empty dict if parsing fails
206+
pass
207+
del tools_in_progress[block["id"]]
143208

144209
yield event
145210

146211
finally:
147212
end_time = time.time()
148213
latency = end_time - start_time
149-
output = "".join(accumulated_content)
150214

151215
self._capture_streaming_event(
152216
posthog_distinct_id,
@@ -157,7 +221,8 @@ def generator():
157221
kwargs,
158222
usage_stats,
159223
latency,
160-
output,
224+
content_blocks,
225+
accumulated_content,
161226
)
162227

163228
return generator()
@@ -172,11 +237,26 @@ def _capture_streaming_event(
172237
kwargs: Dict[str, Any],
173238
usage_stats: Dict[str, int],
174239
latency: float,
175-
output: str,
240+
content_blocks: List[Dict[str, Any]],
241+
accumulated_content: str,
176242
):
177243
if posthog_trace_id is None:
178244
posthog_trace_id = str(uuid.uuid4())
179245

246+
# Format output to match non-streaming version
247+
formatted_output = []
248+
if content_blocks:
249+
formatted_output = [{
250+
"role": "assistant",
251+
"content": content_blocks
252+
}]
253+
else:
254+
# Fallback to accumulated content if no blocks
255+
formatted_output = [{
256+
"role": "assistant",
257+
"content": [{"type": "text", "text": accumulated_content}]
258+
}]
259+
180260
event_properties = {
181261
"$ai_provider": "anthropic",
182262
"$ai_model": kwargs.get("model"),
@@ -189,7 +269,7 @@ def _capture_streaming_event(
189269
"$ai_output_choices": with_privacy_mode(
190270
self._client._ph_client,
191271
posthog_privacy_mode,
192-
[{"content": output, "role": "assistant"}],
272+
formatted_output,
193273
),
194274
"$ai_http_status": 200,
195275
"$ai_input_tokens": usage_stats.get("input_tokens", 0),
@@ -206,6 +286,11 @@ def _capture_streaming_event(
206286
**(posthog_properties or {}),
207287
}
208288

289+
# Add tools if available
290+
available_tools = extract_available_tool_calls("anthropic", kwargs)
291+
if available_tools:
292+
event_properties["$ai_tools"] = available_tools
293+
209294
if posthog_distinct_id is None:
210295
event_properties["$process_person_profile"] = False
211296

posthog/ai/anthropic/anthropic_async.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88

99
import time
1010
import uuid
11-
from typing import Any, Dict, Optional
11+
from typing import Any, Dict, List, Optional
1212

1313
from posthog import setup
1414
from posthog.ai.utils import (
1515
call_llm_and_track_usage_async,
16+
extract_available_tool_calls,
1617
get_model_params,
1718
merge_system_prompt,
1819
with_privacy_mode,
@@ -119,34 +120,97 @@ async def _create_streaming(
119120
):
120121
start_time = time.time()
121122
usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
122-
accumulated_content = []
123+
accumulated_content = ""
124+
content_blocks: List[Dict[str, Any]] = []
125+
tools_in_progress: Dict[str, Dict[str, Any]] = {}
126+
current_text_block: Optional[Dict[str, Any]] = None
123127
response = await super().create(**kwargs)
124128

125129
async def generator():
126130
nonlocal usage_stats
127-
nonlocal accumulated_content # noqa: F824
131+
nonlocal accumulated_content
132+
nonlocal content_blocks
133+
nonlocal tools_in_progress
134+
nonlocal current_text_block
128135
try:
129136
async for event in response:
137+
# Handle usage stats from message_start event
138+
if hasattr(event, "type") and event.type == "message_start":
139+
if hasattr(event, "message") and hasattr(event.message, "usage"):
140+
usage_stats["input_tokens"] = getattr(event.message.usage, "input_tokens", 0)
141+
usage_stats["cache_creation_input_tokens"] = getattr(event.message.usage, "cache_creation_input_tokens", 0)
142+
usage_stats["cache_read_input_tokens"] = getattr(event.message.usage, "cache_read_input_tokens", 0)
143+
144+
# Handle usage stats from message_delta event
130145
if hasattr(event, "usage") and event.usage:
131-
usage_stats = {
132-
k: getattr(event.usage, k, 0)
133-
for k in [
134-
"input_tokens",
135-
"output_tokens",
136-
"cache_read_input_tokens",
137-
"cache_creation_input_tokens",
138-
]
139-
}
140-
141-
if hasattr(event, "content") and event.content:
142-
accumulated_content.append(event.content)
146+
usage_stats["output_tokens"] = getattr(event.usage, "output_tokens", 0)
147+
148+
# Handle content block start events
149+
if hasattr(event, "type") and event.type == "content_block_start":
150+
if hasattr(event, "content_block"):
151+
block = event.content_block
152+
if hasattr(block, "type"):
153+
if block.type == "text":
154+
current_text_block = {
155+
"type": "text",
156+
"text": ""
157+
}
158+
content_blocks.append(current_text_block)
159+
elif block.type == "tool_use":
160+
tool_block = {
161+
"type": "function",
162+
"id": getattr(block, "id", ""),
163+
"function": {
164+
"name": getattr(block, "name", ""),
165+
"arguments": {}
166+
}
167+
}
168+
content_blocks.append(tool_block)
169+
tools_in_progress[block.id] = {
170+
"block": tool_block,
171+
"input_string": ""
172+
}
173+
current_text_block = None
174+
175+
# Handle text delta events
176+
if hasattr(event, "delta"):
177+
if hasattr(event.delta, "text"):
178+
delta_text = event.delta.text or ""
179+
accumulated_content += delta_text
180+
if current_text_block is not None:
181+
current_text_block["text"] += delta_text
182+
183+
# Handle tool input delta events
184+
if hasattr(event, "type") and event.type == "content_block_delta":
185+
if hasattr(event, "delta") and hasattr(event.delta, "type") and event.delta.type == "input_json_delta":
186+
if hasattr(event, "index") and event.index < len(content_blocks):
187+
block = content_blocks[event.index]
188+
if block.get("type") == "function" and block.get("id") in tools_in_progress:
189+
tool = tools_in_progress[block["id"]]
190+
partial_json = getattr(event.delta, "partial_json", "")
191+
tool["input_string"] += partial_json
192+
193+
# Handle content block stop events
194+
if hasattr(event, "type") and event.type == "content_block_stop":
195+
current_text_block = None
196+
# Parse accumulated tool input
197+
if hasattr(event, "index") and event.index < len(content_blocks):
198+
block = content_blocks[event.index]
199+
if block.get("type") == "function" and block.get("id") in tools_in_progress:
200+
tool = tools_in_progress[block["id"]]
201+
try:
202+
import json
203+
block["function"]["arguments"] = json.loads(tool["input_string"])
204+
except (json.JSONDecodeError, Exception):
205+
# Keep empty dict if parsing fails
206+
pass
207+
del tools_in_progress[block["id"]]
143208

144209
yield event
145210

146211
finally:
147212
end_time = time.time()
148213
latency = end_time - start_time
149-
output = "".join(accumulated_content)
150214

151215
await self._capture_streaming_event(
152216
posthog_distinct_id,
@@ -157,7 +221,8 @@ async def generator():
157221
kwargs,
158222
usage_stats,
159223
latency,
160-
output,
224+
content_blocks,
225+
accumulated_content,
161226
)
162227

163228
return generator()
@@ -172,11 +237,26 @@ async def _capture_streaming_event(
172237
kwargs: Dict[str, Any],
173238
usage_stats: Dict[str, int],
174239
latency: float,
175-
output: str,
240+
content_blocks: List[Dict[str, Any]],
241+
accumulated_content: str,
176242
):
177243
if posthog_trace_id is None:
178244
posthog_trace_id = str(uuid.uuid4())
179245

246+
# Format output to match non-streaming version
247+
formatted_output = []
248+
if content_blocks:
249+
formatted_output = [{
250+
"role": "assistant",
251+
"content": content_blocks
252+
}]
253+
else:
254+
# Fallback to accumulated content if no blocks
255+
formatted_output = [{
256+
"role": "assistant",
257+
"content": [{"type": "text", "text": accumulated_content}]
258+
}]
259+
180260
event_properties = {
181261
"$ai_provider": "anthropic",
182262
"$ai_model": kwargs.get("model"),
@@ -189,7 +269,7 @@ async def _capture_streaming_event(
189269
"$ai_output_choices": with_privacy_mode(
190270
self._client._ph_client,
191271
posthog_privacy_mode,
192-
[{"content": output, "role": "assistant"}],
272+
formatted_output,
193273
),
194274
"$ai_http_status": 200,
195275
"$ai_input_tokens": usage_stats.get("input_tokens", 0),
@@ -206,6 +286,11 @@ async def _capture_streaming_event(
206286
**(posthog_properties or {}),
207287
}
208288

289+
# Add tools if available
290+
available_tools = extract_available_tool_calls("anthropic", kwargs)
291+
if available_tools:
292+
event_properties["$ai_tools"] = available_tools
293+
209294
if posthog_distinct_id is None:
210295
event_properties["$process_person_profile"] = False
211296

posthog/ai/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,12 @@ def format_response_anthropic(response):
156156
def format_response_openai(response):
157157
output = []
158158

159+
# Handle Chat Completions response format
159160
if hasattr(response, "choices"):
160161
content = []
161162
role = "assistant"
162163

163164
for choice in response.choices:
164-
# Handle Chat Completions response format
165165
if hasattr(choice, "message") and choice.message:
166166
if choice.message.role:
167167
role = choice.message.role

0 commit comments

Comments (0)