Skip to content

Commit 7361dbb

Browse files
committed
Fix Anthropic streaming usage counting
1 parent 01b6e51 commit 7361dbb

File tree

2 files changed

+7
-17
lines changed

2 files changed

+7
-17
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -536,20 +536,13 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
536536
}
537537

538538

539-
def _map_usage(message: BetaMessage | BetaRawMessageStreamEvent) -> usage.RequestUsage:
539+
def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
540540
if isinstance(message, BetaMessage):
541541
response_usage = message.usage
542542
elif isinstance(message, BetaRawMessageStartEvent):
543543
response_usage = message.message.usage
544544
elif isinstance(message, BetaRawMessageDeltaEvent):
545545
response_usage = message.usage
546-
else:
547-
# No usage information provided in:
548-
# - RawMessageStopEvent
549-
# - RawContentBlockStartEvent
550-
# - RawContentBlockDeltaEvent
551-
# - RawContentBlockStopEvent
552-
return usage.RequestUsage()
553546

554547
# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
555548
# `response_tokens`
@@ -586,10 +579,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
586579
current_block: BetaContentBlock | None = None
587580

588581
async for event in self._response:
589-
self._usage += _map_usage(event)
590-
591582
if isinstance(event, BetaRawMessageStartEvent):
592-
pass
583+
self._usage = _map_usage(event)
593584

594585
elif isinstance(event, BetaRawContentBlockStartEvent):
595586
current_block = event.content_block
@@ -652,7 +643,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
652643
pass
653644

654645
elif isinstance(event, BetaRawMessageDeltaEvent):
655-
pass
646+
self._usage = _map_usage(event)
656647

657648
elif isinstance(event, BetaRawContentBlockStopEvent | BetaRawMessageStopEvent): # pragma: no branch
658649
current_block = None

tests/models/test_anthropic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ async def test_stream_structured(allow_model_requests: None):
608608
BetaRawMessageDeltaEvent(
609609
type='message_delta',
610610
delta=Delta(stop_reason='end_turn'),
611-
usage=BetaMessageDeltaUsage(output_tokens=5),
611+
usage=BetaMessageDeltaUsage(input_tokens=20, output_tokens=5),
612612
),
613613
# Mark message as complete
614614
BetaRawMessageStopEvent(type='message_stop'),
@@ -1291,12 +1291,11 @@ def anth_msg(usage: BetaUsage) -> BetaMessage:
12911291
snapshot(RequestUsage(output_tokens=5, details={'output_tokens': 5})),
12921292
id='RawMessageDeltaEvent',
12931293
),
1294-
pytest.param(
1295-
lambda: BetaRawMessageStopEvent(type='message_stop'), snapshot(RequestUsage()), id='RawMessageStopEvent'
1296-
),
12971294
],
12981295
)
1299-
def test_usage(message_callback: Callable[[], BetaMessage | BetaRawMessageStreamEvent], usage: RunUsage):
1296+
def test_usage(
1297+
message_callback: Callable[[], BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent], usage: RunUsage
1298+
):
13001299
assert _map_usage(message_callback()) == usage
13011300

13021301

0 commit comments

Comments
 (0)