1- from typing import Any , cast
1+ from typing import Any , List , Type , TypeVar , cast , get_args
22
33from pydantic import BaseModel
44from sqlalchemy import types
55from sqlalchemy .dialects .postgresql import JSONB # for Postgres JSONB
66from sqlalchemy .engine .interfaces import Dialect
77
8+ BaseModelType = TypeVar ("BaseModelType" , bound = BaseModel )
9+
810
911class AutoString (types .TypeDecorator ): # type: ignore
1012 impl = types .String
@@ -24,20 +26,35 @@ class PydanticJSONB(types.TypeDecorator): # type: ignore
2426 impl = JSONB # use JSONB type in Postgres (fallback to JSON for others)
2527 cache_ok = True # allow SQLAlchemy to cache results
2628
27- def __init__ (self , model_class , * args , ** kwargs ):
29+ def __init__ (
30+ self ,
31+ model_class : Type [BaseModelType ] | Type [list [BaseModelType ]],
32+ * args ,
33+ ** kwargs ,
34+ ):
2835 super ().__init__ (* args , ** kwargs )
2936 self .model_class = model_class # Pydantic model class to use
3037
31- def process_bind_param (self , value , dialect ):
32- # Called when storing to DB: convert Pydantic model to a dict (JSON-serializable)
38+ def process_bind_param (self , value : Any , dialect ) -> dict | list [dict ] | None : # noqa: ANN401, ARG002, ANN001
3339 if value is None :
3440 return None
3541 if isinstance (value , BaseModel ):
36- return value .model_dump ()
37- return value # assume it's already a dict
38-
39- def process_result_value (self , value , dialect ):
42+ return value .model_dump (mode = "json" )
43+ if isinstance (value , list ):
44+ return [
45+ m .model_dump (mode = "json" ) if isinstance (m , BaseModel ) else m
46+ for m in value
47+ ]
48+ return value
49+
50+ def process_result_value (
51+ self , value : Any , dialect
52+ ) -> BaseModelType | List [BaseModelType ] | None : # noqa: ANN401, ARG002, ANN001
4053 # Called when loading from DB: convert dict to Pydantic model instance
4154 if value is None :
4255 return None
43- return self .model_class .parse_obj (value ) # instantiate Pydantic model
56+ if isinstance (value , dict ):
57+ return self .model_class .model_validate (value ) # type: ignore
58+ if isinstance (value , list ):
59+ return [get_args (self .model_class )[0 ].model_validate (v ) for v in value ] # type: ignore
60+ return value
0 commit comments