77import dataclasses
88from typing import Union
99from abc import ABC , abstractmethod
10+ import inspect
1011
12+ from .utils import ForwardRef
1113from .common import CHECK_TYPES
1214from .validation import TypeMismatchError , ensure_isa as default_ensure_isa
13- from .pytypes import cast_to_type , SumType , NoneType
15+ from .pytypes import TypeCaster , type_caster , SumType , NoneType
1416
1517Required = object ()
1618MAX_SAMPLE_SIZE = 16
1719
20+ class NopTypeCaster :
21+ cache = {}
22+ def to_canon (self , t ):
23+ return t
24+
1825class Configuration (ABC ):
1926 """Generic configuration template for dataclass. Mainly for type-checking.
2027
@@ -48,15 +55,15 @@ class Form:
4855
4956 """
5057
51- def canonize_type (self , t ):
52- """Given a type, return its canonical form .
58+ def on_default (self , default ):
59+ """Called whenever a dataclass member is assigned a default value .
5360 """
54- return t
61+ return default
5562
56- def on_default (self , t , default ):
57- """Called whenever a dataclass member is assigned a default value.
63+ def make_type_caster (self , frame ):
64+ """Return a type caster, as defined in pytypes.TypeCaster
5865 """
59- return t , default
66+ return NopTypeCaster ()
6067
6168 @abstractmethod
6269 def ensure_isa (self , a , b , sampler = None ):
@@ -79,40 +86,48 @@ class PythonConfiguration(Configuration):
7986
8087 This is the default class given to the ``dataclass()`` function.
8188 """
82- canonize_type = staticmethod ( cast_to_type )
89+ make_type_caster = TypeCaster
8390 ensure_isa = staticmethod (default_ensure_isa )
8491
8592 def cast (self , obj , to_type ):
8693 return to_type .cast_from (obj )
8794
88- def on_default (self , type_ , default ):
89- if default is None :
90- type_ = SumType ([type_ , NoneType ])
91- elif isinstance (default , (list , dict , set )):
95+ def on_default (self , default ):
96+ if isinstance (default , (list , dict , set )):
9297 def f (_ = default ):
9398 return copy (_ )
94- default = dataclasses .field (default_factory = f )
95- return type_ , default
96-
99+ return dataclasses .field (default_factory = f )
100+ return default
97101
98102
99- def _post_init (self , config , should_cast , sampler ):
103+ def _post_init (self , config , should_cast , sampler , type_caster ):
100104 for name , field in getattr (self , '__dataclass_fields__' , {}).items ():
101105 value = getattr (self , name )
102106
103107 if value is Required :
104108 raise TypeError (f"Field { name } requires a value" )
105109
110+ try :
111+ type_ = type_caster .cache [id (field )]
112+ except KeyError :
113+ type_ = field .type
114+ if isinstance (type_ , str ):
115+ type_ = ForwardRef (type_ )
116+ type_ = type_caster .to_canon (type_ )
117+ if field .default is None :
118+ type_ = SumType ([type_ , NoneType ])
119+ type_caster .cache [id (field )] = type_
120+
106121 try :
107122 if should_cast : # Basic cast
108123 assert not sampler
109- value = config .cast (value , field . type )
124+ value = config .cast (value , type_ )
110125 object .__setattr__ (self , name , value )
111126 else :
112- config .ensure_isa (value , field . type , sampler )
127+ config .ensure_isa (value , type_ , sampler )
113128 except TypeMismatchError as e :
114129 item_value , item_type = e .args
115- msg = f"[{ type (self ).__name__ } ] Attribute '{ name } ' expected value of type { field . type } ."
130+ msg = f"[{ type (self ).__name__ } ] Attribute '{ name } ' expected value of type ' { type_ } ' ."
116131 msg += f" Instead got { value !r} "
117132 if item_value is not value :
118133 msg += f'\n \n Failed on item: { item_value !r} , expected type { item_type } '
@@ -197,9 +212,9 @@ def _sample(seq, max_sample_size=MAX_SAMPLE_SIZE):
197212 return seq
198213 return random .sample (seq , max_sample_size )
199214
200- def _process_class (cls , config , check_types , ** kw ):
215+ def _process_class (cls , config , check_types , context_frame , ** kw ):
201216 for name , type_ in getattr (cls , '__annotations__' , {}).items ():
202- type_ = config .canonize_type (type_ )
217+ # type_ = config.type_to_canon (type_) if not isinstance(type_, str) else type_
203218
204219 # If default not specified, assign Required, for a later check
205220 # We don't assign MISSING; we want to bypass dataclass which is too strict for this
@@ -211,7 +226,7 @@ def _process_class(cls, config, check_types, **kw):
211226 if default .default is dataclasses .MISSING and default .default_factory is dataclasses .MISSING :
212227 default .default = Required
213228
214- type_ , new_default = config .on_default (type_ , default )
229+ new_default = config .on_default (default )
215230 if new_default is not default :
216231 setattr (cls , name , new_default )
217232
@@ -222,9 +237,12 @@ def _process_class(cls, config, check_types, **kw):
222237
223238 orig_post_init = getattr (cls , '__post_init__' , None )
224239 sampler = _sample if check_types == 'sample' else None
240+ # eval_type_string = EvalInContext(context_frame)
241+ type_caster = config .make_type_caster (context_frame )
225242
226243 def __post_init__ (self ):
227- _post_init (self , config = config , should_cast = check_types == 'cast' , sampler = sampler )
244+ # Only now context_frame has complete information
245+ _post_init (self , config = config , should_cast = check_types == 'cast' , sampler = sampler , type_caster = type_caster )
228246 if orig_post_init is not None :
229247 orig_post_init (self )
230248
@@ -340,8 +358,9 @@ def dataclass(cls=None, *, check_types: Union[bool, str] = CHECK_TYPES,
340358 """
341359 assert isinstance (config , Configuration )
342360
361+ context_frame = inspect .currentframe ().f_back # Get parent frame, to resolve forward-references
343362 def wrap (cls ):
344- return _process_class (cls , config , check_types ,
363+ return _process_class (cls , config , check_types , context_frame ,
345364 init = init , repr = repr , eq = eq , order = order , unsafe_hash = unsafe_hash , frozen = frozen , slots = slots )
346365
347366 # See if we're being called as @dataclass or @dataclass().
0 commit comments