diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 2e3b9e04e4..17ceedcc20 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -536,7 +536,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam: } -def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage: +def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage: if isinstance(message, BetaMessage): response_usage = message.usage elif isinstance(message, BetaRawMessageStartEvent): @@ -544,12 +544,7 @@ def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.Reques elif isinstance(message, BetaRawMessageDeltaEvent): response_usage = message.usage else: - # No usage information provided in: - # - RawMessageStopEvent - # - RawContentBlockStartEvent - # - RawContentBlockDeltaEvent - # - RawContentBlockStopEvent - return usage.RequestUsage() + assert_never(message) # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by # `response_tokens` @@ -586,10 +581,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: current_block: BetaContentBlock | None = None async for event in self._response: - self._usage += _map_usage(event) - if isinstance(event, BetaRawMessageStartEvent): - pass + self._usage = _map_usage(event) elif isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block @@ -652,7 +645,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass elif isinstance(event, BetaRawMessageDeltaEvent): - pass + self._usage = _map_usage(event) elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch current_block = None diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 330810746b..add88de736 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -608,7 +608,7 @@ async def test_stream_structured(allow_model_requests: None): BetaRawMessageDeltaEvent( type='message_delta', delta=Delta(stop_reason='end_turn'), - usage=BetaMessageDeltaUsage(output_tokens=5), + usage=BetaMessageDeltaUsage(input_tokens=20, output_tokens=5), ), # Mark message as complete BetaRawMessageStopEvent(type='message_stop'), @@ -1291,12 +1291,11 @@ def anth_msg(usage: BetaUsage) -> BetaMessage: snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})), id='RawMessageDeltaEvent', ), - pytest.param( - lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(RequestUsage()), id='RawMessageStopEvent' - ), ], ) -def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: RunUsage): +def test_usage( + message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage +): assert _map_usage(message_callback()) == usage