1- from typing import Any , List , Type , TypeVar , cast , get_args
1+ from typing import Any , Dict , List , Type , TypeVar , cast , get_args
22
33from pydantic import BaseModel
44from sqlalchemy import types
@@ -28,7 +28,9 @@ class PydanticJSONB(types.TypeDecorator): # type: ignore
2828
2929 def __init__ (
3030 self ,
31- model_class : Type [BaseModelType ] | Type [list [BaseModelType ]],
31+ model_class : Type [BaseModelType ]
32+ | Type [list [BaseModelType ]]
33+ | Type [Dict [str , BaseModelType ]],
3234 * args ,
3335 ** kwargs ,
3436 ):
@@ -45,15 +47,23 @@ def process_bind_param(self, value: Any, dialect) -> dict | list[dict] | None:
4547 m .model_dump (mode = "json" ) if isinstance (m , BaseModel ) else m
4648 for m in value
4749 ]
50+ if isinstance (value , dict ):
51+ return {
52+ k : v .model_dump (mode = "json" ) if isinstance (v , BaseModel ) else v
53+ for k , v in value .items ()
54+ }
4855 return value
4956
5057 def process_result_value (
5158 self , value : Any , dialect
52- ) -> BaseModelType | List [BaseModelType ] | None : # noqa: ANN401, ARG002, ANN001
53- # Called when loading from DB: convert dict to Pydantic model instance
59+ ) -> BaseModelType | List [BaseModelType ] | Dict [str , BaseModelType ] | None : # noqa: ANN401, ARG002, ANN001
5460 if value is None :
5561 return None
5662 if isinstance (value , dict ):
63+ # If model_class is dict, handle key-value pairs
64+ if isinstance (self .model_class , dict ):
65+ return {k : self .model_class .model_validate (v ) for k , v in value .items ()}
66+ # Regular case: the whole dict represents a single model
5767 return self .model_class .model_validate (value ) # type: ignore
5868 if isinstance (value , list ):
5969 return [get_args (self .model_class )[0 ].model_validate (v ) for v in value ] # type: ignore
0 commit comments