66 Callable ,
77 Dict ,
88 Generic ,
9+ Generator ,
910 List ,
1011 Optional ,
1112 Sequence ,
1213 Tuple ,
1314 Type ,
1415 TypeVar ,
1516 Union ,
17+ get_args ,
18+ get_origin ,
19+ get_type_hints ,
20+ ForwardRef
1621)
22+ from types import GeneratorType
1723
1824import catalogue
1925from pydantic import BaseModel , Field , ValidationError , create_model
2834)
2935from ._errors import ConfigValidationError
3036from .util import is_promise
37+ from . import util
3138
3239_PromisedType = TypeVar ("_PromisedType" )
3340
@@ -65,17 +72,24 @@ def validate(self) -> Any:
6572 def resolve (self , validate : bool = True ) -> Any :
6673 if isinstance (self .getter , catalogue .RegistryError ):
6774 raise self .getter
75+ assert self .schema is not None
6876 kwargs = _recursive_resolve (self .kwargs , validate = validate )
77+ assert isinstance (kwargs , dict )
6978 args = _recursive_resolve (self .var_args , validate = validate )
7079 args = list (args .values ()) if isinstance (args , dict ) else args
7180 if validate :
7281 schema_args = dict (kwargs )
7382 if args :
7483 schema_args [ARGS_FIELD ] = args
84+ #schema_args = _replace_generators(schema_args)
7585 try :
76- _ = self .schema .model_validate (schema_args )
86+ kwargs = self .schema .model_validate (schema_args ). model_dump ( )
7787 except ValidationError as e :
7888 raise ConfigValidationError (config = kwargs , errors = e .errors ()) from None
89+ if args :
90+ # Do type coercion
91+ args = kwargs .pop (ARGS_FIELD )
92+ kwargs = {RESERVED_FIELDS_REVERSE .get (k , k ): v for k , v in kwargs .items ()}
7993 return self .getter (* args , ** kwargs ) # type: ignore
8094
8195 @classmethod
@@ -145,6 +159,7 @@ def resolve(
145159 overrides : Dict [str , Any ] = {},
146160 validate : bool = True ,
147161 ) -> Dict [str , Any ]:
162+ schema = fix_forward_refs (schema )
148163 config = cls .fill (
149164 config ,
150165 schema = schema ,
@@ -278,13 +293,18 @@ def _make_unresolved_schema(
278293 )
279294 elif isinstance (config [name ], dict ):
280295 fields [name ] = cls ._make_unresolved_schema (
281- _make_dummy_schema (config [name ]), config
296+ _make_dummy_schema (config [name ]), config [ name ]
282297 )
298+ elif isinstance (field .annotation , str ) or field .annotation == ForwardRef :
299+ fields [name ] = (Any , Field (field .default ))
283300 else :
284- fields [name ] = (field .annotation , Field (...))
285- return create_model (
286- "UnresolvedConfig" , __config__ = {"extra" : "forbid" }, ** fields
301+ fields [name ] = (Any , Field (field .default ))
302+
303+ model = create_model (
304+ f"{ schema .__name__ } _UnresolvedConfig" , __config__ = schema .model_config , ** fields
287305 )
306+ model .model_rebuild (raise_errors = True )
307+ return model
288308
289309 @classmethod
290310 def _make_unresolved_promise_schema (cls , obj : Dict [str , Any ]) -> Type [BaseModel ]:
@@ -363,7 +383,7 @@ def validate_resolved(config, schema: Type[BaseModel]):
363383 # If value is a generator we can't validate type without
364384 # consuming it (which doesn't work if it's infinite – see
365385 # schedule for examples). So we skip it.
366- config = dict (config )
386+ config = _replace_generators (config )
367387 try :
368388 _ = schema .model_validate (config )
369389 except ValidationError as e :
@@ -471,12 +491,22 @@ def fix_positionals(config):
471491 return config
472492
473493
494+ def fix_forward_refs (schema : Type [BaseModel ]) -> Type [BaseModel ]:
495+ fields = {}
496+ for name , field_info in schema .model_fields .items ():
497+ if isinstance (field_info .annotation , str ) or field_info .annotation == ForwardRef :
498+ fields [name ] = (Any , field_info )
499+ else :
500+ fields [name ] = (field_info .annotation , field_info )
501+ return create_model (schema .__name__ , __config__ = schema .model_config , ** fields )
502+
503+
474504def apply_overrides (
475505 config : Dict [str , Dict [str , Any ]],
476506 overrides : Dict [str , Dict [str , Any ]],
477507) -> Dict [str , Dict [str , Any ]]:
478508 """Build first representation of the config:"""
479- output = copy . deepcopy (config )
509+ output = _shallow_copy (config )
480510 for key , value in overrides .items ():
481511 path = key .split ("." )
482512 err_title = "Error parsing config overrides"
@@ -493,33 +523,73 @@ def apply_overrides(
493523 return output
494524
495525
526+ def _shallow_copy (obj ):
527+ """Ensure dict values in the config are new dicts, allowing assignment, without copying
528+ leaf objects.
529+ """
530+ if isinstance (obj , dict ):
531+ return {k : _shallow_copy (v ) for k , v in obj .items ()}
532+ elif isinstance (obj , list ):
533+ return [_shallow_copy (v ) for v in obj ]
534+ else :
535+ return obj
536+
537+
496538def make_func_schema (func ) -> Type [BaseModel ]:
497539 fields = get_func_fields (func )
498540 model_config = {
499541 "extra" : "forbid" ,
500542 "arbitrary_types_allowed" : True ,
501543 "alias_generator" : alias_generator ,
502544 }
503- return create_model ("ArgModel " , __config__ = model_config , ** fields ) # type: ignore
545+ return create_model (f" { func . __name__ } _ArgModel " , __config__ = model_config , ** fields ) # type: ignore
504546
505547
506548def get_func_fields (func ) -> Dict [str , Tuple [Type , FieldInfo ]]:
507549 # Read the argument annotations and defaults from the function signature
508550 sig_args = {}
509- for param in inspect .signature (func ).parameters .values ():
551+ for name , param in inspect .signature (func ).parameters .items ():
510552 # If no annotation is specified assume it's anything
511553 annotation = param .annotation if param .annotation != param .empty else Any
554+ annotation = _replace_forward_refs (annotation )
512555 # If no default value is specified assume that it's required
513556 default = param .default if param .default != param .empty else ...
514557 # Handle spread arguments and use their annotation as Sequence[whatever]
515558 if param .kind == param .VAR_POSITIONAL :
516559 spread_annot = Sequence [annotation ] # type: ignore
517- sig_args [ARGS_FIELD_ALIAS ] = (spread_annot , Field (default ))
560+ sig_args [ARGS_FIELD_ALIAS ] = (spread_annot , Field (default , ))
518561 else :
519562 name = RESERVED_FIELDS .get (param .name , param .name )
520563 sig_args [name ] = (annotation , Field (default ))
521564 return sig_args
522565
523566
524- def _is_model (type_ ):
525- return issubclass (type_ , BaseModel )
567+ def _replace_forward_refs (annot ):
568+ if isinstance (annot , str ) or annot == ForwardRef :
569+ return Any
570+ elif isinstance (annot , list ):
571+ return [_replace_forward_refs (x ) for x in annot ]
572+ args = get_args (annot )
573+ if not args :
574+ return annot
575+ else :
576+ origin = get_origin (annot )
577+ if origin == Literal :
578+ return annot
579+ args = [_replace_forward_refs (a ) for a in args ]
580+ return origin [* args ]
581+
582+
583+ def _replace_generators (data ):
584+ if isinstance (data , BaseModel ):
585+ return {k : _replace_generators (v ) for k , v in data .model_dump ().items ()}
586+ elif isinstance (data , dict ):
587+ return {k : _replace_generators (v ) for k , v in data .items ()}
588+ elif isinstance (data , GeneratorType ):
589+ return []
590+ elif isinstance (data , list ):
591+ return [_replace_generators (v ) for v in data ]
592+ elif isinstance (data , tuple ):
593+ return tuple ([_replace_generators (v ) for v in data ])
594+ else :
595+ return data
0 commit comments