77from datetime import datetime
88from typing import Any , Literal , TypeAlias , cast , overload
99
10- from genai_prices import extract_usage
1110from pydantic import TypeAdapter
1211from typing_extensions import assert_never
1312
@@ -364,7 +363,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
364363
365364 return ModelResponse (
366365 parts = items ,
367- usage = _map_usage (response , self ._provider .name , self ._model_name ),
366+ usage = _map_usage (response , self ._provider .name , self ._provider . base_url , self . _model_name ),
368367 model_name = response .model ,
369368 provider_response_id = response .id ,
370369 provider_name = self ._provider .name ,
@@ -388,6 +387,7 @@ async def _process_streamed_response(
388387 _response = peekable_response ,
389388 _timestamp = _utils .now_utc (),
390389 _provider_name = self ._provider .name ,
390+ _provider_url = self ._provider .base_url ,
391391 )
392392
393393 def _get_tools (self , model_request_parameters : ModelRequestParameters ) -> list [BetaToolUnionParam ]:
@@ -657,6 +657,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
657657def _map_usage (
658658 message : BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent ,
659659 provider : str ,
660+ provider_url : str ,
660661 model : str ,
661662 existing_usage : usage .RequestUsage | None = None ,
662663) -> usage .RequestUsage :
@@ -675,10 +676,11 @@ def _map_usage(
675676 key : value for key , value in response_usage .model_dump ().items () if isinstance (value , int )
676677 }
677678
678- extracted_usage = extract_usage (dict (model = model , usage = details ), provider_id = provider )
679-
680- return usage .RequestUsage (
681- ** {key : value for key , value in extracted_usage .usage .__dict__ .items () if isinstance (value , int )},
679+ return usage .RequestUsage .extract (
680+ dict (model = model , usage = details ),
681+ provider = provider ,
682+ provider_url = provider_url ,
683+ provider_fallback = 'anthropic' ,
682684 details = details ,
683685 )
684686
@@ -691,13 +693,14 @@ class AnthropicStreamedResponse(StreamedResponse):
691693 _response : AsyncIterable [BetaRawMessageStreamEvent ]
692694 _timestamp : datetime
693695 _provider_name : str
696+ _provider_url : str
694697
695698 async def _get_event_iterator (self ) -> AsyncIterator [ModelResponseStreamEvent ]: # noqa: C901
696699 current_block : BetaContentBlock | None = None
697700
698701 async for event in self ._response :
699702 if isinstance (event , BetaRawMessageStartEvent ):
700- self ._usage = _map_usage (event , self ._provider_name , self ._model_name )
703+ self ._usage = _map_usage (event , self ._provider_name , self ._provider_url , self . _model_name )
701704 self .provider_response_id = event .message .id
702705
703706 elif isinstance (event , BetaRawContentBlockStartEvent ):
@@ -788,7 +791,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
788791 pass
789792
790793 elif isinstance (event , BetaRawMessageDeltaEvent ):
791- self ._usage = _map_usage (event , self ._provider_name , self ._model_name , self ._usage )
794+ self ._usage = _map_usage (event , self ._provider_name , self ._provider_url , self . _model_name , self ._usage )
792795 if raw_finish_reason := event .delta .stop_reason : # pragma: no branch
793796 self .provider_details = {'finish_reason' : raw_finish_reason }
794797 self .finish_reason = _FINISH_REASON_MAP .get (raw_finish_reason )
0 commit comments