Skip to content

Commit 2f63e27

Browse files
committed
Added support for forward-references in dataclasses
1 parent a98a5a8 commit 2f63e27

File tree

5 files changed

+167
-113
lines changed

5 files changed

+167
-113
lines changed

runtype/dataclass.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@
55
import random
66
from copy import copy
77
import dataclasses
8-
from typing import Union
8+
from typing import Union, ForwardRef
99
from abc import ABC, abstractmethod
10+
import inspect
1011

1112
from .common import CHECK_TYPES
1213
from .validation import TypeMismatchError, ensure_isa as default_ensure_isa
13-
from .pytypes import cast_to_type, SumType, NoneType
14+
from .pytypes import TypeCaster, type_caster, SumType, NoneType
1415

1516
Required = object()
1617
MAX_SAMPLE_SIZE = 16
1718

19+
class NopTypeCaster:
20+
cache = {}
21+
def to_canon(self, t):
22+
return t
23+
1824
class Configuration(ABC):
1925
"""Generic configuration template for dataclass. Mainly for type-checking.
2026
@@ -48,15 +54,15 @@ class Form:
4854
4955
"""
5056

51-
def canonize_type(self, t):
52-
"""Given a type, return its canonical form.
57+
def on_default(self, default):
58+
"""Called whenever a dataclass member is assigned a default value.
5359
"""
54-
return t
60+
return default
5561

56-
def on_default(self, t, default):
57-
"""Called whenever a dataclass member is assigned a default value.
62+
def make_type_caster(self, frame):
63+
"""Return a type caster, as defined in pytypes.TypeCaster
5864
"""
59-
return t, default
65+
return NopTypeCaster()
6066

6167
@abstractmethod
6268
def ensure_isa(self, a, b, sampler=None):
@@ -79,40 +85,48 @@ class PythonConfiguration(Configuration):
7985
8086
This is the default class given to the ``dataclass()`` function.
8187
"""
82-
canonize_type = staticmethod(cast_to_type)
88+
make_type_caster = TypeCaster
8389
ensure_isa = staticmethod(default_ensure_isa)
8490

8591
def cast(self, obj, to_type):
8692
return to_type.cast_from(obj)
8793

88-
def on_default(self, type_, default):
89-
if default is None:
90-
type_ = SumType([type_, NoneType])
91-
elif isinstance(default, (list, dict, set)):
94+
def on_default(self, default):
95+
if isinstance(default, (list, dict, set)):
9296
def f(_=default):
9397
return copy(_)
94-
default = dataclasses.field(default_factory=f)
95-
return type_, default
96-
98+
return dataclasses.field(default_factory=f)
99+
return default
97100

98101

99-
def _post_init(self, config, should_cast, sampler):
102+
def _post_init(self, config, should_cast, sampler, type_caster):
100103
for name, field in getattr(self, '__dataclass_fields__', {}).items():
101104
value = getattr(self, name)
102105

103106
if value is Required:
104107
raise TypeError(f"Field {name} requires a value")
105108

109+
try:
110+
type_ = type_caster.cache[id(field)]
111+
except KeyError:
112+
type_ = field.type
113+
if isinstance(type_, str):
114+
type_ = ForwardRef(type_)
115+
type_ = type_caster.to_canon(type_)
116+
if field.default is None:
117+
type_ = SumType([type_, NoneType])
118+
type_caster.cache[id(field)] = type_
119+
106120
try:
107121
if should_cast: # Basic cast
108122
assert not sampler
109-
value = config.cast(value, field.type)
123+
value = config.cast(value, type_)
110124
object.__setattr__(self, name, value)
111125
else:
112-
config.ensure_isa(value, field.type, sampler)
126+
config.ensure_isa(value, type_, sampler)
113127
except TypeMismatchError as e:
114128
item_value, item_type = e.args
115-
msg = f"[{type(self).__name__}] Attribute '{name}' expected value of type {field.type}."
129+
msg = f"[{type(self).__name__}] Attribute '{name}' expected value of type '{type_}'."
116130
msg += f" Instead got {value!r}"
117131
if item_value is not value:
118132
msg += f'\n\n Failed on item: {item_value!r}, expected type {item_type}'
@@ -197,9 +211,9 @@ def _sample(seq, max_sample_size=MAX_SAMPLE_SIZE):
197211
return seq
198212
return random.sample(seq, max_sample_size)
199213

200-
def _process_class(cls, config, check_types, **kw):
214+
def _process_class(cls, config, check_types, context_frame, **kw):
201215
for name, type_ in getattr(cls, '__annotations__', {}).items():
202-
type_ = config.canonize_type(type_)
216+
# type_ = config.type_to_canon(type_) if not isinstance(type_, str) else type_
203217

204218
# If default not specified, assign Required, for a later check
205219
# We don't assign MISSING; we want to bypass dataclass which is too strict for this
@@ -211,7 +225,7 @@ def _process_class(cls, config, check_types, **kw):
211225
if default.default is dataclasses.MISSING and default.default_factory is dataclasses.MISSING:
212226
default.default = Required
213227

214-
type_, new_default = config.on_default(type_, default)
228+
new_default = config.on_default(default)
215229
if new_default is not default:
216230
setattr(cls, name, new_default)
217231

@@ -222,9 +236,12 @@ def _process_class(cls, config, check_types, **kw):
222236

223237
orig_post_init = getattr(cls, '__post_init__', None)
224238
sampler = _sample if check_types=='sample' else None
239+
# eval_type_string = EvalInContext(context_frame)
240+
type_caster = config.make_type_caster(context_frame)
225241

226242
def __post_init__(self):
227-
_post_init(self, config=config, should_cast=check_types == 'cast', sampler=sampler)
243+
# Only now context_frame has complete information
244+
_post_init(self, config=config, should_cast=check_types == 'cast', sampler=sampler, type_caster=type_caster)
228245
if orig_post_init is not None:
229246
orig_post_init(self)
230247

@@ -340,8 +357,9 @@ def dataclass(cls=None, *, check_types: Union[bool, str] = CHECK_TYPES,
340357
"""
341358
assert isinstance(config, Configuration)
342359

360+
context_frame = inspect.currentframe().f_back # Get parent frame, to resolve forward-references
343361
def wrap(cls):
344-
return _process_class(cls, config, check_types,
362+
return _process_class(cls, config, check_types, context_frame,
345363
init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, slots=slots)
346364

347365
# See if we're being called as @dataclass or @dataclass().

runtype/pytypes.py

Lines changed: 95 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class PythonType(base_types.Type, Validator):
2727

2828
class Constraint(base_types.Constraint):
2929
def __init__(self, for_type, predicates):
30-
super().__init__(cast_to_type(for_type), predicates)
30+
super().__init__(type_caster.to_canon(for_type), predicates)
3131

3232
def cast_from(self, obj):
3333
obj = self.type.cast_from(obj)
@@ -194,7 +194,7 @@ def __init__(self, base, item=Any*Any):
194194
super().__init__(base)
195195
if isinstance(item, tuple):
196196
assert len(item) == 2
197-
item = ProductType([cast_to_type(x) for x in item])
197+
item = ProductType([type_caster.to_canon(x) for x in item])
198198
self.item = item
199199

200200
def validate_instance(self, obj, sampler=None):
@@ -329,85 +329,99 @@ def cast_from(self, obj):
329329
origin_frozenset = typing.FrozenSet
330330

331331

332-
def _cast_to_type(t):
333-
if isinstance(t, Validator):
334-
return t
332+
class TypeCaster:
333+
def __init__(self, frame=None):
334+
self.cache = {}
335+
self.frame = frame
335336

336-
if isinstance(t, tuple):
337-
return SumType([cast_to_type(x) for x in t])
337+
def _to_canon(self, t):
338+
to_canon = self.to_canon
339+
340+
if isinstance(t, (base_types.Type, Validator)):
341+
return t
342+
343+
if isinstance(t, typing.ForwardRef):
344+
t = t._evaluate(self.frame.f_globals, self.frame.f_locals, set())
345+
346+
if isinstance(t, tuple):
347+
return SumType([to_canon(x) for x in t])
348+
349+
try:
350+
t.__origin__
351+
except AttributeError:
352+
pass
353+
else:
354+
if getattr(t, '__args__', None) is None:
355+
if t is typing.List:
356+
return List
357+
elif t is typing.Dict:
358+
return Dict
359+
elif t is typing.Set:
360+
return Set
361+
elif t is typing.FrozenSet:
362+
return FrozenSet
363+
elif t is typing.Tuple:
364+
return Tuple
365+
elif t is typing.Mapping: # 3.6
366+
return Mapping
367+
elif t is typing.Sequence:
368+
return Sequence
369+
370+
if t.__origin__ is origin_list:
371+
x ,= t.__args__
372+
return List[to_canon(x)]
373+
elif t.__origin__ is origin_set:
374+
x ,= t.__args__
375+
return Set[to_canon(x)]
376+
elif t.__origin__ is origin_frozenset:
377+
x ,= t.__args__
378+
return FrozenSet[to_canon(x)]
379+
elif t.__origin__ is origin_dict:
380+
k, v = t.__args__
381+
return Dict[to_canon(k), to_canon(v)]
382+
elif t.__origin__ is origin_tuple:
383+
if Ellipsis in t.__args__:
384+
if len(t.__args__) != 2 or t.__args__[0] == Ellipsis:
385+
raise ValueError("Tuple with '...'' expected to be of the exact form: tuple[t, ...].")
386+
return TupleEllipsis[to_canon(t.__args__[0])]
387+
388+
return ProductType([to_canon(x) for x in t.__args__])
389+
390+
elif t.__origin__ is typing.Union:
391+
return SumType([to_canon(x) for x in t.__args__])
392+
elif t.__origin__ is abc.Callable or t is typing.Callable:
393+
# return Callable[ProductType(to_canon(x) for x in t.__args__)]
394+
return Callable # TODO
395+
elif py38 and t.__origin__ is typing.Literal:
396+
return OneOf(t.__args__)
397+
elif t.__origin__ is abc.Mapping or t.__origin__ is typing.Mapping:
398+
k, v = t.__args__
399+
return Mapping[to_canon(k), to_canon(v)]
400+
elif t.__origin__ is abc.Sequence or t.__origin__ is typing.Sequence:
401+
x ,= t.__args__
402+
return Sequence[to_canon(x)]
403+
404+
elif t.__origin__ is type or t.__origin__ is typing.Type:
405+
# TODO test issubclass on t.__args__
406+
return PythonDataType(type)
407+
408+
raise NotImplementedError("No support for type:", t)
409+
410+
if isinstance(t, typing.TypeVar):
411+
return Any # XXX is this correct?
412+
413+
return PythonDataType(t)
414+
415+
def to_canon(self, t):
416+
try:
417+
return self.cache[t]
418+
except KeyError:
419+
try:
420+
res = _type_cast_mapping[t]
421+
except KeyError:
422+
res = self._to_canon(t)
423+
self.cache[t] = res # memoize
424+
return res
338425

339-
try:
340-
t.__origin__
341-
except AttributeError:
342-
pass
343-
else:
344-
if getattr(t, '__args__', None) is None:
345-
if t is typing.List:
346-
return List
347-
elif t is typing.Dict:
348-
return Dict
349-
elif t is typing.Set:
350-
return Set
351-
elif t is typing.FrozenSet:
352-
return FrozenSet
353-
elif t is typing.Tuple:
354-
return Tuple
355-
elif t is typing.Mapping: # 3.6
356-
return Mapping
357-
elif t is typing.Sequence:
358-
return Sequence
359-
360-
if t.__origin__ is origin_list:
361-
x ,= t.__args__
362-
return List[cast_to_type(x)]
363-
elif t.__origin__ is origin_set:
364-
x ,= t.__args__
365-
return Set[cast_to_type(x)]
366-
elif t.__origin__ is origin_frozenset:
367-
x ,= t.__args__
368-
return FrozenSet[cast_to_type(x)]
369-
elif t.__origin__ is origin_dict:
370-
k, v = t.__args__
371-
return Dict[cast_to_type(k), cast_to_type(v)]
372-
elif t.__origin__ is origin_tuple:
373-
if Ellipsis in t.__args__:
374-
if len(t.__args__) != 2 or t.__args__[0] == Ellipsis:
375-
raise ValueError("Tuple with '...'' expected to be of the exact form: tuple[t, ...].")
376-
return TupleEllipsis[cast_to_type(t.__args__[0])]
377-
378-
return ProductType([cast_to_type(x) for x in t.__args__])
379-
380-
elif t.__origin__ is typing.Union:
381-
return SumType([cast_to_type(x) for x in t.__args__])
382-
elif t.__origin__ is abc.Callable or t is typing.Callable:
383-
# return Callable[ProductType(cast_to_type(x) for x in t.__args__)]
384-
return Callable # TODO
385-
elif py38 and t.__origin__ is typing.Literal:
386-
return OneOf(t.__args__)
387-
elif t.__origin__ is abc.Mapping or t.__origin__ is typing.Mapping:
388-
k, v = t.__args__
389-
return Mapping[cast_to_type(k), cast_to_type(v)]
390-
elif t.__origin__ is abc.Sequence or t.__origin__ is typing.Sequence:
391-
x ,= t.__args__
392-
return Sequence[_cast_to_type(x)]
393-
394-
elif t.__origin__ is type or t.__origin__ is typing.Type:
395-
# TODO test issubclass on t.__args__
396-
return PythonDataType(type)
397-
398-
raise NotImplementedError("No support for type:", t)
399-
400-
if isinstance(t, typing.TypeVar):
401-
return Any # XXX is this correct?
402-
403-
return PythonDataType(t)
404-
405-
406-
def cast_to_type(t):
407-
try:
408-
return _type_cast_mapping[t]
409-
except KeyError:
410-
res = _cast_to_type(t)
411-
_type_cast_mapping[t] = res # memoize
412-
return res
413426

427+
type_caster = TypeCaster()

runtype/validation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,23 @@
55

66
from .common import CHECK_TYPES
77
from .utils import get_func_signatures
8-
from .pytypes import TypeMismatchError, cast_to_type
8+
from .pytypes import TypeMismatchError, type_caster
99
from .typesystem import TypeSystem
1010

1111

1212
def ensure_isa(obj, t, sampler=None):
1313
"""Ensure 'obj' is of type 't'. Otherwise, throws a TypeError
1414
"""
15-
t = cast_to_type(t)
15+
t = type_caster.to_canon(t)
1616
t.validate_instance(obj, sampler)
1717

1818

1919
def is_subtype(t1, t2):
2020
"""Test if t1 is a subtype of t2
2121
"""
2222

23-
t1 = cast_to_type(t1)
24-
t2 = cast_to_type(t2)
23+
t1 = type_caster.to_canon(t1)
24+
t2 = type_caster.to_canon(t2)
2525
return t1 <= t2
2626

2727

0 commit comments

Comments
 (0)