Skip to content

Commit d3b7f2d

Browse files
Use defer_build with type adapters, not custom _LazyTypeAdapter (#253)
1 parent 04f8c46 commit d3b7f2d

File tree

3 files changed

+32
-47
lines changed

3 files changed

+32
-47
lines changed

pydantic_ai_slim/pydantic_ai/_pydantic.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from inspect import Parameter, signature
99
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
1010

11-
from pydantic import ConfigDict, TypeAdapter
11+
from pydantic import ConfigDict
1212
from pydantic._internal import _decorators, _generate_schema, _typing_extra
1313
from pydantic._internal._config import ConfigWrapper
1414
from pydantic.fields import FieldInfo
@@ -23,7 +23,7 @@
2323
from .tools import ObjectJsonSchema
2424

2525

26-
__all__ = 'function_schema', 'LazyTypeAdapter'
26+
__all__ = ('function_schema',)
2727

2828

2929
class FunctionSchema(TypedDict):
@@ -214,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
214214
return annotation is RunContext or (
215215
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
216216
)
217-
218-
219-
if TYPE_CHECKING:
220-
LazyTypeAdapter = TypeAdapter
221-
else:
222-
223-
class LazyTypeAdapter:
224-
__slots__ = '_args', '_kwargs', '_type_adapter'
225-
226-
def __init__(self, *args, **kwargs):
227-
self._args = args
228-
self._kwargs = kwargs
229-
self._type_adapter = None
230-
231-
def __getattr__(self, item):
232-
if self._type_adapter is None:
233-
self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
234-
return getattr(self._type_adapter, item)

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pydantic
88
import pydantic_core
99

10-
from . import _pydantic
1110
from ._utils import now_utc as _now_utc
1211

1312

@@ -49,7 +48,7 @@ class UserPrompt:
4948
"""Message type identifier, this type is available on all messages as a discriminator."""
5049

5150

52-
tool_return_ta: pydantic.TypeAdapter[Any] = _pydantic.LazyTypeAdapter(Any)
51+
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
5352

5453

5554
@dataclass
@@ -88,7 +87,7 @@ def model_response_object(self) -> dict[str, Any]:
8887
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
8988

9089

91-
ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
90+
ErrorDetailsTa = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
9291

9392

9493
@dataclass
@@ -229,5 +228,7 @@ def from_text(content: str, timestamp: datetime | None = None) -> ModelResponse:
229228
Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelResponse]
230229
"""Any message send to or returned by a model."""
231230

232-
MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Discriminator('message_kind')]])
231+
MessagesTypeAdapter = pydantic.TypeAdapter(
232+
list[Annotated[Message, pydantic.Discriminator('message_kind')]], config=pydantic.ConfigDict(defer_build=True)
233+
)
233234
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
from datetime import datetime
1010
from typing import Annotated, Any, Literal, Protocol, Union
1111

12+
import pydantic
1213
import pydantic_core
1314
from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
14-
from pydantic import Discriminator, Field, Tag
1515
from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never
1616

17-
from .. import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result
17+
from .. import UnexpectedModelBehavior, _utils, exceptions, result
1818
from ..messages import (
1919
ArgsDict,
2020
Message,
@@ -386,6 +386,7 @@ def timestamp(self) -> datetime:
386386
# TypeAdapters take care of validation and serialization
387387

388388

389+
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
389390
class _GeminiRequest(TypedDict):
390391
"""Schema for an API request to the Gemini API.
391392
@@ -457,7 +458,7 @@ class _GeminiTextPart(TypedDict):
457458

458459

459460
class _GeminiFunctionCallPart(TypedDict):
460-
function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
461+
function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]
461462

462463

463464
def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
@@ -487,7 +488,7 @@ class _GeminiFunctionCall(TypedDict):
487488

488489

489490
class _GeminiFunctionResponsePart(TypedDict):
490-
function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
491+
function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]
491492

492493

493494
def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
@@ -517,11 +518,11 @@ def _part_discriminator(v: Any) -> str:
517518
# TODO discriminator
518519
_GeminiPartUnion = Annotated[
519520
Union[
520-
Annotated[_GeminiTextPart, Tag('text')],
521-
Annotated[_GeminiFunctionCallPart, Tag('function_call')],
522-
Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
521+
Annotated[_GeminiTextPart, pydantic.Tag('text')],
522+
Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
523+
Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
523524
],
524-
Discriminator(_part_discriminator),
525+
pydantic.Discriminator(_part_discriminator),
525526
]
526527

527528

@@ -531,7 +532,7 @@ class _GeminiTextContent(TypedDict):
531532

532533

533534
class _GeminiTools(TypedDict):
534-
function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
535+
function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]
535536

536537

537538
class _GeminiFunction(TypedDict):
@@ -572,6 +573,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
572573
allowed_function_names: list[str]
573574

574575

576+
@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
575577
class _GeminiResponse(TypedDict):
576578
"""Schema for the response from the Gemini API.
577579
@@ -581,8 +583,8 @@ class _GeminiResponse(TypedDict):
581583

582584
candidates: list[_GeminiCandidates]
583585
# usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
584-
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
585-
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
586+
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
587+
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
586588

587589

588590
# TODO: Delete the next three functions once we've reworked streams to be more flexible
@@ -618,14 +620,14 @@ class _GeminiCandidates(TypedDict):
618620
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
619621

620622
content: _GeminiContent
621-
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], Field(alias='finishReason')]]
623+
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
622624
"""
623625
See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
624626
but let's wait until we see them and know what they mean to add them here.
625627
"""
626-
avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
628+
avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
627629
index: NotRequired[int]
628-
safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
630+
safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
629631

630632

631633
class _GeminiUsageMetaData(TypedDict, total=False):
@@ -634,10 +636,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
634636
The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
635637
"""
636638

637-
prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
638-
candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
639-
total_token_count: Annotated[int, Field(alias='totalTokenCount')]
640-
cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
639+
prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
640+
candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
641+
total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
642+
cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
641643

642644

643645
def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
@@ -671,15 +673,15 @@ class _GeminiSafetyRating(TypedDict):
671673
class _GeminiPromptFeedback(TypedDict):
672674
"""See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
673675

674-
block_reason: Annotated[str, Field(alias='blockReason')]
675-
safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
676+
block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
677+
safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]
676678

677679

678-
_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
679-
_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)
680+
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
681+
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)
680682

681683
# steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
682-
_gemini_streamed_response_ta = _pydantic.LazyTypeAdapter(list[_GeminiResponse])
684+
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))
683685

684686

685687
class _GeminiJsonSchema:

0 commit comments

Comments
 (0)