Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datetime import datetime
from typing import Any, Literal, cast, overload

from genai_prices import extract_usage
from pydantic import TypeAdapter
from typing_extensions import assert_never

Expand Down Expand Up @@ -352,7 +351,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:

return ModelResponse(
parts=items,
usage=_map_usage(response, self._provider.name, self._model_name),
usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
model_name=response.model,
provider_response_id=response.id,
provider_name=self._provider.name,
Expand All @@ -376,6 +375,7 @@ async def _process_streamed_response(
_response=peekable_response,
_timestamp=_utils.now_utc(),
_provider_name=self._provider.name,
_provider_url=self._provider.base_url,
)

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
Expand Down Expand Up @@ -620,6 +620,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
def _map_usage(
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
provider: str,
provider_url: str,
model: str,
existing_usage: usage.RequestUsage | None = None,
) -> usage.RequestUsage:
Expand All @@ -638,10 +639,11 @@ def _map_usage(
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
}

extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)

return usage.RequestUsage(
**{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
return usage.RequestUsage.extract(
dict(model=model, usage=details),
provider=provider,
provider_url=provider_url,
provider_fallback='anthropic',
details=details,
)

Expand All @@ -654,13 +656,14 @@ class AnthropicStreamedResponse(StreamedResponse):
_response: AsyncIterable[BetaRawMessageStreamEvent]
_timestamp: datetime
_provider_name: str
_provider_url: str

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
current_block: BetaContentBlock | None = None

async for event in self._response:
if isinstance(event, BetaRawMessageStartEvent):
self._usage = _map_usage(event, self._provider_name, self._model_name)
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
self.provider_response_id = event.message.id

elif isinstance(event, BetaRawContentBlockStartEvent):
Expand Down Expand Up @@ -741,7 +744,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
pass

elif isinstance(event, BetaRawMessageDeltaEvent):
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
self.provider_details = {'finish_reason': raw_finish_reason}
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
Expand Down
36 changes: 35 additions & 1 deletion pydantic_ai_slim/pydantic_ai/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import dataclasses
from copy import copy
from dataclasses import dataclass, fields
from typing import Annotated
from typing import Annotated, Any

from genai_prices.data_snapshot import get_snapshot
from pydantic import AliasChoices, BeforeValidator, Field
from typing_extensions import deprecated, overload

Expand Down Expand Up @@ -120,6 +121,39 @@ def __add__(self, other: RequestUsage) -> RequestUsage:
new_usage.incr(other)
return new_usage

@classmethod
def extract(
    cls,
    data: Any,
    *,
    provider: str,
    provider_url: str,
    provider_fallback: str,
    api_flavor: str | None = None,
    details: dict[str, Any] | None = None,
) -> RequestUsage:
    """Build a `RequestUsage` from raw response data with the help of genai-prices.

    Three provider lookups are attempted in order: by `provider_url`, then by `provider`,
    then by `provider_fallback`. The first attempt that completes without raising produces
    the result. If every attempt raises, a `RequestUsage` holding only `details` is returned.

    Args:
        data: The response data from the model API.
        provider: The actual provider ID.
        provider_url: The provider base_url.
        provider_fallback: The provider ID to fall back to if the actual provider is not found in genai-prices.
            For example, an OpenAI model should set this to "openai" in case it has an obscure provider ID.
        api_flavor: The API flavor to use when extracting usage information,
            e.g. 'chat' or 'responses' for OpenAI.
        details: Stored as the `details` field of the returned `RequestUsage` for convenience;
            a missing or empty value becomes an empty dict.

    Returns:
        A `RequestUsage` populated from the non-`None` attributes of the extracted usage, or one
        containing only `details` when no lookup attempt succeeds.
    """
    if not details:
        details = {}

    # Each attempt is a (provider id, provider API URL) pair; one side is always None.
    lookup_attempts = (
        (None, provider_url),
        (provider, None),
        (provider_fallback, None),
    )
    for candidate_id, candidate_url in lookup_attempts:
        try:
            matched_provider = get_snapshot().find_provider(None, candidate_id, candidate_url)
            _model_ref, extracted = matched_provider.extract_usage(data, api_flavor=api_flavor)
            populated = {name: value for name, value in vars(extracted).items() if value is not None}
            return cls(**populated, details=details)
        except Exception:
            # Best effort: any failure for this candidate simply moves on to the next one.
            continue
    return cls(details=details)


@dataclass(repr=False, kw_only=True)
class RunUsage(UsageBase):
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,14 +1730,14 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
def test_usage(
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
):
assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
assert _map_usage(message_callback(), 'anthropic', '', 'unknown') == usage


def test_streaming_usage():
start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
initial_usage = _map_usage(start, 'anthropic', '', 'unknown')
delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
final_usage = _map_usage(delta, 'anthropic', '', 'unknown', existing_usage=initial_usage)
assert final_usage == snapshot(
RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_usage_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,7 @@ async def tool_c() -> str:

# Only the first batch of 5 tools should have executed
assert len(executed_tools) == 5


def test_usage_unknown_provider():
    """`RequestUsage.extract` yields an empty `RequestUsage` for an unknown provider with blank URL and fallback."""
    extracted = RequestUsage.extract({}, provider='unknown', provider_url='', provider_fallback='')
    assert extracted == RequestUsage()