Skip to content

Commit d912939

Browse files
committed
Have PydanticType use TypeAdapter instead of the BaseModel validate -- so that it works on other types
1 parent 25661fa commit d912939

File tree

2 files changed

+9
-19
lines changed

2 files changed

+9
-19
lines changed

packages/db/src/db/models/message/message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Message(Base, kw_only=True):
6060
content: Mapped[str] = mapped_column(Text, nullable=False)
6161

6262
input_parts: Mapped[list[InputPart] | None] = mapped_column(
63-
ARRAY(PydanticType(InputPart)), # type:ignore[arg-type] # pyright: ignore[reportArgumentType]
63+
ARRAY(PydanticType(InputPart)),
6464
nullable=True,
6565
default=None,
6666
)

packages/db/src/db/models/pydantic_type.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING, Any, final, override
22

3-
from pydantic import BaseModel
3+
from pydantic import TypeAdapter
44
from sqlalchemy import JSON, Dialect, TypeDecorator
55
from sqlalchemy.dialects.postgresql import JSONB
66

@@ -10,7 +10,7 @@
1010

1111
# Taken from https://gist.github.com/pdmtt/a6dc62f051c5597a8cdeeb8271c1e079?permalink_comment_id=5761533#gistcomment-5761533
1212
@final
13-
class PydanticType(TypeDecorator[BaseModel]):
13+
class PydanticType(TypeDecorator[Any]):
1414
"""Pydantic type.
1515
1616
SAVING:
@@ -38,9 +38,10 @@ class PydanticType(TypeDecorator[BaseModel]):
3838
impl = JSONB
3939
cache_ok = True
4040

41-
def __init__(self, pydantic_type: type[BaseModel]) -> None:
41+
def __init__(self, pydantic_type: Any) -> None:
4242
super().__init__()
4343
self.pydantic_type = pydantic_type
44+
self._adapter: TypeAdapter[Any] = TypeAdapter(pydantic_type)
4445

4546
@override
4647
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[JSONB | JSON]:
@@ -55,31 +56,20 @@ def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[JSONB | JSON]:
5556
@override
5657
def process_bind_param(
5758
self,
58-
value: BaseModel | None,
59+
value: Any | None,
5960
dialect: Dialect,
6061
) -> dict[str, Any] | None:
6162
if value is None:
6263
return None
63-
64-
if not isinstance(value, BaseModel): # dynamic typing.
65-
msg = f'The value "{value!r}" is not a pydantic model'
66-
raise TypeError(msg)
67-
68-
# Setting mode to "json" entails that you won't need to define a custom json
69-
# serializer ahead.
70-
return value.model_dump(mode="json")
64+
return self._adapter.dump_python(value, mode="json")
7165

7266
@override
7367
def process_result_value(
7468
self,
7569
value: dict[str, Any] | None,
7670
dialect: Dialect,
77-
) -> BaseModel | None:
78-
# We're assuming that the value will be a dictionary here.
79-
validate_on_load = True
80-
if validate_on_load:
81-
return self.pydantic_type.model_validate(value) if value else None
82-
return self.pydantic_type.model_construct(**value) if value else None
71+
) -> Any | None:
72+
return self._adapter.validate_python(value) if value else None
8373

8474
def __repr__(self) -> str:
8575
# Used by alembic

0 commit comments

Comments
 (0)