88
99import time
1010import uuid
11- from typing import Any , Dict , Optional
11+ from typing import Any , Dict , List , Optional
1212
1313from posthog import setup
1414from posthog .ai .utils import (
1515 call_llm_and_track_usage_async ,
16+ extract_available_tool_calls ,
1617 get_model_params ,
1718 merge_system_prompt ,
1819 with_privacy_mode ,
@@ -119,34 +120,97 @@ async def _create_streaming(
119120 ):
120121 start_time = time .time ()
121122 usage_stats : Dict [str , int ] = {"input_tokens" : 0 , "output_tokens" : 0 }
122- accumulated_content = []
123+ accumulated_content = ""
124+ content_blocks : List [Dict [str , Any ]] = []
125+ tools_in_progress : Dict [str , Dict [str , Any ]] = {}
126+ current_text_block : Optional [Dict [str , Any ]] = None
123127 response = await super ().create (** kwargs )
124128
125129 async def generator ():
126130 nonlocal usage_stats
127- nonlocal accumulated_content # noqa: F824
131+ nonlocal accumulated_content
132+ nonlocal content_blocks
133+ nonlocal tools_in_progress
134+ nonlocal current_text_block
128135 try :
129136 async for event in response :
137+ # Handle usage stats from message_start event
138+ if hasattr (event , "type" ) and event .type == "message_start" :
139+ if hasattr (event , "message" ) and hasattr (event .message , "usage" ):
140+ usage_stats ["input_tokens" ] = getattr (event .message .usage , "input_tokens" , 0 )
141+ usage_stats ["cache_creation_input_tokens" ] = getattr (event .message .usage , "cache_creation_input_tokens" , 0 )
142+ usage_stats ["cache_read_input_tokens" ] = getattr (event .message .usage , "cache_read_input_tokens" , 0 )
143+
144+ # Handle usage stats from message_delta event
130145 if hasattr (event , "usage" ) and event .usage :
131- usage_stats = {
132- k : getattr (event .usage , k , 0 )
133- for k in [
134- "input_tokens" ,
135- "output_tokens" ,
136- "cache_read_input_tokens" ,
137- "cache_creation_input_tokens" ,
138- ]
139- }
140-
141- if hasattr (event , "content" ) and event .content :
142- accumulated_content .append (event .content )
146+ usage_stats ["output_tokens" ] = getattr (event .usage , "output_tokens" , 0 )
147+
148+ # Handle content block start events
149+ if hasattr (event , "type" ) and event .type == "content_block_start" :
150+ if hasattr (event , "content_block" ):
151+ block = event .content_block
152+ if hasattr (block , "type" ):
153+ if block .type == "text" :
154+ current_text_block = {
155+ "type" : "text" ,
156+ "text" : ""
157+ }
158+ content_blocks .append (current_text_block )
159+ elif block .type == "tool_use" :
160+ tool_block = {
161+ "type" : "function" ,
162+ "id" : getattr (block , "id" , "" ),
163+ "function" : {
164+ "name" : getattr (block , "name" , "" ),
165+ "arguments" : {}
166+ }
167+ }
168+ content_blocks .append (tool_block )
169+ tools_in_progress [block .id ] = {
170+ "block" : tool_block ,
171+ "input_string" : ""
172+ }
173+ current_text_block = None
174+
175+ # Handle text delta events
176+ if hasattr (event , "delta" ):
177+ if hasattr (event .delta , "text" ):
178+ delta_text = event .delta .text or ""
179+ accumulated_content += delta_text
180+ if current_text_block is not None :
181+ current_text_block ["text" ] += delta_text
182+
183+ # Handle tool input delta events
184+ if hasattr (event , "type" ) and event .type == "content_block_delta" :
185+ if hasattr (event , "delta" ) and hasattr (event .delta , "type" ) and event .delta .type == "input_json_delta" :
186+ if hasattr (event , "index" ) and event .index < len (content_blocks ):
187+ block = content_blocks [event .index ]
188+ if block .get ("type" ) == "function" and block .get ("id" ) in tools_in_progress :
189+ tool = tools_in_progress [block ["id" ]]
190+ partial_json = getattr (event .delta , "partial_json" , "" )
191+ tool ["input_string" ] += partial_json
192+
193+ # Handle content block stop events
194+ if hasattr (event , "type" ) and event .type == "content_block_stop" :
195+ current_text_block = None
196+ # Parse accumulated tool input
197+ if hasattr (event , "index" ) and event .index < len (content_blocks ):
198+ block = content_blocks [event .index ]
199+ if block .get ("type" ) == "function" and block .get ("id" ) in tools_in_progress :
200+ tool = tools_in_progress [block ["id" ]]
201+ try :
202+ import json
203+ block ["function" ]["arguments" ] = json .loads (tool ["input_string" ])
204+ except (json .JSONDecodeError , Exception ):
205+ # Keep empty dict if parsing fails
206+ pass
207+ del tools_in_progress [block ["id" ]]
143208
144209 yield event
145210
146211 finally :
147212 end_time = time .time ()
148213 latency = end_time - start_time
149- output = "" .join (accumulated_content )
150214
151215 await self ._capture_streaming_event (
152216 posthog_distinct_id ,
@@ -157,7 +221,8 @@ async def generator():
157221 kwargs ,
158222 usage_stats ,
159223 latency ,
160- output ,
224+ content_blocks ,
225+ accumulated_content ,
161226 )
162227
163228 return generator ()
@@ -172,11 +237,26 @@ async def _capture_streaming_event(
172237 kwargs : Dict [str , Any ],
173238 usage_stats : Dict [str , int ],
174239 latency : float ,
175- output : str ,
240+ content_blocks : List [Dict [str , Any ]],
241+ accumulated_content : str ,
176242 ):
177243 if posthog_trace_id is None :
178244 posthog_trace_id = str (uuid .uuid4 ())
179245
246+ # Format output to match non-streaming version
247+ formatted_output = []
248+ if content_blocks :
249+ formatted_output = [{
250+ "role" : "assistant" ,
251+ "content" : content_blocks
252+ }]
253+ else :
254+ # Fallback to accumulated content if no blocks
255+ formatted_output = [{
256+ "role" : "assistant" ,
257+ "content" : [{"type" : "text" , "text" : accumulated_content }]
258+ }]
259+
180260 event_properties = {
181261 "$ai_provider" : "anthropic" ,
182262 "$ai_model" : kwargs .get ("model" ),
@@ -189,7 +269,7 @@ async def _capture_streaming_event(
189269 "$ai_output_choices" : with_privacy_mode (
190270 self ._client ._ph_client ,
191271 posthog_privacy_mode ,
192- [{ "content" : output , "role" : "assistant" }] ,
272+ formatted_output ,
193273 ),
194274 "$ai_http_status" : 200 ,
195275 "$ai_input_tokens" : usage_stats .get ("input_tokens" , 0 ),
@@ -206,6 +286,11 @@ async def _capture_streaming_event(
206286 ** (posthog_properties or {}),
207287 }
208288
289+ # Add tools if available
290+ available_tools = extract_available_tool_calls ("anthropic" , kwargs )
291+ if available_tools :
292+ event_properties ["$ai_tools" ] = available_tools
293+
209294 if posthog_distinct_id is None :
210295 event_properties ["$process_person_profile" ] = False
211296
0 commit comments