diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 85b2c52803..be261e836e 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -7,6 +7,7 @@
 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
+from genai_prices import extract_usage
 from pydantic import TypeAdapter
 from typing_extensions import assert_never
 
@@ -351,7 +352,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
 
         return ModelResponse(
             parts=items,
-            usage=_map_usage(response),
+            usage=_map_usage(response, self._provider.name, self._model_name),
             model_name=response.model,
             provider_response_id=response.id,
             provider_name=self._provider.name,
@@ -616,7 +617,12 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
     }
 
 
-def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
+def _map_usage(
+    message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
+    provider: str,
+    model: str,
+    existing_usage: usage.RequestUsage | None = None,
+) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -626,24 +632,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageD
     else:
         assert_never(message)
 
-    # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
-    # `response_tokens`
-    details: dict[str, int] = {
+    # In streaming, usage appears in different events.
+    # The values are cumulative, meaning new values should replace existing ones entirely.
+    details: dict[str, int] = (existing_usage.details if existing_usage else {}) | {
         key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
     }
 
-    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
-    # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
-    # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
-    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
-    cache_read_tokens = details.get('cache_read_input_tokens', 0)
-    request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
+    extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)
 
     return usage.RequestUsage(
-        input_tokens=request_tokens,
-        cache_read_tokens=cache_read_tokens,
-        cache_write_tokens=cache_write_tokens,
-        output_tokens=response_usage.output_tokens,
+        **{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
         details=details,
     )
 
@@ -662,7 +660,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
 
         async for event in self._response:
             if isinstance(event, BetaRawMessageStartEvent):
-                self._usage = _map_usage(event)
+                self._usage = _map_usage(event, self._provider_name, self._model_name)
                 self.provider_response_id = event.message.id
 
             elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -743,7 +741,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 pass
 
             elif isinstance(event, BetaRawMessageDeltaEvent):
-                self._usage = _map_usage(event)
+                self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
                 if raw_finish_reason := event.delta.stop_reason:  # pragma: no branch
                     self.provider_details = {'finish_reason': raw_finish_reason}
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml
index de6e164d4b..f73ffcdab4 100644
--- a/pydantic_ai_slim/pyproject.toml
+++ b/pydantic_ai_slim/pyproject.toml
@@ -60,7 +60,7 @@ dependencies = [
     "exceptiongroup; python_version < '3.11'",
     "opentelemetry-api>=1.28.0",
     "typing-inspection>=0.4.0",
-    "genai-prices>=0.0.28",
+    "genai-prices>=0.0.30",
 ]
 
 [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index 8d63d89f60..a3c25d97fb 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -1439,7 +1439,16 @@ async def test_anthropic_model_thinking_part_stream(allow_model_requests: None,
                     ),
                     TextPart(content=IsStr()),
                 ],
-                usage=RequestUsage(output_tokens=419, details={'output_tokens': 419}),
+                usage=RequestUsage(
+                    input_tokens=42,
+                    output_tokens=419,
+                    details={
+                        'cache_creation_input_tokens': 0,
+                        'cache_read_input_tokens': 0,
+                        'input_tokens': 42,
+                        'output_tokens': 419,
+                    },
+                ),
                 model_name='claude-3-7-sonnet-20250219',
                 timestamp=IsDatetime(),
                 provider_name='anthropic',
@@ -1716,21 +1725,22 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
             snapshot(RequestUsage(input_tokens=1, output_tokens=1, details={'input_tokens': 1, 'output_tokens': 1})),
             id='RawMessageStartEvent',
         ),
-        pytest.param(
-            lambda: BetaRawMessageDeltaEvent(
-                delta=Delta(),
-                usage=BetaMessageDeltaUsage(output_tokens=5),
-                type='message_delta',
-            ),
-            snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})),
-            id='RawMessageDeltaEvent',
-        ),
     ],
 )
 def test_usage(
     message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
 ):
-    assert _map_usage(message_callback()) == usage
+    assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
+
+
+def test_streaming_usage():
+    start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
+    initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
+    delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
+    final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
+    assert final_usage == snapshot(
+        RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
+    )
 
 
 async def test_anthropic_model_empty_message_on_history(allow_model_requests: None, anthropic_api_key: str):
diff --git a/uv.lock b/uv.lock
index 548ccfddce..f496e39ff8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1198,16 +1198,16 @@ http = [
 
 [[package]]
 name = "genai-prices"
-version = "0.0.28"
+version = "0.0.30"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "eval-type-backport", marker = "python_full_version < '3.11'" },
     { name = "httpx" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/89/c8/245848db2b8b1b393bffed5397e77fe9ea3cba53974639f2b80cf804388e/genai_prices-0.0.28.tar.gz", hash = "sha256:9928aa1c4f0535bec5b01ddfd3dc163888ee90cdc6153501de9f004f00193953", size = 45584, upload-time = "2025-09-27T03:17:08.775Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/91/ec/247749351ad7d49770c1856860fbd681e93aba02f7d8f40e12636a79bf3b/genai_prices-0.0.30.tar.gz", hash = "sha256:3d176d4c0366a58b0480c6929d1c101c989efa5ffeaae3cb53780811d08ac2fa", size = 45722, upload-time = "2025-10-07T18:56:21.6Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/58/9b/6f161c4fd141c0249621687cdcdd91734e6d0e11af74ded8b036de39e9ce/genai_prices-0.0.28-py3-none-any.whl", hash = "sha256:5b879bd02e901ffc6740d42eaffb9f5dfeb8c99f4a6b9e680b0046c84fb2de1b", size = 48187, upload-time = "2025-09-27T03:17:07.472Z" },
+    { url = "https://files.pythonhosted.org/packages/3e/29/364374f38ec219f0831c062d72190cf45f133392bc1c74aa2ff9c0db4788/genai_prices-0.0.30-py3-none-any.whl", hash = "sha256:8fa326931957c95cf30b13cc663973ae8e51e619b787b1c4e2d3e58860f70adb", size = 48337, upload-time = "2025-10-07T18:56:20.231Z" },
 ]
 
 [[package]]
@@ -3370,7 +3370,7 @@ requires-dist = [
     { name = "ddgs", marker = "extra == 'duckduckgo'", specifier = ">=9.0.0" },
     { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
     { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" },
-    { name = "genai-prices", specifier = ">=0.0.28" },
+    { name = "genai-prices", specifier = ">=0.0.30" },
    { name = "google-auth", marker = "extra == 'vertexai'", specifier = ">=2.36.0" },
     { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.31.0" },
     { name = "griffe", specifier = ">=1.3.2" },
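Note (not part of the patch): the change above stops summing input, cache-write, and cache-read tokens by hand in _map_usage and instead feeds the raw Anthropic usage fields through genai_prices.extract_usage. The sketch below mirrors that call with the same argument shape the diff uses; the token values are made up, and the shape of the returned object (a .usage holding integer fields) is assumed from how the patch consumes it.

# Minimal sketch of the usage-normalization step introduced in _map_usage above.
# Assumes genai-prices >= 0.0.30; all token values here are hypothetical.
from genai_prices import extract_usage

# Integer usage fields as Anthropic reports them (cumulative across stream events).
details = {
    'input_tokens': 1,
    'cache_creation_input_tokens': 0,
    'cache_read_input_tokens': 0,
    'output_tokens': 5,
}

# Same call shape as in the patch: the provider id and model name select the
# provider-specific extraction rules inside genai_prices.
extracted = extract_usage(dict(model='claude-sonnet-4-5', usage=details), provider_id='anthropic')

# Keep only the integer fields, as the patch does when building RequestUsage,
# while the raw provider values remain available in `details`.
normalized = {key: value for key, value in extracted.usage.__dict__.items() if isinstance(value, int)}
print(normalized)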