Skip to content

Commit 8d20b12

Browse files
committed
Ensure that old ModelResponses stored in a DB can still be deserialized
1 parent 8149de4 commit 8d20b12

File tree

3 files changed

+90
-6
lines changed

3 files changed

+90
-6
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,14 +1014,19 @@ class ModelResponse:
10141014
provider_name: str | None = None
10151015
"""The name of the LLM provider that generated the response."""
10161016

1017-
provider_details: dict[str, Any] | None = field(default=None)
1017+
provider_details: Annotated[
1018+
dict[str, Any] | None,
1019+
pydantic.Field(validation_alias=pydantic.AliasChoices('provider_details', 'vendor_details')),
1020+
] = None
10181021
"""Additional provider-specific details in a serializable format.
10191022
10201023
This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
10211024
For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
10221025
"""
10231026

1024-
provider_response_id: str | None = None
1027+
provider_response_id: Annotated[
1028+
str | None, pydantic.Field(validation_alias=pydantic.AliasChoices('provider_response_id', 'vendor_id'))
1029+
] = None
10251030
"""request ID as specified by the model provider. This can be used to track the specific request to the model."""
10261031

10271032
def price(self) -> genai_types.PriceCalculation:

pydantic_ai_slim/pydantic_ai/usage.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import dataclasses
44
from copy import copy
55
from dataclasses import dataclass, fields
6+
from typing import Annotated
67

8+
from pydantic import AliasChoices, BeforeValidator, Field
79
from typing_extensions import deprecated, overload
810

911
from . import _utils
@@ -14,15 +16,15 @@
1416

1517
@dataclass(repr=False, kw_only=True)
1618
class UsageBase:
17-
input_tokens: int = 0
19+
input_tokens: Annotated[int, Field(validation_alias=AliasChoices('input_tokens', 'request_tokens'))] = 0
1820
"""Number of input/prompt tokens."""
1921

2022
cache_write_tokens: int = 0
2123
"""Number of tokens written to the cache."""
2224
cache_read_tokens: int = 0
2325
"""Number of tokens read from the cache."""
2426

25-
output_tokens: int = 0
27+
output_tokens: Annotated[int, Field(validation_alias=AliasChoices('output_tokens', 'response_tokens'))] = 0
2628
"""Number of output/completion tokens."""
2729

2830
input_audio_tokens: int = 0
@@ -32,7 +34,7 @@ class UsageBase:
3234
output_audio_tokens: int = 0
3335
"""Number of audio output tokens."""
3436

35-
details: dict[str, int] = dataclasses.field(default_factory=dict)
37+
details: Annotated[dict[str, int], BeforeValidator(lambda d: d or {})] = dataclasses.field(default_factory=dict)
3638
"""Any extra details returned by the model."""
3739

3840
@property

tests/test_messages.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
11
import sys
2+
from datetime import datetime, timezone
23

34
import pytest
5+
from inline_snapshot import snapshot
6+
7+
from pydantic_ai.messages import (
8+
AudioUrl,
9+
BinaryContent,
10+
DocumentUrl,
11+
ImageUrl,
12+
ModelMessagesTypeAdapter,
13+
ModelRequest,
14+
ModelResponse,
15+
RequestUsage,
16+
TextPart,
17+
ThinkingPartDelta,
18+
UserPromptPart,
19+
VideoUrl,
20+
)
421

5-
from pydantic_ai.messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, ThinkingPartDelta, VideoUrl
22+
from .conftest import IsNow
623

724

825
def test_image_url():
@@ -325,3 +342,63 @@ def test_thinking_part_delta_apply_to_thinking_part_delta():
325342
result = content_delta.apply(original_delta)
326343
assert isinstance(result, ThinkingPartDelta)
327344
assert result.content_delta == 'new_content'
345+
346+
347+
def test_pre_usage_refactor_messages_deserializable():
348+
# https://github.com/pydantic/pydantic-ai/pull/2378 changed the `ModelResponse` fields,
349+
# but as we tell people to store those in the DB, we want to be very careful not to break deserialization.
350+
data = [
351+
{
352+
'parts': [
353+
{
354+
'content': 'What is the capital of Mexico?',
355+
'timestamp': datetime.now(tz=timezone.utc),
356+
'part_kind': 'user-prompt',
357+
}
358+
],
359+
'instructions': None,
360+
'kind': 'request',
361+
},
362+
{
363+
'parts': [{'content': 'Mexico City.', 'part_kind': 'text'}],
364+
'usage': {
365+
'requests': 1,
366+
'request_tokens': 13,
367+
'response_tokens': 76,
368+
'total_tokens': 89,
369+
'details': None,
370+
},
371+
'model_name': 'gpt-5-2025-08-07',
372+
'timestamp': datetime.now(tz=timezone.utc),
373+
'kind': 'response',
374+
'vendor_details': {
375+
'finish_reason': 'STOP',
376+
},
377+
'vendor_id': 'chatcmpl-CBpEXeCfDAW4HRcKQwbqsRDn7u7C5',
378+
},
379+
]
380+
messages = ModelMessagesTypeAdapter.validate_python(data)
381+
assert messages == snapshot(
382+
[
383+
ModelRequest(
384+
parts=[
385+
UserPromptPart(
386+
content='What is the capital of Mexico?',
387+
timestamp=IsNow(tz=timezone.utc),
388+
)
389+
]
390+
),
391+
ModelResponse(
392+
parts=[TextPart(content='Mexico City.')],
393+
usage=RequestUsage(
394+
input_tokens=13,
395+
output_tokens=76,
396+
details={},
397+
),
398+
model_name='gpt-5-2025-08-07',
399+
timestamp=IsNow(tz=timezone.utc),
400+
provider_details={'finish_reason': 'STOP'},
401+
provider_response_id='chatcmpl-CBpEXeCfDAW4HRcKQwbqsRDn7u7C5',
402+
),
403+
]
404+
)

0 commit comments

Comments
 (0)