1010
1111import importlib .metadata
1212from collections .abc import Iterable , Mapping , Sequence
13- from typing import Any , Generic , TypeVar , Union , cast
13+ from typing import TYPE_CHECKING , Any , Generic , TypeVar , Union , cast
1414
1515import marshmallow as ma
1616from sqlalchemy .ext .declarative import DeclarativeMeta
17- from sqlalchemy .orm import Session
1817from sqlalchemy .orm .exc import ObjectDeletedError
1918
2019from .fields import get_primary_keys
2120
21+ if TYPE_CHECKING :
22+ from sqlalchemy .orm import Session
23+
2224_LoadDataV3 = Union [Mapping [str , Any ], Iterable [Mapping [str , Any ]]]
2325_LoadDataV4 = Union [Mapping [str , Any ], Sequence [Mapping [str , Any ]]]
2426_LoadDataT = TypeVar ("_LoadDataT" , _LoadDataV3 , _LoadDataV4 )
2729
2830def _cast_data (data ):
2931 if int (importlib .metadata .version ("marshmallow" )[0 ]) >= 4 :
30- return cast (_LoadDataV4 , data )
31- return cast (_LoadDataV3 , data )
32+ return cast (" _LoadDataV4" , data )
33+ return cast (" _LoadDataV3" , data )
3234
3335
3436class LoadInstanceMixin :
@@ -87,12 +89,12 @@ def get_instance(self, data) -> _ModelType | None:
8789 """
8890 if self .transient :
8991 return None
90- model = cast (type [_ModelType ], self .opts .model )
92+ model = cast (" type[_ModelType]" , self .opts .model )
9193 props = get_primary_keys (model )
9294 filters = {prop .key : data .get (prop .key ) for prop in props }
9395 if None not in filters .values ():
9496 try :
95- return cast (Session , self .session ).get (model , filters )
97+ return cast (" Session" , self .session ).get (model , filters )
9698 except ObjectDeletedError :
9799 return None
98100 return None
@@ -114,7 +116,7 @@ def make_instance(self, data, **kwargs) -> _ModelType:
114116 setattr (instance , key , value )
115117 return instance
116118 kwargs , association_attrs = self ._split_model_kwargs_association (data )
117- ModelClass = cast (DeclarativeMeta , self .opts .model )
119+ ModelClass = cast (" DeclarativeMeta" , self .opts .model )
118120 instance = ModelClass (** kwargs )
119121 for attr , value in association_attrs .items ():
120122 setattr (instance , attr , value )
@@ -144,7 +146,7 @@ def load(
144146 raise ValueError ("Deserialization requires a session" )
145147 self .instance = instance or self .instance
146148 try :
147- return cast (ma .Schema , super ()).load (_cast_data (data ), ** kwargs )
149+ return cast (" ma.Schema" , super ()).load (_cast_data (data ), ** kwargs )
148150 finally :
149151 self .instance = None
150152
@@ -159,7 +161,7 @@ def validate(
159161 self ._session = session or self ._session
160162 if not (self .transient or self .session ):
161163 raise ValueError ("Validation requires a session" )
162- return cast (ma .Schema , super ()).validate (_cast_data (data ), ** kwargs )
164+ return cast (" ma.Schema" , super ()).validate (_cast_data (data ), ** kwargs )
163165
164166 def _split_model_kwargs_association (self , data ):
165167 """Split serialized attrs to ensure association proxies are passed separately.
0 commit comments