Skip to content

Commit 3eaa11e

Browse files
authored
Fix token usage for anthropic streaming (#3111)
1 parent f28c6fc commit 3eaa11e

File tree

4 files changed

+41
-33
lines changed

4 files changed

+41
-33
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datetime import datetime
88
from typing import Any, Literal, cast, overload
99

10+
from genai_prices import extract_usage
1011
from pydantic import TypeAdapter
1112
from typing_extensions import assert_never
1213

@@ -351,7 +352,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
351352

352353
return ModelResponse(
353354
parts=items,
354-
usage=_map_usage(response),
355+
usage=_map_usage(response, self._provider.name, self._model_name),
355356
model_name=response.model,
356357
provider_response_id=response.id,
357358
provider_name=self._provider.name,
@@ -616,7 +617,12 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
616617
}
617618

618619

619-
def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
620+
def _map_usage(
621+
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
622+
provider: str,
623+
model: str,
624+
existing_usage: usage.RequestUsage | None = None,
625+
) -> usage.RequestUsage:
620626
if isinstance(message, BetaMessage):
621627
response_usage = message.usage
622628
elif isinstance(message, BetaRawMessageStartEvent):
@@ -626,24 +632,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageD
626632
else:
627633
assert_never(message)
628634

629-
# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
630-
# `response_tokens`
631-
details: dict[str, int] = {
635+
# In streaming, usage appears in different events.
636+
# The values are cumulative, meaning new values should replace existing ones entirely.
637+
details: dict[str, int] = (existing_usage.details if existing_usage else {}) | {
632638
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
633639
}
634640

635-
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
636-
# Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
637-
# This approach maintains request_tokens as the count of all input tokens, with cached counts as details
638-
cache_write_tokens = details.get('cache_creation_input_tokens', 0)
639-
cache_read_tokens = details.get('cache_read_input_tokens', 0)
640-
request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
641+
extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)
641642

642643
return usage.RequestUsage(
643-
input_tokens=request_tokens,
644-
cache_read_tokens=cache_read_tokens,
645-
cache_write_tokens=cache_write_tokens,
646-
output_tokens=response_usage.output_tokens,
644+
**{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
647645
details=details,
648646
)
649647

@@ -662,7 +660,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
662660

663661
async for event in self._response:
664662
if isinstance(event, BetaRawMessageStartEvent):
665-
self._usage = _map_usage(event)
663+
self._usage = _map_usage(event, self._provider_name, self._model_name)
666664
self.provider_response_id = event.message.id
667665

668666
elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -743,7 +741,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
743741
pass
744742

745743
elif isinstance(event, BetaRawMessageDeltaEvent):
746-
self._usage = _map_usage(event)
744+
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
747745
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
748746
self.provider_details = {'finish_reason': raw_finish_reason}
749747
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ dependencies = [
6060
"exceptiongroup; python_version < '3.11'",
6161
"opentelemetry-api>=1.28.0",
6262
"typing-inspection>=0.4.0",
63-
"genai-prices>=0.0.28",
63+
"genai-prices>=0.0.30",
6464
]
6565

6666
[tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]

tests/models/test_anthropic.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,7 +1439,16 @@ async def test_anthropic_model_thinking_part_stream(allow_model_requests: None,
14391439
),
14401440
TextPart(content=IsStr()),
14411441
],
1442-
usage=RequestUsage(output_tokens=419, details={'output_tokens': 419}),
1442+
usage=RequestUsage(
1443+
input_tokens=42,
1444+
output_tokens=419,
1445+
details={
1446+
'cache_creation_input_tokens': 0,
1447+
'cache_read_input_tokens': 0,
1448+
'input_tokens': 42,
1449+
'output_tokens': 419,
1450+
},
1451+
),
14431452
model_name='claude-3-7-sonnet-20250219',
14441453
timestamp=IsDatetime(),
14451454
provider_name='anthropic',
@@ -1716,21 +1725,22 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
17161725
snapshot(RequestUsage(input_tokens=1, output_tokens=1, details={'input_tokens': 1, 'output_tokens': 1})),
17171726
id='RawMessageStartEvent',
17181727
),
1719-
pytest.param(
1720-
lambda: BetaRawMessageDeltaEvent(
1721-
delta=Delta(),
1722-
usage=BetaMessageDeltaUsage(output_tokens=5),
1723-
type='message_delta',
1724-
),
1725-
snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})),
1726-
id='RawMessageDeltaEvent',
1727-
),
17281728
],
17291729
)
17301730
def test_usage(
17311731
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
17321732
):
1733-
assert _map_usage(message_callback()) == usage
1733+
assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
1734+
1735+
1736+
def test_streaming_usage():
1737+
start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
1738+
initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
1739+
delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
1740+
final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
1741+
assert final_usage == snapshot(
1742+
RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
1743+
)
17341744

17351745

17361746
async def test_anthropic_model_empty_message_on_history(allow_model_requests: None, anthropic_api_key: str):

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)