diff --git a/CHANGELOG.md b/CHANGELOG.md index 275f2b64..b4f67aaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 6.7.3 - 2025-09-04 + +- fix: missing usage tokens in Gemini + # 6.7.2 - 2025-09-03 - fix: tool call results in streaming providers diff --git a/posthog/ai/anthropic/anthropic.py b/posthog/ai/anthropic/anthropic.py index 80000f43..9cc03678 100644 --- a/posthog/ai/anthropic/anthropic.py +++ b/posthog/ai/anthropic/anthropic.py @@ -10,7 +10,7 @@ import uuid from typing import Any, Dict, List, Optional -from posthog.ai.types import StreamingContentBlock, ToolInProgress +from posthog.ai.types import StreamingContentBlock, TokenUsage, ToolInProgress from posthog.ai.utils import ( call_llm_and_track_usage, merge_usage_stats, @@ -126,7 +126,7 @@ def _create_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0} + usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0) accumulated_content = "" content_blocks: List[StreamingContentBlock] = [] tools_in_progress: Dict[str, ToolInProgress] = {} @@ -210,14 +210,13 @@ def _capture_streaming_event( posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, content_blocks: List[StreamingContentBlock], accumulated_content: str, ): from posthog.ai.types import StreamingEventData from posthog.ai.anthropic.anthropic_converter import ( - standardize_anthropic_usage, format_anthropic_streaming_input, format_anthropic_streaming_output_complete, ) @@ -236,7 +235,7 @@ def _capture_streaming_event( formatted_output=format_anthropic_streaming_output_complete( content_blocks, accumulated_content ), - usage_stats=standardize_anthropic_usage(usage_stats), + usage_stats=usage_stats, latency=latency, distinct_id=posthog_distinct_id, trace_id=posthog_trace_id, diff --git a/posthog/ai/anthropic/anthropic_async.py b/posthog/ai/anthropic/anthropic_async.py index 34233333..527b73f9 100644 --- a/posthog/ai/anthropic/anthropic_async.py +++ b/posthog/ai/anthropic/anthropic_async.py @@ -11,7 +11,7 @@ from typing import Any, Dict, List, Optional from posthog import setup -from posthog.ai.types import StreamingContentBlock, ToolInProgress +from posthog.ai.types import StreamingContentBlock, TokenUsage, ToolInProgress from posthog.ai.utils import ( call_llm_and_track_usage_async, extract_available_tool_calls, @@ -131,7 +131,7 @@ async def _create_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0} + usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0) accumulated_content = "" content_blocks: List[StreamingContentBlock] = [] tools_in_progress: Dict[str, ToolInProgress] = {} @@ -215,7 +215,7 @@ async def _capture_streaming_event( posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, content_blocks: List[StreamingContentBlock], accumulated_content: str, diff --git a/posthog/ai/anthropic/anthropic_converter.py b/posthog/ai/anthropic/anthropic_converter.py index 7ed96268..7d2e615f 100644 --- a/posthog/ai/anthropic/anthropic_converter.py +++ b/posthog/ai/anthropic/anthropic_converter.py @@ -14,7 +14,6 @@ FormattedMessage, FormattedTextContent, StreamingContentBlock, - StreamingUsageStats, TokenUsage, ToolInProgress, ) @@ -164,7 +163,38 @@ def 
format_anthropic_streaming_content( return formatted -def extract_anthropic_usage_from_event(event: Any) -> StreamingUsageStats: +def extract_anthropic_usage_from_response(response: Any) -> TokenUsage: + """ + Extract usage from a full Anthropic response (non-streaming). + + Args: + response: The complete response from Anthropic API + + Returns: + TokenUsage with standardized usage + """ + if not hasattr(response, "usage"): + return TokenUsage(input_tokens=0, output_tokens=0) + + result = TokenUsage( + input_tokens=getattr(response.usage, "input_tokens", 0), + output_tokens=getattr(response.usage, "output_tokens", 0), + ) + + if hasattr(response.usage, "cache_read_input_tokens"): + cache_read = response.usage.cache_read_input_tokens + if cache_read and cache_read > 0: + result["cache_read_input_tokens"] = cache_read + + if hasattr(response.usage, "cache_creation_input_tokens"): + cache_creation = response.usage.cache_creation_input_tokens + if cache_creation and cache_creation > 0: + result["cache_creation_input_tokens"] = cache_creation + + return result + + +def extract_anthropic_usage_from_event(event: Any) -> TokenUsage: """ Extract usage statistics from an Anthropic streaming event. @@ -175,7 +205,7 @@ def extract_anthropic_usage_from_event(event: Any) -> StreamingUsageStats: Dictionary of usage statistics """ - usage: StreamingUsageStats = {} + usage: TokenUsage = TokenUsage() # Handle usage stats from message_start event if hasattr(event, "type") and event.type == "message_start": @@ -329,26 +359,6 @@ def finalize_anthropic_tool_input( del tools_in_progress[block["id"]] -def standardize_anthropic_usage(usage: Dict[str, Any]) -> TokenUsage: - """ - Standardize Anthropic usage statistics to common TokenUsage format. - - Anthropic already uses standard field names, so this mainly structures the data. - - Args: - usage: Raw usage statistics from Anthropic - - Returns: - Standardized TokenUsage dict - """ - return TokenUsage( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - cache_read_input_tokens=usage.get("cache_read_input_tokens"), - cache_creation_input_tokens=usage.get("cache_creation_input_tokens"), - ) - - def format_anthropic_streaming_input(kwargs: Dict[str, Any]) -> Any: """ Format Anthropic streaming input using system prompt merging. 
diff --git a/posthog/ai/gemini/gemini.py b/posthog/ai/gemini/gemini.py index 0be9c673..ed630d6c 100644 --- a/posthog/ai/gemini/gemini.py +++ b/posthog/ai/gemini/gemini.py @@ -3,6 +3,8 @@ import uuid from typing import Any, Dict, Optional +from posthog.ai.types import TokenUsage + try: from google import genai except ImportError: @@ -294,7 +296,7 @@ def _generate_content_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0} + usage_stats: TokenUsage = TokenUsage(input_tokens=0, output_tokens=0) accumulated_content = [] kwargs_without_stream = {"model": model, "contents": contents, **kwargs} @@ -350,12 +352,11 @@ def _capture_streaming_event( privacy_mode: bool, groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, output: Any, ): from posthog.ai.types import StreamingEventData - from posthog.ai.gemini.gemini_converter import standardize_gemini_usage # Prepare standardized event data formatted_input = self._format_input(contents) @@ -368,7 +369,7 @@ def _capture_streaming_event( kwargs=kwargs, formatted_input=sanitized_input, formatted_output=format_gemini_streaming_output(output), - usage_stats=standardize_gemini_usage(usage_stats), + usage_stats=usage_stats, latency=latency, distinct_id=distinct_id, trace_id=trace_id, diff --git a/posthog/ai/gemini/gemini_converter.py b/posthog/ai/gemini/gemini_converter.py index b5296de1..4fd979c7 100644 --- a/posthog/ai/gemini/gemini_converter.py +++ b/posthog/ai/gemini/gemini_converter.py @@ -10,7 +10,6 @@ from posthog.ai.types import ( FormattedContentItem, FormattedMessage, - StreamingUsageStats, TokenUsage, ) @@ -283,7 +282,54 @@ def format_gemini_input(contents: Any) -> List[FormattedMessage]: return [_format_object_message(contents)] -def extract_gemini_usage_from_chunk(chunk: Any) -> StreamingUsageStats: +def _extract_usage_from_metadata(metadata: Any) -> TokenUsage: + """ + Common logic to extract usage from Gemini metadata. + Used by both streaming and non-streaming paths. + + Args: + metadata: usage_metadata from Gemini response or chunk + + Returns: + TokenUsage with standardized usage + """ + usage = TokenUsage( + input_tokens=getattr(metadata, "prompt_token_count", 0), + output_tokens=getattr(metadata, "candidates_token_count", 0), + ) + + # Add cache tokens if present (don't add if 0) + if hasattr(metadata, "cached_content_token_count"): + cache_tokens = metadata.cached_content_token_count + if cache_tokens and cache_tokens > 0: + usage["cache_read_input_tokens"] = cache_tokens + + # Add reasoning tokens if present (don't add if 0) + if hasattr(metadata, "thoughts_token_count"): + reasoning_tokens = metadata.thoughts_token_count + if reasoning_tokens and reasoning_tokens > 0: + usage["reasoning_tokens"] = reasoning_tokens + + return usage + + +def extract_gemini_usage_from_response(response: Any) -> TokenUsage: + """ + Extract usage statistics from a full Gemini response (non-streaming). + + Args: + response: The complete response from Gemini API + + Returns: + TokenUsage with standardized usage statistics + """ + if not hasattr(response, "usage_metadata") or not response.usage_metadata: + return TokenUsage(input_tokens=0, output_tokens=0) + + return _extract_usage_from_metadata(response.usage_metadata) + + +def extract_gemini_usage_from_chunk(chunk: Any) -> TokenUsage: """ Extract usage statistics from a Gemini streaming chunk. 
@@ -291,21 +337,16 @@ def extract_gemini_usage_from_chunk(chunk: Any) -> StreamingUsageStats: chunk: Streaming chunk from Gemini API Returns: - Dictionary of usage statistics + TokenUsage with standardized usage statistics """ - usage: StreamingUsageStats = {} + usage: TokenUsage = TokenUsage() if not hasattr(chunk, "usage_metadata") or not chunk.usage_metadata: return usage - # Gemini uses prompt_token_count and candidates_token_count - usage["input_tokens"] = getattr(chunk.usage_metadata, "prompt_token_count", 0) - usage["output_tokens"] = getattr(chunk.usage_metadata, "candidates_token_count", 0) - - # Calculate total if both values are defined (including 0) - if "input_tokens" in usage and "output_tokens" in usage: - usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"] + # Use the shared helper to extract usage + usage = _extract_usage_from_metadata(chunk.usage_metadata) return usage @@ -417,22 +458,3 @@ def format_gemini_streaming_output( # Fallback for empty or unexpected input return [{"role": "assistant", "content": [{"type": "text", "text": ""}]}] - - -def standardize_gemini_usage(usage: Dict[str, Any]) -> TokenUsage: - """ - Standardize Gemini usage statistics to common TokenUsage format. - - Gemini already uses standard field names (input_tokens/output_tokens). - - Args: - usage: Raw usage statistics from Gemini - - Returns: - Standardized TokenUsage dict - """ - return TokenUsage( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - # Gemini doesn't currently support cache or reasoning tokens - ) diff --git a/posthog/ai/openai/openai.py b/posthog/ai/openai/openai.py index bdca1f6c..11b3fe92 100644 --- a/posthog/ai/openai/openai.py +++ b/posthog/ai/openai/openai.py @@ -2,6 +2,8 @@ import uuid from typing import Any, Dict, List, Optional +from posthog.ai.types import TokenUsage + try: import openai except ImportError: @@ -120,7 +122,7 @@ def _create_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {} + usage_stats: TokenUsage = TokenUsage() final_content = [] response = self._original.create(**kwargs) @@ -171,14 +173,13 @@ def _capture_streaming_event( posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, output: Any, available_tool_calls: Optional[List[Dict[str, Any]]] = None, ): from posthog.ai.types import StreamingEventData from posthog.ai.openai.openai_converter import ( - standardize_openai_usage, format_openai_streaming_input, format_openai_streaming_output, ) @@ -195,7 +196,7 @@ def _capture_streaming_event( kwargs=kwargs, formatted_input=sanitized_input, formatted_output=format_openai_streaming_output(output, "responses"), - usage_stats=standardize_openai_usage(usage_stats, "responses"), + usage_stats=usage_stats, latency=latency, distinct_id=posthog_distinct_id, trace_id=posthog_trace_id, @@ -316,7 +317,7 @@ def _create_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {} + usage_stats: TokenUsage = TokenUsage() accumulated_content = [] accumulated_tool_calls: Dict[int, Dict[str, Any]] = {} if "stream_options" not in kwargs: @@ -387,7 +388,7 @@ def _capture_streaming_event( posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, output: Any, tool_calls: Optional[List[Dict[str, Any]]] = None, @@ -395,7 +396,6 @@ def 
_capture_streaming_event( ): from posthog.ai.types import StreamingEventData from posthog.ai.openai.openai_converter import ( - standardize_openai_usage, format_openai_streaming_input, format_openai_streaming_output, ) @@ -412,7 +412,7 @@ def _capture_streaming_event( kwargs=kwargs, formatted_input=sanitized_input, formatted_output=format_openai_streaming_output(output, "chat", tool_calls), - usage_stats=standardize_openai_usage(usage_stats, "chat"), + usage_stats=usage_stats, latency=latency, distinct_id=posthog_distinct_id, trace_id=posthog_trace_id, diff --git a/posthog/ai/openai/openai_async.py b/posthog/ai/openai/openai_async.py index 57bc7d3d..69cce4c8 100644 --- a/posthog/ai/openai/openai_async.py +++ b/posthog/ai/openai/openai_async.py @@ -2,6 +2,8 @@ import uuid from typing import Any, Dict, List, Optional +from posthog.ai.types import TokenUsage + try: import openai except ImportError: @@ -124,7 +126,7 @@ async def _create_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {} + usage_stats: TokenUsage = TokenUsage() final_content = [] response = self._original.create(**kwargs) @@ -176,7 +178,7 @@ async def _capture_streaming_event( posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, output: Any, available_tool_calls: Optional[List[Dict[str, Any]]] = None, @@ -336,7 +338,7 @@ async def _create_streaming( **kwargs: Any, ): start_time = time.time() - usage_stats: Dict[str, int] = {} + usage_stats: TokenUsage = TokenUsage() accumulated_content = [] accumulated_tool_calls: Dict[int, Dict[str, Any]] = {} @@ -406,7 +408,7 @@ async def _capture_streaming_event( posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], kwargs: Dict[str, Any], - usage_stats: Dict[str, int], + usage_stats: TokenUsage, latency: float, output: Any, tool_calls: Optional[List[Dict[str, Any]]] = None, @@ -430,8 +432,8 @@ async def _capture_streaming_event( format_openai_streaming_output(output, "chat", tool_calls), ), "$ai_http_status": 200, - "$ai_input_tokens": usage_stats.get("prompt_tokens", 0), - "$ai_output_tokens": usage_stats.get("completion_tokens", 0), + "$ai_input_tokens": usage_stats.get("input_tokens", 0), + "$ai_output_tokens": usage_stats.get("output_tokens", 0), "$ai_cache_read_input_tokens": usage_stats.get( "cache_read_input_tokens", 0 ), @@ -501,13 +503,13 @@ async def create( end_time = time.time() # Extract usage statistics if available - usage_stats = {} + usage_stats: TokenUsage = TokenUsage() if hasattr(response, "usage") and response.usage: - usage_stats = { - "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), - "total_tokens": getattr(response.usage, "total_tokens", 0), - } + usage_stats = TokenUsage( + input_tokens=getattr(response.usage, "prompt_tokens", 0), + output_tokens=getattr(response.usage, "completion_tokens", 0), + ) latency = end_time - start_time @@ -521,7 +523,7 @@ async def create( sanitize_openai_response(kwargs.get("input")), ), "$ai_http_status": 200, - "$ai_input_tokens": usage_stats.get("prompt_tokens", 0), + "$ai_input_tokens": usage_stats.get("input_tokens", 0), "$ai_latency": latency, "$ai_trace_id": posthog_trace_id, "$ai_base_url": str(self._client.base_url), diff --git a/posthog/ai/openai/openai_converter.py b/posthog/ai/openai/openai_converter.py index 2429270b..6a56e838 100644 --- a/posthog/ai/openai/openai_converter.py +++ b/posthog/ai/openai/openai_converter.py @@ -14,7 +14,6 @@ 
FormattedImageContent, FormattedMessage, FormattedTextContent, - StreamingUsageStats, TokenUsage, ) @@ -256,9 +255,69 @@ def format_openai_streaming_content( return formatted +def extract_openai_usage_from_response(response: Any) -> TokenUsage: + """ + Extract usage statistics from a full OpenAI response (non-streaming). + Handles both Chat Completions and Responses API. + + Args: + response: The complete response from OpenAI API + + Returns: + TokenUsage with standardized usage statistics + """ + if not hasattr(response, "usage"): + return TokenUsage(input_tokens=0, output_tokens=0) + + cached_tokens = 0 + input_tokens = 0 + output_tokens = 0 + reasoning_tokens = 0 + + # Responses API format + if hasattr(response.usage, "input_tokens"): + input_tokens = response.usage.input_tokens + if hasattr(response.usage, "output_tokens"): + output_tokens = response.usage.output_tokens + if hasattr(response.usage, "input_tokens_details") and hasattr( + response.usage.input_tokens_details, "cached_tokens" + ): + cached_tokens = response.usage.input_tokens_details.cached_tokens + if hasattr(response.usage, "output_tokens_details") and hasattr( + response.usage.output_tokens_details, "reasoning_tokens" + ): + reasoning_tokens = response.usage.output_tokens_details.reasoning_tokens + + # Chat Completions format + if hasattr(response.usage, "prompt_tokens"): + input_tokens = response.usage.prompt_tokens + if hasattr(response.usage, "completion_tokens"): + output_tokens = response.usage.completion_tokens + if hasattr(response.usage, "prompt_tokens_details") and hasattr( + response.usage.prompt_tokens_details, "cached_tokens" + ): + cached_tokens = response.usage.prompt_tokens_details.cached_tokens + if hasattr(response.usage, "completion_tokens_details") and hasattr( + response.usage.completion_tokens_details, "reasoning_tokens" + ): + reasoning_tokens = response.usage.completion_tokens_details.reasoning_tokens + + result = TokenUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + + if cached_tokens > 0: + result["cache_read_input_tokens"] = cached_tokens + if reasoning_tokens > 0: + result["reasoning_tokens"] = reasoning_tokens + + return result + + def extract_openai_usage_from_chunk( chunk: Any, provider_type: str = "chat" -) -> StreamingUsageStats: +) -> TokenUsage: """ Extract usage statistics from an OpenAI streaming chunk. 
@@ -272,16 +331,16 @@ def extract_openai_usage_from_chunk( Dictionary of usage statistics """ - usage: StreamingUsageStats = {} + usage: TokenUsage = TokenUsage() if provider_type == "chat": if not hasattr(chunk, "usage") or not chunk.usage: return usage # Chat Completions API uses prompt_tokens and completion_tokens - usage["prompt_tokens"] = getattr(chunk.usage, "prompt_tokens", 0) - usage["completion_tokens"] = getattr(chunk.usage, "completion_tokens", 0) - usage["total_tokens"] = getattr(chunk.usage, "total_tokens", 0) + # Standardize to input_tokens and output_tokens + usage["input_tokens"] = getattr(chunk.usage, "prompt_tokens", 0) + usage["output_tokens"] = getattr(chunk.usage, "completion_tokens", 0) # Handle cached tokens if hasattr(chunk.usage, "prompt_tokens_details") and hasattr( @@ -310,7 +369,6 @@ def extract_openai_usage_from_chunk( response_usage = chunk.response.usage usage["input_tokens"] = getattr(response_usage, "input_tokens", 0) usage["output_tokens"] = getattr(response_usage, "output_tokens", 0) - usage["total_tokens"] = getattr(response_usage, "total_tokens", 0) # Handle cached tokens if hasattr(response_usage, "input_tokens_details") and hasattr( @@ -535,37 +593,6 @@ def format_openai_streaming_output( ] -def standardize_openai_usage( - usage: Dict[str, Any], api_type: str = "chat" -) -> TokenUsage: - """ - Standardize OpenAI usage statistics to common TokenUsage format. - - Args: - usage: Raw usage statistics from OpenAI - api_type: Either "chat" or "responses" to handle different field names - - Returns: - Standardized TokenUsage dict - """ - if api_type == "chat": - # Chat API uses prompt_tokens/completion_tokens - return TokenUsage( - input_tokens=usage.get("prompt_tokens", 0), - output_tokens=usage.get("completion_tokens", 0), - cache_read_input_tokens=usage.get("cache_read_input_tokens"), - reasoning_tokens=usage.get("reasoning_tokens"), - ) - else: # responses API - # Responses API uses input_tokens/output_tokens - return TokenUsage( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - cache_read_input_tokens=usage.get("cache_read_input_tokens"), - reasoning_tokens=usage.get("reasoning_tokens"), - ) - - def format_openai_streaming_input( kwargs: Dict[str, Any], api_type: str = "chat" ) -> Any: diff --git a/posthog/ai/types.py b/posthog/ai/types.py index bc20e69c..d90a0df8 100644 --- a/posthog/ai/types.py +++ b/posthog/ai/types.py @@ -77,24 +77,6 @@ class ProviderResponse(TypedDict, total=False): error: Optional[str] -class StreamingUsageStats(TypedDict, total=False): - """ - Usage statistics collected during streaming. - - Different providers populate different fields during streaming. - """ - - input_tokens: int - output_tokens: int - cache_read_input_tokens: Optional[int] - cache_creation_input_tokens: Optional[int] - reasoning_tokens: Optional[int] - # OpenAI-specific names - prompt_tokens: Optional[int] - completion_tokens: Optional[int] - total_tokens: Optional[int] - - class StreamingContentBlock(TypedDict, total=False): """ Content block used during streaming to accumulate content. 
@@ -133,7 +115,7 @@ class StreamingEventData(TypedDict): kwargs: Dict[str, Any] # Original kwargs for tool extraction and special handling formatted_input: Any # Provider-formatted input ready for tracking formatted_output: Any # Provider-formatted output ready for tracking - usage_stats: TokenUsage # Standardized token counts + usage_stats: TokenUsage latency: float distinct_id: Optional[str] trace_id: Optional[str] diff --git a/posthog/ai/utils.py b/posthog/ai/utils.py index 6daca1b6..f4392521 100644 --- a/posthog/ai/utils.py +++ b/posthog/ai/utils.py @@ -2,9 +2,8 @@ import uuid from typing import Any, Callable, Dict, Optional - from posthog.client import Client as PostHogClient -from posthog.ai.types import StreamingEventData, StreamingUsageStats +from posthog.ai.types import StreamingEventData, TokenUsage from posthog.ai.sanitization import ( sanitize_openai, sanitize_anthropic, @@ -14,7 +13,7 @@ def merge_usage_stats( - target: Dict[str, int], source: StreamingUsageStats, mode: str = "incremental" + target: TokenUsage, source: TokenUsage, mode: str = "incremental" ) -> None: """ Merge streaming usage statistics into target dict, handling None values. @@ -25,19 +24,49 @@ def merge_usage_stats( Args: target: Dictionary to update with usage stats - source: StreamingUsageStats that may contain None values + source: TokenUsage that may contain None values mode: Either "incremental" or "cumulative" """ if mode == "incremental": # Add new values to existing totals - for key, value in source.items(): - if value is not None and isinstance(value, int): - target[key] = target.get(key, 0) + value + source_input = source.get("input_tokens") + if source_input is not None: + current = target.get("input_tokens") or 0 + target["input_tokens"] = current + source_input + + source_output = source.get("output_tokens") + if source_output is not None: + current = target.get("output_tokens") or 0 + target["output_tokens"] = current + source_output + + source_cache_read = source.get("cache_read_input_tokens") + if source_cache_read is not None: + current = target.get("cache_read_input_tokens") or 0 + target["cache_read_input_tokens"] = current + source_cache_read + + source_cache_creation = source.get("cache_creation_input_tokens") + if source_cache_creation is not None: + current = target.get("cache_creation_input_tokens") or 0 + target["cache_creation_input_tokens"] = current + source_cache_creation + + source_reasoning = source.get("reasoning_tokens") + if source_reasoning is not None: + current = target.get("reasoning_tokens") or 0 + target["reasoning_tokens"] = current + source_reasoning elif mode == "cumulative": # Replace with latest values (already cumulative) - for key, value in source.items(): - if value is not None and isinstance(value, int): - target[key] = value + if source.get("input_tokens") is not None: + target["input_tokens"] = source["input_tokens"] + if source.get("output_tokens") is not None: + target["output_tokens"] = source["output_tokens"] + if source.get("cache_read_input_tokens") is not None: + target["cache_read_input_tokens"] = source["cache_read_input_tokens"] + if source.get("cache_creation_input_tokens") is not None: + target["cache_creation_input_tokens"] = source[ + "cache_creation_input_tokens" + ] + if source.get("reasoning_tokens") is not None: + target["reasoning_tokens"] = source["reasoning_tokens"] else: raise ValueError(f"Invalid mode: {mode}. 
Must be 'incremental' or 'cumulative'") @@ -64,74 +93,31 @@ def get_model_params(kwargs: Dict[str, Any]) -> Dict[str, Any]: return model_params -def get_usage(response, provider: str) -> Dict[str, Any]: +def get_usage(response, provider: str) -> TokenUsage: + """ + Extract usage statistics from response based on provider. + Delegates to provider-specific converter functions. + """ if provider == "anthropic": - return { - "input_tokens": response.usage.input_tokens, - "output_tokens": response.usage.output_tokens, - "cache_read_input_tokens": response.usage.cache_read_input_tokens, - "cache_creation_input_tokens": response.usage.cache_creation_input_tokens, - } + from posthog.ai.anthropic.anthropic_converter import ( + extract_anthropic_usage_from_response, + ) + + return extract_anthropic_usage_from_response(response) elif provider == "openai": - cached_tokens = 0 - input_tokens = 0 - output_tokens = 0 - reasoning_tokens = 0 - - # responses api - if hasattr(response.usage, "input_tokens"): - input_tokens = response.usage.input_tokens - if hasattr(response.usage, "output_tokens"): - output_tokens = response.usage.output_tokens - if hasattr(response.usage, "input_tokens_details") and hasattr( - response.usage.input_tokens_details, "cached_tokens" - ): - cached_tokens = response.usage.input_tokens_details.cached_tokens - if hasattr(response.usage, "output_tokens_details") and hasattr( - response.usage.output_tokens_details, "reasoning_tokens" - ): - reasoning_tokens = response.usage.output_tokens_details.reasoning_tokens - - # chat completions - if hasattr(response.usage, "prompt_tokens"): - input_tokens = response.usage.prompt_tokens - if hasattr(response.usage, "completion_tokens"): - output_tokens = response.usage.completion_tokens - if hasattr(response.usage, "prompt_tokens_details") and hasattr( - response.usage.prompt_tokens_details, "cached_tokens" - ): - cached_tokens = response.usage.prompt_tokens_details.cached_tokens + from posthog.ai.openai.openai_converter import ( + extract_openai_usage_from_response, + ) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "cache_read_input_tokens": cached_tokens, - "reasoning_tokens": reasoning_tokens, - } + return extract_openai_usage_from_response(response) elif provider == "gemini": - input_tokens = 0 - output_tokens = 0 + from posthog.ai.gemini.gemini_converter import ( + extract_gemini_usage_from_response, + ) - if hasattr(response, "usage_metadata") and response.usage_metadata: - input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0) - output_tokens = getattr( - response.usage_metadata, "candidates_token_count", 0 - ) + return extract_gemini_usage_from_response(response) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - "reasoning_tokens": 0, - } - return { - "input_tokens": 0, - "output_tokens": 0, - "cache_read_input_tokens": 0, - "cache_creation_input_tokens": 0, - "reasoning_tokens": 0, - } + return TokenUsage(input_tokens=0, output_tokens=0) def format_response(response, provider: str): @@ -169,6 +155,7 @@ def extract_available_tool_calls(provider: str, kwargs: Dict[str, Any]): from posthog.ai.openai.openai_converter import extract_openai_tools return extract_openai_tools(kwargs) + return None def merge_system_prompt(kwargs: Dict[str, Any], provider: str): @@ -187,9 +174,9 @@ def merge_system_prompt(kwargs: Dict[str, Any], provider: str): contents = kwargs.get("contents", []) return 
format_gemini_input(contents) elif provider == "openai": - # For OpenAI, handle both Chat Completions and Responses API from posthog.ai.openai.openai_converter import format_openai_input + # For OpenAI, handle both Chat Completions and Responses API messages_param = kwargs.get("messages") input_param = kwargs.get("input") @@ -250,7 +237,7 @@ def call_llm_and_track_usage( response = None error = None http_status = 200 - usage: Dict[str, Any] = {} + usage: TokenUsage = TokenUsage() error_params: Dict[str, Any] = {} try: @@ -305,27 +292,17 @@ def call_llm_and_track_usage( if available_tool_calls: event_properties["$ai_tools"] = available_tool_calls - if ( - usage.get("cache_read_input_tokens") is not None - and usage.get("cache_read_input_tokens", 0) > 0 - ): - event_properties["$ai_cache_read_input_tokens"] = usage.get( - "cache_read_input_tokens", 0 - ) + cache_read = usage.get("cache_read_input_tokens") + if cache_read is not None and cache_read > 0: + event_properties["$ai_cache_read_input_tokens"] = cache_read - if ( - usage.get("cache_creation_input_tokens") is not None - and usage.get("cache_creation_input_tokens", 0) > 0 - ): - event_properties["$ai_cache_creation_input_tokens"] = usage.get( - "cache_creation_input_tokens", 0 - ) + cache_creation = usage.get("cache_creation_input_tokens") + if cache_creation is not None and cache_creation > 0: + event_properties["$ai_cache_creation_input_tokens"] = cache_creation - if ( - usage.get("reasoning_tokens") is not None - and usage.get("reasoning_tokens", 0) > 0 - ): - event_properties["$ai_reasoning_tokens"] = usage.get("reasoning_tokens", 0) + reasoning = usage.get("reasoning_tokens") + if reasoning is not None and reasoning > 0: + event_properties["$ai_reasoning_tokens"] = reasoning if posthog_distinct_id is None: event_properties["$process_person_profile"] = False @@ -367,7 +344,7 @@ async def call_llm_and_track_usage_async( response = None error = None http_status = 200 - usage: Dict[str, Any] = {} + usage: TokenUsage = TokenUsage() error_params: Dict[str, Any] = {} try: @@ -422,21 +399,13 @@ async def call_llm_and_track_usage_async( if available_tool_calls: event_properties["$ai_tools"] = available_tool_calls - if ( - usage.get("cache_read_input_tokens") is not None - and usage.get("cache_read_input_tokens", 0) > 0 - ): - event_properties["$ai_cache_read_input_tokens"] = usage.get( - "cache_read_input_tokens", 0 - ) + cache_read = usage.get("cache_read_input_tokens") + if cache_read is not None and cache_read > 0: + event_properties["$ai_cache_read_input_tokens"] = cache_read - if ( - usage.get("cache_creation_input_tokens") is not None - and usage.get("cache_creation_input_tokens", 0) > 0 - ): - event_properties["$ai_cache_creation_input_tokens"] = usage.get( - "cache_creation_input_tokens", 0 - ) + cache_creation = usage.get("cache_creation_input_tokens") + if cache_creation is not None and cache_creation > 0: + event_properties["$ai_cache_creation_input_tokens"] = cache_creation if posthog_distinct_id is None: event_properties["$process_person_profile"] = False diff --git a/posthog/test/ai/anthropic/test_anthropic.py b/posthog/test/ai/anthropic/test_anthropic.py index fcb64c15..5f65a99e 100644 --- a/posthog/test/ai/anthropic/test_anthropic.py +++ b/posthog/test/ai/anthropic/test_anthropic.py @@ -12,8 +12,6 @@ except ImportError: ANTHROPIC_AVAILABLE = False -ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") - # Skip all tests if Anthropic is not available pytestmark = pytest.mark.skipif( not ANTHROPIC_AVAILABLE, reason="Anthropic 
package is not available" @@ -373,7 +371,6 @@ def test_privacy_mode_global(mock_client, mock_anthropic_response): assert props["$ai_output_choices"] is None -@pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") def test_basic_integration(mock_client): """Test basic non-streaming integration.""" @@ -415,7 +412,6 @@ def test_basic_integration(mock_client): assert isinstance(props["$ai_latency"], float) -@pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") async def test_basic_async_integration(mock_client): """Test async non-streaming integration.""" @@ -459,7 +455,6 @@ async def mock_async_create(**kwargs): assert isinstance(props["$ai_latency"], float) -@pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") async def test_async_streaming_system_prompt(mock_client): """Test async streaming with system prompt.""" diff --git a/posthog/test/ai/gemini/test_gemini.py b/posthog/test/ai/gemini/test_gemini.py index f874ce4a..e66164a9 100644 --- a/posthog/test/ai/gemini/test_gemini.py +++ b/posthog/test/ai/gemini/test_gemini.py @@ -31,6 +31,9 @@ def mock_gemini_response(): mock_usage = MagicMock() mock_usage.prompt_token_count = 20 mock_usage.candidates_token_count = 10 + # Ensure cache and reasoning tokens are not present (not MagicMock) + mock_usage.cached_content_token_count = 0 + mock_usage.thoughts_token_count = 0 mock_response.usage_metadata = mock_usage mock_candidate = MagicMock() @@ -64,6 +67,8 @@ def mock_gemini_response_with_function_calls(): mock_usage = MagicMock() mock_usage.prompt_token_count = 25 mock_usage.candidates_token_count = 15 + mock_usage.cached_content_token_count = 0 + mock_usage.thoughts_token_count = 0 mock_response.usage_metadata = mock_usage # Mock function call @@ -110,6 +115,8 @@ def mock_gemini_response_function_calls_only(): mock_usage = MagicMock() mock_usage.prompt_token_count = 30 mock_usage.candidates_token_count = 12 + mock_usage.cached_content_token_count = 0 + mock_usage.thoughts_token_count = 0 mock_response.usage_metadata = mock_usage # Mock function call @@ -180,6 +187,8 @@ def mock_streaming_response(): mock_usage1 = MagicMock() mock_usage1.prompt_token_count = 10 mock_usage1.candidates_token_count = 5 + mock_usage1.cached_content_token_count = 0 + mock_usage1.thoughts_token_count = 0 mock_chunk1.usage_metadata = mock_usage1 mock_chunk2 = MagicMock() @@ -187,6 +196,8 @@ def mock_streaming_response(): mock_usage2 = MagicMock() mock_usage2.prompt_token_count = 10 mock_usage2.candidates_token_count = 10 + mock_usage2.cached_content_token_count = 0 + mock_usage2.thoughts_token_count = 0 mock_chunk2.usage_metadata = mock_usage2 yield mock_chunk1 @@ -235,6 +246,8 @@ def mock_streaming_response(): mock_usage1 = MagicMock() mock_usage1.prompt_token_count = 15 mock_usage1.candidates_token_count = 5 + mock_usage1.cached_content_token_count = 0 + mock_usage1.thoughts_token_count = 0 mock_chunk1.usage_metadata = mock_usage1 mock_chunk2 = MagicMock() @@ -242,6 +255,8 @@ def mock_streaming_response(): mock_usage2 = MagicMock() mock_usage2.prompt_token_count = 15 mock_usage2.candidates_token_count = 10 + mock_usage2.cached_content_token_count = 0 + mock_usage2.thoughts_token_count = 0 mock_chunk2.usage_metadata = mock_usage2 yield mock_chunk1 @@ -730,3 +745,93 @@ def test_function_calls_only_no_content( assert props["$ai_input_tokens"] == 30 assert props["$ai_output_tokens"] == 12 assert props["$ai_http_status"] == 200 + + +def test_cache_and_reasoning_tokens(mock_client, 
mock_google_genai_client): + """Test that cache and reasoning tokens are properly extracted""" + # Create a mock response with cache and reasoning tokens + mock_response = MagicMock() + mock_response.text = "Test response with cache" + + mock_usage = MagicMock() + mock_usage.prompt_token_count = 100 + mock_usage.candidates_token_count = 50 + mock_usage.cached_content_token_count = 30 # Cache tokens + mock_usage.thoughts_token_count = 10 # Reasoning tokens + mock_response.usage_metadata = mock_usage + + # Mock candidates + mock_candidate = MagicMock() + mock_candidate.text = "Test response with cache" + mock_response.candidates = [mock_candidate] + + mock_google_genai_client.models.generate_content.return_value = mock_response + + client = Client(api_key="test-key", posthog_client=mock_client) + + response = client.models.generate_content( + model="gemini-2.5-pro", + contents="Test with cache", + posthog_distinct_id="test-id", + ) + + assert response == mock_response + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Check that all token types are present + assert props["$ai_input_tokens"] == 100 + assert props["$ai_output_tokens"] == 50 + assert props["$ai_cache_read_input_tokens"] == 30 + assert props["$ai_reasoning_tokens"] == 10 + + +def test_streaming_cache_and_reasoning_tokens(mock_client, mock_google_genai_client): + """Test that cache and reasoning tokens are properly extracted in streaming""" + # Create mock chunks with cache and reasoning tokens + chunk1 = MagicMock() + chunk1.text = "Hello " + chunk1_usage = MagicMock() + chunk1_usage.prompt_token_count = 100 + chunk1_usage.candidates_token_count = 5 + chunk1_usage.cached_content_token_count = 30 # Cache tokens + chunk1_usage.thoughts_token_count = 0 + chunk1.usage_metadata = chunk1_usage + + chunk2 = MagicMock() + chunk2.text = "world!" + chunk2_usage = MagicMock() + chunk2_usage.prompt_token_count = 100 + chunk2_usage.candidates_token_count = 10 + chunk2_usage.cached_content_token_count = 30 # Same cache tokens + chunk2_usage.thoughts_token_count = 5 # Reasoning tokens + chunk2.usage_metadata = chunk2_usage + + mock_stream = iter([chunk1, chunk2]) + mock_google_genai_client.models.generate_content_stream.return_value = mock_stream + + client = Client(api_key="test-key", posthog_client=mock_client) + + response = client.models.generate_content_stream( + model="gemini-2.5-pro", + contents="Test streaming with cache", + posthog_distinct_id="test-id", + ) + + # Consume the stream + result = list(response) + assert len(result) == 2 + + # Check PostHog capture was called + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Check that all token types are present (should use final chunk's usage) + assert props["$ai_input_tokens"] == 100 + assert props["$ai_output_tokens"] == 10 + assert props["$ai_cache_read_input_tokens"] == 30 + assert props["$ai_reasoning_tokens"] == 5 diff --git a/posthog/version.py b/posthog/version.py index 7cda3ab0..9a4f295e 100644 --- a/posthog/version.py +++ b/posthog/version.py @@ -1,4 +1,4 @@ -VERSION = "6.7.2" +VERSION = "6.7.3" if __name__ == "__main__": print(VERSION, end="") # noqa: T201
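
Note (supplementary, not part of the patch): the change above replaces the loose `Dict[str, int]` / `StreamingUsageStats` dictionaries with the `TokenUsage` TypedDict everywhere, and `merge_usage_stats` in `posthog/ai/utils.py` now spells out its two modes per field: `incremental` sums per-chunk deltas, while `cumulative` keeps the latest running total (the mode the Gemini streaming tests above rely on). The sketch below is a minimal, self-contained illustration of those merge semantics only; the `TokenUsage` class and the loop-based helper here are local stand-ins for `posthog.ai.types.TokenUsage` and the patched per-field implementation, not the library code itself.

```python
# Illustrative sketch only: mirrors the merge semantics of the patched
# merge_usage_stats helper. TokenUsage here is a local stand-in for
# posthog.ai.types.TokenUsage, and the loop is a compact equivalent of
# the helper's explicit per-field handling.
from typing import TypedDict


class TokenUsage(TypedDict, total=False):
    input_tokens: int
    output_tokens: int
    cache_read_input_tokens: int
    cache_creation_input_tokens: int
    reasoning_tokens: int


_FIELDS = (
    "input_tokens",
    "output_tokens",
    "cache_read_input_tokens",
    "cache_creation_input_tokens",
    "reasoning_tokens",
)


def merge_usage_stats(
    target: TokenUsage, source: TokenUsage, mode: str = "incremental"
) -> None:
    """Merge one chunk's usage into a running total, skipping missing fields."""
    for field in _FIELDS:
        value = source.get(field)
        if value is None:
            continue
        if mode == "incremental":
            # Chunks report deltas: add to whatever has accumulated so far.
            target[field] = (target.get(field) or 0) + value
        elif mode == "cumulative":
            # Chunks repeat running totals: the latest value wins.
            target[field] = value
        else:
            raise ValueError(f"Invalid mode: {mode}. Must be 'incremental' or 'cumulative'")


# Example 1: chunks that report per-event deltas are summed.
totals = TokenUsage(input_tokens=0, output_tokens=0)
merge_usage_stats(totals, TokenUsage(input_tokens=12), mode="incremental")
merge_usage_stats(totals, TokenUsage(output_tokens=5), mode="incremental")
merge_usage_stats(totals, TokenUsage(output_tokens=7, cache_read_input_tokens=3), mode="incremental")
assert totals == {"input_tokens": 12, "output_tokens": 12, "cache_read_input_tokens": 3}

# Example 2: chunks that repeat cumulative running totals (as the Gemini
# streaming test above expects), so the final chunk's counts are kept.
totals = TokenUsage(input_tokens=0, output_tokens=0)
merge_usage_stats(totals, TokenUsage(input_tokens=100, output_tokens=5), mode="cumulative")
merge_usage_stats(totals, TokenUsage(input_tokens=100, output_tokens=10, reasoning_tokens=5), mode="cumulative")
assert totals == {"input_tokens": 100, "output_tokens": 10, "reasoning_tokens": 5}
```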