Skip to content

Commit d2a39f6

Browse files
authored
More robust anthropic usage extraction (#3115)
1 parent 3eaa11e commit d2a39f6

File tree

4 files changed

+53
-12
lines changed

4 files changed

+53
-12
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from datetime import datetime
88
from typing import Any, Literal, cast, overload
99

10-
from genai_prices import extract_usage
1110
from pydantic import TypeAdapter
1211
from typing_extensions import assert_never
1312

@@ -352,7 +351,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
352351

353352
return ModelResponse(
354353
parts=items,
355-
usage=_map_usage(response, self._provider.name, self._model_name),
354+
usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
356355
model_name=response.model,
357356
provider_response_id=response.id,
358357
provider_name=self._provider.name,
@@ -376,6 +375,7 @@ async def _process_streamed_response(
376375
_response=peekable_response,
377376
_timestamp=_utils.now_utc(),
378377
_provider_name=self._provider.name,
378+
_provider_url=self._provider.base_url,
379379
)
380380

381381
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
@@ -620,6 +620,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
620620
def _map_usage(
621621
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
622622
provider: str,
623+
provider_url: str,
623624
model: str,
624625
existing_usage: usage.RequestUsage | None = None,
625626
) -> usage.RequestUsage:
@@ -638,10 +639,11 @@ def _map_usage(
638639
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
639640
}
640641

641-
extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)
642-
643-
return usage.RequestUsage(
644-
**{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
642+
return usage.RequestUsage.extract(
643+
dict(model=model, usage=details),
644+
provider=provider,
645+
provider_url=provider_url,
646+
provider_fallback='anthropic',
645647
details=details,
646648
)
647649

@@ -654,13 +656,14 @@ class AnthropicStreamedResponse(StreamedResponse):
654656
_response: AsyncIterable[BetaRawMessageStreamEvent]
655657
_timestamp: datetime
656658
_provider_name: str
659+
_provider_url: str
657660

658661
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
659662
current_block: BetaContentBlock | None = None
660663

661664
async for event in self._response:
662665
if isinstance(event, BetaRawMessageStartEvent):
663-
self._usage = _map_usage(event, self._provider_name, self._model_name)
666+
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
664667
self.provider_response_id = event.message.id
665668

666669
elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -741,7 +744,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
741744
pass
742745

743746
elif isinstance(event, BetaRawMessageDeltaEvent):
744-
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
747+
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
745748
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
746749
self.provider_details = {'finish_reason': raw_finish_reason}
747750
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import dataclasses
44
from copy import copy
55
from dataclasses import dataclass, fields
6-
from typing import Annotated
6+
from typing import Annotated, Any
77

8+
from genai_prices.data_snapshot import get_snapshot
89
from pydantic import AliasChoices, BeforeValidator, Field
910
from typing_extensions import deprecated, overload
1011

@@ -120,6 +121,39 @@ def __add__(self, other: RequestUsage) -> RequestUsage:
120121
new_usage.incr(other)
121122
return new_usage
122123

124+
@classmethod
125+
def extract(
126+
cls,
127+
data: Any,
128+
*,
129+
provider: str,
130+
provider_url: str,
131+
provider_fallback: str,
132+
api_flavor: str | None = None,
133+
details: dict[str, Any] | None = None,
134+
) -> RequestUsage:
135+
"""Extract usage information from the response data using genai-prices.
136+
137+
Args:
138+
data: The response data from the model API.
139+
provider: The actual provider ID.
140+
provider_url: The provider's base URL (its `base_url`); this is tried first when looking up the provider.
141+
provider_fallback: The fallback provider ID, tried last if provider lookup or usage extraction fails for both the provider URL and the actual provider ID.
142+
For example, an OpenAI model should set this to "openai" in case it has an obscure provider ID.
143+
api_flavor: The API flavor to use when extracting usage information,
144+
e.g. 'chat' or 'responses' for OpenAI.
145+
details: Becomes the `details` field on the returned `RequestUsage` for convenience.
146+
"""
147+
details = details or {}
148+
for provider_id, provider_api_url in [(None, provider_url), (provider, None), (provider_fallback, None)]:
149+
try:
150+
provider_obj = get_snapshot().find_provider(None, provider_id, provider_api_url)
151+
_model_ref, extracted_usage = provider_obj.extract_usage(data, api_flavor=api_flavor)
152+
return cls(**{k: v for k, v in extracted_usage.__dict__.items() if v is not None}, details=details)
153+
except Exception:
154+
pass
155+
return cls(details=details)
156+
123157

124158
@dataclass(repr=False, kw_only=True)
125159
class RunUsage(UsageBase):

tests/models/test_anthropic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,14 +1730,14 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
17301730
def test_usage(
17311731
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
17321732
):
1733-
assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
1733+
assert _map_usage(message_callback(), 'anthropic', '', 'unknown') == usage
17341734

17351735

17361736
def test_streaming_usage():
17371737
start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
1738-
initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
1738+
initial_usage = _map_usage(start, 'anthropic', '', 'unknown')
17391739
delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
1740-
final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
1740+
final_usage = _map_usage(delta, 'anthropic', '', 'unknown', existing_usage=initial_usage)
17411741
assert final_usage == snapshot(
17421742
RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
17431743
)

tests/test_usage_limits.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,7 @@ async def tool_c() -> str:
417417

418418
# Only the first batch of 5 tools should have executed
419419
assert len(executed_tools) == 5
420+
421+
422+
def test_usage_unknown_provider():
423+
assert RequestUsage.extract({}, provider='unknown', provider_url='', provider_fallback='') == RequestUsage()

0 commit comments

Comments
 (0)