diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5c284421..446992f8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 4.9.0 - 2025-06-13
+
+- feat(ai): track reasoning and cache tokens in the LangChain callback
+
 ## 4.8.0 - 2025-06-10
 
 - fix: export scoped, rather than tracked, decorator
diff --git a/posthog/ai/langchain/callbacks.py b/posthog/ai/langchain/callbacks.py
index c3106055..763d6880 100644
--- a/posthog/ai/langchain/callbacks.py
+++ b/posthog/ai/langchain/callbacks.py
@@ -14,7 +14,6 @@
     List,
     Optional,
     Sequence,
-    Tuple,
     Union,
     cast,
 )
@@ -569,9 +568,14 @@ def _capture_generation(
             event_properties["$ai_is_error"] = True
         else:
             # Add usage
-            input_tokens, output_tokens = _parse_usage(output)
-            event_properties["$ai_input_tokens"] = input_tokens
-            event_properties["$ai_output_tokens"] = output_tokens
+            usage = _parse_usage(output)
+            event_properties["$ai_input_tokens"] = usage.input_tokens
+            event_properties["$ai_output_tokens"] = usage.output_tokens
+            event_properties["$ai_cache_creation_input_tokens"] = (
+                usage.cache_write_tokens
+            )
+            event_properties["$ai_cache_read_input_tokens"] = usage.cache_read_tokens
+            event_properties["$ai_reasoning_tokens"] = usage.reasoning_tokens
 
             # Generation results
             generation_result = output.generations[-1]
@@ -647,9 +651,18 @@ def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
     return message_dict
 
 
+@dataclass
+class ModelUsage:
+    input_tokens: Optional[int]
+    output_tokens: Optional[int]
+    cache_write_tokens: Optional[int]
+    cache_read_tokens: Optional[int]
+    reasoning_tokens: Optional[int]
+
+
 def _parse_usage_model(
-    usage: Union[BaseModel, Dict],
-) -> Tuple[Union[int, None], Union[int, None]]:
+    usage: Union[BaseModel, dict],
+) -> ModelUsage:
     if isinstance(usage, BaseModel):
         usage = usage.__dict__
 
@@ -657,15 +670,23 @@
         # https://pypi.org/project/langchain-anthropic/ (works also for Bedrock-Anthropic)
         ("input_tokens", "input"),
         ("output_tokens", "output"),
+        ("cache_creation_input_tokens", "cache_write"),
+        ("cache_read_input_tokens", "cache_read"),
         # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/get-token-count
         ("prompt_token_count", "input"),
         ("candidates_token_count", "output"),
+        ("cached_content_token_count", "cache_read"),
+        ("thoughts_token_count", "reasoning"),
         # Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-cw.html#runtime-cloudwatch-metrics
         ("inputTokenCount", "input"),
         ("outputTokenCount", "output"),
+        ("cacheCreationInputTokenCount", "cache_write"),
+        ("cacheReadInputTokenCount", "cache_read"),
         # Bedrock Anthropic
         ("prompt_tokens", "input"),
         ("completion_tokens", "output"),
+        ("cache_creation_input_tokens", "cache_write"),
+        ("cache_read_input_tokens", "cache_read"),
         # langchain-ibm https://pypi.org/project/langchain-ibm/
         ("input_token_count", "input"),
         ("generated_token_count", "output"),
@@ -683,13 +704,45 @@
 
             parsed_usage[type_key] = final_count
 
-    return parsed_usage.get("input"), parsed_usage.get("output")
+    # Caching (OpenAI & langchain 0.3.9+)
+    if "input_token_details" in usage and isinstance(
+        usage["input_token_details"], dict
+    ):
+        parsed_usage["cache_write"] = usage["input_token_details"].get("cache_creation")
+        parsed_usage["cache_read"] = usage["input_token_details"].get("cache_read")
+
+    # Reasoning (OpenAI & langchain 0.3.9+)
+    if "output_token_details" in usage and isinstance(
+        usage["output_token_details"], dict
+    ):
+        parsed_usage["reasoning"] = usage["output_token_details"].get("reasoning")
+
+    field_mapping = {
+        "input": "input_tokens",
+        "output": "output_tokens",
+        "cache_write": "cache_write_tokens",
+        "cache_read": "cache_read_tokens",
+        "reasoning": "reasoning_tokens",
+    }
+    return ModelUsage(
+        **{
+            dataclass_key: parsed_usage.get(mapped_key) or 0
+            for mapped_key, dataclass_key in field_mapping.items()
+        },
+    )
 
 
-def _parse_usage(response: LLMResult):
+def _parse_usage(response: LLMResult) -> ModelUsage:
     # langchain-anthropic uses the usage field
     llm_usage_keys = ["token_usage", "usage"]
-    llm_usage: Tuple[Union[int, None], Union[int, None]] = (None, None)
+    llm_usage: ModelUsage = ModelUsage(
+        input_tokens=None,
+        output_tokens=None,
+        cache_write_tokens=None,
+        cache_read_tokens=None,
+        reasoning_tokens=None,
+    )
+
     if response.llm_output is not None:
         for key in llm_usage_keys:
             if response.llm_output.get(key):
diff --git a/posthog/test/ai/langchain/test_callbacks.py b/posthog/test/ai/langchain/test_callbacks.py
index 3e9a01b4..fdfdfc6c 100644
--- a/posthog/test/ai/langchain/test_callbacks.py
+++ b/posthog/test/ai/langchain/test_callbacks.py
@@ -1378,11 +1378,11 @@ def get_weather(city: Literal["nyc", "sf"]):
     )
     graph.invoke(inputs, config={"callbacks": [cb]})
     calls = [call[1] for call in mock_client.capture.call_args_list]
-    assert len(calls) == 21
+    assert len(calls) == 15
     for call in calls:
         assert call["properties"]["$ai_trace_id"] == "test-trace-id"
     assert len([call for call in calls if call["event"] == "$ai_generation"]) == 2
-    assert len([call for call in calls if call["event"] == "$ai_span"]) == 18
+    assert len([call for call in calls if call["event"] == "$ai_span"]) == 12
     assert len([call for call in calls if call["event"] == "$ai_trace"]) == 1
 
 
@@ -1435,11 +1435,13 @@ def span_3(_):
 
     assert mock_client.capture.call_count == 3
 
-    span2, span1, trace = [
-        call[1]["properties"] for call in mock_client.capture.call_args_list
-    ]
-    assert span2["$ai_parent_id"] == span1["$ai_span_id"]
-    assert span1["$ai_parent_id"] == trace["$ai_trace_id"]
+    calls = mock_client.capture.call_args_list
+    span_props_2 = calls[0][1]["properties"]
+    span_props_1 = calls[1][1]["properties"]
+    trace_props = calls[2][1]["properties"]
+
+    assert span_props_2["$ai_parent_id"] == span_props_1["$ai_span_id"]
+    assert span_props_1["$ai_parent_id"] == trace_props["$ai_trace_id"]
 
 
 def test_captures_error_with_details_in_span(mock_client):
@@ -1478,3 +1480,250 @@ def span(_):
         == "ValueError"
     )
     assert mock_client.capture.call_args_list[1][1]["properties"]["$ai_is_error"]
+
+
+def test_openai_reasoning_tokens(mock_client):
+    """Test that OpenAI reasoning tokens are captured correctly."""
+    prompt = ChatPromptTemplate.from_messages(
+        [("user", "Think step by step about this problem")]
+    )
+
+    # Mock response with reasoning tokens in output_token_details
+    model = FakeMessagesListChatModel(
+        responses=[
+            AIMessage(
+                content="Let me think through this step by step...",
+                usage_metadata={
+                    "input_tokens": 10,
+                    "output_tokens": 25,
+                    "total_tokens": 35,
+                    "output_token_details": {"reasoning": 15},  # 15 reasoning tokens
+                },
+            )
+        ]
+    )
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = prompt | model
+    result = chain.invoke({}, config={"callbacks": callbacks})
+
+    assert result.content == "Let me think through this step by step..."
+    assert mock_client.capture.call_count == 3
+
+    generation_args = mock_client.capture.call_args_list[1][1]
+    generation_props = generation_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_props["$ai_input_tokens"] == 10
+    assert generation_props["$ai_output_tokens"] == 25
+    assert generation_props["$ai_reasoning_tokens"] == 15
+
+
+def test_anthropic_cache_write_and_read_tokens(mock_client):
+    """Test that Anthropic cache creation and read tokens are captured correctly."""
+    prompt = ChatPromptTemplate.from_messages([("user", "Analyze this large document")])
+
+    # First call with cache creation
+    model_write = FakeMessagesListChatModel(
+        responses=[
+            AIMessage(
+                content="I've analyzed the document and cached the context.",
+                usage_metadata={
+                    "total_tokens": 1050,
+                    "input_tokens": 1000,
+                    "output_tokens": 50,
+                    "cache_creation_input_tokens": 800,  # Anthropic cache write
+                },
+            )
+        ]
+    )
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = prompt | model_write
+    result = chain.invoke({}, config={"callbacks": callbacks})
+
+    assert result.content == "I've analyzed the document and cached the context."
+    assert mock_client.capture.call_count == 3
+
+    generation_args = mock_client.capture.call_args_list[1][1]
+    generation_props = generation_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_props["$ai_input_tokens"] == 1000
+    assert generation_props["$ai_output_tokens"] == 50
+    assert generation_props["$ai_cache_creation_input_tokens"] == 800
+    assert generation_props["$ai_cache_read_input_tokens"] == 0
+    assert generation_props["$ai_reasoning_tokens"] == 0
+
+    # Reset mock for second call
+    mock_client.reset_mock()
+
+    # Second call with cache read
+    model_read = FakeMessagesListChatModel(
+        responses=[
+            AIMessage(
+                content="Using cached analysis to provide quick response.",
+                usage_metadata={
+                    "input_tokens": 200,
+                    "output_tokens": 30,
+                    "total_tokens": 1030,
+                    "cache_read_input_tokens": 800,  # Anthropic cache read
+                },
+            )
+        ]
+    )
+
+    chain = prompt | model_read
+    result = chain.invoke({}, config={"callbacks": callbacks})
+
+    assert result.content == "Using cached analysis to provide quick response."
+    assert mock_client.capture.call_count == 3
+
+    generation_args = mock_client.capture.call_args_list[1][1]
+    generation_props = generation_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_props["$ai_input_tokens"] == 200
+    assert generation_props["$ai_output_tokens"] == 30
+    assert generation_props["$ai_cache_creation_input_tokens"] == 0
+    assert generation_props["$ai_cache_read_input_tokens"] == 800
+    assert generation_props["$ai_reasoning_tokens"] == 0
+
+
+def test_openai_cache_read_tokens(mock_client):
+    """Test that OpenAI cache read tokens are captured correctly."""
+    prompt = ChatPromptTemplate.from_messages(
+        [("user", "Use the cached prompt for this request")]
+    )
+
+    # Mock response with cache read tokens in input_token_details
+    model = FakeMessagesListChatModel(
+        responses=[
+            AIMessage(
+                content="Response using cached prompt context.",
+                usage_metadata={
+                    "input_tokens": 150,
+                    "output_tokens": 40,
+                    "total_tokens": 190,
+                    "input_token_details": {
+                        "cache_read": 100,  # 100 tokens read from cache
+                        "cache_creation": 0,
+                    },
+                },
+            )
+        ]
+    )
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = prompt | model
+    result = chain.invoke({}, config={"callbacks": callbacks})
+
+    assert result.content == "Response using cached prompt context."
+    assert mock_client.capture.call_count == 3
+
+    generation_args = mock_client.capture.call_args_list[1][1]
+    generation_props = generation_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_props["$ai_input_tokens"] == 150
+    assert generation_props["$ai_output_tokens"] == 40
+    assert generation_props["$ai_cache_read_input_tokens"] == 100
+    assert generation_props["$ai_cache_creation_input_tokens"] == 0
+    assert generation_props["$ai_reasoning_tokens"] == 0
+
+
+def test_openai_cache_creation_tokens(mock_client):
+    """Test that OpenAI cache creation tokens are captured correctly."""
+    prompt = ChatPromptTemplate.from_messages(
+        [("user", "Create a cache for this large prompt context")]
+    )
+
+    # Mock response with cache creation tokens in input_token_details
+    model = FakeMessagesListChatModel(
+        responses=[
+            AIMessage(
+                content="Created cache for the prompt context.",
+                usage_metadata={
+                    "input_tokens": 2000,
+                    "output_tokens": 25,
+                    "total_tokens": 2025,
+                    "input_token_details": {
+                        "cache_creation": 1500,  # 1500 tokens written to cache
+                        "cache_read": 0,
+                    },
+                },
+            )
+        ]
+    )
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = prompt | model
+    result = chain.invoke({}, config={"callbacks": callbacks})
+
+    assert result.content == "Created cache for the prompt context."
+    assert mock_client.capture.call_count == 3
+
+    generation_args = mock_client.capture.call_args_list[1][1]
+    generation_props = generation_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_props["$ai_input_tokens"] == 2000
+    assert generation_props["$ai_output_tokens"] == 25
+    assert generation_props["$ai_cache_creation_input_tokens"] == 1500
+    assert generation_props["$ai_cache_read_input_tokens"] == 0
+    assert generation_props["$ai_reasoning_tokens"] == 0
+
+
+def test_combined_reasoning_and_cache_tokens(mock_client):
+    """Test that both reasoning tokens and cache tokens can be captured together."""
+    prompt = ChatPromptTemplate.from_messages(
+        [("user", "Think through this cached problem")]
+    )
+
+    # Mock response with both reasoning and cache tokens
+    model = FakeMessagesListChatModel(
+        responses=[
+            AIMessage(
+                content="Let me reason through this using cached context...",
+                usage_metadata={
+                    "input_tokens": 500,
+                    "output_tokens": 100,
+                    "total_tokens": 600,
+                    "input_token_details": {"cache_read": 300, "cache_creation": 0},
+                    "output_token_details": {"reasoning": 60},  # 60 reasoning tokens
+                },
+            )
+        ]
+    )
+
+    callbacks = [CallbackHandler(mock_client)]
+    chain = prompt | model
+    result = chain.invoke({}, config={"callbacks": callbacks})
+
+    assert result.content == "Let me reason through this using cached context..."
+    assert mock_client.capture.call_count == 3
+
+    generation_args = mock_client.capture.call_args_list[1][1]
+    generation_props = generation_args["properties"]
+
+    assert generation_args["event"] == "$ai_generation"
+    assert generation_props["$ai_input_tokens"] == 500
+    assert generation_props["$ai_output_tokens"] == 100
+    assert generation_props["$ai_cache_read_input_tokens"] == 300
+    assert generation_props["$ai_cache_creation_input_tokens"] == 0
+    assert generation_props["$ai_reasoning_tokens"] == 60
+
+
+@pytest.mark.skipif(not OPENAI_API_KEY, reason="OPENAI_API_KEY is not set")
+def test_openai_reasoning_tokens_integration(mock_client):
+    model = ChatOpenAI(
+        api_key=OPENAI_API_KEY, model="o4-mini", max_completion_tokens=10
+    )
+    cb = CallbackHandler(
+        mock_client, trace_id="test-trace-id", distinct_id="test-distinct-id"
+    )
+    model.invoke("what is the weather in sf", config={"callbacks": [cb]})
+    call = mock_client.capture.call_args_list[0][1]
+    assert call["properties"]["$ai_reasoning_tokens"] is not None
+    assert call["properties"]["$ai_input_tokens"] is not None
+    assert call["properties"]["$ai_output_tokens"] is not None
diff --git a/posthog/version.py b/posthog/version.py
index 424f4a85..f19455b9 100644
--- a/posthog/version.py
+++ b/posthog/version.py
@@ -1,4 +1,4 @@
-VERSION = "4.8.0"
+VERSION = "4.9.0"
 
 if __name__ == "__main__":
     print(VERSION, end="")  # noqa: T201
diff --git a/pyproject.toml b/pyproject.toml
index 383810ba..00485529 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,10 +68,11 @@ test = [
     "django",
     "openai",
     "anthropic",
-    "langgraph",
-    "langchain-community>=0.2.0",
-    "langchain-openai>=0.2.0",
-    "langchain-anthropic>=0.2.0",
+    "langgraph>=0.4.8",
+    "langchain-core>=0.3.65",
+    "langchain-community>=0.3.25",
+    "langchain-openai>=0.3.22",
+    "langchain-anthropic>=0.3.15",
     "google-genai",
     "pydantic",
     "parameterized>=0.8.1",