Skip to content

Commit 1f37cef

Browse files
committed
multiple provider lookups
1 parent 3a5f215 commit 1f37cef

File tree

4 files changed

+45
-15
lines changed

4 files changed

+45
-15
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
351351

352352
return ModelResponse(
353353
parts=items,
354-
usage=_map_usage(response, self._provider.name, self._model_name),
354+
usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
355355
model_name=response.model,
356356
provider_response_id=response.id,
357357
provider_name=self._provider.name,
@@ -375,6 +375,7 @@ async def _process_streamed_response(
375375
_response=peekable_response,
376376
_timestamp=_utils.now_utc(),
377377
_provider_name=self._provider.name,
378+
_provider_url=self._provider.base_url,
378379
)
379380

380381
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolUnionParam]:
@@ -619,6 +620,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
619620
def _map_usage(
620621
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
621622
provider: str,
623+
provider_url: str,
622624
model: str,
623625
existing_usage: usage.RequestUsage | None = None,
624626
) -> usage.RequestUsage:
@@ -637,7 +639,13 @@ def _map_usage(
637639
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
638640
}
639641

640-
return usage.RequestUsage.extract(dict(model=model, usage=details), provider_id=provider, details=details)
642+
return usage.RequestUsage.extract(
643+
dict(model=model, usage=details),
644+
provider=provider,
645+
provider_url=provider_url,
646+
provider_fallback='anthropic',
647+
details=details,
648+
)
641649

642650

643651
@dataclass
@@ -648,13 +656,14 @@ class AnthropicStreamedResponse(StreamedResponse):
648656
_response: AsyncIterable[BetaRawMessageStreamEvent]
649657
_timestamp: datetime
650658
_provider_name: str
659+
_provider_url: str
651660

652661
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
653662
current_block: BetaContentBlock | None = None
654663

655664
async for event in self._response:
656665
if isinstance(event, BetaRawMessageStartEvent):
657-
self._usage = _map_usage(event, self._provider_name, self._model_name)
666+
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name)
658667
self.provider_response_id = event.message.id
659668

660669
elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -735,7 +744,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
735744
pass
736745

737746
elif isinstance(event, BetaRawMessageDeltaEvent):
738-
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
747+
self._usage = _map_usage(event, self._provider_name, self._provider_url, self._model_name, self._usage)
739748
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
740749
self.provider_details = {'finish_reason': raw_finish_reason}
741750
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,36 @@ def __add__(self, other: RequestUsage) -> RequestUsage:
123123

124124
@classmethod
125125
def extract(
126-
cls, data: Any, *, provider_id: str, api_flavor: str | None = None, details: dict[str, Any] | None = None
126+
cls,
127+
data: Any,
128+
*,
129+
provider: str,
130+
provider_url: str,
131+
provider_fallback: str,
132+
api_flavor: str | None = None,
133+
details: dict[str, Any] | None = None,
127134
) -> RequestUsage:
135+
"""Extract usage information from the response data using genai-prices.
136+
137+
Args:
138+
data: The response data from the model API.
139+
provider: The actual provider ID
140+
provider_url: The provider base_url
141+
provider_fallback: The fallback provider ID to use if the actual provider is not found in genai-prices.
142+
For example, an OpenAI model should set this to "openai" in case it has an obscure provider ID.
143+
api_flavor: The API flavor to use when extracting usage information,
144+
e.g. 'chat' or 'responses' for OpenAI.
145+
details: Becomes the `details` field on the returned `RequestUsage` for convenience.
146+
"""
128147
details = details or {}
129-
try:
130-
provider = get_snapshot().find_provider(None, provider_id, None)
131-
_model_ref, extracted_usage = provider.extract_usage(data, api_flavor=api_flavor)
132-
return cls(**{k: v for k, v in extracted_usage.__dict__.items() if v is not None}, details=details)
133-
except Exception:
134-
return cls(details=details)
148+
for provider_id, provider_api_url in [(None, provider_url), (provider, None), (provider_fallback, None)]:
149+
try:
150+
provider_obj = get_snapshot().find_provider(None, provider_id, provider_api_url)
151+
_model_ref, extracted_usage = provider_obj.extract_usage(data, api_flavor=api_flavor)
152+
return cls(**{k: v for k, v in extracted_usage.__dict__.items() if v is not None}, details=details)
153+
except Exception:
154+
pass
155+
return cls(details=details)
135156

136157

137158
@dataclass(repr=False, kw_only=True)

tests/models/test_anthropic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,14 +1730,14 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
17301730
def test_usage(
17311731
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
17321732
):
1733-
assert _map_usage(message_callback(), 'anthropic', 'unknown') == usage
1733+
assert _map_usage(message_callback(), 'anthropic', '', 'unknown') == usage
17341734

17351735

17361736
def test_streaming_usage():
17371737
start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
1738-
initial_usage = _map_usage(start, 'anthropic', 'unknown')
1738+
initial_usage = _map_usage(start, 'anthropic', '', 'unknown')
17391739
delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
1740-
final_usage = _map_usage(delta, 'anthropic', 'unknown', existing_usage=initial_usage)
1740+
final_usage = _map_usage(delta, 'anthropic', '', 'unknown', existing_usage=initial_usage)
17411741
assert final_usage == snapshot(
17421742
RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
17431743
)

tests/test_usage_limits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,4 +420,4 @@ async def tool_c() -> str:
420420

421421

422422
def test_usage_unknown_provider():
423-
assert RequestUsage.extract({}, provider_id='unknown') == RequestUsage()
423+
assert RequestUsage.extract({}, provider='unknown', provider_url='', provider_fallback='') == RequestUsage()

0 commit comments

Comments
 (0)