Skip to content

Commit 3bdc0a7

Browse files
committed
sync flutter to support enums and improve behavior
1 parent 982a3fe commit 3bdc0a7

File tree

1 file changed

+88
-29
lines changed

1 file changed

+88
-29
lines changed

snooty/flutter.py

Lines changed: 88 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import collections.abc
22
import dataclasses
3+
import enum
34
import typing
4-
from typing import cast, Any, Callable, Dict, Set, Tuple, Type, TypeVar, Optional, Union # noqa
5+
from dataclasses import MISSING
6+
from typing import cast, Any, Callable, Dict, Set, Tuple, Type, TypeVar, Iterator, Optional, Union
57
from typing_extensions import Protocol
68

7-
# As of Python 3.7, the mypy Field definition is different from Python's version.
8-
# Use an old-fashioned comment until this situation is fixed.
9-
CACHED_TYPES = {} # type: Dict[type, Optional[Dict[str, dataclasses.Field[Any]]]]
10-
119

1210
class HasAnnotations(Protocol):
1311
__annotations__: Dict[str, type]
@@ -17,15 +15,66 @@ class Constructable(Protocol):
1715
def __init__(self, **kwargs: object) -> None: ...
1816

1917

18+
_A = TypeVar('_A', bound=HasAnnotations)
19+
_C = TypeVar('_C', bound=Constructable)
20+
21+
2022
class mapping_dict(Dict[str, Any]):
2123
"""A dictionary that also contains source line information."""
2224
__slots__ = ('_start_line', '_end_line')
2325
_start_line: int
2426
_end_line: int
2527

2628

27-
T = TypeVar('T', bound=HasAnnotations)
28-
C = TypeVar('C', bound=Constructable)
29+
@dataclasses.dataclass
30+
class _Field:
31+
"""A single field in a _TypeThunk."""
32+
__slots__ = ('default_factory', 'type')
33+
34+
default_factory: Optional[Callable[[], Any]]
35+
type: Type[Any]
36+
37+
38+
class _TypeThunk:
39+
"""Type hints cannot be fully resolved at module runtime due to ForwardRefs. Instead,
40+
store the type here, and resolve type hints only when needed. By that time, hopefully all
41+
types have been declared."""
42+
__slots__ = ('type', '_fields')
43+
44+
def __init__(self, klass: Type[Any]) -> None:
45+
self.type = klass
46+
self._fields: Optional[Dict[str, _Field]] = None
47+
48+
def __getitem__(self, field_name: str) -> _Field:
49+
return self.fields[field_name]
50+
51+
def __iter__(self) -> Iterator[str]:
52+
return iter(self.fields.keys())
53+
54+
@property
55+
def fields(self) -> Dict[str, _Field]:
56+
def make_factory(value: object) -> Callable[[], Any]:
57+
return lambda: value
58+
59+
if self._fields is None:
60+
hints = typing.get_type_hints(self.type)
61+
# This is gnarly. Sorry. For each field, store its default_factory if present; otherwise
62+
# create a factory returning its default if present; otherwise None. Default parameter
63+
# in the lambda is a ~~hack~~ to avoid messing up the variable binding.
64+
fields: Dict[str, _Field] = {
65+
field.name: _Field(
66+
field.default_factory if field.default_factory is not MISSING # type: ignore
67+
else ((make_factory(field.default)) if field.default is not MISSING
68+
else None),
69+
hints[field.name]) for field in dataclasses.fields(self.type)
70+
}
71+
72+
self._fields = fields
73+
74+
return self._fields
75+
76+
77+
CACHED_TYPES: Dict[type, _TypeThunk] = {}
2978

3079

3180
def _add_indefinite_article(s: str) -> str:
@@ -131,9 +180,9 @@ def inner(ty: type, plural: bool, level: int) -> str:
131180
return inner(ty, False, 0), hints
132181

133182

134-
def checked(klass: Type[T]) -> Type[T]:
183+
def checked(klass: Type[_A]) -> Type[_A]:
135184
"""Marks a dataclass as being deserializable."""
136-
CACHED_TYPES[klass] = None
185+
CACHED_TYPES[klass] = _TypeThunk(klass)
137186
return klass
138187

139188

@@ -167,19 +216,28 @@ def __init__(self, ty: type, bad_data: object, bad_field: str) -> None:
167216
self.bad_field = bad_field
168217

169218

170-
def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
219+
def check_type(ty: Type[_C], data: object) -> _C:
171220
# Check for a primitive type
172221
if isinstance(ty, type) and issubclass(ty, (str, int, float, bool, type(None))):
173222
if not isinstance(data, ty):
174223
raise LoadWrongType(ty, data)
175-
return cast(C, data)
224+
return cast(_C, data)
225+
226+
if isinstance(ty, enum.EnumMeta):
227+
try:
228+
if isinstance(data, str):
229+
return ty[data]
230+
if isinstance(data, int):
231+
return ty(data)
232+
except (KeyError, ValueError) as err:
233+
raise LoadWrongType(ty, data) from err
234+
235+
# Check if the given type is a known flutter-annotated type
236+
if ty in CACHED_TYPES:
237+
if not isinstance(data, dict):
238+
raise LoadWrongType(ty, data)
176239

177-
# Check for an object
178-
if isinstance(data, dict) and ty in CACHED_TYPES:
179240
annotations = CACHED_TYPES[ty]
180-
if annotations is None:
181-
annotations = {field.name: field for field in dataclasses.fields(ty)}
182-
CACHED_TYPES[ty] = annotations
183241
result: Dict[str, object] = {}
184242

185243
# Assign missing fields None
@@ -198,14 +256,12 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
198256
have_value = False
199257
if key in missing:
200258
# Use the field's default_factory if it's defined
201-
try:
202-
result[key] = field.default_factory() # type: ignore
259+
if field.default_factory is not None:
260+
result[key] = field.default_factory()
203261
have_value = True
204-
except TypeError:
205-
pass
206262

207263
if not have_value:
208-
result[key] = check_type(field.type, value, ty.__module__)
264+
result[key] = check_type(field.type, value)
209265

210266
output = ty(**result)
211267
start_line = getattr(data, '_start_line', None)
@@ -220,23 +276,26 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
220276
if origin is list:
221277
if not isinstance(data, list):
222278
raise LoadWrongType(ty, data)
223-
return cast(C, [check_type(args[0], x, ty_module) for x in data])
279+
return cast(_C, [check_type(args[0], x) for x in data])
224280
elif origin is dict:
225281
if not isinstance(data, dict):
226282
raise LoadWrongType(ty, data)
227283
key_type, value_type = args
228-
return cast(C, {
229-
check_type(key_type, k, ty_module): check_type(value_type, v, ty_module)
284+
return cast(_C, {
285+
check_type(key_type, k): check_type(value_type, v)
230286
for k, v in data.items()})
231-
elif origin is tuple and isinstance(data, collections.abc.Collection):
287+
elif origin is tuple:
288+
if not isinstance(data, collections.abc.Collection):
289+
raise LoadWrongType(ty, data)
290+
232291
if not len(data) == len(args):
233292
raise LoadWrongArity(ty, data)
234-
return cast(C, tuple(
235-
check_type(tuple_ty, x, ty_module) for x, tuple_ty in zip(data, args)))
293+
return cast(_C, tuple(
294+
check_type(tuple_ty, x) for x, tuple_ty in zip(data, args)))
236295
elif origin is Union:
237296
for candidate_ty in args:
238297
try:
239-
return cast(C, check_type(candidate_ty, data, ty_module))
298+
return cast(_C, check_type(candidate_ty, data))
240299
except LoadError:
241300
pass
242301

@@ -245,6 +304,6 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C:
245304
raise LoadError('Unsupported PEP-484 type', ty, data)
246305

247306
if ty is object or ty is Any or isinstance(data, ty):
248-
return cast(C, data)
307+
return cast(_C, data)
249308

250309
raise LoadError('Unloadable type', ty, data)

0 commit comments

Comments
 (0)