Skip to content

Commit ae72f5b

Browse files
authored
Merge branch 'main' into feature/native-mcp-support
2 parents b08309a + d2a39f6 commit ae72f5b

File tree

4 files changed

+53
-12
lines changed

4 files changed

+53
-12
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from datetime import datetime
88
from typing import Any, Literal, TypeAlias, cast, overload
99

10-
from genai_prices import extract_usage
1110
from pydantic import TypeAdapter
1211
from typing_extensions import assert_never
1312

@@ -364,7 +363,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
364363

365364
return ModelResponse(
366365
parts=items,
367-
usage=_map_usage(response, self._provider.name, self._model_name),
366+
usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
368367
model_name=response.model,
369368
provider_response_id=response.id,
370369
provider_name=self._provider.name,
@@ -388,6 +387,7 @@ async def _process_streamed_response(
388387
_response=peekable_response,
389388
_timestamp=_utils.now_utc(),
390389
_provider_name=self._provider.name,
390+
_provider_url=self._provider.base_url,
391391
)
392392

393393
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
@@ -657,6 +657,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
657657
def _map_usage(
658658
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
659659
provider: str,
660+
provider_url: str,
660661
model: str,
661662
existing_usage: usage.RequestUsage | None = None,
662663
) -> usage.RequestUsage:
@@ -675,10 +676,11 @@ def _map_usage(
675676
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
676677
}
677678

678-
extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)
679-
680-
return usage.RequestUsage(
681-
**{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
679+
return usage.RequestUsage.extract(
680+
dict(model=model, usage=details),
681+
provider=provider,
682+
provider_url=provider_url,
683+
provider_fallback='anthropic',
682684
details=details,
683685
)
684686

@@ -691,13 +693,14 @@ class AnthropicStreamedResponse(StreamedResponse):
691693
_response: AsyncIterable[BetaRawMessageStreamEvent]
692694
_timestamp: datetime
693695
_provider_name: str
696+
_provider_url: str
694697

695698
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
696699
current_block: BetaContentBlock | None = None
697700

698701
async for event in self._response:
699702
if isinstance(event, BetaRawMessageStartEvent):
700-
self._usage = _map_usage(event, self._provider_name, self._model_name)
703+
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
701704
self.provider_response_id = event.message.id
702705

703706
elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -788,7 +791,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
788791
pass
789792

790793
elif isinstance(event, BetaRawMessageDeltaEvent):
791-
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
794+
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
792795
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
793796
self.provider_details = {'finish_reason': raw_finish_reason}
794797
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import dataclasses
44
from copy import copy
55
from dataclasses import dataclass, fields
6-
from typing import Annotated
6+
from typing import Annotated, Any
77

8+
from genai_prices.data_snapshot import get_snapshot
89
from pydantic import AliasChoices, BeforeValidator, Field
910
from typing_extensions import deprecated, overload
1011

@@ -120,6 +121,39 @@ def __add__(self, other: RequestUsage) -> RequestUsage:
120121
new_usage.incr(other)
121122
return new_usage
122123

124+
@classmethod
125+
def extract(
126+
cls,
127+
data: Any,
128+
*,
129+
provider: str,
130+
provider_url: str,
131+
provider_fallback: str,
132+
api_flavor: str | None = None,
133+
details: dict[str, Any] | None = None,
134+
) -> RequestUsage:
135+
"""Extract usage information from the response data using genai-prices.
136+
137+
Args:
138+
data: The response data from the model API.
139+
provider: The actual provider ID, tried second if extraction via `provider_url` fails.
140+
provider_url: The provider's base URL, tried first when looking up the provider in genai-prices.
141+
provider_fallback: The fallback provider ID to use if the actual provider is not found in genai-prices.
142+
For example, an OpenAI model should set this to "openai" in case it has an obscure provider ID.
143+
api_flavor: The API flavor to use when extracting usage information,
144+
e.g. 'chat' or 'responses' for OpenAI.
145+
details: Becomes the `details` field on the returned `RequestUsage`, including when extraction fails for every provider.
146+
"""
147+
details = details or {}
148+
for provider_id, provider_api_url in [(None, provider_url), (provider, None), (provider_fallback, None)]:
149+
try:
150+
provider_obj = get_snapshot().find_provider(None, provider_id, provider_api_url)
151+
_model_ref, extracted_usage = provider_obj.extract_usage(data, api_flavor=api_flavor)
152+
return cls(**{k: v for k, v in extracted_usage.__dict__.items() if v is not None}, details=details)
153+
except Exception:
154+
pass
155+
return cls(details=details)
156+
123157

124158
@dataclass(repr=False, kw_only=True)
125159
class RunUsage(UsageBase):

tests/models/test_anthropic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,14 +1730,14 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
17301730
def test_usage(
17311731
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
17321732
):
1733-
assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
1733+
assert _map_usage(message_callback(), 'anthropic', '', 'unknown') == usage
17341734

17351735

17361736
def test_streaming_usage():
17371737
start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
1738-
initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
1738+
initial_usage = _map_usage(start, 'anthropic', '', 'unknown')
17391739
delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
1740-
final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
1740+
final_usage = _map_usage(delta, 'anthropic', '', 'unknown', existing_usage=initial_usage)
17411741
assert final_usage == snapshot(
17421742
RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
17431743
)

tests/test_usage_limits.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,7 @@ async def tool_c() -> str:
417417

418418
# Only the first batch of 5 tools should have executed
419419
assert len(executed_tools) == 5
420+
421+
422+
def test_usage_unknown_provider():
423+
assert RequestUsage.extract({}, provider='unknown', provider_url='', provider_fallback='') == RequestUsage()

0 commit comments

Comments
 (0)