Skip to content

Commit 10fd481

Browse files
committed
Support serialization of lists and improve type hints for model binding and result processing
1 parent 1dab2cb commit 10fd481

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

sqlmodel/sql/sqltypes.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import Any, cast
1+
from typing import Any, List, Type, TypeVar, cast, get_args
22

33
from pydantic import BaseModel
44
from sqlalchemy import types
55
from sqlalchemy.dialects.postgresql import JSONB # for Postgres JSONB
66
from sqlalchemy.engine.interfaces import Dialect
77

8+
BaseModelType = TypeVar("BaseModelType", bound=BaseModel)
9+
810

911
class AutoString(types.TypeDecorator): # type: ignore
1012
impl = types.String
@@ -24,20 +26,35 @@ class PydanticJSONB(types.TypeDecorator): # type: ignore
2426
impl = JSONB # use JSONB type in Postgres (fallback to JSON for others)
2527
cache_ok = True # allow SQLAlchemy to cache results
2628

27-
def __init__(self, model_class, *args, **kwargs):
29+
def __init__(
30+
self,
31+
model_class: Type[BaseModelType] | Type[list[BaseModelType]],
32+
*args,
33+
**kwargs,
34+
):
2835
super().__init__(*args, **kwargs)
2936
self.model_class = model_class # Pydantic model class to use
3037

31-
def process_bind_param(self, value, dialect):
32-
# Called when storing to DB: convert Pydantic model to a dict (JSON-serializable)
38+
def process_bind_param(self, value: Any, dialect) -> dict | list[dict] | None: # noqa: ANN401, ARG002, ANN001
3339
if value is None:
3440
return None
3541
if isinstance(value, BaseModel):
36-
return value.model_dump()
37-
return value # assume it's already a dict
38-
39-
def process_result_value(self, value, dialect):
42+
return value.model_dump(mode="json")
43+
if isinstance(value, list):
44+
return [
45+
m.model_dump(mode="json") if isinstance(m, BaseModel) else m
46+
for m in value
47+
]
48+
return value
49+
50+
def process_result_value(
51+
self, value: Any, dialect
52+
) -> BaseModelType | List[BaseModelType] | None: # noqa: ANN401, ARG002, ANN001
4053
# Called when loading from DB: convert dict to Pydantic model instance
4154
if value is None:
4255
return None
43-
return self.model_class.parse_obj(value) # instantiate Pydantic model
56+
if isinstance(value, dict):
57+
return self.model_class.model_validate(value) # type: ignore
58+
if isinstance(value, list):
59+
return [get_args(self.model_class)[0].model_validate(v) for v in value] # type: ignore
60+
return value

0 commit comments

Comments
 (0)