11import collections .abc
22import dataclasses
3+ import enum
34import typing
4- from typing import cast , Any , Callable , Dict , Set , Tuple , Type , TypeVar , Optional , Union # noqa
5+ from dataclasses import MISSING
6+ from typing import cast , Any , Callable , Dict , Set , Tuple , Type , TypeVar , Iterator , Optional , Union
57from typing_extensions import Protocol
68
7- # As of Python 3.7, the mypy Field definition is different from Python's version.
8- # Use an old-fashioned comment until this situation is fixed.
9- CACHED_TYPES = {} # type: Dict[type, Optional[Dict[str, dataclasses.Field[Any]]]]
10-
119
1210class HasAnnotations (Protocol ):
1311 __annotations__ : Dict [str , type ]
@@ -17,15 +15,66 @@ class Constructable(Protocol):
1715 def __init__ (self , ** kwargs : object ) -> None : ...
1816
1917
18+ _A = TypeVar ('_A' , bound = HasAnnotations )
19+ _C = TypeVar ('_C' , bound = Constructable )
20+
21+
2022class mapping_dict (Dict [str , Any ]):
2123 """A dictionary that also contains source line information."""
2224 __slots__ = ('_start_line' , '_end_line' )
2325 _start_line : int
2426 _end_line : int
2527
2628
27- T = TypeVar ('T' , bound = HasAnnotations )
28- C = TypeVar ('C' , bound = Constructable )
29+ @dataclasses .dataclass
30+ class _Field :
31+ """A single field in a _TypeThunk."""
32+ __slots__ = ('default_factory' , 'type' )
33+
34+ default_factory : Optional [Callable [[], Any ]]
35+ type : Type [Any ]
36+
37+
38+ class _TypeThunk :
39+ """Type hints cannot be fully resolved at module runtime due to ForwardRefs. Instead,
40+ store the type here, and resolve type hints only when needed. By that time, hopefully all
41+ types have been declared."""
42+ __slots__ = ('type' , '_fields' )
43+
44+ def __init__ (self , klass : Type [Any ]) -> None :
45+ self .type = klass
46+ self ._fields : Optional [Dict [str , _Field ]] = None
47+
48+ def __getitem__ (self , field_name : str ) -> _Field :
49+ return self .fields [field_name ]
50+
51+ def __iter__ (self ) -> Iterator [str ]:
52+ return iter (self .fields .keys ())
53+
54+ @property
55+ def fields (self ) -> Dict [str , _Field ]:
56+ def make_factory (value : object ) -> Callable [[], Any ]:
57+ return lambda : value
58+
59+ if self ._fields is None :
60+ hints = typing .get_type_hints (self .type )
61+ # This is gnarly. Sorry. For each field, store its default_factory if present; otherwise
62+ # create a factory returning its default if present; otherwise None. Default parameter
63+ # in the lambda is a ~~hack~~ to avoid messing up the variable binding.
64+ fields : Dict [str , _Field ] = {
65+ field .name : _Field (
66+ field .default_factory if field .default_factory is not MISSING # type: ignore
67+ else ((make_factory (field .default )) if field .default is not MISSING
68+ else None ),
69+ hints [field .name ]) for field in dataclasses .fields (self .type )
70+ }
71+
72+ self ._fields = fields
73+
74+ return self ._fields
75+
76+
77+ CACHED_TYPES : Dict [type , _TypeThunk ] = {}
2978
3079
3180def _add_indefinite_article (s : str ) -> str :
@@ -131,9 +180,9 @@ def inner(ty: type, plural: bool, level: int) -> str:
131180 return inner (ty , False , 0 ), hints
132181
133182
134- def checked (klass : Type [T ]) -> Type [T ]:
183+ def checked (klass : Type [_A ]) -> Type [_A ]:
135184 """Marks a dataclass as being deserializable."""
136- CACHED_TYPES [klass ] = None
185+ CACHED_TYPES [klass ] = _TypeThunk ( klass )
137186 return klass
138187
139188
@@ -167,19 +216,28 @@ def __init__(self, ty: type, bad_data: object, bad_field: str) -> None:
167216 self .bad_field = bad_field
168217
169218
170- def check_type (ty : Type [C ], data : object , ty_module : str = '' ) -> C :
219+ def check_type (ty : Type [_C ], data : object ) -> _C :
171220 # Check for a primitive type
172221 if isinstance (ty , type ) and issubclass (ty , (str , int , float , bool , type (None ))):
173222 if not isinstance (data , ty ):
174223 raise LoadWrongType (ty , data )
175- return cast (C , data )
224+ return cast (_C , data )
225+
226+ if isinstance (ty , enum .EnumMeta ):
227+ try :
228+ if isinstance (data , str ):
229+ return ty [data ]
230+ if isinstance (data , int ):
231+ return ty (data )
232+ except (KeyError , ValueError ) as err :
233+ raise LoadWrongType (ty , data ) from err
234+
235+ # Check if the given type is a known flutter-annotated type
236+ if ty in CACHED_TYPES :
237+ if not isinstance (data , dict ):
238+ raise LoadWrongType (ty , data )
176239
177- # Check for an object
178- if isinstance (data , dict ) and ty in CACHED_TYPES :
179240 annotations = CACHED_TYPES [ty ]
180- if annotations is None :
181- annotations = {field .name : field for field in dataclasses .fields (ty )}
182- CACHED_TYPES [ty ] = annotations
183241 result : Dict [str , object ] = {}
184242
185243 # Assign missing fields None
@@ -198,14 +256,12 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
198256 have_value = False
199257 if key in missing :
200258 # Use the field's default_factory if it's defined
201- try :
202- result [key ] = field .default_factory () # type: ignore
259+ if field . default_factory is not None :
260+ result [key ] = field .default_factory ()
203261 have_value = True
204- except TypeError :
205- pass
206262
207263 if not have_value :
208- result [key ] = check_type (field .type , value , ty . __module__ )
264+ result [key ] = check_type (field .type , value )
209265
210266 output = ty (** result )
211267 start_line = getattr (data , '_start_line' , None )
@@ -220,23 +276,26 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
220276 if origin is list :
221277 if not isinstance (data , list ):
222278 raise LoadWrongType (ty , data )
223- return cast (C , [check_type (args [0 ], x , ty_module ) for x in data ])
279+ return cast (_C , [check_type (args [0 ], x ) for x in data ])
224280 elif origin is dict :
225281 if not isinstance (data , dict ):
226282 raise LoadWrongType (ty , data )
227283 key_type , value_type = args
228- return cast (C , {
229- check_type (key_type , k , ty_module ): check_type (value_type , v , ty_module )
284+ return cast (_C , {
285+ check_type (key_type , k ): check_type (value_type , v )
230286 for k , v in data .items ()})
231- elif origin is tuple and isinstance (data , collections .abc .Collection ):
287+ elif origin is tuple :
288+ if not isinstance (data , collections .abc .Collection ):
289+ raise LoadWrongType (ty , data )
290+
232291 if not len (data ) == len (args ):
233292 raise LoadWrongArity (ty , data )
234- return cast (C , tuple (
235- check_type (tuple_ty , x , ty_module ) for x , tuple_ty in zip (data , args )))
293+ return cast (_C , tuple (
294+ check_type (tuple_ty , x ) for x , tuple_ty in zip (data , args )))
236295 elif origin is Union :
237296 for candidate_ty in args :
238297 try :
239- return cast (C , check_type (candidate_ty , data , ty_module ))
298+ return cast (_C , check_type (candidate_ty , data ))
240299 except LoadError :
241300 pass
242301
@@ -245,6 +304,6 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
245304 raise LoadError ('Unsupported PEP-484 type' , ty , data )
246305
247306 if ty is object or ty is Any or isinstance (data , ty ):
248- return cast (C , data )
307+ return cast (_C , data )
249308
250309 raise LoadError ('Unloadable type' , ty , data )
0 commit comments