Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime
from typing import Any, Literal, cast, overload

from genai_prices import extract_usage
from pydantic import TypeAdapter
from typing_extensions import assert_never

Expand Down Expand Up @@ -351,7 +352,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:

return ModelResponse(
parts=items,
usage=_map_usage(response),
usage=_map_usage(response, self._provider.name, self._model_name),
model_name=response.model,
provider_response_id=response.id,
provider_name=self._provider.name,
Expand Down Expand Up @@ -616,7 +617,12 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
}


def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
def _map_usage(
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
provider: str,
model: str,
existing_usage: usage.RequestUsage | None = None,
) -> usage.RequestUsage:
if isinstance(message, BetaMessage):
response_usage = message.usage
elif isinstance(message, BetaRawMessageStartEvent):
Expand All @@ -626,24 +632,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageD
else:
assert_never(message)

# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
# `response_tokens`
details: dict[str, int] = {
# In streaming, usage appears in different events.
# The values are cumulative, meaning new values should replace existing ones entirely.
details: dict[str, int] = (existing_usage.details if existing_usage else {}) | {
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
}

# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
# Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
# This approach maintains request_tokens as the count of all input tokens, with cached counts as details
cache_write_tokens = details.get('cache_creation_input_tokens', 0)
cache_read_tokens = details.get('cache_read_input_tokens', 0)
request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)

return usage.RequestUsage(
input_tokens=request_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
output_tokens=response_usage.output_tokens,
**{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
details=details,
)

Expand All @@ -662,7 +660,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

async for event in self._response:
if isinstance(event, BetaRawMessageStartEvent):
self._usage = _map_usage(event)
self._usage = _map_usage(event, self._provider_name, self._model_name)
self.provider_response_id = event.message.id

elif isinstance(event, BetaRawContentBlockStartEvent):
Expand Down Expand Up @@ -743,7 +741,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
pass

elif isinstance(event, BetaRawMessageDeltaEvent):
self._usage = _map_usage(event)
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
self.provider_details = {'finish_reason': raw_finish_reason}
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dependencies = [
"exceptiongroup; python_version < '3.11'",
"opentelemetry-api>=1.28.0",
"typing-inspection>=0.4.0",
"genai-prices>=0.0.28",
"genai-prices>=0.0.30",
]

[tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]
Expand Down
32 changes: 21 additions & 11 deletions tests/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,7 +1439,16 @@ async def test_anthropic_model_thinking_part_stream(allow_model_requests: None,
),
TextPart(content=IsStr()),
],
usage=RequestUsage(output_tokens=419, details={'output_tokens': 419}),
usage=RequestUsage(
input_tokens=42,
output_tokens=419,
details={
'cache_creation_input_tokens': 0,
'cache_read_input_tokens': 0,
'input_tokens': 42,
'output_tokens': 419,
},
),
model_name='claude-3-7-sonnet-20250219',
timestamp=IsDatetime(),
provider_name='anthropic',
Expand Down Expand Up @@ -1716,21 +1725,22 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
snapshot(RequestUsage(input_tokens=1, output_tokens=1, details={'input_tokens': 1, 'output_tokens': 1})),
id='RawMessageStartEvent',
),
pytest.param(
lambda: BetaRawMessageDeltaEvent(
delta=Delta(),
usage=BetaMessageDeltaUsage(output_tokens=5),
type='message_delta',
),
snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})),
id='RawMessageDeltaEvent',
),
],
)
def test_usage(
    message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
):
    """Parametrized check that `_map_usage` converts a raw Anthropic message/event into the expected usage.

    `message_callback` builds one of the three accepted Anthropic payload types;
    `usage` is the snapshot of the `RequestUsage` that `_map_usage` should produce
    for the 'anthropic' provider with model 'claude-sonnet-4-5'.
    """
    assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage


def test_streaming_usage():
    """Verify cumulative usage merging across streaming events.

    A `message_start` event establishes the initial usage (input_tokens=1, output_tokens=1);
    a later `message_delta` event carries only output tokens. Passing the initial usage as
    `existing_usage` should keep input_tokens=1 while the delta's output_tokens=5 entirely
    replaces (not adds to) the earlier output value, since Anthropic streaming usage
    values are cumulative.
    """
    start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
    initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
    # Delta events have no input-token data; only output_tokens is present here.
    delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
    final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
    assert final_usage == snapshot(
        RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
    )


async def test_anthropic_model_empty_message_on_history(allow_model_requests: None, anthropic_api_key: str):
Expand Down
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.