@@ -352,7 +352,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
352352
353353 return ModelResponse (
354354 parts = items ,
355- usage = _map_usage (response , self ._provider .name ),
355+ usage = _map_usage (response , self ._provider .name , self . _model_name ),
356356 model_name = response .model ,
357357 provider_response_id = response .id ,
358358 provider_name = self ._provider .name ,
@@ -619,7 +619,8 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
619619
620620def _map_usage (
621621 message : BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent ,
622- provider : str = 'anthropic' ,
622+ provider : str ,
623+ model : str ,
623624 existing_usage : usage .RequestUsage | None = None ,
624625) -> usage .RequestUsage :
625626 if isinstance (message , BetaMessage ):
@@ -637,10 +638,7 @@ def _map_usage(
637638 key : value for key , value in response_usage .model_dump ().items () if isinstance (value , int )
638639 }
639640
640- # `extract_usage` expects a mapping with a `model` and `usage` key.
641- # Not all the actual types of messages here have a model so we just make a dummy message.
642- # We only care about the numbers of tokens etc.
643- extracted_usage = extract_usage (dict (model = 'claude-sonnet-4-5' , usage = details ), provider_id = provider )
641+ extracted_usage = extract_usage (dict (model = model , usage = details ), provider_id = provider )
644642
645643 return usage .RequestUsage (
646644 ** {key : value for key , value in extracted_usage .usage .__dict__ .items () if isinstance (value , int )},
@@ -662,7 +660,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
662660
663661 async for event in self ._response :
664662 if isinstance (event , BetaRawMessageStartEvent ):
665- self ._usage = _map_usage (event , self ._provider_name )
663+ self ._usage = _map_usage (event , self ._provider_name , self . _model_name )
666664 self .provider_response_id = event .message .id
667665
668666 elif isinstance (event , BetaRawContentBlockStartEvent ):
@@ -743,7 +741,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
743741 pass
744742
745743 elif isinstance (event , BetaRawMessageDeltaEvent ):
746- self ._usage = _map_usage (event , self ._provider_name , self ._usage )
744+ self ._usage = _map_usage (event , self ._provider_name , self ._model_name , self . _usage )
747745 if raw_finish_reason := event .delta .stop_reason : # pragma: no branch
748746 self .provider_details = {'finish_reason' : raw_finish_reason }
749747 self .finish_reason = _FINISH_REASON_MAP .get (raw_finish_reason )
0 commit comments