Skip to content

Commit d50e24f

Browse files
committed
add usage
1 parent 18c5ac1 commit d50e24f

File tree

4 files changed

+499
-58
lines changed

4 files changed

+499
-58
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itertools
55
import json
66
import warnings
7-
from collections.abc import AsyncIterable, AsyncIterator, Sequence
7+
from collections.abc import AsyncIterable, AsyncIterator, Iterable, Sequence
88
from contextlib import asynccontextmanager
99
from dataclasses import dataclass, field, replace
1010
from datetime import datetime
@@ -62,8 +62,9 @@
6262
ChatCompletionContentPartInputAudioParam,
6363
ChatCompletionContentPartParam,
6464
ChatCompletionContentPartTextParam,
65+
chat_completion,
66+
chat_completion_chunk,
6567
)
66-
from openai.types.chat.chat_completion_chunk import Choice
6768
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
6869
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
6970
from openai.types.chat.chat_completion_content_part_param import File, FileFile
@@ -543,28 +544,7 @@ def _process_provider_details(self, response: chat.ChatCompletion) -> dict[str,
543544
544545
This method may be overridden by subclasses of `OpenAIChatModel` to apply custom mappings.
545546
"""
546-
choice = response.choices[0]
547-
provider_details: dict[str, Any] = {}
548-
549-
# Add logprobs to vendor_details if available
550-
if choice.logprobs is not None and choice.logprobs.content:
551-
# Convert logprobs to a serializable format
552-
provider_details['logprobs'] = [
553-
{
554-
'token': lp.token,
555-
'bytes': lp.bytes,
556-
'logprob': lp.logprob,
557-
'top_logprobs': [
558-
{'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
559-
],
560-
}
561-
for lp in choice.logprobs.content
562-
]
563-
564-
raw_finish_reason = choice.finish_reason
565-
provider_details['finish_reason'] = raw_finish_reason
566-
567-
return provider_details
547+
return _map_provider_details(response.choices[0])
568548

569549
def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
570550
"""Process a non-streamed response, and prepare a message to return."""
@@ -618,7 +598,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
618598

619599
return ModelResponse(
620600
parts=items,
621-
usage=_map_usage(response, self._provider.name, self._provider.base_url, self._model_name),
601+
usage=self._map_usage(response),
622602
model_name=response.model,
623603
timestamp=timestamp,
624604
provider_details=self._process_provider_details(response),
@@ -680,6 +660,9 @@ def _streamed_response_cls(self) -> type[OpenAIStreamedResponse]:
680660
"""
681661
return OpenAIStreamedResponse
682662

663+
def _map_usage(self, response: chat.ChatCompletion) -> usage.RequestUsage:
664+
return _map_usage(response, self._provider.name, self._provider.base_url, self._model_name)
665+
683666
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
684667
return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]
685668

@@ -1767,7 +1750,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
17671750
for event in self._map_part_delta(choice):
17681751
yield event
17691752

1770-
def _validate_response(self):
1753+
def _validate_response(self) -> AsyncIterable[ChatCompletionChunk]:
17711754
"""Hook that validates incoming chunks.
17721755
17731756
This method may be overridden by subclasses of `OpenAIStreamedResponse` to apply custom chunk validations.
@@ -1776,7 +1759,7 @@ def _validate_response(self):
17761759
"""
17771760
return self._response
17781761

1779-
def _map_part_delta(self, choice: Choice):
1762+
def _map_part_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[ModelResponseStreamEvent]:
17801763
"""Hook that determines the sequence of mappings that will be called to produce events.
17811764
17821765
This method may be overridden by subclasses of `OpenAIStreamedResponse` to customize the mapping.
@@ -1785,7 +1768,7 @@ def _map_part_delta(self, choice: Choice):
17851768
self._map_thinking_delta(choice), self._map_text_delta(choice), self._map_tool_call_delta(choice)
17861769
)
17871770

1788-
def _map_thinking_delta(self, choice: Choice):
1771+
def _map_thinking_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[ModelResponseStreamEvent]:
17891772
"""Hook that maps thinking delta content to events.
17901773
17911774
This method may be overridden by subclasses of `OpenAIStreamedResponse` to customize the mapping.
@@ -1811,7 +1794,7 @@ def _map_thinking_delta(self, choice: Choice):
18111794
provider_name=self.provider_name,
18121795
)
18131796

1814-
def _map_text_delta(self, choice: Choice):
1797+
def _map_text_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[ModelResponseStreamEvent]:
18151798
"""Hook that maps text delta content to events.
18161799
18171800
This method may be overridden by subclasses of `OpenAIStreamedResponse` to customize the mapping.
@@ -1831,7 +1814,7 @@ def _map_text_delta(self, choice: Choice):
18311814
maybe_event.part.provider_name = self.provider_name
18321815
yield maybe_event
18331816

1834-
def _map_tool_call_delta(self, choice: Choice):
1817+
def _map_tool_call_delta(self, choice: chat_completion_chunk.Choice) -> Iterable[ModelResponseStreamEvent]:
18351818
"""Hook that maps tool call delta content to events.
18361819
18371820
This method may be overridden by subclasses of `OpenAIStreamedResponse` to customize the mapping.
@@ -1851,11 +1834,9 @@ def _map_provider_details(self, chunk: ChatCompletionChunk) -> dict[str, str] |
18511834
18521835
This method may be overridden by subclasses of `OpenAIStreamedResponse` to customize the provider details.
18531836
"""
1854-
choice = chunk.choices[0]
1855-
if raw_finish_reason := choice.finish_reason:
1856-
return {'finish_reason': raw_finish_reason}
1837+
return _map_provider_details(chunk.choices[0])
18571838

1858-
def _map_usage(self, response: ChatCompletionChunk):
1839+
def _map_usage(self, response: ChatCompletionChunk) -> usage.RequestUsage:
18591840
return _map_usage(response, self._provider_name, self._provider_url, self._model_name)
18601841

18611842
def _map_finish_reason(
@@ -2177,7 +2158,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
21772158
UserWarning,
21782159
)
21792160

2180-
def _map_usage(self, response: responses.Response):
2161+
def _map_usage(self, response: responses.Response) -> usage.RequestUsage:
21812162
return _map_usage(response, self._provider_name, self._provider_url, self._model_name)
21822163

21832164
@property
@@ -2237,6 +2218,32 @@ def _map_usage(
22372218
)
22382219

22392220

2221+
def _map_provider_details(
2222+
choice: chat_completion_chunk.Choice | chat_completion.Choice,
2223+
) -> dict[str, Any]:
2224+
provider_details: dict[str, Any] = {}
2225+
2226+
# Add logprobs to vendor_details if available
2227+
if choice.logprobs is not None and choice.logprobs.content:
2228+
# Convert logprobs to a serializable format
2229+
provider_details['logprobs'] = [
2230+
{
2231+
'token': lp.token,
2232+
'bytes': lp.bytes,
2233+
'logprob': lp.logprob,
2234+
'top_logprobs': [
2235+
{'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
2236+
],
2237+
}
2238+
for lp in choice.logprobs.content
2239+
]
2240+
2241+
if raw_finish_reason := choice.finish_reason:
2242+
provider_details['finish_reason'] = raw_finish_reason
2243+
2244+
return provider_details
2245+
2246+
22402247
def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
22412248
# When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
22422249
# Before our `ToolCallPart` gained the `id` field alongside `tool_call_id` field, we combined the two fields into a single string stored on `tool_call_id`.

pydantic_ai_slim/pydantic_ai/models/openrouter.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
try:
2626
from openai import APIError
27-
from openai.types import chat
27+
from openai.types import chat, completion_usage
2828
from openai.types.chat import chat_completion, chat_completion_chunk
2929

3030
from .openai import OpenAIChatModel, OpenAIChatModelSettings, OpenAIStreamedResponse
@@ -220,6 +220,12 @@ class WebPlugin(TypedDict, total=False):
220220
OpenRouterPlugin = WebPlugin
221221

222222

223+
class OpenRouterUsageConfig(TypedDict, total=False):
224+
"""Configuration for OpenRouter usage."""
225+
226+
include: bool
227+
228+
223229
class OpenRouterModelSettings(ModelSettings, total=False):
224230
"""Settings used for an OpenRouter model request."""
225231

@@ -254,6 +260,16 @@ class OpenRouterModelSettings(ModelSettings, total=False):
254260
"""
255261

256262
openrouter_plugins: list[OpenRouterPlugin]
263+
"""To enable plugins in the request.
264+
265+
Plugins are tools that can be used to extend the functionality of the model. [See more](https://openrouter.ai/docs/features/web-search)
266+
"""
267+
268+
openrouter_usage: OpenRouterUsageConfig
269+
"""To control the usage of the model.
270+
271+
The usage config object consolidates settings for enabling detailed usage information. [See more](https://openrouter.ai/docs/use-cases/usage-accounting)
272+
"""
257273

258274

259275
class OpenRouterError(BaseModel):
@@ -357,6 +373,30 @@ class OpenRouterChoice(chat_completion.Choice):
357373
"""A wrapped chat completion message with OpenRouter specific attributes."""
358374

359375

376+
class OpenRouterCostDetails(BaseModel):
377+
"""OpenRouter specific cost details."""
378+
379+
upstream_inference_cost: int | None = None
380+
381+
382+
class OpenRouterCompletionTokenDetails(completion_usage.CompletionTokensDetails):
383+
"""Wraps OpenAI completion token details with OpenRouter specific attributes."""
384+
385+
image_tokens: int | None = None
386+
387+
388+
class OpenRouterUsage(completion_usage.CompletionUsage):
389+
"""Wraps OpenAI completion usage with OpenRouter specific attributes."""
390+
391+
cost: float | None = None
392+
393+
cost_details: OpenRouterCostDetails | None = None
394+
395+
is_byok: bool | None = None
396+
397+
completion_tokens_details: OpenRouterCompletionTokenDetails | None = None # type: ignore[reportIncompatibleVariableOverride]
398+
399+
360400
class OpenRouterChatCompletion(chat.ChatCompletion):
361401
"""Wraps OpenAI chat completion with OpenRouter specific attributes."""
362402

@@ -369,6 +409,9 @@ class OpenRouterChatCompletion(chat.ChatCompletion):
369409
error: OpenRouterError | None = None
370410
"""OpenRouter specific error attribute."""
371411

412+
usage: OpenRouterUsage | None = None # type: ignore[reportIncompatibleVariableOverride]
413+
"""OpenRouter specific usage attribute."""
414+
372415

373416
def _openrouter_settings_to_openai_settings(model_settings: OpenRouterModelSettings) -> OpenAIChatModelSettings:
374417
"""Transforms a 'OpenRouterModelSettings' object into an 'OpenAIChatModelSettings' object.
@@ -389,6 +432,8 @@ def _openrouter_settings_to_openai_settings(model_settings: OpenRouterModelSetti
389432
extra_body['preset'] = preset
390433
if transforms := model_settings.pop('openrouter_transforms', None):
391434
extra_body['transforms'] = transforms
435+
if usage := model_settings.pop('openrouter_usage', None):
436+
extra_body['usage'] = usage
392437

393438
model_settings['extra_body'] = extra_body
394439

@@ -401,30 +446,40 @@ def _map_usage(
401446
provider_url: str,
402447
model: str,
403448
) -> RequestUsage:
449+
assert isinstance(response, OpenRouterChatCompletion) or isinstance(response, OpenRouterChatCompletionChunk)
450+
builder = RequestUsage()
451+
404452
response_usage = response.usage
405453
if response_usage is None:
406-
return RequestUsage()
407-
408-
usage_data = response_usage.model_dump(exclude_none=True)
409-
details = {
410-
k: v
411-
for k, v in usage_data.items()
412-
if k not in {'prompt_tokens', 'completion_tokens', 'input_tokens', 'output_tokens', 'total_tokens'}
413-
if isinstance(v, int)
414-
}
415-
response_data = dict(model=model, usage=usage_data)
416-
417-
if response_usage.completion_tokens_details is not None: # pragma: lax no cover
418-
details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
419-
420-
return RequestUsage.extract(
421-
response_data,
422-
provider=provider,
423-
provider_url=provider_url,
424-
provider_fallback='openai',
425-
api_flavor='chat',
426-
details=details,
427-
)
454+
return builder
455+
456+
builder.input_tokens = response_usage.prompt_tokens
457+
builder.output_tokens = response_usage.completion_tokens
458+
459+
if prompt_token_details := response_usage.prompt_tokens_details:
460+
if cached_tokens := prompt_token_details.cached_tokens:
461+
builder.cache_read_tokens = cached_tokens
462+
463+
if audio_tokens := prompt_token_details.audio_tokens: # pragma: lax no cover
464+
builder.input_audio_tokens = audio_tokens
465+
466+
if video_tokens := prompt_token_details.video_tokens: # pragma: lax no cover
467+
builder.details['input_video_tokens'] = video_tokens
468+
469+
if completion_token_details := response_usage.completion_tokens_details:
470+
if reasoning_tokens := completion_token_details.reasoning_tokens:
471+
builder.details['reasoning_tokens'] = reasoning_tokens
472+
473+
if image_tokens := completion_token_details.image_tokens: # pragma: lax no cover
474+
builder.details['output_image_tokens'] = image_tokens
475+
476+
if (is_byok := response_usage.is_byok) is not None:
477+
builder.details['is_byok'] = is_byok
478+
479+
if cost := response_usage.cost:
480+
builder.details['cost'] = int(cost * 1000000) # convert to microcost
481+
482+
return builder
428483

429484

430485
class OpenRouterModel(OpenAIChatModel):
@@ -524,6 +579,10 @@ def _map_model_response(self, message: ModelResponse) -> chat.ChatCompletionMess
524579
def _streamed_response_cls(self):
525580
return OpenRouterStreamedResponse
526581

582+
@override
583+
def _map_usage(self, response: chat.ChatCompletion) -> RequestUsage:
584+
return _map_usage(response, self._provider.name, self._provider.base_url, self._model_name)
585+
527586
@override
528587
def _map_finish_reason( # type: ignore[reportIncompatibleMethodOverride]
529588
self, key: Literal['stop', 'length', 'tool_calls', 'content_filter', 'error']
@@ -566,6 +625,9 @@ class OpenRouterChatCompletionChunk(chat.ChatCompletionChunk):
566625
choices: list[OpenRouterChunkChoice] # type: ignore[reportIncompatibleVariableOverride]
567626
"""A list of chat completion chunk choices modified with OpenRouter specific attributes."""
568627

628+
usage: OpenRouterUsage | None = None # type: ignore[reportIncompatibleVariableOverride]
629+
"""Usage statistics for the completion request."""
630+
569631

570632
@dataclass
571633
class OpenRouterStreamedResponse(OpenAIStreamedResponse):

0 commit comments

Comments
 (0)