 from datetime import datetime
 from typing import Any, Literal, cast, overload
 
+from genai_prices import extract_usage
 from pydantic import TypeAdapter
 from typing_extensions import assert_never
 
@@ -351,7 +352,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
 
         return ModelResponse(
             parts=items,
-            usage=_map_usage(response),
+            usage=_map_usage(response, self._provider.name, self._model_name),
             model_name=response.model,
             provider_response_id=response.id,
             provider_name=self._provider.name,
@@ -616,7 +617,12 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
     }
 
 
-def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent) -> usage.RequestUsage:
+def _map_usage(
+    message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent,
+    provider: str,
+    model: str,
+    existing_usage: usage.RequestUsage | None = None,
+) -> usage.RequestUsage:
     if isinstance(message, BetaMessage):
         response_usage = message.usage
     elif isinstance(message, BetaRawMessageStartEvent):
@@ -626,24 +632,16 @@ def _map_usage(message: BetaMessage | BetaRawMessageStartEvent | BetaRawMessageD
     else:
         assert_never(message)
 
-    # Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
-    # `response_tokens`
-    details: dict[str, int] = {
+    # In streaming, usage appears in different events.
+    # The values are cumulative, meaning new values should replace existing ones entirely.
+    details: dict[str, int] = (existing_usage.details if existing_usage else {}) | {
         key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
     }
 
-    # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
-    # Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
-    # This approach maintains request_tokens as the count of all input tokens, with cached counts as details
-    cache_write_tokens = details.get('cache_creation_input_tokens', 0)
-    cache_read_tokens = details.get('cache_read_input_tokens', 0)
-    request_tokens = details.get('input_tokens', 0) + cache_write_tokens + cache_read_tokens
+    extracted_usage = extract_usage(dict(model=model, usage=details), provider_id=provider)
 
     return usage.RequestUsage(
-        input_tokens=request_tokens,
-        cache_read_tokens=cache_read_tokens,
-        cache_write_tokens=cache_write_tokens,
-        output_tokens=response_usage.output_tokens,
+        **{key: value for key, value in extracted_usage.usage.__dict__.items() if isinstance(value, int)},
         details=details,
     )
 
@@ -662,7 +660,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
 
         async for event in self._response:
             if isinstance(event, BetaRawMessageStartEvent):
-                self._usage = _map_usage(event)
+                self._usage = _map_usage(event, self._provider_name, self._model_name)
                 self.provider_response_id = event.message.id
 
             elif isinstance(event, BetaRawContentBlockStartEvent):
@@ -743,7 +741,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 pass
 
             elif isinstance(event, BetaRawMessageDeltaEvent):
-                self._usage = _map_usage(event)
+                self._usage = _map_usage(event, self._provider_name, self._model_name, self._usage)
                 if raw_finish_reason := event.delta.stop_reason:  # pragma: no branch
                     self.provider_details = {'finish_reason': raw_finish_reason}
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
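
For illustration only (not part of the diff): a minimal sketch of the merge behavior described by the new comment in `_map_usage`. Anthropic reports cumulative usage across streaming events, so the later event's values replace the earlier ones via the dict `|` operator rather than being summed; the token counts below are made up.

```python
# Hypothetical cumulative usage details from a message_start event and a later
# message_delta event (values invented for illustration).
start_details = {'input_tokens': 25, 'cache_read_input_tokens': 10, 'output_tokens': 1}
delta_details = {'output_tokens': 87}

# Same merge as in _map_usage: the delta's cumulative values replace earlier ones,
# while keys the delta omits (e.g. input_tokens) are carried over unchanged.
merged = start_details | delta_details
assert merged == {'input_tokens': 25, 'cache_read_input_tokens': 10, 'output_tokens': 87}
```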