diff --git a/CHANGELOG.md b/CHANGELOG.md
index 333daad4..275f2b64 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+# 6.7.2 - 2025-09-03
+
+- fix: tool call results in streaming providers
+
 # 6.7.1 - 2025-09-01
 
 - fix: Add base64 inline image sanitization
diff --git a/mypy-baseline.txt b/mypy-baseline.txt
index 7289247f..0e32db71 100644
--- a/mypy-baseline.txt
+++ b/mypy-baseline.txt
@@ -36,10 +36,5 @@ posthog/client.py:0: error: "None" has no attribute "start" [attr-defined]
 posthog/client.py:0: error: "None" has no attribute "get" [attr-defined]
 posthog/client.py:0: error: Statement is unreachable [unreachable]
 posthog/client.py:0: error: Statement is unreachable [unreachable]
-posthog/ai/utils.py:0: error: Need type annotation for "output" (hint: "output: list[<type>] = ...") [var-annotated]
-posthog/ai/utils.py:0: error: Function "builtins.any" is not valid as a type [valid-type]
-posthog/ai/utils.py:0: note: Perhaps you meant "typing.Any" instead of "any"?
-posthog/ai/utils.py:0: error: Function "builtins.any" is not valid as a type [valid-type]
-posthog/ai/utils.py:0: note: Perhaps you meant "typing.Any" instead of "any"?
 posthog/client.py:0: error: Name "urlparse" already defined (possibly by an import) [no-redef]
 posthog/client.py:0: error: Name "parse_qs" already defined (possibly by an import) [no-redef]
diff --git a/posthog/ai/anthropic/__init__.py b/posthog/ai/anthropic/__init__.py
index 82fabcb1..3648625f 100644
--- a/posthog/ai/anthropic/__init__.py
+++ b/posthog/ai/anthropic/__init__.py
@@ -6,6 +6,12 @@
     AsyncAnthropicBedrock,
     AsyncAnthropicVertex,
 )
+from .anthropic_converter import (
+    format_anthropic_response,
+    format_anthropic_input,
+    extract_anthropic_tools,
+    format_anthropic_streaming_content,
+)
 
 __all__ = [
     "Anthropic",
@@ -14,4 +20,8 @@
     "AsyncAnthropicBedrock",
     "AnthropicVertex",
     "AsyncAnthropicVertex",
+    "format_anthropic_response",
+    "format_anthropic_input",
+    "extract_anthropic_tools",
+    "format_anthropic_streaming_content",
 ]
diff --git a/posthog/ai/anthropic/anthropic.py b/posthog/ai/anthropic/anthropic.py
index ffb34dee..80000f43 100644
--- a/posthog/ai/anthropic/anthropic.py
+++ b/posthog/ai/anthropic/anthropic.py
@@ -8,13 +8,19 @@
 import time
 import uuid
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
+from posthog.ai.types import StreamingContentBlock, ToolInProgress
 from posthog.ai.utils import (
     call_llm_and_track_usage,
-    get_model_params,
-    merge_system_prompt,
-    with_privacy_mode,
+    merge_usage_stats,
+)
+from posthog.ai.anthropic.anthropic_converter import (
+    extract_anthropic_usage_from_event,
+    handle_anthropic_content_block_start,
+    handle_anthropic_text_delta,
+    handle_anthropic_tool_delta,
+    finalize_anthropic_tool_input,
 )
 from posthog.ai.sanitization import sanitize_anthropic
 from posthog.client import Client as PostHogClient
@@ -62,6 +68,7 @@
             posthog_groups: Optional group analytics properties
             **kwargs: Arguments passed to Anthropic's messages.create
         """
+
         if posthog_trace_id is None:
             posthog_trace_id = str(uuid.uuid4())
 
@@ -120,34 +127,65 @@ def _create_streaming(
     ):
         start_time = time.time()
         usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
-        accumulated_content = []
+        accumulated_content = ""
+        content_blocks: List[StreamingContentBlock] = []
+        tools_in_progress: Dict[str, ToolInProgress] = {}
+        current_text_block: Optional[StreamingContentBlock] = None
 
         response = super().create(**kwargs)
 
         def generator():
             nonlocal usage_stats
-            nonlocal accumulated_content  # noqa: F824
+            nonlocal accumulated_content
+            nonlocal content_blocks
+            nonlocal tools_in_progress
+            nonlocal current_text_block
+
             try:
                 for event in response:
-                    if hasattr(event, "usage") and event.usage:
-                        usage_stats = {
-                            k: getattr(event.usage, k, 0)
-                            for k in [
-                                "input_tokens",
-                                "output_tokens",
-                                "cache_read_input_tokens",
-                                "cache_creation_input_tokens",
-                            ]
-                        }
-
-                    if hasattr(event, "content") and event.content:
-                        accumulated_content.append(event.content)
+                    # Extract usage stats from event
+                    event_usage = extract_anthropic_usage_from_event(event)
+                    merge_usage_stats(usage_stats, event_usage)
+
+                    # Handle content block start events
+                    if hasattr(event, "type") and event.type == "content_block_start":
+                        block, tool = handle_anthropic_content_block_start(event)
+
+                        if block:
+                            content_blocks.append(block)
+
+                            if block.get("type") == "text":
+                                current_text_block = block
+                            else:
+                                current_text_block = None
+
+                        if tool:
+                            tool_id = tool["block"].get("id")
+                            if tool_id:
+                                tools_in_progress[tool_id] = tool
+
+                    # Handle text delta events
+                    delta_text = handle_anthropic_text_delta(event, current_text_block)
+
+                    if delta_text:
+                        accumulated_content += delta_text
+
+                    # Handle tool input delta events
+                    handle_anthropic_tool_delta(
+                        event, content_blocks, tools_in_progress
+                    )
+
+                    # Handle content block stop events
+                    if hasattr(event, "type") and event.type == "content_block_stop":
+                        current_text_block = None
+                        finalize_anthropic_tool_input(
+                            event, content_blocks, tools_in_progress
+                        )
 
                     yield event
             finally:
                 end_time = time.time()
                 latency = end_time - start_time
-                output = "".join(accumulated_content)
 
                 self._capture_streaming_event(
                     posthog_distinct_id,
@@ -158,7 +196,8 @@ def generator():
                     kwargs,
                     usage_stats,
                     latency,
-                    output,
+                    content_blocks,
+                    accumulated_content,
                 )
 
         return generator()
@@ -173,47 +212,38 @@ def _capture_streaming_event(
         kwargs: Dict[str, Any],
         usage_stats: Dict[str, int],
         latency: float,
-        output: str,
+        content_blocks: List[StreamingContentBlock],
+        accumulated_content: str,
     ):
-        if posthog_trace_id is None:
-            posthog_trace_id = str(uuid.uuid4())
-
-        event_properties = {
-            "$ai_provider": "anthropic",
-            "$ai_model": kwargs.get("model"),
-            "$ai_model_parameters": get_model_params(kwargs),
-            "$ai_input": with_privacy_mode(
-                self._client._ph_client,
-                posthog_privacy_mode,
-                sanitize_anthropic(merge_system_prompt(kwargs, "anthropic")),
-            ),
-            "$ai_output_choices": with_privacy_mode(
-                self._client._ph_client,
-                posthog_privacy_mode,
-                [{"content": output, "role": "assistant"}],
-            ),
-            "$ai_http_status": 200,
-            "$ai_input_tokens": usage_stats.get("input_tokens", 0),
-            "$ai_output_tokens": usage_stats.get("output_tokens", 0),
-            "$ai_cache_read_input_tokens": usage_stats.get(
-                "cache_read_input_tokens", 0
-            ),
-            "$ai_cache_creation_input_tokens": usage_stats.get(
-                "cache_creation_input_tokens", 0
+        from posthog.ai.types import StreamingEventData
+        from posthog.ai.anthropic.anthropic_converter import (
+            standardize_anthropic_usage,
+            format_anthropic_streaming_input,
+            format_anthropic_streaming_output_complete,
+        )
+        from posthog.ai.utils import capture_streaming_event
+
+        # Prepare standardized event data
+        formatted_input = format_anthropic_streaming_input(kwargs)
+        sanitized_input = sanitize_anthropic(formatted_input)
+
+        event_data = StreamingEventData(
+            provider="anthropic",
+            model=kwargs.get("model", "unknown"),
+            base_url=str(self._client.base_url),
+            kwargs=kwargs,
+            formatted_input=sanitized_input,
+            formatted_output=format_anthropic_streaming_output_complete(
content_blocks, accumulated_content ), - "$ai_latency": latency, - "$ai_trace_id": posthog_trace_id, - "$ai_base_url": str(self._client.base_url), - **(posthog_properties or {}), - } - - if posthog_distinct_id is None: - event_properties["$process_person_profile"] = False - - if hasattr(self._client._ph_client, "capture"): - self._client._ph_client.capture( - distinct_id=posthog_distinct_id or posthog_trace_id, - event="$ai_generation", - properties=event_properties, - groups=posthog_groups, - ) + usage_stats=standardize_anthropic_usage(usage_stats), + latency=latency, + distinct_id=posthog_distinct_id, + trace_id=posthog_trace_id, + properties=posthog_properties, + privacy_mode=posthog_privacy_mode, + groups=posthog_groups, + ) + + # Use the common capture function + capture_streaming_event(self._client._ph_client, event_data) diff --git a/posthog/ai/anthropic/anthropic_async.py b/posthog/ai/anthropic/anthropic_async.py index afb8dc58..34233333 100644 --- a/posthog/ai/anthropic/anthropic_async.py +++ b/posthog/ai/anthropic/anthropic_async.py @@ -8,15 +8,26 @@ import time import uuid -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from posthog import setup +from posthog.ai.types import StreamingContentBlock, ToolInProgress from posthog.ai.utils import ( call_llm_and_track_usage_async, + extract_available_tool_calls, get_model_params, merge_system_prompt, + merge_usage_stats, with_privacy_mode, ) +from posthog.ai.anthropic.anthropic_converter import ( + format_anthropic_streaming_content, + extract_anthropic_usage_from_event, + handle_anthropic_content_block_start, + handle_anthropic_text_delta, + handle_anthropic_tool_delta, + finalize_anthropic_tool_input, +) from posthog.ai.sanitization import sanitize_anthropic from posthog.client import Client as PostHogClient @@ -62,6 +73,7 @@ async def create( posthog_groups: Optional group analytics properties **kwargs: Arguments passed to Anthropic's messages.create """ + if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) @@ -120,34 +132,65 @@ async def _create_streaming( ): start_time = time.time() usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0} - accumulated_content = [] + accumulated_content = "" + content_blocks: List[StreamingContentBlock] = [] + tools_in_progress: Dict[str, ToolInProgress] = {} + current_text_block: Optional[StreamingContentBlock] = None response = await super().create(**kwargs) async def generator(): nonlocal usage_stats - nonlocal accumulated_content # noqa: F824 + nonlocal accumulated_content + nonlocal content_blocks + nonlocal tools_in_progress + nonlocal current_text_block + try: async for event in response: - if hasattr(event, "usage") and event.usage: - usage_stats = { - k: getattr(event.usage, k, 0) - for k in [ - "input_tokens", - "output_tokens", - "cache_read_input_tokens", - "cache_creation_input_tokens", - ] - } - - if hasattr(event, "content") and event.content: - accumulated_content.append(event.content) + # Extract usage stats from event + event_usage = extract_anthropic_usage_from_event(event) + merge_usage_stats(usage_stats, event_usage) + + # Handle content block start events + if hasattr(event, "type") and event.type == "content_block_start": + block, tool = handle_anthropic_content_block_start(event) + + if block: + content_blocks.append(block) + + if block.get("type") == "text": + current_text_block = block + else: + current_text_block = None + + if tool: + tool_id = tool["block"].get("id") + if tool_id: + 
tools_in_progress[tool_id] = tool + + # Handle text delta events + delta_text = handle_anthropic_text_delta(event, current_text_block) + + if delta_text: + accumulated_content += delta_text + + # Handle tool input delta events + handle_anthropic_tool_delta( + event, content_blocks, tools_in_progress + ) + + # Handle content block stop events + if hasattr(event, "type") and event.type == "content_block_stop": + current_text_block = None + finalize_anthropic_tool_input( + event, content_blocks, tools_in_progress + ) yield event finally: end_time = time.time() latency = end_time - start_time - output = "".join(accumulated_content) await self._capture_streaming_event( posthog_distinct_id, @@ -158,7 +201,8 @@ async def generator(): kwargs, usage_stats, latency, - output, + content_blocks, + accumulated_content, ) return generator() @@ -173,11 +217,27 @@ async def _capture_streaming_event( kwargs: Dict[str, Any], usage_stats: Dict[str, int], latency: float, - output: str, + content_blocks: List[StreamingContentBlock], + accumulated_content: str, ): if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) + # Format output using converter + formatted_content = format_anthropic_streaming_content(content_blocks) + formatted_output = [] + + if formatted_content: + formatted_output = [{"role": "assistant", "content": formatted_content}] + else: + # Fallback to accumulated content if no blocks + formatted_output = [ + { + "role": "assistant", + "content": [{"type": "text", "text": accumulated_content}], + } + ] + event_properties = { "$ai_provider": "anthropic", "$ai_model": kwargs.get("model"), @@ -190,7 +250,7 @@ async def _capture_streaming_event( "$ai_output_choices": with_privacy_mode( self._client._ph_client, posthog_privacy_mode, - [{"content": output, "role": "assistant"}], + formatted_output, ), "$ai_http_status": 200, "$ai_input_tokens": usage_stats.get("input_tokens", 0), @@ -207,6 +267,12 @@ async def _capture_streaming_event( **(posthog_properties or {}), } + # Add tools if available + available_tools = extract_available_tool_calls("anthropic", kwargs) + + if available_tools: + event_properties["$ai_tools"] = available_tools + if posthog_distinct_id is None: event_properties["$process_person_profile"] = False diff --git a/posthog/ai/anthropic/anthropic_converter.py b/posthog/ai/anthropic/anthropic_converter.py new file mode 100644 index 00000000..7ed96268 --- /dev/null +++ b/posthog/ai/anthropic/anthropic_converter.py @@ -0,0 +1,393 @@ +""" +Anthropic-specific conversion utilities. + +This module handles the conversion of Anthropic API responses and inputs +into standardized formats for PostHog tracking. +""" + +import json +from typing import Any, Dict, List, Optional, Tuple + +from posthog.ai.types import ( + FormattedContentItem, + FormattedFunctionCall, + FormattedMessage, + FormattedTextContent, + StreamingContentBlock, + StreamingUsageStats, + TokenUsage, + ToolInProgress, +) + + +def format_anthropic_response(response: Any) -> List[FormattedMessage]: + """ + Format an Anthropic response into standardized message format. 
+ + Args: + response: The response object from Anthropic API + + Returns: + List of formatted messages with role and content + """ + + output: List[FormattedMessage] = [] + + if response is None: + return output + + content: List[FormattedContentItem] = [] + + # Process content blocks from the response + if hasattr(response, "content"): + for choice in response.content: + if ( + hasattr(choice, "type") + and choice.type == "text" + and hasattr(choice, "text") + and choice.text + ): + text_content: FormattedTextContent = { + "type": "text", + "text": choice.text, + } + content.append(text_content) + + elif ( + hasattr(choice, "type") + and choice.type == "tool_use" + and hasattr(choice, "name") + and hasattr(choice, "id") + ): + function_call: FormattedFunctionCall = { + "type": "function", + "id": choice.id, + "function": { + "name": choice.name, + "arguments": getattr(choice, "input", {}), + }, + } + content.append(function_call) + + if content: + message: FormattedMessage = { + "role": "assistant", + "content": content, + } + output.append(message) + + return output + + +def format_anthropic_input( + messages: List[Dict[str, Any]], system: Optional[str] = None +) -> List[FormattedMessage]: + """ + Format Anthropic input messages with optional system prompt. + + Args: + messages: List of message dictionaries + system: Optional system prompt to prepend + + Returns: + List of formatted messages + """ + + formatted_messages: List[FormattedMessage] = [] + + # Add system message if provided + if system is not None: + formatted_messages.append({"role": "system", "content": system}) + + # Add user messages + if messages: + for msg in messages: + # Messages are already in the correct format, just ensure type safety + formatted_msg: FormattedMessage = { + "role": msg.get("role", "user"), + "content": msg.get("content", ""), + } + formatted_messages.append(formatted_msg) + + return formatted_messages + + +def extract_anthropic_tools(kwargs: Dict[str, Any]) -> Optional[Any]: + """ + Extract tool definitions from Anthropic API kwargs. + + Args: + kwargs: Keyword arguments passed to Anthropic API + + Returns: + Tool definitions if present, None otherwise + """ + + return kwargs.get("tools", None) + + +def format_anthropic_streaming_content( + content_blocks: List[StreamingContentBlock], +) -> List[FormattedContentItem]: + """ + Format content blocks from Anthropic streaming response. + + Used by streaming handlers to format accumulated content blocks. + + Args: + content_blocks: List of content block dictionaries from streaming + + Returns: + List of formatted content items + """ + + formatted: List[FormattedContentItem] = [] + + for block in content_blocks: + if block.get("type") == "text": + formatted.append( + { + "type": "text", + "text": block.get("text") or "", + } + ) + + elif block.get("type") == "function": + formatted.append( + { + "type": "function", + "id": block.get("id"), + "function": block.get("function") or {}, + } + ) + + return formatted + + +def extract_anthropic_usage_from_event(event: Any) -> StreamingUsageStats: + """ + Extract usage statistics from an Anthropic streaming event. 
+ + Args: + event: Streaming event from Anthropic API + + Returns: + Dictionary of usage statistics + """ + + usage: StreamingUsageStats = {} + + # Handle usage stats from message_start event + if hasattr(event, "type") and event.type == "message_start": + if hasattr(event, "message") and hasattr(event.message, "usage"): + usage["input_tokens"] = getattr(event.message.usage, "input_tokens", 0) + usage["cache_creation_input_tokens"] = getattr( + event.message.usage, "cache_creation_input_tokens", 0 + ) + usage["cache_read_input_tokens"] = getattr( + event.message.usage, "cache_read_input_tokens", 0 + ) + + # Handle usage stats from message_delta event + if hasattr(event, "usage") and event.usage: + usage["output_tokens"] = getattr(event.usage, "output_tokens", 0) + + return usage + + +def handle_anthropic_content_block_start( + event: Any, +) -> Tuple[Optional[StreamingContentBlock], Optional[ToolInProgress]]: + """ + Handle content block start event from Anthropic streaming. + + Args: + event: Content block start event + + Returns: + Tuple of (content_block, tool_in_progress) + """ + + if not (hasattr(event, "type") and event.type == "content_block_start"): + return None, None + + if not hasattr(event, "content_block"): + return None, None + + block = event.content_block + + if not hasattr(block, "type"): + return None, None + + if block.type == "text": + content_block: StreamingContentBlock = {"type": "text", "text": ""} + return content_block, None + + elif block.type == "tool_use": + tool_block: StreamingContentBlock = { + "type": "function", + "id": getattr(block, "id", ""), + "function": {"name": getattr(block, "name", ""), "arguments": {}}, + } + tool_in_progress: ToolInProgress = {"block": tool_block, "input_string": ""} + return tool_block, tool_in_progress + + return None, None + + +def handle_anthropic_text_delta( + event: Any, current_block: Optional[StreamingContentBlock] +) -> Optional[str]: + """ + Handle text delta event from Anthropic streaming. + + Args: + event: Delta event + current_block: Current text block being accumulated + + Returns: + Text delta if present + """ + + if hasattr(event, "delta") and hasattr(event.delta, "text"): + delta_text = event.delta.text or "" + + if current_block is not None and current_block.get("type") == "text": + text_val = current_block.get("text") + if text_val is not None: + current_block["text"] = text_val + delta_text + else: + current_block["text"] = delta_text + + return delta_text + + return None + + +def handle_anthropic_tool_delta( + event: Any, + content_blocks: List[StreamingContentBlock], + tools_in_progress: Dict[str, ToolInProgress], +) -> None: + """ + Handle tool input delta event from Anthropic streaming. 
+ + Args: + event: Tool delta event + content_blocks: List of content blocks + tools_in_progress: Dictionary tracking tools being accumulated + """ + + if not (hasattr(event, "type") and event.type == "content_block_delta"): + return + + if not ( + hasattr(event, "delta") + and hasattr(event.delta, "type") + and event.delta.type == "input_json_delta" + ): + return + + if hasattr(event, "index") and event.index < len(content_blocks): + block = content_blocks[event.index] + + if block.get("type") == "function" and block.get("id") in tools_in_progress: + tool = tools_in_progress[block["id"]] + partial_json = getattr(event.delta, "partial_json", "") + tool["input_string"] += partial_json + + +def finalize_anthropic_tool_input( + event: Any, + content_blocks: List[StreamingContentBlock], + tools_in_progress: Dict[str, ToolInProgress], +) -> None: + """ + Finalize tool input when content block stops. + + Args: + event: Content block stop event + content_blocks: List of content blocks + tools_in_progress: Dictionary tracking tools being accumulated + """ + + if not (hasattr(event, "type") and event.type == "content_block_stop"): + return + + if hasattr(event, "index") and event.index < len(content_blocks): + block = content_blocks[event.index] + + if block.get("type") == "function" and block.get("id") in tools_in_progress: + tool = tools_in_progress[block["id"]] + + try: + block["function"]["arguments"] = json.loads(tool["input_string"]) + except (json.JSONDecodeError, Exception): + # Keep empty dict if parsing fails + pass + + del tools_in_progress[block["id"]] + + +def standardize_anthropic_usage(usage: Dict[str, Any]) -> TokenUsage: + """ + Standardize Anthropic usage statistics to common TokenUsage format. + + Anthropic already uses standard field names, so this mainly structures the data. + + Args: + usage: Raw usage statistics from Anthropic + + Returns: + Standardized TokenUsage dict + """ + return TokenUsage( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + cache_read_input_tokens=usage.get("cache_read_input_tokens"), + cache_creation_input_tokens=usage.get("cache_creation_input_tokens"), + ) + + +def format_anthropic_streaming_input(kwargs: Dict[str, Any]) -> Any: + """ + Format Anthropic streaming input using system prompt merging. + + Args: + kwargs: Keyword arguments passed to Anthropic API + + Returns: + Formatted input ready for PostHog tracking + """ + from posthog.ai.utils import merge_system_prompt + + return merge_system_prompt(kwargs, "anthropic") + + +def format_anthropic_streaming_output_complete( + content_blocks: List[StreamingContentBlock], accumulated_content: str +) -> List[FormattedMessage]: + """ + Format complete Anthropic streaming output. + + Combines existing logic for formatting content blocks with fallback to accumulated content. 
+ + Args: + content_blocks: List of content blocks accumulated during streaming + accumulated_content: Raw accumulated text content as fallback + + Returns: + Formatted messages ready for PostHog tracking + """ + formatted_content = format_anthropic_streaming_content(content_blocks) + + if formatted_content: + return [{"role": "assistant", "content": formatted_content}] + else: + # Fallback to accumulated content if no blocks + return [ + { + "role": "assistant", + "content": [{"type": "text", "text": accumulated_content}], + } + ] diff --git a/posthog/ai/gemini/__init__.py b/posthog/ai/gemini/__init__.py index c1d71e10..eb17989d 100644 --- a/posthog/ai/gemini/__init__.py +++ b/posthog/ai/gemini/__init__.py @@ -1,4 +1,9 @@ from .gemini import Client +from .gemini_converter import ( + format_gemini_input, + format_gemini_response, + extract_gemini_tools, +) # Create a genai-like module for perfect drop-in replacement @@ -8,4 +13,10 @@ class _GenAI: genai = _GenAI() -__all__ = ["Client", "genai"] +__all__ = [ + "Client", + "genai", + "format_gemini_input", + "format_gemini_response", + "extract_gemini_tools", +] diff --git a/posthog/ai/gemini/gemini.py b/posthog/ai/gemini/gemini.py index 9de0e0a9..0be9c673 100644 --- a/posthog/ai/gemini/gemini.py +++ b/posthog/ai/gemini/gemini.py @@ -13,8 +13,14 @@ from posthog import setup from posthog.ai.utils import ( call_llm_and_track_usage, - get_model_params, - with_privacy_mode, + capture_streaming_event, + merge_usage_stats, +) +from posthog.ai.gemini.gemini_converter import ( + format_gemini_input, + extract_gemini_usage_from_chunk, + extract_gemini_content_from_chunk, + format_gemini_streaming_output, ) from posthog.ai.sanitization import sanitize_gemini from posthog.client import Client as PostHogClient @@ -72,6 +78,7 @@ def __init__( posthog_groups: Default groups for all calls (can be overridden per call) **kwargs: Additional arguments (for future compatibility) """ + self._ph_client = posthog_client or setup() if self._ph_client is None: @@ -133,6 +140,7 @@ def __init__( posthog_groups: Default groups for all calls **kwargs: Additional arguments (for future compatibility) """ + self._ph_client = posthog_client or setup() if self._ph_client is None: @@ -150,14 +158,19 @@ def __init__( # Add Vertex AI parameters if provided if vertexai is not None: client_args["vertexai"] = vertexai + if credentials is not None: client_args["credentials"] = credentials + if project is not None: client_args["project"] = project + if location is not None: client_args["location"] = location + if debug_config is not None: client_args["debug_config"] = debug_config + if http_options is not None: client_args["http_options"] = http_options @@ -175,6 +188,7 @@ def __init__( raise ValueError( "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable" ) + client_args["api_key"] = api_key self._client = genai.Client(**client_args) @@ -189,6 +203,7 @@ def _merge_posthog_params( call_groups: Optional[Dict[str, Any]], ): """Merge call-level PostHog parameters with client defaults.""" + # Use call-level values if provided, otherwise fall back to defaults distinct_id = ( call_distinct_id @@ -204,6 +219,7 @@ def _merge_posthog_params( # Merge properties: default properties + call properties (call properties override) properties = dict(self._default_properties) + if call_properties: properties.update(call_properties) @@ -239,6 +255,7 @@ def generate_content( posthog_groups: Group analytics properties (overrides client default) **kwargs: 
Arguments passed to Gemini's generate_content """ + # Merge PostHog parameters distinct_id, trace_id, properties, privacy_mode, groups = ( self._merge_posthog_params( @@ -288,25 +305,24 @@ def generator(): nonlocal accumulated_content # noqa: F824 try: for chunk in response: - if hasattr(chunk, "usage_metadata") and chunk.usage_metadata: - usage_stats = { - "input_tokens": getattr( - chunk.usage_metadata, "prompt_token_count", 0 - ), - "output_tokens": getattr( - chunk.usage_metadata, "candidates_token_count", 0 - ), - } - - if hasattr(chunk, "text") and chunk.text: - accumulated_content.append(chunk.text) + # Extract usage stats from chunk + chunk_usage = extract_gemini_usage_from_chunk(chunk) + + if chunk_usage: + # Gemini reports cumulative totals, not incremental values + merge_usage_stats(usage_stats, chunk_usage, mode="cumulative") + + # Extract content from chunk (now returns content blocks) + content_block = extract_gemini_content_from_chunk(chunk) + + if content_block is not None: + accumulated_content.append(content_block) yield chunk finally: end_time = time.time() latency = end_time - start_time - output = "".join(accumulated_content) self._capture_streaming_event( model, @@ -319,7 +335,7 @@ def generator(): kwargs, usage_stats, latency, - output, + accumulated_content, ) return generator() @@ -336,61 +352,38 @@ def _capture_streaming_event( kwargs: Dict[str, Any], usage_stats: Dict[str, int], latency: float, - output: str, + output: Any, ): - if trace_id is None: - trace_id = str(uuid.uuid4()) - - event_properties = { - "$ai_provider": "gemini", - "$ai_model": model, - "$ai_model_parameters": get_model_params(kwargs), - "$ai_input": with_privacy_mode( - self._ph_client, - privacy_mode, - sanitize_gemini(self._format_input(contents)), - ), - "$ai_output_choices": with_privacy_mode( - self._ph_client, - privacy_mode, - [{"content": output, "role": "assistant"}], - ), - "$ai_http_status": 200, - "$ai_input_tokens": usage_stats.get("input_tokens", 0), - "$ai_output_tokens": usage_stats.get("output_tokens", 0), - "$ai_latency": latency, - "$ai_trace_id": trace_id, - "$ai_base_url": self._base_url, - **(properties or {}), - } - - if distinct_id is None: - event_properties["$process_person_profile"] = False - - if hasattr(self._ph_client, "capture"): - self._ph_client.capture( - distinct_id=distinct_id, - event="$ai_generation", - properties=event_properties, - groups=groups, - ) + from posthog.ai.types import StreamingEventData + from posthog.ai.gemini.gemini_converter import standardize_gemini_usage + + # Prepare standardized event data + formatted_input = self._format_input(contents) + sanitized_input = sanitize_gemini(formatted_input) + + event_data = StreamingEventData( + provider="gemini", + model=model, + base_url=self._base_url, + kwargs=kwargs, + formatted_input=sanitized_input, + formatted_output=format_gemini_streaming_output(output), + usage_stats=standardize_gemini_usage(usage_stats), + latency=latency, + distinct_id=distinct_id, + trace_id=trace_id, + properties=properties, + privacy_mode=privacy_mode, + groups=groups, + ) + + # Use the common capture function + capture_streaming_event(self._ph_client, event_data) def _format_input(self, contents): """Format input contents for PostHog tracking""" - if isinstance(contents, str): - return [{"role": "user", "content": contents}] - elif isinstance(contents, list): - formatted = [] - for item in contents: - if isinstance(item, str): - formatted.append({"role": "user", "content": item}) - elif hasattr(item, "text"): - 
formatted.append({"role": "user", "content": item.text}) - else: - formatted.append({"role": "user", "content": str(item)}) - return formatted - else: - return [{"role": "user", "content": str(contents)}] + + return format_gemini_input(contents) def generate_content_stream( self, diff --git a/posthog/ai/gemini/gemini_converter.py b/posthog/ai/gemini/gemini_converter.py new file mode 100644 index 00000000..b5296de1 --- /dev/null +++ b/posthog/ai/gemini/gemini_converter.py @@ -0,0 +1,438 @@ +""" +Gemini-specific conversion utilities. + +This module handles the conversion of Gemini API responses and inputs +into standardized formats for PostHog tracking. +""" + +from typing import Any, Dict, List, Optional, TypedDict, Union + +from posthog.ai.types import ( + FormattedContentItem, + FormattedMessage, + StreamingUsageStats, + TokenUsage, +) + + +class GeminiPart(TypedDict, total=False): + """Represents a part in a Gemini message.""" + + text: str + + +class GeminiMessage(TypedDict, total=False): + """Represents a Gemini message with various possible fields.""" + + role: str + parts: List[Union[GeminiPart, Dict[str, Any]]] + content: Union[str, List[Any]] + text: str + + +def _extract_text_from_parts(parts: List[Any]) -> str: + """ + Extract and concatenate text from a parts array. + + Args: + parts: List of parts that may contain text content + + Returns: + Concatenated text from all parts + """ + + content_parts = [] + + for part in parts: + if isinstance(part, dict) and "text" in part: + content_parts.append(part["text"]) + + elif isinstance(part, str): + content_parts.append(part) + + elif hasattr(part, "text"): + # Get the text attribute value + text_value = getattr(part, "text", "") + content_parts.append(text_value if text_value else str(part)) + + else: + content_parts.append(str(part)) + + return "".join(content_parts) + + +def _format_dict_message(item: Dict[str, Any]) -> FormattedMessage: + """ + Format a dictionary message into standardized format. + + Args: + item: Dictionary containing message data + + Returns: + Formatted message with role and content + """ + + # Handle dict format with parts array (Gemini-specific format) + if "parts" in item and isinstance(item["parts"], list): + content = _extract_text_from_parts(item["parts"]) + return {"role": item.get("role", "user"), "content": content} + + # Handle dict with content field + if "content" in item: + content = item["content"] + + if isinstance(content, list): + # If content is a list, extract text from it + content = _extract_text_from_parts(content) + + elif not isinstance(content, str): + content = str(content) + + return {"role": item.get("role", "user"), "content": content} + + # Handle dict with text field + if "text" in item: + return {"role": item.get("role", "user"), "content": item["text"]} + + # Fallback to string representation + return {"role": "user", "content": str(item)} + + +def _format_object_message(item: Any) -> FormattedMessage: + """ + Format an object (with attributes) into standardized format. 
+ + Args: + item: Object that may have text or parts attributes + + Returns: + Formatted message with role and content + """ + + # Handle object with parts attribute + if hasattr(item, "parts") and hasattr(item.parts, "__iter__"): + content = _extract_text_from_parts(item.parts) + role = getattr(item, "role", "user") if hasattr(item, "role") else "user" + + # Ensure role is a string + if not isinstance(role, str): + role = "user" + + return {"role": role, "content": content} + + # Handle object with text attribute + if hasattr(item, "text"): + role = getattr(item, "role", "user") if hasattr(item, "role") else "user" + + # Ensure role is a string + if not isinstance(role, str): + role = "user" + + return {"role": role, "content": item.text} + + # Handle object with content attribute + if hasattr(item, "content"): + role = getattr(item, "role", "user") if hasattr(item, "role") else "user" + + # Ensure role is a string + if not isinstance(role, str): + role = "user" + + content = item.content + + if isinstance(content, list): + content = _extract_text_from_parts(content) + + elif not isinstance(content, str): + content = str(content) + return {"role": role, "content": content} + + # Fallback to string representation + return {"role": "user", "content": str(item)} + + +def format_gemini_response(response: Any) -> List[FormattedMessage]: + """ + Format a Gemini response into standardized message format. + + Args: + response: The response object from Gemini API + + Returns: + List of formatted messages with role and content + """ + + output: List[FormattedMessage] = [] + + if response is None: + return output + + if hasattr(response, "candidates") and response.candidates: + for candidate in response.candidates: + if hasattr(candidate, "content") and candidate.content: + content: List[FormattedContentItem] = [] + + if hasattr(candidate.content, "parts") and candidate.content.parts: + for part in candidate.content.parts: + if hasattr(part, "text") and part.text: + content.append( + { + "type": "text", + "text": part.text, + } + ) + + elif hasattr(part, "function_call") and part.function_call: + function_call = part.function_call + content.append( + { + "type": "function", + "function": { + "name": function_call.name, + "arguments": function_call.args, + }, + } + ) + + if content: + output.append( + { + "role": "assistant", + "content": content, + } + ) + + elif hasattr(candidate, "text") and candidate.text: + output.append( + { + "role": "assistant", + "content": [{"type": "text", "text": candidate.text}], + } + ) + + elif hasattr(response, "text") and response.text: + output.append( + { + "role": "assistant", + "content": [{"type": "text", "text": response.text}], + } + ) + + return output + + +def extract_gemini_tools(kwargs: Dict[str, Any]) -> Optional[Any]: + """ + Extract tool definitions from Gemini API kwargs. + + Args: + kwargs: Keyword arguments passed to Gemini API + + Returns: + Tool definitions if present, None otherwise + """ + + if "config" in kwargs and hasattr(kwargs["config"], "tools"): + return kwargs["config"].tools + + return None + + +def format_gemini_input(contents: Any) -> List[FormattedMessage]: + """ + Format Gemini input contents into standardized message format for PostHog tracking. 
+ + This function handles various input formats: + - String inputs + - List of strings, dicts, or objects + - Single dict or object + - Gemini-specific format with parts array + + Args: + contents: Input contents in various possible formats + + Returns: + List of formatted messages with role and content fields + """ + + # Handle string input + if isinstance(contents, str): + return [{"role": "user", "content": contents}] + + # Handle list input + if isinstance(contents, list): + formatted: List[FormattedMessage] = [] + + for item in contents: + if isinstance(item, str): + formatted.append({"role": "user", "content": item}) + + elif isinstance(item, dict): + formatted.append(_format_dict_message(item)) + + else: + formatted.append(_format_object_message(item)) + + return formatted + + # Handle single dict input + if isinstance(contents, dict): + return [_format_dict_message(contents)] + + # Handle single object input + return [_format_object_message(contents)] + + +def extract_gemini_usage_from_chunk(chunk: Any) -> StreamingUsageStats: + """ + Extract usage statistics from a Gemini streaming chunk. + + Args: + chunk: Streaming chunk from Gemini API + + Returns: + Dictionary of usage statistics + """ + + usage: StreamingUsageStats = {} + + if not hasattr(chunk, "usage_metadata") or not chunk.usage_metadata: + return usage + + # Gemini uses prompt_token_count and candidates_token_count + usage["input_tokens"] = getattr(chunk.usage_metadata, "prompt_token_count", 0) + usage["output_tokens"] = getattr(chunk.usage_metadata, "candidates_token_count", 0) + + # Calculate total if both values are defined (including 0) + if "input_tokens" in usage and "output_tokens" in usage: + usage["total_tokens"] = usage["input_tokens"] + usage["output_tokens"] + + return usage + + +def extract_gemini_content_from_chunk(chunk: Any) -> Optional[Dict[str, Any]]: + """ + Extract content (text or function call) from a Gemini streaming chunk. + + Args: + chunk: Streaming chunk from Gemini API + + Returns: + Content block dictionary if present, None otherwise + """ + + # Check for text content + if hasattr(chunk, "text") and chunk.text: + return {"type": "text", "text": chunk.text} + + # Check for function calls in candidates + if hasattr(chunk, "candidates") and chunk.candidates: + for candidate in chunk.candidates: + if hasattr(candidate, "content") and candidate.content: + if hasattr(candidate.content, "parts") and candidate.content.parts: + for part in candidate.content.parts: + # Check for function_call part + if hasattr(part, "function_call") and part.function_call: + function_call = part.function_call + return { + "type": "function", + "function": { + "name": function_call.name, + "arguments": function_call.args, + }, + } + # Also check for text in parts + elif hasattr(part, "text") and part.text: + return {"type": "text", "text": part.text} + + return None + + +def format_gemini_streaming_output( + accumulated_content: Union[str, List[Any]], +) -> List[FormattedMessage]: + """ + Format the final output from Gemini streaming. 
+ + Args: + accumulated_content: Accumulated content from streaming (string, list of strings, or list of content blocks) + + Returns: + List of formatted messages + """ + + # Handle legacy string input (backward compatibility) + if isinstance(accumulated_content, str): + return [ + { + "role": "assistant", + "content": [{"type": "text", "text": accumulated_content}], + } + ] + + # Handle list input + if isinstance(accumulated_content, list): + content: List[FormattedContentItem] = [] + text_parts = [] + + for item in accumulated_content: + if isinstance(item, str): + # Legacy support: accumulate strings + text_parts.append(item) + elif isinstance(item, dict): + # New format: content blocks + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + elif item.get("type") == "function": + # If we have accumulated text, add it first + if text_parts: + content.append( + { + "type": "text", + "text": "".join(text_parts), + } + ) + text_parts = [] + + # Add the function call + content.append( + { + "type": "function", + "function": item.get("function", {}), + } + ) + + # Add any remaining text + if text_parts: + content.append( + { + "type": "text", + "text": "".join(text_parts), + } + ) + + # If we have content, return it + if content: + return [{"role": "assistant", "content": content}] + + # Fallback for empty or unexpected input + return [{"role": "assistant", "content": [{"type": "text", "text": ""}]}] + + +def standardize_gemini_usage(usage: Dict[str, Any]) -> TokenUsage: + """ + Standardize Gemini usage statistics to common TokenUsage format. + + Gemini already uses standard field names (input_tokens/output_tokens). + + Args: + usage: Raw usage statistics from Gemini + + Returns: + Standardized TokenUsage dict + """ + return TokenUsage( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + # Gemini doesn't currently support cache or reasoning tokens + ) diff --git a/posthog/ai/openai/__init__.py b/posthog/ai/openai/__init__.py index bf73e0a3..88281e1a 100644 --- a/posthog/ai/openai/__init__.py +++ b/posthog/ai/openai/__init__.py @@ -1,5 +1,20 @@ from .openai import OpenAI from .openai_async import AsyncOpenAI from .openai_providers import AsyncAzureOpenAI, AzureOpenAI +from .openai_converter import ( + format_openai_response, + format_openai_input, + extract_openai_tools, + format_openai_streaming_content, +) -__all__ = ["OpenAI", "AsyncOpenAI", "AzureOpenAI", "AsyncAzureOpenAI"] +__all__ = [ + "OpenAI", + "AsyncOpenAI", + "AzureOpenAI", + "AsyncAzureOpenAI", + "format_openai_response", + "format_openai_input", + "extract_openai_tools", + "format_openai_streaming_content", +] diff --git a/posthog/ai/openai/openai.py b/posthog/ai/openai/openai.py index e7feccd0..bdca1f6c 100644 --- a/posthog/ai/openai/openai.py +++ b/posthog/ai/openai/openai.py @@ -12,9 +12,15 @@ from posthog.ai.utils import ( call_llm_and_track_usage, extract_available_tool_calls, - get_model_params, + merge_usage_stats, with_privacy_mode, ) +from posthog.ai.openai.openai_converter import ( + extract_openai_usage_from_chunk, + extract_openai_content_from_chunk, + extract_openai_tool_calls_from_chunk, + accumulate_openai_tool_calls, +) from posthog.ai.sanitization import sanitize_openai, sanitize_openai_response from posthog.client import Client as PostHogClient from posthog import setup @@ -34,6 +40,7 @@ def __init__(self, posthog_client: Optional[PostHogClient] = None, **kwargs): posthog_client: If provided, events will be captured via this client instead of the 
global `posthog`. **openai_config: Any additional keyword args to set on openai (e.g. organization="xxx"). """ + super().__init__(**kwargs) self._ph_client = posthog_client or setup() @@ -123,35 +130,17 @@ def generator(): try: for chunk in response: - if hasattr(chunk, "type") and chunk.type == "response.completed": - res = chunk.response - if res.output and len(res.output) > 0: - final_content.append(res.output[0]) - - if hasattr(chunk, "usage") and chunk.usage: - usage_stats = { - k: getattr(chunk.usage, k, 0) - for k in [ - "input_tokens", - "output_tokens", - "total_tokens", - ] - } - - # Add support for cached tokens - if hasattr(chunk.usage, "output_tokens_details") and hasattr( - chunk.usage.output_tokens_details, "reasoning_tokens" - ): - usage_stats["reasoning_tokens"] = ( - chunk.usage.output_tokens_details.reasoning_tokens - ) - - if hasattr(chunk.usage, "input_tokens_details") and hasattr( - chunk.usage.input_tokens_details, "cached_tokens" - ): - usage_stats["cache_read_input_tokens"] = ( - chunk.usage.input_tokens_details.cached_tokens - ) + # Extract usage stats from chunk + chunk_usage = extract_openai_usage_from_chunk(chunk, "responses") + + if chunk_usage: + merge_usage_stats(usage_stats, chunk_usage) + + # Extract content from chunk + content = extract_openai_content_from_chunk(chunk, "responses") + + if content is not None: + final_content.append(content) yield chunk @@ -169,7 +158,7 @@ def generator(): usage_stats, latency, output, - extract_available_tool_calls("openai", kwargs), + None, # Responses API doesn't have tools ) return generator() @@ -187,49 +176,36 @@ def _capture_streaming_event( output: Any, available_tool_calls: Optional[List[Dict[str, Any]]] = None, ): - if posthog_trace_id is None: - posthog_trace_id = str(uuid.uuid4()) - - event_properties = { - "$ai_provider": "openai", - "$ai_model": kwargs.get("model"), - "$ai_model_parameters": get_model_params(kwargs), - "$ai_input": with_privacy_mode( - self._client._ph_client, - posthog_privacy_mode, - sanitize_openai_response(kwargs.get("input")), - ), - "$ai_output_choices": with_privacy_mode( - self._client._ph_client, - posthog_privacy_mode, - output, - ), - "$ai_http_status": 200, - "$ai_input_tokens": usage_stats.get("input_tokens", 0), - "$ai_output_tokens": usage_stats.get("output_tokens", 0), - "$ai_cache_read_input_tokens": usage_stats.get( - "cache_read_input_tokens", 0 - ), - "$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0), - "$ai_latency": latency, - "$ai_trace_id": posthog_trace_id, - "$ai_base_url": str(self._client.base_url), - **(posthog_properties or {}), - } - - if available_tool_calls: - event_properties["$ai_tools"] = available_tool_calls - - if posthog_distinct_id is None: - event_properties["$process_person_profile"] = False + from posthog.ai.types import StreamingEventData + from posthog.ai.openai.openai_converter import ( + standardize_openai_usage, + format_openai_streaming_input, + format_openai_streaming_output, + ) + from posthog.ai.utils import capture_streaming_event + + # Prepare standardized event data + formatted_input = format_openai_streaming_input(kwargs, "responses") + sanitized_input = sanitize_openai_response(formatted_input) + + event_data = StreamingEventData( + provider="openai", + model=kwargs.get("model", "unknown"), + base_url=str(self._client.base_url), + kwargs=kwargs, + formatted_input=sanitized_input, + formatted_output=format_openai_streaming_output(output, "responses"), + usage_stats=standardize_openai_usage(usage_stats, "responses"), + 
latency=latency, + distinct_id=posthog_distinct_id, + trace_id=posthog_trace_id, + properties=posthog_properties, + privacy_mode=posthog_privacy_mode, + groups=posthog_groups, + ) - if hasattr(self._client._ph_client, "capture"): - self._client._ph_client.capture( - distinct_id=posthog_distinct_id or posthog_trace_id, - event="$ai_generation", - properties=event_properties, - groups=posthog_groups, - ) + # Use the common capture function + capture_streaming_event(self._client._ph_client, event_data) def parse( self, @@ -342,6 +318,7 @@ def _create_streaming( start_time = time.time() usage_stats: Dict[str, int] = {} accumulated_content = [] + accumulated_tool_calls: Dict[int, Dict[str, Any]] = {} if "stream_options" not in kwargs: kwargs["stream_options"] = {} kwargs["stream_options"]["include_usage"] = True @@ -350,50 +327,42 @@ def _create_streaming( def generator(): nonlocal usage_stats nonlocal accumulated_content # noqa: F824 + nonlocal accumulated_tool_calls try: for chunk in response: - if hasattr(chunk, "usage") and chunk.usage: - usage_stats = { - k: getattr(chunk.usage, k, 0) - for k in [ - "prompt_tokens", - "completion_tokens", - "total_tokens", - ] - } - - # Add support for cached tokens - if hasattr(chunk.usage, "prompt_tokens_details") and hasattr( - chunk.usage.prompt_tokens_details, "cached_tokens" - ): - usage_stats["cache_read_input_tokens"] = ( - chunk.usage.prompt_tokens_details.cached_tokens - ) - - if hasattr(chunk.usage, "output_tokens_details") and hasattr( - chunk.usage.output_tokens_details, "reasoning_tokens" - ): - usage_stats["reasoning_tokens"] = ( - chunk.usage.output_tokens_details.reasoning_tokens - ) - - if ( - hasattr(chunk, "choices") - and chunk.choices - and len(chunk.choices) > 0 - ): - if chunk.choices[0].delta and chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - if content: - accumulated_content.append(content) + # Extract usage stats from chunk + chunk_usage = extract_openai_usage_from_chunk(chunk, "chat") + + if chunk_usage: + merge_usage_stats(usage_stats, chunk_usage) + + # Extract content from chunk + content = extract_openai_content_from_chunk(chunk, "chat") + + if content is not None: + accumulated_content.append(content) + + # Extract and accumulate tool calls from chunk + chunk_tool_calls = extract_openai_tool_calls_from_chunk(chunk) + if chunk_tool_calls: + accumulate_openai_tool_calls( + accumulated_tool_calls, chunk_tool_calls + ) yield chunk finally: end_time = time.time() latency = end_time - start_time - output = "".join(accumulated_content) + + # Convert accumulated tool calls dict to list + tool_calls_list = ( + list(accumulated_tool_calls.values()) + if accumulated_tool_calls + else None + ) + self._capture_streaming_event( posthog_distinct_id, posthog_trace_id, @@ -403,7 +372,8 @@ def generator(): kwargs, usage_stats, latency, - output, + accumulated_content, + tool_calls_list, extract_available_tool_calls("openai", kwargs), ) @@ -420,51 +390,39 @@ def _capture_streaming_event( usage_stats: Dict[str, int], latency: float, output: Any, + tool_calls: Optional[List[Dict[str, Any]]] = None, available_tool_calls: Optional[List[Dict[str, Any]]] = None, ): - if posthog_trace_id is None: - posthog_trace_id = str(uuid.uuid4()) - - event_properties = { - "$ai_provider": "openai", - "$ai_model": kwargs.get("model"), - "$ai_model_parameters": get_model_params(kwargs), - "$ai_input": with_privacy_mode( - self._client._ph_client, - posthog_privacy_mode, - sanitize_openai(kwargs.get("messages")), - ), - 
"$ai_output_choices": with_privacy_mode( - self._client._ph_client, - posthog_privacy_mode, - [{"content": output, "role": "assistant"}], - ), - "$ai_http_status": 200, - "$ai_input_tokens": usage_stats.get("prompt_tokens", 0), - "$ai_output_tokens": usage_stats.get("completion_tokens", 0), - "$ai_cache_read_input_tokens": usage_stats.get( - "cache_read_input_tokens", 0 - ), - "$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0), - "$ai_latency": latency, - "$ai_trace_id": posthog_trace_id, - "$ai_base_url": str(self._client.base_url), - **(posthog_properties or {}), - } - - if available_tool_calls: - event_properties["$ai_tools"] = available_tool_calls - - if posthog_distinct_id is None: - event_properties["$process_person_profile"] = False + from posthog.ai.types import StreamingEventData + from posthog.ai.openai.openai_converter import ( + standardize_openai_usage, + format_openai_streaming_input, + format_openai_streaming_output, + ) + from posthog.ai.utils import capture_streaming_event + + # Prepare standardized event data + formatted_input = format_openai_streaming_input(kwargs, "chat") + sanitized_input = sanitize_openai(formatted_input) + + event_data = StreamingEventData( + provider="openai", + model=kwargs.get("model", "unknown"), + base_url=str(self._client.base_url), + kwargs=kwargs, + formatted_input=sanitized_input, + formatted_output=format_openai_streaming_output(output, "chat", tool_calls), + usage_stats=standardize_openai_usage(usage_stats, "chat"), + latency=latency, + distinct_id=posthog_distinct_id, + trace_id=posthog_trace_id, + properties=posthog_properties, + privacy_mode=posthog_privacy_mode, + groups=posthog_groups, + ) - if hasattr(self._client._ph_client, "capture"): - self._client._ph_client.capture( - distinct_id=posthog_distinct_id or posthog_trace_id, - event="$ai_generation", - properties=event_properties, - groups=posthog_groups, - ) + # Use the common capture function + capture_streaming_event(self._client._ph_client, event_data) class WrappedEmbeddings: @@ -501,6 +459,7 @@ def create( Returns: The response from OpenAI's embeddings.create call. """ + if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) diff --git a/posthog/ai/openai/openai_async.py b/posthog/ai/openai/openai_async.py index ae1a5352..57bc7d3d 100644 --- a/posthog/ai/openai/openai_async.py +++ b/posthog/ai/openai/openai_async.py @@ -14,8 +14,16 @@ call_llm_and_track_usage_async, extract_available_tool_calls, get_model_params, + merge_usage_stats, with_privacy_mode, ) +from posthog.ai.openai.openai_converter import ( + extract_openai_usage_from_chunk, + extract_openai_content_from_chunk, + extract_openai_tool_calls_from_chunk, + accumulate_openai_tool_calls, + format_openai_streaming_output, +) from posthog.ai.sanitization import sanitize_openai, sanitize_openai_response from posthog.client import Client as PostHogClient @@ -35,6 +43,7 @@ def __init__(self, posthog_client: Optional[PostHogClient] = None, **kwargs): of the global posthog. **openai_config: Any additional keyword args to set on openai (e.g. organization="xxx"). 
""" + super().__init__(**kwargs) self._ph_client = posthog_client or setup() @@ -67,6 +76,7 @@ def __init__(self, client: AsyncOpenAI, original_responses): def __getattr__(self, name): """Fallback to original responses object for any methods we don't explicitly handle.""" + return getattr(self._original, name) async def create( @@ -116,7 +126,7 @@ async def _create_streaming( start_time = time.time() usage_stats: Dict[str, int] = {} final_content = [] - response = await self._original.create(**kwargs) + response = self._original.create(**kwargs) async def async_generator(): nonlocal usage_stats @@ -124,35 +134,17 @@ async def async_generator(): try: async for chunk in response: - if hasattr(chunk, "type") and chunk.type == "response.completed": - res = chunk.response - if res.output and len(res.output) > 0: - final_content.append(res.output[0]) - - if hasattr(chunk, "usage") and chunk.usage: - usage_stats = { - k: getattr(chunk.usage, k, 0) - for k in [ - "input_tokens", - "output_tokens", - "total_tokens", - ] - } - - # Add support for cached tokens - if hasattr(chunk.usage, "output_tokens_details") and hasattr( - chunk.usage.output_tokens_details, "reasoning_tokens" - ): - usage_stats["reasoning_tokens"] = ( - chunk.usage.output_tokens_details.reasoning_tokens - ) - - if hasattr(chunk.usage, "input_tokens_details") and hasattr( - chunk.usage.input_tokens_details, "cached_tokens" - ): - usage_stats["cache_read_input_tokens"] = ( - chunk.usage.input_tokens_details.cached_tokens - ) + # Extract usage stats from chunk + chunk_usage = extract_openai_usage_from_chunk(chunk, "responses") + + if chunk_usage: + merge_usage_stats(usage_stats, chunk_usage) + + # Extract content from chunk + content = extract_openai_content_from_chunk(chunk, "responses") + + if content is not None: + final_content.append(content) yield chunk @@ -160,6 +152,7 @@ async def async_generator(): end_time = time.time() latency = end_time - start_time output = final_content + await self._capture_streaming_event( posthog_distinct_id, posthog_trace_id, @@ -203,7 +196,7 @@ async def _capture_streaming_event( "$ai_output_choices": with_privacy_mode( self._client._ph_client, posthog_privacy_mode, - output, + format_openai_streaming_output(output, "responses"), ), "$ai_http_status": 200, "$ai_input_tokens": usage_stats.get("input_tokens", 0), @@ -345,59 +338,50 @@ async def _create_streaming( start_time = time.time() usage_stats: Dict[str, int] = {} accumulated_content = [] + accumulated_tool_calls: Dict[int, Dict[str, Any]] = {} if "stream_options" not in kwargs: kwargs["stream_options"] = {} kwargs["stream_options"]["include_usage"] = True - response = await self._original.create(**kwargs) + response = self._original.create(**kwargs) async def async_generator(): nonlocal usage_stats nonlocal accumulated_content # noqa: F824 + nonlocal accumulated_tool_calls try: async for chunk in response: - if hasattr(chunk, "usage") and chunk.usage: - usage_stats = { - k: getattr(chunk.usage, k, 0) - for k in [ - "prompt_tokens", - "completion_tokens", - "total_tokens", - ] - } - - # Add support for cached tokens - if hasattr(chunk.usage, "prompt_tokens_details") and hasattr( - chunk.usage.prompt_tokens_details, "cached_tokens" - ): - usage_stats["cache_read_input_tokens"] = ( - chunk.usage.prompt_tokens_details.cached_tokens - ) - - if hasattr(chunk.usage, "output_tokens_details") and hasattr( - chunk.usage.output_tokens_details, "reasoning_tokens" - ): - usage_stats["reasoning_tokens"] = ( - 
chunk.usage.output_tokens_details.reasoning_tokens - ) - - if ( - hasattr(chunk, "choices") - and chunk.choices - and len(chunk.choices) > 0 - ): - if chunk.choices[0].delta and chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - if content: - accumulated_content.append(content) + # Extract usage stats from chunk + chunk_usage = extract_openai_usage_from_chunk(chunk, "chat") + if chunk_usage: + merge_usage_stats(usage_stats, chunk_usage) + + # Extract content from chunk + content = extract_openai_content_from_chunk(chunk, "chat") + if content is not None: + accumulated_content.append(content) + + # Extract and accumulate tool calls from chunk + chunk_tool_calls = extract_openai_tool_calls_from_chunk(chunk) + if chunk_tool_calls: + accumulate_openai_tool_calls( + accumulated_tool_calls, chunk_tool_calls + ) yield chunk finally: end_time = time.time() latency = end_time - start_time - output = "".join(accumulated_content) + + # Convert accumulated tool calls dict to list + tool_calls_list = ( + list(accumulated_tool_calls.values()) + if accumulated_tool_calls + else None + ) + await self._capture_streaming_event( posthog_distinct_id, posthog_trace_id, @@ -407,7 +391,8 @@ async def async_generator(): kwargs, usage_stats, latency, - output, + accumulated_content, + tool_calls_list, extract_available_tool_calls("openai", kwargs), ) @@ -424,6 +409,7 @@ async def _capture_streaming_event( usage_stats: Dict[str, int], latency: float, output: Any, + tool_calls: Optional[List[Dict[str, Any]]] = None, available_tool_calls: Optional[List[Dict[str, Any]]] = None, ): if posthog_trace_id is None: @@ -441,7 +427,7 @@ async def _capture_streaming_event( "$ai_output_choices": with_privacy_mode( self._client._ph_client, posthog_privacy_mode, - [{"content": output, "role": "assistant"}], + format_openai_streaming_output(output, "chat", tool_calls), ), "$ai_http_status": 200, "$ai_input_tokens": usage_stats.get("prompt_tokens", 0), @@ -480,6 +466,7 @@ def __init__(self, client: AsyncOpenAI, original_embeddings): def __getattr__(self, name): """Fallback to original embeddings object for any methods we don't explicitly handle.""" + return getattr(self._original, name) async def create( @@ -505,15 +492,17 @@ async def create( Returns: The response from OpenAI's embeddings.create call. 
""" + if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) start_time = time.time() - response = await self._original.create(**kwargs) + response = self._original.create(**kwargs) end_time = time.time() # Extract usage statistics if available usage_stats = {} + if hasattr(response, "usage") and response.usage: usage_stats = { "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), @@ -563,6 +552,7 @@ def __init__(self, client: AsyncOpenAI, original_beta): def __getattr__(self, name): """Fallback to original beta object for any methods we don't explicitly handle.""" + return getattr(self._original, name) @property @@ -579,6 +569,7 @@ def __init__(self, client: AsyncOpenAI, original_beta_chat): def __getattr__(self, name): """Fallback to original beta chat object for any methods we don't explicitly handle.""" + return getattr(self._original, name) @property @@ -595,6 +586,7 @@ def __init__(self, client: AsyncOpenAI, original_beta_completions): def __getattr__(self, name): """Fallback to original beta completions object for any methods we don't explicitly handle.""" + return getattr(self._original, name) async def parse( diff --git a/posthog/ai/openai/openai_converter.py b/posthog/ai/openai/openai_converter.py new file mode 100644 index 00000000..2429270b --- /dev/null +++ b/posthog/ai/openai/openai_converter.py @@ -0,0 +1,585 @@ +""" +OpenAI-specific conversion utilities. + +This module handles the conversion of OpenAI API responses and inputs +into standardized formats for PostHog tracking. It supports both +Chat Completions API and Responses API formats. +""" + +from typing import Any, Dict, List, Optional + +from posthog.ai.types import ( + FormattedContentItem, + FormattedFunctionCall, + FormattedImageContent, + FormattedMessage, + FormattedTextContent, + StreamingUsageStats, + TokenUsage, +) + + +def format_openai_response(response: Any) -> List[FormattedMessage]: + """ + Format an OpenAI response into standardized message format. + + Handles both Chat Completions API and Responses API formats. 
+ + Args: + response: The response object from OpenAI API + + Returns: + List of formatted messages with role and content + """ + + output: List[FormattedMessage] = [] + + if response is None: + return output + + # Handle Chat Completions response format + if hasattr(response, "choices"): + content: List[FormattedContentItem] = [] + role = "assistant" + + for choice in response.choices: + if hasattr(choice, "message") and choice.message: + if choice.message.role: + role = choice.message.role + + if choice.message.content: + content.append( + { + "type": "text", + "text": choice.message.content, + } + ) + + if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: + for tool_call in choice.message.tool_calls: + content.append( + { + "type": "function", + "id": tool_call.id, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + ) + + if content: + output.append( + { + "role": role, + "content": content, + } + ) + + # Handle Responses API format + if hasattr(response, "output"): + content = [] + role = "assistant" + + for item in response.output: + if item.type == "message": + role = item.role + + if hasattr(item, "content") and isinstance(item.content, list): + for content_item in item.content: + if ( + hasattr(content_item, "type") + and content_item.type == "output_text" + and hasattr(content_item, "text") + ): + content.append( + { + "type": "text", + "text": content_item.text, + } + ) + + elif hasattr(content_item, "text"): + content.append({"type": "text", "text": content_item.text}) + + elif ( + hasattr(content_item, "type") + and content_item.type == "input_image" + and hasattr(content_item, "image_url") + ): + image_content: FormattedImageContent = { + "type": "image", + "image": content_item.image_url, + } + content.append(image_content) + + elif hasattr(item, "content"): + text_content = {"type": "text", "text": str(item.content)} + content.append(text_content) + + elif hasattr(item, "type") and item.type == "function_call": + content.append( + { + "type": "function", + "id": getattr(item, "call_id", getattr(item, "id", "")), + "function": { + "name": item.name, + "arguments": getattr(item, "arguments", {}), + }, + } + ) + + if content: + output.append( + { + "role": role, + "content": content, + } + ) + + return output + + +def format_openai_input( + messages: Optional[List[Dict[str, Any]]] = None, input_data: Optional[Any] = None +) -> List[FormattedMessage]: + """ + Format OpenAI input messages. + + Handles both messages parameter (Chat Completions) and input parameter (Responses API). 
+ + Args: + messages: List of message dictionaries for Chat Completions API + input_data: Input data for Responses API + + Returns: + List of formatted messages + """ + + formatted_messages: List[FormattedMessage] = [] + + # Handle Chat Completions API format + if messages is not None: + for msg in messages: + formatted_messages.append( + { + "role": msg.get("role", "user"), + "content": msg.get("content", ""), + } + ) + + # Handle Responses API format + if input_data is not None: + if isinstance(input_data, list): + for item in input_data: + role = "user" + content = "" + + if isinstance(item, dict): + role = item.get("role", "user") + content = item.get("content", "") + + elif isinstance(item, str): + content = item + + else: + content = str(item) + + formatted_messages.append({"role": role, "content": content}) + + elif isinstance(input_data, str): + formatted_messages.append({"role": "user", "content": input_data}) + + else: + formatted_messages.append({"role": "user", "content": str(input_data)}) + + return formatted_messages + + +def extract_openai_tools(kwargs: Dict[str, Any]) -> Optional[Any]: + """ + Extract tool definitions from OpenAI API kwargs. + + Args: + kwargs: Keyword arguments passed to OpenAI API + + Returns: + Tool definitions if present, None otherwise + """ + + # Check for tools parameter (newer API) + if "tools" in kwargs: + return kwargs["tools"] + + # Check for functions parameter (older API) + if "functions" in kwargs: + return kwargs["functions"] + + return None + + +def format_openai_streaming_content( + accumulated_content: str, tool_calls: Optional[List[Dict[str, Any]]] = None +) -> List[FormattedContentItem]: + """ + Format content from OpenAI streaming response. + + Used by streaming handlers to format accumulated content. + + Args: + accumulated_content: Accumulated text content from streaming + tool_calls: Optional list of tool calls accumulated during streaming + + Returns: + List of formatted content items + """ + formatted: List[FormattedContentItem] = [] + + # Add text content if present + if accumulated_content: + text_content: FormattedTextContent = { + "type": "text", + "text": accumulated_content, + } + formatted.append(text_content) + + # Add tool calls if present + if tool_calls: + for tool_call in tool_calls: + function_call: FormattedFunctionCall = { + "type": "function", + "id": tool_call.get("id"), + "function": tool_call.get("function", {}), + } + formatted.append(function_call) + + return formatted + + +def extract_openai_usage_from_chunk( + chunk: Any, provider_type: str = "chat" +) -> StreamingUsageStats: + """ + Extract usage statistics from an OpenAI streaming chunk. + + Handles both Chat Completions and Responses API formats. 
+ + Args: + chunk: Streaming chunk from OpenAI API + provider_type: Either "chat" or "responses" to handle different API formats + + Returns: + Dictionary of usage statistics + """ + + usage: StreamingUsageStats = {} + + if provider_type == "chat": + if not hasattr(chunk, "usage") or not chunk.usage: + return usage + + # Chat Completions API uses prompt_tokens and completion_tokens + usage["prompt_tokens"] = getattr(chunk.usage, "prompt_tokens", 0) + usage["completion_tokens"] = getattr(chunk.usage, "completion_tokens", 0) + usage["total_tokens"] = getattr(chunk.usage, "total_tokens", 0) + + # Handle cached tokens + if hasattr(chunk.usage, "prompt_tokens_details") and hasattr( + chunk.usage.prompt_tokens_details, "cached_tokens" + ): + usage["cache_read_input_tokens"] = ( + chunk.usage.prompt_tokens_details.cached_tokens + ) + + # Handle reasoning tokens + if hasattr(chunk.usage, "completion_tokens_details") and hasattr( + chunk.usage.completion_tokens_details, "reasoning_tokens" + ): + usage["reasoning_tokens"] = ( + chunk.usage.completion_tokens_details.reasoning_tokens + ) + + elif provider_type == "responses": + # For Responses API, usage is only in chunk.response.usage for completed events + if hasattr(chunk, "type") and chunk.type == "response.completed": + if ( + hasattr(chunk, "response") + and hasattr(chunk.response, "usage") + and chunk.response.usage + ): + response_usage = chunk.response.usage + usage["input_tokens"] = getattr(response_usage, "input_tokens", 0) + usage["output_tokens"] = getattr(response_usage, "output_tokens", 0) + usage["total_tokens"] = getattr(response_usage, "total_tokens", 0) + + # Handle cached tokens + if hasattr(response_usage, "input_tokens_details") and hasattr( + response_usage.input_tokens_details, "cached_tokens" + ): + usage["cache_read_input_tokens"] = ( + response_usage.input_tokens_details.cached_tokens + ) + + # Handle reasoning tokens + if hasattr(response_usage, "output_tokens_details") and hasattr( + response_usage.output_tokens_details, "reasoning_tokens" + ): + usage["reasoning_tokens"] = ( + response_usage.output_tokens_details.reasoning_tokens + ) + + return usage + + +def extract_openai_content_from_chunk( + chunk: Any, provider_type: str = "chat" +) -> Optional[str]: + """ + Extract content from an OpenAI streaming chunk. + + Handles both Chat Completions and Responses API formats. + + Args: + chunk: Streaming chunk from OpenAI API + provider_type: Either "chat" or "responses" to handle different API formats + + Returns: + Text content if present, None otherwise + """ + + if provider_type == "chat": + # Chat Completions API format + if ( + hasattr(chunk, "choices") + and chunk.choices + and len(chunk.choices) > 0 + and chunk.choices[0].delta + and chunk.choices[0].delta.content + ): + return chunk.choices[0].delta.content + + elif provider_type == "responses": + # Responses API format + if hasattr(chunk, "type") and chunk.type == "response.completed": + if hasattr(chunk, "response") and chunk.response: + res = chunk.response + if res.output and len(res.output) > 0: + # Return the full output for responses + return res.output[0] + + return None + + +def extract_openai_tool_calls_from_chunk(chunk: Any) -> Optional[List[Dict[str, Any]]]: + """ + Extract tool calls from an OpenAI streaming chunk. 
+ + Args: + chunk: Streaming chunk from OpenAI API + + Returns: + List of tool call deltas if present, None otherwise + """ + if ( + hasattr(chunk, "choices") + and chunk.choices + and len(chunk.choices) > 0 + and chunk.choices[0].delta + and hasattr(chunk.choices[0].delta, "tool_calls") + and chunk.choices[0].delta.tool_calls + ): + tool_calls = [] + for tool_call in chunk.choices[0].delta.tool_calls: + tc_dict = { + "index": getattr(tool_call, "index", None), + } + + if hasattr(tool_call, "id") and tool_call.id: + tc_dict["id"] = tool_call.id + + if hasattr(tool_call, "type") and tool_call.type: + tc_dict["type"] = tool_call.type + + if hasattr(tool_call, "function") and tool_call.function: + function_dict = {} + if hasattr(tool_call.function, "name") and tool_call.function.name: + function_dict["name"] = tool_call.function.name + if ( + hasattr(tool_call.function, "arguments") + and tool_call.function.arguments + ): + function_dict["arguments"] = tool_call.function.arguments + tc_dict["function"] = function_dict + + tool_calls.append(tc_dict) + return tool_calls + + return None + + +def accumulate_openai_tool_calls( + accumulated_tool_calls: Dict[int, Dict[str, Any]], + chunk_tool_calls: List[Dict[str, Any]], +) -> None: + """ + Accumulate tool calls from streaming chunks. + + OpenAI sends tool calls incrementally: + - First chunk has id, type, function.name and partial function.arguments + - Subsequent chunks have more function.arguments + + Args: + accumulated_tool_calls: Dictionary mapping index to accumulated tool call data + chunk_tool_calls: List of tool call deltas from current chunk + """ + for tool_call_delta in chunk_tool_calls: + index = tool_call_delta.get("index") + if index is None: + continue + + # Initialize tool call if first time seeing this index + if index not in accumulated_tool_calls: + accumulated_tool_calls[index] = { + "id": "", + "type": "function", + "function": { + "name": "", + "arguments": "", + }, + } + + # Update with new data from delta + tc = accumulated_tool_calls[index] + + if "id" in tool_call_delta and tool_call_delta["id"]: + tc["id"] = tool_call_delta["id"] + + if "type" in tool_call_delta and tool_call_delta["type"]: + tc["type"] = tool_call_delta["type"] + + if "function" in tool_call_delta: + func_delta = tool_call_delta["function"] + if "name" in func_delta and func_delta["name"]: + tc["function"]["name"] = func_delta["name"] + if "arguments" in func_delta and func_delta["arguments"]: + # Arguments are sent incrementally, concatenate them + tc["function"]["arguments"] += func_delta["arguments"] + + +def format_openai_streaming_output( + accumulated_content: Any, + provider_type: str = "chat", + tool_calls: Optional[List[Dict[str, Any]]] = None, +) -> List[FormattedMessage]: + """ + Format the final output from OpenAI streaming. 
+ + Args: + accumulated_content: Accumulated content from streaming (string for chat, list for responses) + provider_type: Either "chat" or "responses" to handle different API formats + tool_calls: Optional list of accumulated tool calls + + Returns: + List of formatted messages + """ + + if provider_type == "chat": + content_items: List[FormattedContentItem] = [] + + # Add text content if present + if isinstance(accumulated_content, str) and accumulated_content: + content_items.append({"type": "text", "text": accumulated_content}) + elif isinstance(accumulated_content, list): + # If it's a list of strings, join them + text = "".join(str(item) for item in accumulated_content if item) + if text: + content_items.append({"type": "text", "text": text}) + + # Add tool calls if present + if tool_calls: + for tool_call in tool_calls: + if "function" in tool_call: + function_call: FormattedFunctionCall = { + "type": "function", + "id": tool_call.get("id", ""), + "function": tool_call["function"], + } + content_items.append(function_call) + + # Return formatted message with content + if content_items: + return [{"role": "assistant", "content": content_items}] + else: + # Empty response + return [{"role": "assistant", "content": []}] + + elif provider_type == "responses": + # Responses API: accumulated_content is a list of output items + if isinstance(accumulated_content, list) and accumulated_content: + # The output is already formatted, just return it + return accumulated_content + elif isinstance(accumulated_content, str): + return [ + { + "role": "assistant", + "content": [{"type": "text", "text": accumulated_content}], + } + ] + + # Fallback for any other format + return [ + { + "role": "assistant", + "content": [{"type": "text", "text": str(accumulated_content)}], + } + ] + + +def standardize_openai_usage( + usage: Dict[str, Any], api_type: str = "chat" +) -> TokenUsage: + """ + Standardize OpenAI usage statistics to common TokenUsage format. + + Args: + usage: Raw usage statistics from OpenAI + api_type: Either "chat" or "responses" to handle different field names + + Returns: + Standardized TokenUsage dict + """ + if api_type == "chat": + # Chat API uses prompt_tokens/completion_tokens + return TokenUsage( + input_tokens=usage.get("prompt_tokens", 0), + output_tokens=usage.get("completion_tokens", 0), + cache_read_input_tokens=usage.get("cache_read_input_tokens"), + reasoning_tokens=usage.get("reasoning_tokens"), + ) + else: # responses API + # Responses API uses input_tokens/output_tokens + return TokenUsage( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + cache_read_input_tokens=usage.get("cache_read_input_tokens"), + reasoning_tokens=usage.get("reasoning_tokens"), + ) + + +def format_openai_streaming_input( + kwargs: Dict[str, Any], api_type: str = "chat" +) -> Any: + """ + Format OpenAI streaming input based on API type. + + Args: + kwargs: Keyword arguments passed to OpenAI API + api_type: Either "chat" or "responses" + + Returns: + Formatted input ready for PostHog tracking + """ + if api_type == "chat": + return kwargs.get("messages") + else: # responses API + return kwargs.get("input") diff --git a/posthog/ai/types.py b/posthog/ai/types.py new file mode 100644 index 00000000..bc20e69c --- /dev/null +++ b/posthog/ai/types.py @@ -0,0 +1,142 @@ +""" +Common type definitions for PostHog AI SDK. + +These types are used for formatting messages and responses across different AI providers +(Anthropic, OpenAI, Gemini, etc.) 
to ensure consistency in tracking and data structure. +""" + +from typing import Any, Dict, List, Optional, TypedDict, Union + + +class FormattedTextContent(TypedDict): + """Formatted text content item.""" + + type: str # Literal["text"] + text: str + + +class FormattedFunctionCall(TypedDict, total=False): + """Formatted function/tool call content item.""" + + type: str # Literal["function"] + id: Optional[str] + function: Dict[str, Any] # Contains 'name' and 'arguments' + + +class FormattedImageContent(TypedDict): + """Formatted image content item.""" + + type: str # Literal["image"] + image: str + + +# Union type for all formatted content items +FormattedContentItem = Union[ + FormattedTextContent, + FormattedFunctionCall, + FormattedImageContent, + Dict[str, Any], # Fallback for unknown content types +] + + +class FormattedMessage(TypedDict): + """ + Standardized message format for PostHog tracking. + + Used across all providers to ensure consistent message structure + when sending events to PostHog. + """ + + role: str + content: Union[str, List[FormattedContentItem], Any] + + +class TokenUsage(TypedDict, total=False): + """ + Token usage information for AI model responses. + + Different providers may populate different fields. + """ + + input_tokens: int + output_tokens: int + cache_read_input_tokens: Optional[int] + cache_creation_input_tokens: Optional[int] + reasoning_tokens: Optional[int] + + +class ProviderResponse(TypedDict, total=False): + """ + Standardized provider response format. + + Used for consistent response formatting across all providers. + """ + + messages: List[FormattedMessage] + usage: TokenUsage + error: Optional[str] + + +class StreamingUsageStats(TypedDict, total=False): + """ + Usage statistics collected during streaming. + + Different providers populate different fields during streaming. + """ + + input_tokens: int + output_tokens: int + cache_read_input_tokens: Optional[int] + cache_creation_input_tokens: Optional[int] + reasoning_tokens: Optional[int] + # OpenAI-specific names + prompt_tokens: Optional[int] + completion_tokens: Optional[int] + total_tokens: Optional[int] + + +class StreamingContentBlock(TypedDict, total=False): + """ + Content block used during streaming to accumulate content. + + Used for tracking text and function calls as they stream in. + """ + + type: str # "text" or "function" + text: Optional[str] + id: Optional[str] + function: Optional[Dict[str, Any]] + + +class ToolInProgress(TypedDict): + """ + Tracks a tool/function call being accumulated during streaming. + + Used by Anthropic to accumulate JSON input for tools. + """ + + block: StreamingContentBlock + input_string: str + + +class StreamingEventData(TypedDict): + """ + Standardized data for streaming events across all providers. + + This type ensures consistent data structure when capturing streaming events, + with all provider-specific formatting already completed. 
+ """ + + provider: str # "openai", "anthropic", "gemini" + model: str + base_url: str + kwargs: Dict[str, Any] # Original kwargs for tool extraction and special handling + formatted_input: Any # Provider-formatted input ready for tracking + formatted_output: Any # Provider-formatted output ready for tracking + usage_stats: TokenUsage # Standardized token counts + latency: float + distinct_id: Optional[str] + trace_id: Optional[str] + properties: Optional[Dict[str, Any]] + privacy_mode: bool + groups: Optional[Dict[str, Any]] diff --git a/posthog/ai/utils.py b/posthog/ai/utils.py index 5687ffb2..6daca1b6 100644 --- a/posthog/ai/utils.py +++ b/posthog/ai/utils.py @@ -1,10 +1,10 @@ import time import uuid -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Optional -from httpx import URL from posthog.client import Client as PostHogClient +from posthog.ai.types import StreamingEventData, StreamingUsageStats from posthog.ai.sanitization import ( sanitize_openai, sanitize_anthropic, @@ -13,6 +13,35 @@ ) +def merge_usage_stats( + target: Dict[str, int], source: StreamingUsageStats, mode: str = "incremental" +) -> None: + """ + Merge streaming usage statistics into target dict, handling None values. + + Supports two modes: + - "incremental": Add source values to target (for APIs that report new tokens) + - "cumulative": Replace target with source values (for APIs that report totals) + + Args: + target: Dictionary to update with usage stats + source: StreamingUsageStats that may contain None values + mode: Either "incremental" or "cumulative" + """ + if mode == "incremental": + # Add new values to existing totals + for key, value in source.items(): + if value is not None and isinstance(value, int): + target[key] = target.get(key, 0) + value + elif mode == "cumulative": + # Replace with latest values (already cumulative) + for key, value in source.items(): + if value is not None and isinstance(value, int): + target[key] = value + else: + raise ValueError(f"Invalid mode: {mode}. Must be 'incremental' or 'cumulative'") + + def get_model_params(kwargs: Dict[str, Any]) -> Dict[str, Any]: """ Extracts model parameters from the kwargs dictionary. @@ -109,275 +138,96 @@ def format_response(response, provider: str): """ Format a regular (non-streaming) response. 
""" - output = [] - if response is None: - return output if provider == "anthropic": - return format_response_anthropic(response) - elif provider == "openai": - return format_response_openai(response) - elif provider == "gemini": - return format_response_gemini(response) - return output - + from posthog.ai.anthropic.anthropic_converter import format_anthropic_response -def format_response_anthropic(response): - output = [] - content = [] - - for choice in response.content: - if ( - hasattr(choice, "type") - and choice.type == "text" - and hasattr(choice, "text") - and choice.text - ): - content.append({"type": "text", "text": choice.text}) - elif ( - hasattr(choice, "type") - and choice.type == "tool_use" - and hasattr(choice, "name") - and hasattr(choice, "id") - ): - tool_call = { - "type": "function", - "id": choice.id, - "function": { - "name": choice.name, - "arguments": getattr(choice, "input", {}), - }, - } - content.append(tool_call) - - if content: - message = { - "role": "assistant", - "content": content, - } - output.append(message) - - return output - - -def format_response_openai(response): - output = [] - - if hasattr(response, "choices"): - content = [] - role = "assistant" - - for choice in response.choices: - # Handle Chat Completions response format - if hasattr(choice, "message") and choice.message: - if choice.message.role: - role = choice.message.role - - if choice.message.content: - content.append({"type": "text", "text": choice.message.content}) - - if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: - for tool_call in choice.message.tool_calls: - content.append( - { - "type": "function", - "id": tool_call.id, - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, - } - ) - - if content: - message = { - "role": role, - "content": content, - } - output.append(message) - - # Handle Responses API format - if hasattr(response, "output"): - content = [] - role = "assistant" - - for item in response.output: - if item.type == "message": - role = item.role - - if hasattr(item, "content") and isinstance(item.content, list): - for content_item in item.content: - if ( - hasattr(content_item, "type") - and content_item.type == "output_text" - and hasattr(content_item, "text") - ): - content.append({"type": "text", "text": content_item.text}) - elif hasattr(content_item, "text"): - content.append({"type": "text", "text": content_item.text}) - elif ( - hasattr(content_item, "type") - and content_item.type == "input_image" - and hasattr(content_item, "image_url") - ): - content.append( - { - "type": "image", - "image": content_item.image_url, - } - ) - elif hasattr(item, "content"): - content.append({"type": "text", "text": str(item.content)}) - - elif hasattr(item, "type") and item.type == "function_call": - content.append( - { - "type": "function", - "id": getattr(item, "call_id", getattr(item, "id", "")), - "function": { - "name": item.name, - "arguments": getattr(item, "arguments", {}), - }, - } - ) + return format_anthropic_response(response) + elif provider == "openai": + from posthog.ai.openai.openai_converter import format_openai_response - if content: - message = { - "role": role, - "content": content, - } - output.append(message) - - return output - - -def format_response_gemini(response): - output = [] - - if hasattr(response, "candidates") and response.candidates: - for candidate in response.candidates: - if hasattr(candidate, "content") and candidate.content: - content = [] - - if 
hasattr(candidate.content, "parts") and candidate.content.parts: - for part in candidate.content.parts: - if hasattr(part, "text") and part.text: - content.append({"type": "text", "text": part.text}) - elif hasattr(part, "function_call") and part.function_call: - function_call = part.function_call - content.append( - { - "type": "function", - "function": { - "name": function_call.name, - "arguments": function_call.args, - }, - } - ) - - if content: - message = { - "role": "assistant", - "content": content, - } - output.append(message) - - elif hasattr(candidate, "text") and candidate.text: - output.append( - { - "role": "assistant", - "content": [{"type": "text", "text": candidate.text}], - } - ) - elif hasattr(response, "text") and response.text: - output.append( - { - "role": "assistant", - "content": [{"type": "text", "text": response.text}], - } - ) + return format_openai_response(response) + elif provider == "gemini": + from posthog.ai.gemini.gemini_converter import format_gemini_response - return output + return format_gemini_response(response) + return [] def extract_available_tool_calls(provider: str, kwargs: Dict[str, Any]): + """ + Extract available tool calls for the given provider. + """ if provider == "anthropic": - if "tools" in kwargs: - return kwargs["tools"] + from posthog.ai.anthropic.anthropic_converter import extract_anthropic_tools - return None + return extract_anthropic_tools(kwargs) elif provider == "gemini": - if "config" in kwargs and hasattr(kwargs["config"], "tools"): - return kwargs["config"].tools + from posthog.ai.gemini.gemini_converter import extract_gemini_tools - return None + return extract_gemini_tools(kwargs) elif provider == "openai": - if "tools" in kwargs: - return kwargs["tools"] + from posthog.ai.openai.openai_converter import extract_openai_tools - return None + return extract_openai_tools(kwargs) def merge_system_prompt(kwargs: Dict[str, Any], provider: str): - messages: List[Dict[str, Any]] = [] + """ + Merge system prompts and format messages for the given provider. 
+ """ if provider == "anthropic": + from posthog.ai.anthropic.anthropic_converter import format_anthropic_input + messages = kwargs.get("messages") or [] - if kwargs.get("system") is None: - return messages - return [{"role": "system", "content": kwargs.get("system")}] + messages + system = kwargs.get("system") + return format_anthropic_input(messages, system) elif provider == "gemini": - contents = kwargs.get("contents", []) - if isinstance(contents, str): - return [{"role": "user", "content": contents}] - elif isinstance(contents, list): - formatted = [] - for item in contents: - if isinstance(item, str): - formatted.append({"role": "user", "content": item}) - elif hasattr(item, "text"): - formatted.append({"role": "user", "content": item.text}) - else: - formatted.append({"role": "user", "content": str(item)}) - return formatted - else: - return [{"role": "user", "content": str(contents)}] - - # For OpenAI, handle both Chat Completions and Responses API - if kwargs.get("messages") is not None: - messages = list(kwargs.get("messages", [])) - - if kwargs.get("input") is not None: - input_data = kwargs.get("input") - if isinstance(input_data, list): - messages.extend(input_data) - else: - messages.append({"role": "user", "content": input_data}) - - # Check if system prompt is provided as a separate parameter - if kwargs.get("system") is not None: - has_system = any(msg.get("role") == "system" for msg in messages) - if not has_system: - messages = [{"role": "system", "content": kwargs.get("system")}] + messages - - # For Responses API, add instructions to the system prompt if provided - if kwargs.get("instructions") is not None: - # Find the system message if it exists - system_idx = next( - (i for i, msg in enumerate(messages) if msg.get("role") == "system"), None - ) + from posthog.ai.gemini.gemini_converter import format_gemini_input - if system_idx is not None: - # Append instructions to existing system message - system_content = messages[system_idx].get("content", "") - messages[system_idx]["content"] = ( - f"{system_content}\n\n{kwargs.get('instructions')}" + contents = kwargs.get("contents", []) + return format_gemini_input(contents) + elif provider == "openai": + # For OpenAI, handle both Chat Completions and Responses API + from posthog.ai.openai.openai_converter import format_openai_input + + messages_param = kwargs.get("messages") + input_param = kwargs.get("input") + + # Get base formatted messages + messages = format_openai_input(messages_param, input_param) + + # Check if system prompt is provided as a separate parameter + if kwargs.get("system") is not None: + has_system = any(msg.get("role") == "system" for msg in messages) + if not has_system: + messages = [ + {"role": "system", "content": kwargs.get("system")} + ] + messages + + # For Responses API, add instructions to the system prompt if provided + if kwargs.get("instructions") is not None: + # Find the system message if it exists + system_idx = next( + (i for i, msg in enumerate(messages) if msg.get("role") == "system"), + None, ) - else: - # Create a new system message with instructions - messages = [ - {"role": "system", "content": kwargs.get("instructions")} - ] + messages - return messages + if system_idx is not None: + # Append instructions to existing system message + system_content = messages[system_idx].get("content", "") + messages[system_idx]["content"] = ( + f"{system_content}\n\n{kwargs.get('instructions')}" + ) + else: + # Create a new system message with instructions + messages = [ + {"role": "system", 
"content": kwargs.get("instructions")} + ] + messages + + return messages + + # Default case - return empty list + return [] def call_llm_and_track_usage( @@ -388,7 +238,7 @@ def call_llm_and_track_usage( posthog_properties: Optional[Dict[str, Any]], posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], - base_url: URL, + base_url: str, call_method: Callable[..., Any], **kwargs: Any, ) -> Any: @@ -401,7 +251,7 @@ def call_llm_and_track_usage( error = None http_status = 200 usage: Dict[str, Any] = {} - error_params: Dict[str, any] = {} + error_params: Dict[str, Any] = {} try: response = call_method(**kwargs) @@ -509,7 +359,7 @@ async def call_llm_and_track_usage_async( posthog_properties: Optional[Dict[str, Any]], posthog_privacy_mode: bool, posthog_groups: Optional[Dict[str, Any]], - base_url: URL, + base_url: str, call_async_method: Callable[..., Any], **kwargs: Any, ) -> Any: @@ -518,7 +368,7 @@ async def call_llm_and_track_usage_async( error = None http_status = 200 usage: Dict[str, Any] = {} - error_params: Dict[str, any] = {} + error_params: Dict[str, Any] = {} try: response = await call_async_method(**kwargs) @@ -629,3 +479,105 @@ def with_privacy_mode(ph_client: PostHogClient, privacy_mode: bool, value: Any): if ph_client.privacy_mode or privacy_mode: return None return value + + +def capture_streaming_event( + ph_client: PostHogClient, + event_data: StreamingEventData, +): + """ + Unified streaming event capture for all LLM providers. + + This function handles the common logic for capturing streaming events across all providers. + All provider-specific formatting should be done BEFORE calling this function. + + The function handles: + - Building PostHog event properties + - Extracting and adding tools based on provider + - Applying privacy mode + - Adding special token fields (cache, reasoning) + - Provider-specific fields (e.g., OpenAI instructions) + - Sending the event to PostHog + + Args: + ph_client: PostHog client instance + event_data: Standardized streaming event data containing all necessary information + """ + trace_id = event_data.get("trace_id") or str(uuid.uuid4()) + + # Build base event properties + event_properties = { + "$ai_provider": event_data["provider"], + "$ai_model": event_data["model"], + "$ai_model_parameters": get_model_params(event_data["kwargs"]), + "$ai_input": with_privacy_mode( + ph_client, + event_data["privacy_mode"], + event_data["formatted_input"], + ), + "$ai_output_choices": with_privacy_mode( + ph_client, + event_data["privacy_mode"], + event_data["formatted_output"], + ), + "$ai_http_status": 200, + "$ai_input_tokens": event_data["usage_stats"].get("input_tokens", 0), + "$ai_output_tokens": event_data["usage_stats"].get("output_tokens", 0), + "$ai_latency": event_data["latency"], + "$ai_trace_id": trace_id, + "$ai_base_url": str(event_data["base_url"]), + **(event_data.get("properties") or {}), + } + + # Extract and add tools based on provider + available_tools = extract_available_tool_calls( + event_data["provider"], + event_data["kwargs"], + ) + if available_tools: + event_properties["$ai_tools"] = available_tools + + # Add optional token fields + # For Anthropic, always include cache fields even if 0 (backward compatibility) + # For others, only include if present and non-zero + if event_data["provider"] == "anthropic": + # Anthropic always includes cache fields + cache_read = event_data["usage_stats"].get("cache_read_input_tokens", 0) + cache_creation = event_data["usage_stats"].get("cache_creation_input_tokens", 0) + 
event_properties["$ai_cache_read_input_tokens"] = cache_read + event_properties["$ai_cache_creation_input_tokens"] = cache_creation + else: + # Other providers only include if non-zero + optional_token_fields = [ + "cache_read_input_tokens", + "cache_creation_input_tokens", + "reasoning_tokens", + ] + + for field in optional_token_fields: + value = event_data["usage_stats"].get(field) + if value is not None and isinstance(value, int) and value > 0: + event_properties[f"$ai_{field}"] = value + + # Handle provider-specific fields + if ( + event_data["provider"] == "openai" + and event_data["kwargs"].get("instructions") is not None + ): + event_properties["$ai_instructions"] = with_privacy_mode( + ph_client, + event_data["privacy_mode"], + event_data["kwargs"]["instructions"], + ) + + if event_data.get("distinct_id") is None: + event_properties["$process_person_profile"] = False + + # Send event to PostHog + if hasattr(ph_client, "capture"): + ph_client.capture( + distinct_id=event_data.get("distinct_id") or trace_id, + event="$ai_generation", + properties=event_properties, + groups=event_data.get("groups"), + ) diff --git a/posthog/test/ai/anthropic/test_anthropic.py b/posthog/test/ai/anthropic/test_anthropic.py index fde6f78b..fcb64c15 100644 --- a/posthog/test/ai/anthropic/test_anthropic.py +++ b/posthog/test/ai/anthropic/test_anthropic.py @@ -1,5 +1,4 @@ import os -import time from unittest.mock import patch import pytest @@ -21,6 +20,89 @@ ) +# ======================= +# Reusable Mock Helpers +# ======================= + + +class MockContent: + """Reusable mock content class for Anthropic responses.""" + + def __init__(self, text="Bar", content_type="text"): + self.type = content_type + self.text = text + + +class MockUsage: + """Reusable mock usage class for Anthropic responses.""" + + def __init__( + self, + input_tokens=18, + output_tokens=1, + cache_read_input_tokens=0, + cache_creation_input_tokens=0, + ): + self.input_tokens = input_tokens + self.output_tokens = output_tokens + self.cache_read_input_tokens = cache_read_input_tokens + self.cache_creation_input_tokens = cache_creation_input_tokens + + +class MockResponse: + """Reusable mock response class for Anthropic messages.""" + + def __init__( + self, + content_text="Bar", + model="claude-3-opus-20240229", + input_tokens=18, + output_tokens=1, + cache_read=0, + cache_creation=0, + ): + self.content = [MockContent(text=content_text)] + self.model = model + self.usage = MockUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_input_tokens=cache_read, + cache_creation_input_tokens=cache_creation, + ) + + +def create_mock_response(**kwargs): + """Factory function to create mock responses with custom parameters.""" + return MockResponse(**kwargs) + + +# Streaming mock helpers +class MockStreamEvent: + """Reusable mock event class for streaming responses.""" + + def __init__(self, event_type=None, **kwargs): + self.type = event_type + for key, value in kwargs.items(): + setattr(self, key, value) + + +class MockContentBlock: + """Reusable mock content block for streaming.""" + + def __init__(self, block_type, **kwargs): + self.type = block_type + for key, value in kwargs.items(): + setattr(self, key, value) + + +class MockDelta: + """Reusable mock delta for streaming events.""" + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + @pytest.fixture def mock_client(): with patch("posthog.client.Client") as mock_client: @@ -46,22 +128,77 @@ def 
mock_anthropic_response(): @pytest.fixture -def mock_anthropic_stream(): - class MockStreamEvent: - def __init__(self, content, usage=None): - self.content = content - self.usage = usage +def mock_anthropic_stream_with_tools(): + """Mock stream events for tool calls.""" + + class MockMessage: + def __init__(self): + self.usage = MockUsage( + input_tokens=50, + cache_creation_input_tokens=0, + cache_read_input_tokens=5, + ) def stream_generator(): - yield MockStreamEvent("A") - yield MockStreamEvent("B") - yield MockStreamEvent( - "C", - usage=Usage( - input_tokens=20, - output_tokens=10, - ), + # Message start with usage + event = MockStreamEvent("message_start") + event.message = MockMessage() + yield event + + # Text block start + event = MockStreamEvent("content_block_start") + event.content_block = MockContentBlock("text") + event.index = 0 + yield event + + # Text delta + event = MockStreamEvent("content_block_delta") + event.delta = MockDelta(text="I'll check the weather for you.") + event.index = 0 + yield event + + # Text block stop + event = MockStreamEvent("content_block_stop") + event.index = 0 + yield event + + # Tool use block start + event = MockStreamEvent("content_block_start") + event.content_block = MockContentBlock( + "tool_use", id="toolu_stream123", name="get_weather" + ) + event.index = 1 + yield event + + # Tool input delta 1 + event = MockStreamEvent("content_block_delta") + event.delta = MockDelta( + type="input_json_delta", partial_json='{"location": "San' + ) + event.index = 1 + yield event + + # Tool input delta 2 + event = MockStreamEvent("content_block_delta") + event.delta = MockDelta( + type="input_json_delta", partial_json=' Francisco", "unit": "celsius"}' ) + event.index = 1 + yield event + + # Tool block stop + event = MockStreamEvent("content_block_stop") + event.index = 1 + yield event + + # Message delta with final usage + event = MockStreamEvent("message_delta") + event.usage = MockUsage(output_tokens=25) + yield event + + # Message stop + event = MockStreamEvent("message_stop") + yield event return stream_generator() @@ -174,83 +311,6 @@ def test_basic_completion(mock_client, mock_anthropic_response): assert isinstance(props["$ai_latency"], float) -def test_streaming(mock_client, mock_anthropic_stream): - with patch( - "anthropic.resources.Messages.create", return_value=mock_anthropic_stream - ): - client = Anthropic(api_key="test-key", posthog_client=mock_client) - response = client.messages.create( - model="claude-3-opus-20240229", - messages=[{"role": "user", "content": "Hello"}], - stream=True, - posthog_distinct_id="test-id", - posthog_properties={"foo": "bar"}, - ) - - # Consume the stream - chunks = list(response) - assert len(chunks) == 3 - assert chunks[0].content == "A" - assert chunks[1].content == "B" - assert chunks[2].content == "C" - - # Wait a bit to ensure the capture is called - time.sleep(0.1) - assert mock_client.capture.call_count == 1 - - call_args = mock_client.capture.call_args[1] - props = call_args["properties"] - - assert call_args["distinct_id"] == "test-id" - assert call_args["event"] == "$ai_generation" - assert props["$ai_provider"] == "anthropic" - assert props["$ai_model"] == "claude-3-opus-20240229" - assert props["$ai_input"] == [{"role": "user", "content": "Hello"}] - assert props["$ai_output_choices"] == [{"role": "assistant", "content": "ABC"}] - assert props["$ai_input_tokens"] == 20 - assert props["$ai_output_tokens"] == 10 - assert isinstance(props["$ai_latency"], float) - assert props["foo"] == "bar" - - -def 
test_streaming_with_stream_endpoint(mock_client, mock_anthropic_stream): - with patch( - "anthropic.resources.Messages.create", return_value=mock_anthropic_stream - ): - client = Anthropic(api_key="test-key", posthog_client=mock_client) - response = client.messages.stream( - model="claude-3-opus-20240229", - messages=[{"role": "user", "content": "Hello"}], - posthog_distinct_id="test-id", - posthog_properties={"foo": "bar"}, - ) - - # Consume the stream - chunks = list(response) - assert len(chunks) == 3 - assert chunks[0].content == "A" - assert chunks[1].content == "B" - assert chunks[2].content == "C" - - # Wait a bit to ensure the capture is called - time.sleep(0.1) - assert mock_client.capture.call_count == 1 - - call_args = mock_client.capture.call_args[1] - props = call_args["properties"] - - assert call_args["distinct_id"] == "test-id" - assert call_args["event"] == "$ai_generation" - assert props["$ai_provider"] == "anthropic" - assert props["$ai_model"] == "claude-3-opus-20240229" - assert props["$ai_input"] == [{"role": "user", "content": "Hello"}] - assert props["$ai_output_choices"] == [{"role": "assistant", "content": "ABC"}] - assert props["$ai_input_tokens"] == 20 - assert props["$ai_output_tokens"] == 10 - assert isinstance(props["$ai_latency"], float) - assert props["foo"] == "bar" - - def test_groups(mock_client, mock_anthropic_response): with patch( "anthropic.resources.Messages.create", return_value=mock_anthropic_response @@ -315,16 +375,22 @@ def test_privacy_mode_global(mock_client, mock_anthropic_response): @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") def test_basic_integration(mock_client): - client = Anthropic(posthog_client=mock_client) - client.messages.create( - model="claude-3-opus-20240229", - messages=[{"role": "user", "content": "Foo"}], - max_tokens=1, - temperature=0, - posthog_distinct_id="test-id", - posthog_properties={"foo": "bar"}, - system="You must always answer with 'Bar'.", - ) + """Test basic non-streaming integration.""" + + with patch( + "anthropic.resources.Messages.create", + return_value=create_mock_response(), + ): + client = Anthropic(posthog_client=mock_client) + client.messages.create( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "Foo"}], + max_tokens=1, + temperature=0, + posthog_distinct_id="test-id", + posthog_properties={"foo": "bar"}, + system="You must always answer with 'Bar'.", + ) assert mock_client.capture.call_count == 1 @@ -351,15 +417,27 @@ def test_basic_integration(mock_client): @pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") async def test_basic_async_integration(mock_client): - client = AsyncAnthropic(posthog_client=mock_client) - await client.messages.create( - model="claude-3-opus-20240229", - messages=[{"role": "user", "content": "You must always answer with 'Bar'."}], - max_tokens=1, - temperature=0, - posthog_distinct_id="test-id", - posthog_properties={"foo": "bar"}, - ) + """Test async non-streaming integration.""" + + # Make the mock async + async def mock_async_create(**kwargs): + return create_mock_response(input_tokens=16) + + with patch( + "anthropic.resources.messages.AsyncMessages.create", + side_effect=mock_async_create, + ): + client = AsyncAnthropic(posthog_client=mock_client) + await client.messages.create( + model="claude-3-opus-20240229", + messages=[ + {"role": "user", "content": "You must always answer with 'Bar'."} + ], + max_tokens=1, + temperature=0, + posthog_distinct_id="test-id", + 
posthog_properties={"foo": "bar"}, + ) assert mock_client.capture.call_count == 1 @@ -381,52 +459,51 @@ async def test_basic_async_integration(mock_client): assert isinstance(props["$ai_latency"], float) -def test_streaming_system_prompt(mock_client, mock_anthropic_stream): +@pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") +async def test_async_streaming_system_prompt(mock_client): + """Test async streaming with system prompt.""" + + # Create a simple mock async stream using reusable helpers + async def mock_async_stream(): + # Yield some events + yield MockStreamEvent(type="message_start") + yield MockStreamEvent(type="content_block_start") + yield MockStreamEvent(type="content_block_delta", text="Bar") + + # Final message with usage + final_msg = MockStreamEvent(type="message_delta") + final_msg.usage = MockUsage( + input_tokens=10, + output_tokens=5, + cache_read_input_tokens=0, + cache_creation_input_tokens=0, + ) + yield final_msg + + # Mock create to return a coroutine that yields the async generator + # This matches the actual behavior when stream=True with await + async def async_create_wrapper(**kwargs): + return mock_async_stream() + with patch( - "anthropic.resources.Messages.create", return_value=mock_anthropic_stream + "anthropic.resources.messages.AsyncMessages.create", + side_effect=async_create_wrapper, ): - client = Anthropic(api_key="test-key", posthog_client=mock_client) - response = client.messages.create( + client = AsyncAnthropic(posthog_client=mock_client) + response = await client.messages.create( model="claude-3-opus-20240229", - system="Foo", - messages=[{"role": "user", "content": "Bar"}], + system="You must always answer with 'Bar'.", + messages=[{"role": "user", "content": "Foo"}], stream=True, + max_tokens=1, ) - # Consume the stream - list(response) + # Consume the stream - async finally block completes before this returns + [c async for c in response] - # Wait a bit to ensure the capture is called - time.sleep(0.1) + # Capture happens in the async finally block before generator completes assert mock_client.capture.call_count == 1 - call_args = mock_client.capture.call_args[1] - props = call_args["properties"] - - assert props["$ai_input"] == [ - {"role": "system", "content": "Foo"}, - {"role": "user", "content": "Bar"}, - ] - - -@pytest.mark.skipif(not ANTHROPIC_API_KEY, reason="ANTHROPIC_API_KEY is not set") -async def test_async_streaming_system_prompt(mock_client, mock_anthropic_stream): - client = AsyncAnthropic(posthog_client=mock_client) - response = await client.messages.create( - model="claude-3-opus-20240229", - system="You must always answer with 'Bar'.", - messages=[{"role": "user", "content": "Foo"}], - stream=True, - max_tokens=1, - ) - - # Consume the stream - [c async for c in response] - - # Wait a bit to ensure the capture is called - time.sleep(0.1) - assert mock_client.capture.call_count == 1 - call_args = mock_client.capture.call_args[1] props = call_args["properties"] @@ -746,3 +823,219 @@ async def run_test(): assert props["$ai_input_tokens"] == 25 assert props["$ai_output_tokens"] == 15 assert props["$ai_http_status"] == 200 + + +def test_streaming_with_tool_calls(mock_client, mock_anthropic_stream_with_tools): + """Test that tool calls are properly captured in streaming mode.""" + with patch( + "anthropic.resources.Messages.create", + return_value=mock_anthropic_stream_with_tools, + ): + client = Anthropic(api_key="test-key", posthog_client=mock_client) + response = client.messages.create( + 
model="claude-3-5-sonnet-20241022", + system="You are a helpful weather assistant.", + messages=[ + {"role": "user", "content": "What's the weather in San Francisco?"} + ], + tools=[ + { + "name": "get_weather", + "description": "Get weather information", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string"}, + }, + "required": ["location"], + }, + } + ], + stream=True, + posthog_distinct_id="test-id", + ) + + # Consume the stream - this triggers the finally block synchronously + list(response) + + # Capture happens synchronously when generator is exhausted + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "anthropic" + assert props["$ai_model"] == "claude-3-5-sonnet-20241022" + + # Verify system prompt is included in input + assert props["$ai_input"] == [ + {"role": "system", "content": "You are a helpful weather assistant."}, + {"role": "user", "content": "What's the weather in San Francisco?"}, + ] + + # Verify that tools are captured in the properties + assert props["$ai_tools"] == [ + { + "name": "get_weather", + "description": "Get weather information", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string"}, + }, + "required": ["location"], + }, + } + ] + + # Verify output contains both text and tool call + output_choices = props["$ai_output_choices"] + assert len(output_choices) == 1 + + assistant_message = output_choices[0] + assert assistant_message["role"] == "assistant" + + content = assistant_message["content"] + assert isinstance(content, list) + assert len(content) == 2 + + # Verify text block + text_block = content[0] + assert text_block["type"] == "text" + assert text_block["text"] == "I'll check the weather for you." 
+ + # Verify tool call block + tool_block = content[1] + assert tool_block["type"] == "function" + assert tool_block["id"] == "toolu_stream123" + assert tool_block["function"]["name"] == "get_weather" + assert tool_block["function"]["arguments"] == { + "location": "San Francisco", + "unit": "celsius", + } + + # Check token usage + assert props["$ai_input_tokens"] == 50 + assert props["$ai_output_tokens"] == 25 + assert props["$ai_cache_read_input_tokens"] == 5 + assert props["$ai_cache_creation_input_tokens"] == 0 + + +def test_async_streaming_with_tool_calls(mock_client, mock_anthropic_stream_with_tools): + """Test that tool calls are properly captured in async streaming mode.""" + import asyncio + + async def mock_async_generator(): + # Convert regular generator to async generator + for event in mock_anthropic_stream_with_tools: + yield event + + async def mock_async_create(**kwargs): + # Return the async generator (to be awaited by the implementation) + return mock_async_generator() + + with patch( + "anthropic.resources.AsyncMessages.create", + side_effect=mock_async_create, + ): + async_client = AsyncAnthropic(api_key="test-key", posthog_client=mock_client) + + async def run_test(): + response = await async_client.messages.create( + model="claude-3-5-sonnet-20241022", + system="You are a helpful weather assistant.", + messages=[ + {"role": "user", "content": "What's the weather in San Francisco?"} + ], + tools=[ + { + "name": "get_weather", + "description": "Get weather information", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string"}, + }, + "required": ["location"], + }, + } + ], + stream=True, + posthog_distinct_id="test-id", + ) + + # Consume the async stream + [event async for event in response] + + # asyncio.run() waits for all async operations to complete + asyncio.run(run_test()) + + # Capture completes before asyncio.run() returns + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "anthropic" + assert props["$ai_model"] == "claude-3-5-sonnet-20241022" + + # Verify system prompt is included in input + assert props["$ai_input"] == [ + {"role": "system", "content": "You are a helpful weather assistant."}, + {"role": "user", "content": "What's the weather in San Francisco?"}, + ] + + # Verify that tools are captured in the properties + assert props["$ai_tools"] == [ + { + "name": "get_weather", + "description": "Get weather information", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string"}, + }, + "required": ["location"], + }, + } + ] + + # Verify output contains both text and tool call + output_choices = props["$ai_output_choices"] + assert len(output_choices) == 1 + + assistant_message = output_choices[0] + assert assistant_message["role"] == "assistant" + + content = assistant_message["content"] + assert isinstance(content, list) + assert len(content) == 2 + + # Verify text block + text_block = content[0] + assert text_block["type"] == "text" + assert text_block["text"] == "I'll check the weather for you." 
+ + # Verify tool call block + tool_block = content[1] + assert tool_block["type"] == "function" + assert tool_block["id"] == "toolu_stream123" + assert tool_block["function"]["name"] == "get_weather" + assert tool_block["function"]["arguments"] == { + "location": "San Francisco", + "unit": "celsius", + } + + # Check token usage + assert props["$ai_input_tokens"] == 50 + assert props["$ai_output_tokens"] == 25 + assert props["$ai_cache_read_input_tokens"] == 5 + assert props["$ai_cache_creation_input_tokens"] == 0 diff --git a/posthog/test/ai/gemini/test_gemini.py b/posthog/test/ai/gemini/test_gemini.py index 4dfcc9e0..f874ce4a 100644 --- a/posthog/test/ai/gemini/test_gemini.py +++ b/posthog/test/ai/gemini/test_gemini.py @@ -226,6 +226,87 @@ def mock_streaming_response(): assert isinstance(props["$ai_latency"], float) +def test_new_client_streaming_with_tools(mock_client, mock_google_genai_client): + """Test that tools are captured in streaming mode""" + + def mock_streaming_response(): + mock_chunk1 = MagicMock() + mock_chunk1.text = "I'll check " + mock_usage1 = MagicMock() + mock_usage1.prompt_token_count = 15 + mock_usage1.candidates_token_count = 5 + mock_chunk1.usage_metadata = mock_usage1 + + mock_chunk2 = MagicMock() + mock_chunk2.text = "the weather" + mock_usage2 = MagicMock() + mock_usage2.prompt_token_count = 15 + mock_usage2.candidates_token_count = 10 + mock_chunk2.usage_metadata = mock_usage2 + + yield mock_chunk1 + yield mock_chunk2 + + # Mock the generate_content_stream method + mock_google_genai_client.models.generate_content_stream.return_value = ( + mock_streaming_response() + ) + + client = Client(api_key="test-key", posthog_client=mock_client) + + # Create mock tools configuration + mock_tool = MagicMock() + mock_tool.function_declarations = [ + MagicMock( + name="get_current_weather", + description="Gets the current weather for a given location.", + parameters=MagicMock( + type="OBJECT", + properties={ + "location": MagicMock( + type="STRING", + description="The city and state, e.g. 
San Francisco, CA", + ) + }, + required=["location"], + ), + ) + ] + + mock_config = MagicMock() + mock_config.tools = [mock_tool] + + response = client.models.generate_content_stream( + model="gemini-2.0-flash", + contents=["What's the weather in SF?"], + config=mock_config, + posthog_distinct_id="test-id", + posthog_properties={"feature": "streaming_with_tools"}, + ) + + chunks = list(response) + assert len(chunks) == 2 + assert chunks[0].text == "I'll check " + assert chunks[1].text == "the weather" + + # Check that the streaming event was captured with tools + assert mock_client.capture.call_count == 1 + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "gemini" + assert props["$ai_model"] == "gemini-2.0-flash" + assert props["$ai_input_tokens"] == 15 + assert props["$ai_output_tokens"] == 10 + assert props["feature"] == "streaming_with_tools" + assert isinstance(props["$ai_latency"], float) + + # Verify that tools are captured in the $ai_tools property in streaming mode + assert props["$ai_tools"] == [mock_tool] + + def test_new_client_groups(mock_client, mock_google_genai_client, mock_gemini_response): """Test groups functionality with new Client API""" mock_google_genai_client.models.generate_content.return_value = mock_gemini_response @@ -302,12 +383,32 @@ def test_new_client_different_input_formats( props = call_args["properties"] assert props["$ai_input"] == [{"role": "user", "content": "Hello"}] - # Test list input + # Test Gemini-specific format with parts array (like in the screenshot) + mock_client.reset_mock() + client.models.generate_content( + model="gemini-2.0-flash", + contents=[{"role": "user", "parts": [{"text": "hey"}]}], + posthog_distinct_id="test-id", + ) + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] == [{"role": "user", "content": "hey"}] + + # Test multiple parts in the parts array + mock_client.reset_mock() + client.models.generate_content( + model="gemini-2.0-flash", + contents=[{"role": "user", "parts": [{"text": "Hello "}, {"text": "world"}]}], + posthog_distinct_id="test-id", + ) + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + assert props["$ai_input"] == [{"role": "user", "content": "Hello world"}] + + # Test list input with string mock_client.capture.reset_mock() - mock_part = MagicMock() - mock_part.text = "List item" client.models.generate_content( - model="gemini-2.0-flash", contents=[mock_part], posthog_distinct_id="test-id" + model="gemini-2.0-flash", contents=["List item"], posthog_distinct_id="test-id" ) call_args = mock_client.capture.call_args[1] props = call_args["properties"] diff --git a/posthog/test/ai/openai/test_openai.py b/posthog/test/ai/openai/test_openai.py index d16e5361..cc4d6aea 100644 --- a/posthog/test/ai/openai/test_openai.py +++ b/posthog/test/ai/openai/test_openai.py @@ -890,10 +890,29 @@ def test_streaming_with_tool_calls(mock_client): assert defined_tool["function"]["description"] == "Get weather" assert defined_tool["function"]["parameters"] == {} - # Check that the content was also accumulated + # Check that both text content and tool calls were accumulated + output_content = props["$ai_output_choices"][0]["content"] + + # Find text content and tool call in the output + text_content = None + tool_call_content = None + for item in output_content: + if item["type"] == 
"text": + text_content = item + elif item["type"] == "function": + tool_call_content = item + + # Verify text content + assert text_content is not None + assert text_content["text"] == "The weather in San Francisco is 15°C." + + # Verify tool call was captured + assert tool_call_content is not None + assert tool_call_content["id"] == "call_abc123" + assert tool_call_content["function"]["name"] == "get_weather" assert ( - props["$ai_output_choices"][0]["content"] - == "The weather in San Francisco is 15°C." + tool_call_content["function"]["arguments"] + == '{"location": "San Francisco", "unit": "celsius"}' ) # Check token usage @@ -1014,6 +1033,85 @@ def test_responses_parse(mock_client, mock_parsed_response): assert isinstance(props["$ai_latency"], float) +def test_responses_api_streaming_with_tokens(mock_client): + """Test that Responses API streaming properly captures token usage from response.usage.""" + from openai.types.responses import ResponseUsage + from unittest.mock import MagicMock + + # Create mock response chunks with usage data in the correct location + chunks = [] + + # First chunk - just content, no usage + chunk1 = MagicMock() + chunk1.type = "response.text.delta" + chunk1.text = "Test " + chunks.append(chunk1) + + # Second chunk - more content + chunk2 = MagicMock() + chunk2.type = "response.text.delta" + chunk2.text = "response" + chunks.append(chunk2) + + # Final chunk - completed event with usage in response.usage + chunk3 = MagicMock() + chunk3.type = "response.completed" + chunk3.response = MagicMock() + chunk3.response.usage = ResponseUsage( + input_tokens=25, + output_tokens=30, + total_tokens=55, + input_tokens_details={"prompt_tokens": 25, "cached_tokens": 0}, + output_tokens_details={"reasoning_tokens": 0}, + ) + chunk3.response.output = ["Test response"] + chunks.append(chunk3) + + captured_kwargs = {} + + def mock_streaming_response(**kwargs): + # Capture the kwargs to verify stream_options was NOT added + captured_kwargs.update(kwargs) + return iter(chunks) + + with patch( + "openai.resources.responses.Responses.create", + side_effect=mock_streaming_response, + ): + client = OpenAI(api_key="test-key", posthog_client=mock_client) + + # Consume the streaming response + response = client.responses.create( + model="gpt-4o-mini", + input=[{"role": "user", "content": "Test message"}], + stream=True, + posthog_distinct_id="test-id", + posthog_properties={"test": "streaming"}, + ) + + # Consume all chunks + list(response) + + # Verify stream_options was NOT added (Responses API doesn't support it) + assert "stream_options" not in captured_kwargs + + # Verify capture was called + assert mock_client.capture.call_count == 1 + + call_args = mock_client.capture.call_args[1] + props = call_args["properties"] + + # Verify tokens are captured correctly from response.usage (not 0) + assert call_args["distinct_id"] == "test-id" + assert call_args["event"] == "$ai_generation" + assert props["$ai_provider"] == "openai" + assert props["$ai_model"] == "gpt-4o-mini" + assert props["$ai_input_tokens"] == 25 # Should not be 0 + assert props["$ai_output_tokens"] == 30 # Should not be 0 + assert props["test"] == "streaming" + assert isinstance(props["$ai_latency"], float) + + def test_tool_definition(mock_client, mock_openai_response): """Test that tools defined in the create function are captured in $ai_tools property""" with patch( diff --git a/posthog/version.py b/posthog/version.py index e9674076..7cda3ab0 100644 --- a/posthog/version.py +++ b/posthog/version.py @@ -1,4 +1,4 @@ 
-VERSION = "6.7.1" +VERSION = "6.7.2" if __name__ == "__main__": print(VERSION, end="") # noqa: T201