|
1 | 1 | import collections.abc |
| 2 | +import dataclasses |
2 | 3 | import typing |
3 | | -from typing import cast, Any, Callable, Dict, Tuple, Type, TypeVar, Optional, Union |
| 4 | +from typing import cast, Any, Callable, Dict, Set, Tuple, Type, TypeVar, Optional, Union # noqa |
4 | 5 | from typing_extensions import Protocol |
5 | 6 |
|
6 | | -CACHED_TYPES: Dict[type, Optional[Dict[str, type]]] = {} |
| 7 | +# As of Python 3.7, the mypy Field definition is different from Python's version. |
| 8 | +# Use an old-fashioned comment until this situation is fixed. |
| 9 | +CACHED_TYPES = {} # type: Dict[type, Optional[Dict[str, dataclasses.Field[Any]]]] |
7 | 10 |
|
8 | 11 |
|
9 | 12 | class HasAnnotations(Protocol): |
@@ -75,6 +78,9 @@ def inner(ty: type, plural: bool, level: int) -> str: |
75 | 78 | if ty is type(None): # noqa |
76 | 79 | return 'nothing' |
77 | 80 |
|
| 81 | + if ty is object or ty is Any: |
| 82 | + return 'anything' |
| 83 | + |
78 | 84 | level += 1 |
79 | 85 | if level > 4: |
80 | 86 | # Making nested English clauses understandable is hard. Give up. |
@@ -171,22 +177,34 @@ def check_type(ty: Type[C], data: object, ty_module: str = '') -> C: |
171 | 177 | if isinstance(data, dict) and ty in CACHED_TYPES: |
172 | 178 | annotations = CACHED_TYPES[ty] |
173 | 179 | if annotations is None: |
174 | | - annotations = typing.get_type_hints(ty) |
| 180 | + annotations = {field.name: field for field in dataclasses.fields(ty)} |
175 | 181 | CACHED_TYPES[ty] = annotations |
176 | 182 | result: Dict[str, object] = {} |
177 | 183 |
|
178 | 184 | # Assign missing fields None |
| 185 | + missing: Set[str] = set() |
179 | 186 | for key in annotations: |
180 | 187 | if key not in data: |
181 | 188 | data[key] = None |
| 189 | + missing.add(key) |
182 | 190 |
|
183 | 191 | # Check field types |
184 | 192 | for key, value in data.items(): |
185 | 193 | if key not in annotations: |
186 | 194 | raise LoadUnknownField(ty, data, key) |
187 | 195 |
|
188 | | - expected_type = annotations[key] |
189 | | - result[key] = check_type(expected_type, value, ty.__module__) |
| 196 | + field = annotations[key] |
| 197 | + have_value = False |
| 198 | + if key in missing: |
| 199 | + # Use the field's default_factory if it's defined |
| 200 | + try: |
| 201 | + result[key] = field.default_factory() # type: ignore |
| 202 | + have_value = True |
| 203 | + except TypeError: |
| 204 | + pass |
| 205 | + |
| 206 | + if not have_value: |
| 207 | + result[key] = check_type(field.type, value, ty.__module__) |
190 | 208 |
|
191 | 209 | output = ty(**result) |
192 | 210 | start_line = getattr(data, '_start_line', None) |
|
0 commit comments