55import random
66from copy import copy
77import dataclasses
8- from typing import Union
8+ from typing import Union , ForwardRef
99from abc import ABC , abstractmethod
10+ import inspect
1011
1112from .common import CHECK_TYPES
1213from .validation import TypeMismatchError , ensure_isa as default_ensure_isa
13- from .pytypes import cast_to_type , SumType , NoneType
14+ from .pytypes import TypeCaster , type_caster , SumType , NoneType
1415
1516Required = object ()
1617MAX_SAMPLE_SIZE = 16
1718
19+ class NopTypeCaster :
20+ cache = {}
21+ def to_canon (self , t ):
22+ return t
23+
1824class Configuration (ABC ):
1925 """Generic configuration template for dataclass. Mainly for type-checking.
2026
@@ -48,15 +54,15 @@ class Form:
4854
4955 """
5056
51- def canonize_type (self , t ):
52- """Given a type, return its canonical form .
57+ def on_default (self , default ):
58+ """Called whenever a dataclass member is assigned a default value .
5359 """
54- return t
60+ return default
5561
56- def on_default (self , t , default ):
57- """Called whenever a dataclass member is assigned a default value.
62+ def make_type_caster (self , frame ):
63+ """Return a type caster, as defined in pytypes.TypeCaster
5864 """
59- return t , default
65+ return NopTypeCaster ()
6066
6167 @abstractmethod
6268 def ensure_isa (self , a , b , sampler = None ):
@@ -79,40 +85,48 @@ class PythonConfiguration(Configuration):
7985
8086 This is the default class given to the ``dataclass()`` function.
8187 """
82- canonize_type = staticmethod ( cast_to_type )
88+ make_type_caster = TypeCaster
8389 ensure_isa = staticmethod (default_ensure_isa )
8490
8591 def cast (self , obj , to_type ):
8692 return to_type .cast_from (obj )
8793
88- def on_default (self , type_ , default ):
89- if default is None :
90- type_ = SumType ([type_ , NoneType ])
91- elif isinstance (default , (list , dict , set )):
94+ def on_default (self , default ):
95+ if isinstance (default , (list , dict , set )):
9296 def f (_ = default ):
9397 return copy (_ )
94- default = dataclasses .field (default_factory = f )
95- return type_ , default
96-
98+ return dataclasses .field (default_factory = f )
99+ return default
97100
98101
99- def _post_init (self , config , should_cast , sampler ):
102+ def _post_init (self , config , should_cast , sampler , type_caster ):
100103 for name , field in getattr (self , '__dataclass_fields__' , {}).items ():
101104 value = getattr (self , name )
102105
103106 if value is Required :
104107 raise TypeError (f"Field { name } requires a value" )
105108
109+ try :
110+ type_ = type_caster .cache [id (field )]
111+ except KeyError :
112+ type_ = field .type
113+ if isinstance (type_ , str ):
114+ type_ = ForwardRef (type_ )
115+ type_ = type_caster .to_canon (type_ )
116+ if field .default is None :
117+ type_ = SumType ([type_ , NoneType ])
118+ type_caster .cache [id (field )] = type_
119+
106120 try :
107121 if should_cast : # Basic cast
108122 assert not sampler
109- value = config .cast (value , field . type )
123+ value = config .cast (value , type_ )
110124 object .__setattr__ (self , name , value )
111125 else :
112- config .ensure_isa (value , field . type , sampler )
126+ config .ensure_isa (value , type_ , sampler )
113127 except TypeMismatchError as e :
114128 item_value , item_type = e .args
115- msg = f"[{ type (self ).__name__ } ] Attribute '{ name } ' expected value of type { field . type } ."
129+ msg = f"[{ type (self ).__name__ } ] Attribute '{ name } ' expected value of type ' { type_ } ' ."
116130 msg += f" Instead got { value !r} "
117131 if item_value is not value :
118132 msg += f'\n \n Failed on item: { item_value !r} , expected type { item_type } '
@@ -197,9 +211,9 @@ def _sample(seq, max_sample_size=MAX_SAMPLE_SIZE):
197211 return seq
198212 return random .sample (seq , max_sample_size )
199213
200- def _process_class (cls , config , check_types , ** kw ):
214+ def _process_class (cls , config , check_types , context_frame , ** kw ):
201215 for name , type_ in getattr (cls , '__annotations__' , {}).items ():
202- type_ = config .canonize_type (type_ )
216+ # type_ = config.type_to_canon (type_) if not isinstance(type_, str) else type_
203217
204218 # If default not specified, assign Required, for a later check
205219 # We don't assign MISSING; we want to bypass dataclass which is too strict for this
@@ -211,7 +225,7 @@ def _process_class(cls, config, check_types, **kw):
211225 if default .default is dataclasses .MISSING and default .default_factory is dataclasses .MISSING :
212226 default .default = Required
213227
214- type_ , new_default = config .on_default (type_ , default )
228+ new_default = config .on_default (default )
215229 if new_default is not default :
216230 setattr (cls , name , new_default )
217231
@@ -222,9 +236,12 @@ def _process_class(cls, config, check_types, **kw):
222236
223237 orig_post_init = getattr (cls , '__post_init__' , None )
224238 sampler = _sample if check_types == 'sample' else None
239+ # eval_type_string = EvalInContext(context_frame)
240+ type_caster = config .make_type_caster (context_frame )
225241
226242 def __post_init__ (self ):
227- _post_init (self , config = config , should_cast = check_types == 'cast' , sampler = sampler )
243+ # Only now context_frame has complete information
244+ _post_init (self , config = config , should_cast = check_types == 'cast' , sampler = sampler , type_caster = type_caster )
228245 if orig_post_init is not None :
229246 orig_post_init (self )
230247
@@ -340,8 +357,9 @@ def dataclass(cls=None, *, check_types: Union[bool, str] = CHECK_TYPES,
340357 """
341358 assert isinstance (config , Configuration )
342359
360+ context_frame = inspect .currentframe ().f_back # Get parent frame, to resolve forward-references
343361 def wrap (cls ):
344- return _process_class (cls , config , check_types ,
362+ return _process_class (cls , config , check_types , context_frame ,
345363 init = init , repr = repr , eq = eq , order = order , unsafe_hash = unsafe_hash , frozen = frozen , slots = slots )
346364
347365 # See if we're being called as @dataclass or @dataclass().
0 commit comments