77from datetime import datetime
88from typing import Any , Literal , cast , overload
99
10- from genai_prices import extract_usage
1110from pydantic import TypeAdapter
1211from typing_extensions import assert_never
1312
@@ -352,7 +351,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
352351
353352 return ModelResponse (
354353 parts = items ,
355- usage = _map_usage (response , self ._provider .name , self ._model_name ),
354+ usage = _map_usage (response , self ._provider .name , self ._provider . base_url , self . _model_name ),
356355 model_name = response .model ,
357356 provider_response_id = response .id ,
358357 provider_name = self ._provider .name ,
@@ -376,6 +375,7 @@ async def _process_streamed_response(
376375 _response = peekable_response ,
377376 _timestamp = _utils .now_utc (),
378377 _provider_name = self ._provider .name ,
378+ _provider_url = self ._provider .base_url ,
379379 )
380380
381381 def _get_tools (self , model_request_parameters : ModelRequestParameters ) -> list [BetaToolUnionParam ]:
@@ -620,6 +620,7 @@ def _map_tool_definition(f: ToolDefinition) -> BetaToolParam:
620620def _map_usage (
621621 message : BetaMessage | BetaRawMessageStartEvent | BetaRawMessageDeltaEvent ,
622622 provider : str ,
623+ provider_url : str ,
623624 model : str ,
624625 existing_usage : usage .RequestUsage | None = None ,
625626) -> usage .RequestUsage :
@@ -638,10 +639,11 @@ def _map_usage(
638639 key : value for key , value in response_usage .model_dump ().items () if isinstance (value , int )
639640 }
640641
641- extracted_usage = extract_usage (dict (model = model , usage = details ), provider_id = provider )
642-
643- return usage .RequestUsage (
644- ** {key : value for key , value in extracted_usage .usage .__dict__ .items () if isinstance (value , int )},
642+ return usage .RequestUsage .extract (
643+ dict (model = model , usage = details ),
644+ provider = provider ,
645+ provider_url = provider_url ,
646+ provider_fallback = 'anthropic' ,
645647 details = details ,
646648 )
647649
@@ -654,13 +656,14 @@ class AnthropicStreamedResponse(StreamedResponse):
654656 _response : AsyncIterable [BetaRawMessageStreamEvent ]
655657 _timestamp : datetime
656658 _provider_name : str
659+ _provider_url : str
657660
658661 async def _get_event_iterator (self ) -> AsyncIterator [ModelResponseStreamEvent ]: # noqa: C901
659662 current_block : BetaContentBlock | None = None
660663
661664 async for event in self ._response :
662665 if isinstance (event , BetaRawMessageStartEvent ):
663- self ._usage = _map_usage (event , self ._provider_name , self ._model_name )
666+ self ._usage = _map_usage (event , self ._provider_name , self ._provider_url , self . _model_name )
664667 self .provider_response_id = event .message .id
665668
666669 elif isinstance (event , BetaRawContentBlockStartEvent ):
@@ -741,7 +744,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
741744 pass
742745
743746 elif isinstance (event , BetaRawMessageDeltaEvent ):
744- self ._usage = _map_usage (event , self ._provider_name , self ._model_name , self ._usage )
747+ self ._usage = _map_usage (event , self ._provider_name , self ._provider_url , self . _model_name , self ._usage )
745748 if raw_finish_reason := event .delta .stop_reason : # pragma: no branch
746749 self .provider_details = {'finish_reason' : raw_finish_reason }
747750 self .finish_reason = _FINISH_REASON_MAP .get (raw_finish_reason )
0 commit comments