88
99from __future__ import annotations
1010
11- from typing import TYPE_CHECKING , Any , Generic , TypeVar , cast
11+ import importlib .metadata
12+ from collections .abc import Iterable , Mapping , Sequence
13+ from typing import Any , Generic , TypeVar , Union , cast
1214
1315import marshmallow as ma
1416from sqlalchemy .ext .declarative import DeclarativeMeta
1719
1820from .fields import get_primary_keys
1921
20- if TYPE_CHECKING :
21- from collections . abc import Iterable , Mapping
22-
22+ _LoadDataV3 = Union [ Mapping [ str , Any ], Iterable [ Mapping [ str , Any ]]]
23+ _LoadDataV4 = Union [ Mapping [ str , Any ], Sequence [ Mapping [ str , Any ]]]
24+ _LoadDataT = TypeVar ( "_LoadDataT" , _LoadDataV3 , _LoadDataV4 )
2325_ModelType = TypeVar ("_ModelType" , bound = DeclarativeMeta )
2426
2527
28+ def _cast_data (data ):
29+ if int (importlib .metadata .version ("marshmallow" )[0 ]) >= 4 :
30+ return cast (_LoadDataV4 , data )
31+ return cast (_LoadDataV3 , data )
32+
33+
2634class LoadInstanceMixin :
2735 class Opts :
2836 model : type [DeclarativeMeta ] | None
@@ -114,7 +122,7 @@ def make_instance(self, data, **kwargs) -> _ModelType:
114122
115123 def load (
116124 self ,
117- data : Mapping [ str , Any ] | Iterable [ Mapping [ str , Any ]] ,
125+ data : _LoadDataT ,
118126 * ,
119127 session : Session | None = None ,
120128 instance : _ModelType | None = None ,
@@ -136,13 +144,13 @@ def load(
136144 raise ValueError ("Deserialization requires a session" )
137145 self .instance = instance or self .instance
138146 try :
139- return cast (ma .Schema , super ()).load (data , ** kwargs )
147+ return cast (ma .Schema , super ()).load (_cast_data ( data ) , ** kwargs )
140148 finally :
141149 self .instance = None
142150
143151 def validate (
144152 self ,
145- data : Mapping [ str , Any ] | Iterable [ Mapping [ str , Any ]] ,
153+ data : _LoadDataT ,
146154 * ,
147155 session : Session | None = None ,
148156 ** kwargs ,
@@ -151,7 +159,7 @@ def validate(
151159 self ._session = session or self ._session
152160 if not (self .transient or self .session ):
153161 raise ValueError ("Validation requires a session" )
154- return cast (ma .Schema , super ()).validate (data , ** kwargs )
162+ return cast (ma .Schema , super ()).validate (_cast_data ( data ) , ** kwargs )
155163
156164 def _split_model_kwargs_association (self , data ):
157165 """Split serialized attrs to ensure association proxies are passed separately.
0 commit comments