Skip to content

Commit e797f11

Browse files
committed
pass actual model
1 parent 0b62da4 commit e797f11

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
352352

353353
return ModelResponse(
354354
parts=items,
355-
usage=_map_usage(response, self._provider.name),
355+
usage=_map_usage(response, self._provider.name, self._model_name),
356356
model_name=response.model,
357357
provider_response_id=response.id,
358358
provider_name=self._provider.name,
@@ -619,7 +619,8 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
619619

620620
def _map_usage(
621621
message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
622-
provider: str = 'anthropic',
622+
provider: str,
623+
model: str,
623624
existing_usage: usage.RequestUsage | None = None,
624625
) -> usage.RequestUsage:
625626
if isinstance(message, BetaMessage):
@@ -637,10 +638,7 @@ def _map_usage(
637638
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
638639
}
639640

640-
# `extract_usage` expects a mapping with a `model` and `usage` key.
641-
# Not all the actual types of messages here have a model so we just make a dummy message.
642-
# We only care about the numbers of tokens etc.
643-
extracted_usage = extract_usage(dict(model='claude-sonnet-4-5', usage=details), provider_id=provider)
641+
extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)
644642

645643
return usage.RequestUsage(
646644
**{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
@@ -662,7 +660,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
662660

663661
async for event in self._response:
664662
if isinstance(event, BetaRawMessageStartEvent):
665-
self._usage = _map_usage(event, self._provider_name)
663+
self._usage = _map_usage(event, self._provider_name, self._model_name)
666664
self.provider_response_id = event.message.id
667665

668666
elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -743,7 +741,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
743741
pass
744742

745743
elif isinstance(event, BetaRawMessageDeltaEvent):
746-
self._usage = _map_usage(event, self._provider_name, self._usage)
744+
self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
747745
if raw_finish_reason := event.delta.stop_reason: # pragma: no branch
748746
self.provider_details = {'finish_reason': raw_finish_reason}
749747
self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

tests/models/test_anthropic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,14 +1730,14 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
17301730
def test_usage(
17311731
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
17321732
):
1733-
assert _map_usage(message_callback()) == usage
1733+
assert _map_usage(message_callback(), 'anthropic', 'claude-sonnet-4-5') == usage
17341734

17351735

17361736
def test_streaming_usage():
17371737
start = BetaRawMessageStartEvent(message=anth_msg(BetaUsage(input_tokens=1, output_tokens=1)), type='message_start')
1738-
initial_usage = _map_usage(start)
1738+
initial_usage = _map_usage(start, 'anthropic', 'claude-sonnet-4-5')
17391739
delta = BetaRawMessageDeltaEvent(delta=Delta(), usage=BetaMessageDeltaUsage(output_tokens=5), type='message_delta')
1740-
final_usage = _map_usage(delta, existing_usage=initial_usage)
1740+
final_usage = _map_usage(delta, 'anthropic', 'claude-sonnet-4-5', existing_usage=initial_usage)
17411741
assert final_usage == snapshot(
17421742
RequestUsage(input_tokens=1, output_tokens=5, details={'input_tokens': 1, 'output_tokens': 5})
17431743
)

0 commit comments

Comments
 (0)