Skip to content

Commit fe56e0d

Browse files
authored
Merge pull request #10 from erezsh/forward_ref
Added support for forward-references in dataclasses
2 parents a98a5a8 + 533c507 commit fe56e0d

File tree

6 files changed

+207
-112
lines changed

6 files changed

+207
-112
lines changed

runtype/dataclass.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,21 @@
77
import dataclasses
88
from typing import Union
99
from abc import ABC, abstractmethod
10+
import inspect
1011

12+
from .utils import ForwardRef
1113
from .common import CHECK_TYPES
1214
from .validation import TypeMismatchError, ensure_isa as default_ensure_isa
13-
from .pytypes import cast_to_type, SumType, NoneType
15+
from .pytypes import TypeCaster, type_caster, SumType, NoneType
1416

1517
Required = object()
1618
MAX_SAMPLE_SIZE = 16
1719

20+
class NopTypeCaster:
21+
cache = {}
22+
def to_canon(self, t):
23+
return t
24+
1825
class Configuration(ABC):
1926
"""Generic configuration template for dataclass. Mainly for type-checking.
2027
@@ -48,15 +55,15 @@ class Form:
4855
4956
"""
5057

51-
def canonize_type(self, t):
52-
"""Given a type, return its canonical form.
58+
def on_default(self, default):
59+
"""Called whenever a dataclass member is assigned a default value.
5360
"""
54-
return t
61+
return default
5562

56-
def on_default(self, t, default):
57-
"""Called whenever a dataclass member is assigned a default value.
63+
def make_type_caster(self, frame):
64+
"""Return a type caster, as defined in pytypes.TypeCaster
5865
"""
59-
return t, default
66+
return NopTypeCaster()
6067

6168
@abstractmethod
6269
def ensure_isa(self, a, b, sampler=None):
@@ -79,40 +86,48 @@ class PythonConfiguration(Configuration):
7986
8087
This is the default class given to the ``dataclass()`` function.
8188
"""
82-
canonize_type = staticmethod(cast_to_type)
89+
make_type_caster = TypeCaster
8390
ensure_isa = staticmethod(default_ensure_isa)
8491

8592
def cast(self, obj, to_type):
8693
return to_type.cast_from(obj)
8794

88-
def on_default(self, type_, default):
89-
if default is None:
90-
type_ = SumType([type_, NoneType])
91-
elif isinstance(default, (list, dict, set)):
95+
def on_default(self, default):
96+
if isinstance(default, (list, dict, set)):
9297
def f(_=default):
9398
return copy(_)
94-
default = dataclasses.field(default_factory=f)
95-
return type_, default
96-
99+
return dataclasses.field(default_factory=f)
100+
return default
97101

98102

99-
def _post_init(self, config, should_cast, sampler):
103+
def _post_init(self, config, should_cast, sampler, type_caster):
100104
for name, field in getattr(self, '__dataclass_fields__', {}).items():
101105
value = getattr(self, name)
102106

103107
if value is Required:
104108
raise TypeError(f"Field {name} requires a value")
105109

110+
try:
111+
type_ = type_caster.cache[id(field)]
112+
except KeyError:
113+
type_ = field.type
114+
if isinstance(type_, str):
115+
type_ = ForwardRef(type_)
116+
type_ = type_caster.to_canon(type_)
117+
if field.default is None:
118+
type_ = SumType([type_, NoneType])
119+
type_caster.cache[id(field)] = type_
120+
106121
try:
107122
if should_cast: # Basic cast
108123
assert not sampler
109-
value = config.cast(value, field.type)
124+
value = config.cast(value, type_)
110125
object.__setattr__(self, name, value)
111126
else:
112-
config.ensure_isa(value, field.type, sampler)
127+
config.ensure_isa(value, type_, sampler)
113128
except TypeMismatchError as e:
114129
item_value, item_type = e.args
115-
msg = f"[{type(self).__name__}] Attribute '{name}' expected value of type {field.type}."
130+
msg = f"[{type(self).__name__}] Attribute '{name}' expected value of type '{type_}'."
116131
msg += f" Instead got {value!r}"
117132
if item_value is not value:
118133
msg += f'\n\n Failed on item: {item_value!r}, expected type {item_type}'
@@ -197,9 +212,9 @@ def _sample(seq, max_sample_size=MAX_SAMPLE_SIZE):
197212
return seq
198213
return random.sample(seq, max_sample_size)
199214

200-
def _process_class(cls, config, check_types, **kw):
215+
def _process_class(cls, config, check_types, context_frame, **kw):
201216
for name, type_ in getattr(cls, '__annotations__', {}).items():
202-
type_ = config.canonize_type(type_)
217+
# type_ = config.type_to_canon(type_) if not isinstance(type_, str) else type_
203218

204219
# If default not specified, assign Required, for a later check
205220
# We don't assign MISSING; we want to bypass dataclass which is too strict for this
@@ -211,7 +226,7 @@ def _process_class(cls, config, check_types, **kw):
211226
if default.default is dataclasses.MISSING and default.default_factory is dataclasses.MISSING:
212227
default.default = Required
213228

214-
type_, new_default = config.on_default(type_, default)
229+
new_default = config.on_default(default)
215230
if new_default is not default:
216231
setattr(cls, name, new_default)
217232

@@ -222,9 +237,12 @@ def _process_class(cls, config, check_types, **kw):
222237

223238
orig_post_init = getattr(cls, '__post_init__', None)
224239
sampler = _sample if check_types=='sample' else None
240+
# eval_type_string = EvalInContext(context_frame)
241+
type_caster = config.make_type_caster(context_frame)
225242

226243
def __post_init__(self):
227-
_post_init(self, config=config, should_cast=check_types == 'cast', sampler=sampler)
244+
# Only now context_frame has complete information
245+
_post_init(self, config=config, should_cast=check_types == 'cast', sampler=sampler, type_caster=type_caster)
228246
if orig_post_init is not None:
229247
orig_post_init(self)
230248

@@ -340,8 +358,9 @@ def dataclass(cls=None, *, check_types: Union[bool, str] = CHECK_TYPES,
340358
"""
341359
assert isinstance(config, Configuration)
342360

361+
context_frame = inspect.currentframe().f_back # Get parent frame, to resolve forward-references
343362
def wrap(cls):
344-
return _process_class(cls, config, check_types,
363+
return _process_class(cls, config, check_types, context_frame,
345364
init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, slots=slots)
346365

347366
# See if we're being called as @dataclass or @dataclass().

runtype/pytypes.py

Lines changed: 96 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import typing
1010
from datetime import datetime
1111

12+
from .utils import ForwardRef
1213
from .base_types import DataType, Validator, TypeMismatchError
1314
from . import base_types
1415
from . import datetime_parse
@@ -27,7 +28,7 @@ class PythonType(base_types.Type, Validator):
2728

2829
class Constraint(base_types.Constraint):
2930
def __init__(self, for_type, predicates):
30-
super().__init__(cast_to_type(for_type), predicates)
31+
super().__init__(type_caster.to_canon(for_type), predicates)
3132

3233
def cast_from(self, obj):
3334
obj = self.type.cast_from(obj)
@@ -194,7 +195,7 @@ def __init__(self, base, item=Any*Any):
194195
super().__init__(base)
195196
if isinstance(item, tuple):
196197
assert len(item) == 2
197-
item = ProductType([cast_to_type(x) for x in item])
198+
item = ProductType([type_caster.to_canon(x) for x in item])
198199
self.item = item
199200

200201
def validate_instance(self, obj, sampler=None):
@@ -329,85 +330,99 @@ def cast_from(self, obj):
329330
origin_frozenset = typing.FrozenSet
330331

331332

332-
def _cast_to_type(t):
333-
if isinstance(t, Validator):
334-
return t
333+
class TypeCaster:
334+
def __init__(self, frame=None):
335+
self.cache = {}
336+
self.frame = frame
335337

336-
if isinstance(t, tuple):
337-
return SumType([cast_to_type(x) for x in t])
338+
def _to_canon(self, t):
339+
to_canon = self.to_canon
340+
341+
if isinstance(t, (base_types.Type, Validator)):
342+
return t
343+
344+
if isinstance(t, ForwardRef):
345+
t = t._evaluate(self.frame.f_globals, self.frame.f_locals, set())
346+
347+
if isinstance(t, tuple):
348+
return SumType([to_canon(x) for x in t])
349+
350+
try:
351+
t.__origin__
352+
except AttributeError:
353+
pass
354+
else:
355+
if getattr(t, '__args__', None) is None:
356+
if t is typing.List:
357+
return List
358+
elif t is typing.Dict:
359+
return Dict
360+
elif t is typing.Set:
361+
return Set
362+
elif t is typing.FrozenSet:
363+
return FrozenSet
364+
elif t is typing.Tuple:
365+
return Tuple
366+
elif t is typing.Mapping: # 3.6
367+
return Mapping
368+
elif t is typing.Sequence:
369+
return Sequence
370+
371+
if t.__origin__ is origin_list:
372+
x ,= t.__args__
373+
return List[to_canon(x)]
374+
elif t.__origin__ is origin_set:
375+
x ,= t.__args__
376+
return Set[to_canon(x)]
377+
elif t.__origin__ is origin_frozenset:
378+
x ,= t.__args__
379+
return FrozenSet[to_canon(x)]
380+
elif t.__origin__ is origin_dict:
381+
k, v = t.__args__
382+
return Dict[to_canon(k), to_canon(v)]
383+
elif t.__origin__ is origin_tuple:
384+
if Ellipsis in t.__args__:
385+
if len(t.__args__) != 2 or t.__args__[0] == Ellipsis:
386+
raise ValueError("Tuple with '...'' expected to be of the exact form: tuple[t, ...].")
387+
return TupleEllipsis[to_canon(t.__args__[0])]
388+
389+
return ProductType([to_canon(x) for x in t.__args__])
390+
391+
elif t.__origin__ is typing.Union:
392+
return SumType([to_canon(x) for x in t.__args__])
393+
elif t.__origin__ is abc.Callable or t is typing.Callable:
394+
# return Callable[ProductType(to_canon(x) for x in t.__args__)]
395+
return Callable # TODO
396+
elif py38 and t.__origin__ is typing.Literal:
397+
return OneOf(t.__args__)
398+
elif t.__origin__ is abc.Mapping or t.__origin__ is typing.Mapping:
399+
k, v = t.__args__
400+
return Mapping[to_canon(k), to_canon(v)]
401+
elif t.__origin__ is abc.Sequence or t.__origin__ is typing.Sequence:
402+
x ,= t.__args__
403+
return Sequence[to_canon(x)]
404+
405+
elif t.__origin__ is type or t.__origin__ is typing.Type:
406+
# TODO test issubclass on t.__args__
407+
return PythonDataType(type)
408+
409+
raise NotImplementedError("No support for type:", t)
410+
411+
if isinstance(t, typing.TypeVar):
412+
return Any # XXX is this correct?
413+
414+
return PythonDataType(t)
415+
416+
def to_canon(self, t):
417+
try:
418+
return self.cache[t]
419+
except KeyError:
420+
try:
421+
res = _type_cast_mapping[t]
422+
except KeyError:
423+
res = self._to_canon(t)
424+
self.cache[t] = res # memoize
425+
return res
338426

339-
try:
340-
t.__origin__
341-
except AttributeError:
342-
pass
343-
else:
344-
if getattr(t, '__args__', None) is None:
345-
if t is typing.List:
346-
return List
347-
elif t is typing.Dict:
348-
return Dict
349-
elif t is typing.Set:
350-
return Set
351-
elif t is typing.FrozenSet:
352-
return FrozenSet
353-
elif t is typing.Tuple:
354-
return Tuple
355-
elif t is typing.Mapping: # 3.6
356-
return Mapping
357-
elif t is typing.Sequence:
358-
return Sequence
359-
360-
if t.__origin__ is origin_list:
361-
x ,= t.__args__
362-
return List[cast_to_type(x)]
363-
elif t.__origin__ is origin_set:
364-
x ,= t.__args__
365-
return Set[cast_to_type(x)]
366-
elif t.__origin__ is origin_frozenset:
367-
x ,= t.__args__
368-
return FrozenSet[cast_to_type(x)]
369-
elif t.__origin__ is origin_dict:
370-
k, v = t.__args__
371-
return Dict[cast_to_type(k), cast_to_type(v)]
372-
elif t.__origin__ is origin_tuple:
373-
if Ellipsis in t.__args__:
374-
if len(t.__args__) != 2 or t.__args__[0] == Ellipsis:
375-
raise ValueError("Tuple with '...'' expected to be of the exact form: tuple[t, ...].")
376-
return TupleEllipsis[cast_to_type(t.__args__[0])]
377-
378-
return ProductType([cast_to_type(x) for x in t.__args__])
379-
380-
elif t.__origin__ is typing.Union:
381-
return SumType([cast_to_type(x) for x in t.__args__])
382-
elif t.__origin__ is abc.Callable or t is typing.Callable:
383-
# return Callable[ProductType(cast_to_type(x) for x in t.__args__)]
384-
return Callable # TODO
385-
elif py38 and t.__origin__ is typing.Literal:
386-
return OneOf(t.__args__)
387-
elif t.__origin__ is abc.Mapping or t.__origin__ is typing.Mapping:
388-
k, v = t.__args__
389-
return Mapping[cast_to_type(k), cast_to_type(v)]
390-
elif t.__origin__ is abc.Sequence or t.__origin__ is typing.Sequence:
391-
x ,= t.__args__
392-
return Sequence[_cast_to_type(x)]
393-
394-
elif t.__origin__ is type or t.__origin__ is typing.Type:
395-
# TODO test issubclass on t.__args__
396-
return PythonDataType(type)
397-
398-
raise NotImplementedError("No support for type:", t)
399-
400-
if isinstance(t, typing.TypeVar):
401-
return Any # XXX is this correct?
402-
403-
return PythonDataType(t)
404-
405-
406-
def cast_to_type(t):
407-
try:
408-
return _type_cast_mapping[t]
409-
except KeyError:
410-
res = _cast_to_type(t)
411-
_type_cast_mapping[t] = res # memoize
412-
return res
413427

428+
type_caster = TypeCaster()

runtype/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
11
import inspect
2+
import sys
3+
4+
if sys.version_info < (3, 7):
5+
# python 3.6
6+
from typing import _ForwardRef as ForwardRef
7+
_orig_eval = ForwardRef._eval_type
8+
elif sys.version_info < (3, 9):
9+
from typing import ForwardRef
10+
_orig_eval = ForwardRef._evaluate
11+
else:
12+
from typing import ForwardRef
13+
14+
if sys.version_info < (3, 9):
15+
def _evaluate(self, g, l, _):
16+
return _orig_eval(self, g, l)
17+
ForwardRef._evaluate = _evaluate
18+
19+
220

321
def get_func_signatures(typesystem, f):
422
sig = inspect.signature(f)

0 commit comments

Comments
 (0)