diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index be261e836e..497a03a4f0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -7,7 +7,6 @@ from datetime import datetime from typing import Any, Literal, cast, overload -from genai_prices import extract_usage from pydantic import TypeAdapter from typing_extensions import assert_never @@ -352,7 +351,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse: return ModelResponse( parts=items, - usage=_map_usage(response, self._provider.name, self._model_name), + usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name), model_name=response.model, provider_response_id=response.id, provider_name=self._provider.name, @@ -376,6 +375,7 @@ async def _process_streamed_response( _response=peekable_response, _timestamp=_utils.now_utc(), _provider_name=self._provider.name, + _provider_url=self._provider.base_url, ) def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]: @@ -620,6 +620,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam: def _map_usage( message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent, provider: str, + provider_url: str, model: str, existing_usage: usage.RequestUsage | None = None, ) -> usage.RequestUsage: @@ -638,10 +639,11 @@ def _map_usage( key: value for key, value in response_usage.model_dump().items() if isinstance(value, int) } - extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider) - - return usage.RequestUsage( - **{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)}, + return usage.RequestUsage.extract( + dict(model=model, usage=details), + provider=provider, + provider_url=provider_url, + provider_fallback='anthropic', details=details, ) @@ -654,13 +656,14 @@ class AnthropicStreamedResponse(StreamedResponse): _response: AsyncIterable[BetaRawMessageStreamEvent] _timestamp: datetime _provider_name: str + _provider_url: str async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 current_block: BetaContentBlock | None = None async for event in self._response: if isinstance(event, BetaRawMessageStartEvent): - self._usage = _map_usage(event, self._provider_name, self._model_name) + self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name) self.provider_response_id = event.message.id elif isinstance(event, BetaRawContentBlockStartEvent): @@ -741,7 +744,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass elif isinstance(event, BetaRawMessageDeltaEvent): - self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage) + self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage) if raw_finish_reason := event.delta.stop_reason: # pragma: no branch self.provider_details = {'finish_reason': raw_finish_reason} self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 8eae608263..71556e8e94 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -3,8 +3,9 @@ import dataclasses from copy import copy from dataclasses import dataclass, fields -from typing import Annotated +from typing import Annotated, Any +from genai_prices.data_snapshot import get_snapshot from pydantic import AliasChoices, BeforeValidator, Field from typing_extensions import deprecated, overload @@ -120,6 +121,39 @@ def __add__(self, other: RequestUsage) -> RequestUsage: new_usage.incr(other) return new_usage + @classmethod + def extract( + cls, + data: Any, + *, + provider: str, + provider_url: str, + provider_fallback: str, + api_flavor: str | None = None, + details: dict[str, Any] | None = None, + ) -> RequestUsage: + """Extract usage information from the response data using genai-prices. + + Args: + data: The response data from the model API. + provider: The actual provider ID + provider_url: The provider base_url + provider_fallback: The fallback provider ID to use if the actual provider is not found in genai-prices. + For example, an OpenAI model should set this to "openai" in case it has an obscure provider ID. + api_flavor: The API flavor to use when extracting usage information, + e.g. 'chat' or 'responses' for OpenAI. + details: Becomes the `details` field on the returned `RequestUsage` for convenience. + """ + details = details or {} + for provider_id, provider_api_url in [(None, provider_url), (provider, None), (provider_fallback, None)]: + try: + provider_obj = get_snapshot().find_provider(None, provider_id, provider_api_url) + _model_ref, extracted_usage = provider_obj.extract_usage(data, api_flavor=api_flavor) + return cls(**{k: v for k, v in extracted_usage.__dict__.items() if v is not None}, details=details) + except Exception: + pass + return cls(details=details) + @dataclass(repr=False, kw_only=True) class RunUsage(UsageBase): diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index a3c25d97fb..5fe850bf45 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1730,14 +1730,14 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: def test_usage( message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage ): - assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage + assert _map_usage(message_callback(), 'anthropic', '', 'unknown') == usage def test_streaming_usage(): start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start') - initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5') + initial_usage = _map_usage(start, 'anthropic', '', 'unknown') delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta') - final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage) + final_usage = _map_usage(delta, 'anthropic', '', 'unknown', existing_usage=initial_usage) assert final_usage == snapshot( RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5}) ) diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 2eeeab15b6..d3c40ed375 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -417,3 +417,7 @@ async def tool_c() -> str: # Only the first batch of 5 tools should have executed assert len(executed_tools) == 5 + + +def test_usage_unknown_provider(): + assert RequestUsage.extract({}, provider='unknown', provider_url='', provider_fallback='') == RequestUsage()