Skip to content

Commit 29867c7

Browse files
Kludex and alexmojaki authored
Add price() method to ModelResponse (#2584)
Co-authored-by: Alex Hall <[email protected]>
1 parent 984ca11 commit 29867c7

31 files changed

+361
-34
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def usage(self) -> RequestUsage:
5555
def model_name(self) -> str:
5656
return self.response.model_name or '' # pragma: no cover
5757

58+
@property
59+
def provider_name(self) -> str:
60+
return self.response.provider_name or '' # pragma: no cover
61+
5862
@property
5963
def timestamp(self) -> datetime:
6064
return self.response.timestamp # pragma: no cover

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pydantic
1212
import pydantic_core
13+
from genai_prices import calc_price, types as genai_types
1314
from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
1415
from typing_extensions import TypeAlias, deprecated
1516

@@ -848,6 +849,9 @@ class ModelResponse:
848849
kind: Literal['response'] = 'response'
849850
"""Message type identifier, this is available on all parts as a discriminator."""
850851

852+
provider_name: str | None = None
853+
"""The name of the LLM provider that generated the response."""
854+
851855
provider_details: dict[str, Any] | None = field(default=None)
852856
"""Additional provider-specific details in a serializable format.
853857
@@ -858,6 +862,19 @@ class ModelResponse:
858862
provider_request_id: str | None = None
859863
"""request ID as specified by the model provider. This can be used to track the specific request to the model."""
860864

865+
def price(self) -> genai_types.PriceCalculation:
866+
"""Calculate the price of the usage.
867+
868+
Uses [`genai-prices`](https://github.com/pydantic/genai-prices).
869+
"""
870+
assert self.model_name, 'Model name is required to calculate price'
871+
return calc_price(
872+
self.usage,
873+
self.model_name,
874+
provider_id=self.provider_name,
875+
genai_request_timestamp=self.timestamp,
876+
)
877+
861878
def otel_events(self, settings: InstrumentationSettings) -> list[Event]:
862879
"""Return OpenTelemetry events for the response."""
863880
result: list[Event] = []

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ def get(self) -> ModelResponse:
598598
model_name=self.model_name,
599599
timestamp=self.timestamp,
600600
usage=self.usage(),
601+
provider_name=self.provider_name,
601602
)
602603

603604
def usage(self) -> RequestUsage:
@@ -610,6 +611,12 @@ def model_name(self) -> str:
610611
"""Get the model name of the response."""
611612
raise NotImplementedError()
612613

614+
@property
615+
@abstractmethod
616+
def provider_name(self) -> str | None:
617+
"""Get the provider name."""
618+
raise NotImplementedError()
619+
613620
@property
614621
@abstractmethod
615622
def timestamp(self) -> datetime:

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,11 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
326326
)
327327

328328
return ModelResponse(
329-
items, usage=_map_usage(response), model_name=response.model, provider_request_id=response.id
329+
items,
330+
usage=_map_usage(response),
331+
model_name=response.model,
332+
provider_request_id=response.id,
333+
provider_name=self._provider.name,
330334
)
331335

332336
async def _process_streamed_response(
@@ -344,6 +348,7 @@ async def _process_streamed_response(
344348
_model_name=self._model_name,
345349
_response=peekable_response,
346350
_timestamp=timestamp,
351+
_provider_name=self._provider.name,
347352
)
348353

349354
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[BetaToolParam]:
@@ -574,6 +579,7 @@ class AnthropicStreamedResponse(StreamedResponse):
574579
_model_name: AnthropicModelName
575580
_response: AsyncIterable[BetaRawMessageStreamEvent]
576581
_timestamp: datetime
582+
_provider_name: str
577583

578584
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
579585
current_block: BetaContentBlock | None = None
@@ -655,6 +661,11 @@ def model_name(self) -> AnthropicModelName:
655661
"""Get the model name of the response."""
656662
return self._model_name
657663

664+
@property
665+
def provider_name(self) -> str:
666+
"""Get the provider name."""
667+
return self._provider_name
668+
658669
@property
659670
def timestamp(self) -> datetime:
660671
"""Get the timestamp of the response."""

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
240240

241241
@staticmethod
242242
def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
243-
tool_spec: ToolSpecificationTypeDef = {
244-
'name': f.name,
245-
'inputSchema': {'json': f.parameters_json_schema},
246-
}
243+
tool_spec: ToolSpecificationTypeDef = {'name': f.name, 'inputSchema': {'json': f.parameters_json_schema}}
247244

248245
if f.description: # pragma: no branch
249246
tool_spec['description'] = f.description
@@ -275,6 +272,7 @@ async def request_stream(
275272
model_request_parameters=model_request_parameters,
276273
_model_name=self.model_name,
277274
_event_stream=response,
275+
_provider_name=self._provider.name,
278276
)
279277

280278
async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
@@ -304,7 +302,9 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
304302
output_tokens=response['usage']['outputTokens'],
305303
)
306304
vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
307-
return ModelResponse(items, usage=u, model_name=self.model_name, provider_request_id=vendor_id)
305+
return ModelResponse(
306+
items, usage=u, model_name=self.model_name, provider_request_id=vendor_id, provider_name=self._provider.name
307+
)
308308

309309
@overload
310310
async def _messages_create(
@@ -594,6 +594,7 @@ class BedrockStreamedResponse(StreamedResponse):
594594

595595
_model_name: BedrockModelName
596596
_event_stream: EventStream[ConverseStreamOutputTypeDef]
597+
_provider_name: str
597598
_timestamp: datetime = field(default_factory=_utils.now_utc)
598599

599600
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
@@ -660,15 +661,20 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
660661
if maybe_event: # pragma: no branch
661662
yield maybe_event
662663

663-
@property
664-
def timestamp(self) -> datetime:
665-
return self._timestamp
666-
667664
@property
668665
def model_name(self) -> str:
669666
"""Get the model name of the response."""
670667
return self._model_name
671668

669+
@property
670+
def provider_name(self) -> str:
671+
"""Get the provider name."""
672+
return self._provider_name
673+
674+
@property
675+
def timestamp(self) -> datetime:
676+
return self._timestamp
677+
672678
def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> usage.RequestUsage:
673679
return usage.RequestUsage(
674680
input_tokens=metadata['usage']['inputTokens'],

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse:
205205
tool_call_id=c.id or _generate_tool_call_id(),
206206
)
207207
)
208-
return ModelResponse(parts=parts, usage=_map_usage(response), model_name=self._model_name)
208+
return ModelResponse(
209+
parts=parts, usage=_map_usage(response), model_name=self._model_name, provider_name=self._provider.name
210+
)
209211

210212
def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
211213
"""Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,11 @@ def model_name(self) -> str:
304304
"""Get the model name of the response."""
305305
return self._model_name
306306

307+
@property
308+
def provider_name(self) -> None:
309+
"""Get the provider name."""
310+
return None
311+
307312
@property
308313
def timestamp(self) -> datetime:
309314
"""Get the timestamp of the response."""

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ async def _process_streamed_response(
305305
_model_name=self._model_name,
306306
_content=content,
307307
_stream=aiter_bytes,
308+
_provider_name=self._provider.name,
308309
)
309310

310311
async def _message_to_gemini_content(
@@ -425,6 +426,7 @@ class GeminiStreamedResponse(StreamedResponse):
425426
_model_name: GeminiModelName
426427
_content: bytearray
427428
_stream: AsyncIterator[bytes]
429+
_provider_name: str
428430
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
429431

430432
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
@@ -495,6 +497,11 @@ def model_name(self) -> GeminiModelName:
495497
"""Get the model name of the response."""
496498
return self._model_name
497499

500+
@property
501+
def provider_name(self) -> str:
502+
"""Get the provider name."""
503+
return self._provider_name
504+
498505
@property
499506
def timestamp(self) -> datetime:
500507
"""Get the timestamp of the response."""

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
395395
return _process_response_from_parts(
396396
parts,
397397
response.model_version or self._model_name,
398+
self._provider.name,
398399
usage,
399400
vendor_id=vendor_id,
400401
vendor_details=vendor_details,
@@ -414,6 +415,7 @@ async def _process_streamed_response(
414415
_model_name=self._model_name,
415416
_response=peekable_response,
416417
_timestamp=first_chunk.create_time or _utils.now_utc(),
418+
_provider_name=self._provider.name,
417419
)
418420

419421
async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
@@ -523,6 +525,7 @@ class GeminiStreamedResponse(StreamedResponse):
523525
_model_name: GoogleModelName
524526
_response: AsyncIterator[GenerateContentResponse]
525527
_timestamp: datetime
528+
_provider_name: str
526529

527530
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
528531
async for chunk in self._response:
@@ -564,6 +567,11 @@ def model_name(self) -> GoogleModelName:
564567
"""Get the model name of the response."""
565568
return self._model_name
566569

570+
@property
571+
def provider_name(self) -> str:
572+
"""Get the provider name."""
573+
return self._provider_name
574+
567575
@property
568576
def timestamp(self) -> datetime:
569577
"""Get the timestamp of the response."""
@@ -599,6 +607,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
599607
def _process_response_from_parts(
600608
parts: list[Part],
601609
model_name: GoogleModelName,
610+
provider_name: str,
602611
usage: usage.RequestUsage,
603612
vendor_id: str | None,
604613
vendor_details: dict[str, Any] | None = None,
@@ -636,7 +645,12 @@ def _process_response_from_parts(
636645
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
637646
)
638647
return ModelResponse(
639-
parts=items, model_name=model_name, usage=usage, provider_request_id=vendor_id, provider_details=vendor_details
648+
parts=items,
649+
model_name=model_name,
650+
usage=usage,
651+
provider_request_id=vendor_id,
652+
provider_details=vendor_details,
653+
provider_name=provider_name,
640654
)
641655

642656

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
290290
model_name=response.model,
291291
timestamp=timestamp,
292292
provider_request_id=response.id,
293+
provider_name=self._provider.name,
293294
)
294295

295296
async def _process_streamed_response(
@@ -309,6 +310,7 @@ async def _process_streamed_response(
309310
_model_name=self._model_name,
310311
_model_profile=self.profile,
311312
_timestamp=number_to_datetime(first_chunk.created),
313+
_provider_name=self._provider.name,
312314
)
313315

314316
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -444,6 +446,7 @@ class GroqStreamedResponse(StreamedResponse):
444446
_model_profile: ModelProfile
445447
_response: AsyncIterable[chat.ChatCompletionChunk]
446448
_timestamp: datetime
449+
_provider_name: str
447450

448451
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
449452
async for chunk in self._response:
@@ -482,6 +485,11 @@ def model_name(self) -> GroqModelName:
482485
"""Get the model name of the response."""
483486
return self._model_name
484487

488+
@property
489+
def provider_name(self) -> str:
490+
"""Get the provider name."""
491+
return self._provider_name
492+
485493
@property
486494
def timestamp(self) -> datetime:
487495
"""Get the timestamp of the response."""

0 commit comments

Comments (0)