77 Callable ,
88 Dict ,
99 Generic ,
10- Generator ,
1110 List ,
1211 Literal ,
1312 Optional ,
1615 Type ,
1716 TypeVar ,
1817 Union ,
19- get_args ,
20- get_origin ,
21- get_type_hints ,
22- ForwardRef
2318)
24- from types import GeneratorType
2519
2620import catalogue
2721from pydantic import BaseModel , ConfigDict , Field , ValidationError , create_model
3630)
3731from ._errors import ConfigValidationError
3832from .util import is_promise
39- from . import util
4033
4134_PromisedType = TypeVar ("_PromisedType" )
4235
@@ -74,24 +67,17 @@ def validate(self) -> Any:
7467 def resolve (self , validate : bool = True ) -> Any :
7568 if isinstance (self .getter , catalogue .RegistryError ):
7669 raise self .getter
77- assert self .schema is not None
7870 kwargs = _recursive_resolve (self .kwargs , validate = validate )
79- assert isinstance (kwargs , dict )
8071 args = _recursive_resolve (self .var_args , validate = validate )
8172 args = list (args .values ()) if isinstance (args , dict ) else args
8273 if validate :
8374 schema_args = dict (kwargs )
8475 if args :
8576 schema_args [ARGS_FIELD ] = args
86- #schema_args = _replace_generators(schema_args)
8777 try :
88- kwargs = self .schema .model_validate (schema_args ). model_dump ( )
78+ _ = self .schema .model_validate (schema_args )
8979 except ValidationError as e :
9080 raise ConfigValidationError (config = kwargs , errors = e .errors ()) from None
91- if args :
92- # Do type coercion
93- args = kwargs .pop (ARGS_FIELD )
94- kwargs = {RESERVED_FIELDS_REVERSE .get (k , k ): v for k , v in kwargs .items ()}
9581 return self .getter (* args , ** kwargs ) # type: ignore
9682
9783 @classmethod
@@ -161,7 +147,6 @@ def resolve(
161147 overrides : Dict [str , Any ] = {},
162148 validate : bool = True ,
163149 ) -> Dict [str , Any ]:
164- schema = fix_forward_refs (schema )
165150 config = cls .fill (
166151 config ,
167152 schema = schema ,
@@ -289,18 +274,13 @@ def _make_unresolved_schema(
289274 )
290275 elif isinstance (config [name ], dict ):
291276 fields [name ] = cls ._make_unresolved_schema (
292- _make_dummy_schema (config [name ]), config [ name ]
277+ _make_dummy_schema (config [name ]), config
293278 )
294- elif isinstance (field .annotation , str ) or field .annotation == ForwardRef :
295- fields [name ] = (Any , Field (field .default ))
296279 else :
297- fields [name ] = (Any , Field (field .default ))
298-
299- model = create_model (
300- f"{ schema .__name__ } _UnresolvedConfig" , __config__ = schema .model_config , ** fields
280+ fields [name ] = (field .annotation , Field (...))
281+ return create_model (
282+ "UnresolvedConfig" , __config__ = {"extra" : "forbid" }, ** fields
301283 )
302- model .model_rebuild (raise_errors = True )
303- return model
304284
305285 @classmethod
306286 def _make_unresolved_promise_schema (cls , obj : Dict [str , Any ]) -> Type [BaseModel ]:
@@ -377,7 +357,7 @@ def validate_resolved(config, schema: Type[BaseModel]):
377357 # If value is a generator we can't validate type without
378358 # consuming it (which doesn't work if it's infinite – see
379359 # schedule for examples). So we skip it.
380- config = _replace_generators (config )
360+ config = dict (config )
381361 try :
382362 _ = schema .model_validate (config )
383363 except ValidationError as e :
@@ -485,22 +465,12 @@ def fix_positionals(config):
485465 return config
486466
487467
488- def fix_forward_refs (schema : Type [BaseModel ]) -> Type [BaseModel ]:
489- fields = {}
490- for name , field_info in schema .model_fields .items ():
491- if isinstance (field_info .annotation , str ) or field_info .annotation == ForwardRef :
492- fields [name ] = (Any , field_info )
493- else :
494- fields [name ] = (field_info .annotation , field_info )
495- return create_model (schema .__name__ , __config__ = schema .model_config , ** fields )
496-
497-
498468def apply_overrides (
499469 config : Dict [str , Dict [str , Any ]],
500470 overrides : Dict [str , Dict [str , Any ]],
501471) -> Dict [str , Dict [str , Any ]]:
502472 """Build first representation of the config:"""
503- output = _shallow_copy (config )
473+ output = copy . deepcopy (config )
504474 for key , value in overrides .items ():
505475 path = key .split ("." )
506476 err_title = "Error parsing config overrides"
@@ -517,73 +487,33 @@ def apply_overrides(
517487 return output
518488
519489
520- def _shallow_copy (obj ):
521- """Ensure dict values in the config are new dicts, allowing assignment, without copying
522- leaf objects.
523- """
524- if isinstance (obj , dict ):
525- return {k : _shallow_copy (v ) for k , v in obj .items ()}
526- elif isinstance (obj , list ):
527- return [_shallow_copy (v ) for v in obj ]
528- else :
529- return obj
530-
531-
532490def make_func_schema (func ) -> Type [BaseModel ]:
533491 fields = get_func_fields (func )
534492 model_config = {
535493 "extra" : "forbid" ,
536494 "arbitrary_types_allowed" : True ,
537495 "alias_generator" : alias_generator ,
538496 }
539- return create_model (f" { func . __name__ } _ArgModel " , __config__ = model_config , ** fields ) # type: ignore
497+ return create_model ("ArgModel " , __config__ = model_config , ** fields ) # type: ignore
540498
541499
542500def get_func_fields (func ) -> Dict [str , Tuple [Type , FieldInfo ]]:
543501 # Read the argument annotations and defaults from the function signature
544502 sig_args = {}
545- for name , param in inspect .signature (func ).parameters .items ():
503+ for param in inspect .signature (func ).parameters .values ():
546504 # If no annotation is specified assume it's anything
547505 annotation = param .annotation if param .annotation != param .empty else Any
548- annotation = _replace_forward_refs (annotation )
549506 # If no default value is specified assume that it's required
550507 default = param .default if param .default != param .empty else ...
551508 # Handle spread arguments and use their annotation as Sequence[whatever]
552509 if param .kind == param .VAR_POSITIONAL :
553510 spread_annot = Sequence [annotation ] # type: ignore
554- sig_args [ARGS_FIELD_ALIAS ] = (spread_annot , Field (default , ))
511+ sig_args [ARGS_FIELD_ALIAS ] = (spread_annot , Field (default ))
555512 else :
556513 name = RESERVED_FIELDS .get (param .name , param .name )
557514 sig_args [name ] = (annotation , Field (default ))
558515 return sig_args
559516
560517
561- def _replace_forward_refs (annot ):
562- if isinstance (annot , str ) or annot == ForwardRef :
563- return Any
564- elif isinstance (annot , list ):
565- return [_replace_forward_refs (x ) for x in annot ]
566- args = get_args (annot )
567- if not args :
568- return annot
569- else :
570- origin = get_origin (annot )
571- if origin == Literal :
572- return annot
573- args = [_replace_forward_refs (a ) for a in args ]
574- return origin [* args ]
575-
576-
577- def _replace_generators (data ):
578- if isinstance (data , BaseModel ):
579- return {k : _replace_generators (v ) for k , v in data .model_dump ().items ()}
580- elif isinstance (data , dict ):
581- return {k : _replace_generators (v ) for k , v in data .items ()}
582- elif isinstance (data , GeneratorType ):
583- return []
584- elif isinstance (data , list ):
585- return [_replace_generators (v ) for v in data ]
586- elif isinstance (data , tuple ):
587- return tuple ([_replace_generators (v ) for v in data ])
588- else :
589- return data
518+ def _is_model (type_ ):
519+ return issubclass (type_ , BaseModel )
0 commit comments