Skip to content

Commit 1f9aba8

Browse files
committed
Revert "Revert #b50eee8 (explosion#67)"
This reverts commit 281a42a.
1 parent 6f9b1bc commit 1f9aba8

File tree

1 file changed

+82
-12
lines changed

1 file changed

+82
-12
lines changed

confection/_registry.py

Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,20 @@
66
Callable,
77
Dict,
88
Generic,
9+
Generator,
910
List,
1011
Optional,
1112
Sequence,
1213
Tuple,
1314
Type,
1415
TypeVar,
1516
Union,
17+
get_args,
18+
get_origin,
19+
get_type_hints,
20+
ForwardRef
1621
)
22+
from types import GeneratorType
1723

1824
import catalogue
1925
from pydantic import BaseModel, Field, ValidationError, create_model
@@ -28,6 +34,7 @@
2834
)
2935
from ._errors import ConfigValidationError
3036
from .util import is_promise
37+
from . import util
3138

3239
_PromisedType = TypeVar("_PromisedType")
3340

@@ -65,17 +72,24 @@ def validate(self) -> Any:
6572
def resolve(self, validate: bool = True) -> Any:
6673
if isinstance(self.getter, catalogue.RegistryError):
6774
raise self.getter
75+
assert self.schema is not None
6876
kwargs = _recursive_resolve(self.kwargs, validate=validate)
77+
assert isinstance(kwargs, dict)
6978
args = _recursive_resolve(self.var_args, validate=validate)
7079
args = list(args.values()) if isinstance(args, dict) else args
7180
if validate:
7281
schema_args = dict(kwargs)
7382
if args:
7483
schema_args[ARGS_FIELD] = args
84+
#schema_args = _replace_generators(schema_args)
7585
try:
76-
_ = self.schema.model_validate(schema_args)
86+
kwargs = self.schema.model_validate(schema_args).model_dump()
7787
except ValidationError as e:
7888
raise ConfigValidationError(config=kwargs, errors=e.errors()) from None
89+
if args:
90+
# Do type coercion
91+
args = kwargs.pop(ARGS_FIELD)
92+
kwargs = {RESERVED_FIELDS_REVERSE.get(k, k): v for k, v in kwargs.items()}
7993
return self.getter(*args, **kwargs) # type: ignore
8094

8195
@classmethod
@@ -145,6 +159,7 @@ def resolve(
145159
overrides: Dict[str, Any] = {},
146160
validate: bool = True,
147161
) -> Dict[str, Any]:
162+
schema = fix_forward_refs(schema)
148163
config = cls.fill(
149164
config,
150165
schema=schema,
@@ -278,13 +293,18 @@ def _make_unresolved_schema(
278293
)
279294
elif isinstance(config[name], dict):
280295
fields[name] = cls._make_unresolved_schema(
281-
_make_dummy_schema(config[name]), config
296+
_make_dummy_schema(config[name]), config[name]
282297
)
298+
elif isinstance(field.annotation, str) or field.annotation == ForwardRef:
299+
fields[name] = (Any, Field(field.default))
283300
else:
284-
fields[name] = (field.annotation, Field(...))
285-
return create_model(
286-
"UnresolvedConfig", __config__={"extra": "forbid"}, **fields
301+
fields[name] = (Any, Field(field.default))
302+
303+
model = create_model(
304+
f"{schema.__name__}_UnresolvedConfig", __config__=schema.model_config, **fields
287305
)
306+
model.model_rebuild(raise_errors=True)
307+
return model
288308

289309
@classmethod
290310
def _make_unresolved_promise_schema(cls, obj: Dict[str, Any]) -> Type[BaseModel]:
@@ -363,7 +383,7 @@ def validate_resolved(config, schema: Type[BaseModel]):
363383
# If value is a generator we can't validate type without
364384
# consuming it (which doesn't work if it's infinite – see
365385
# schedule for examples). So we skip it.
366-
config = dict(config)
386+
config = _replace_generators(config)
367387
try:
368388
_ = schema.model_validate(config)
369389
except ValidationError as e:
@@ -471,12 +491,22 @@ def fix_positionals(config):
471491
return config
472492

473493

494+
def fix_forward_refs(schema: Type[BaseModel]) -> Type[BaseModel]:
495+
fields = {}
496+
for name, field_info in schema.model_fields.items():
497+
if isinstance(field_info.annotation, str) or field_info.annotation == ForwardRef:
498+
fields[name] = (Any, field_info)
499+
else:
500+
fields[name] = (field_info.annotation, field_info)
501+
return create_model(schema.__name__, __config__=schema.model_config, **fields)
502+
503+
474504
def apply_overrides(
475505
config: Dict[str, Dict[str, Any]],
476506
overrides: Dict[str, Dict[str, Any]],
477507
) -> Dict[str, Dict[str, Any]]:
478508
"""Build first representation of the config:"""
479-
output = copy.deepcopy(config)
509+
output = _shallow_copy(config)
480510
for key, value in overrides.items():
481511
path = key.split(".")
482512
err_title = "Error parsing config overrides"
@@ -493,33 +523,73 @@ def apply_overrides(
493523
return output
494524

495525

526+
def _shallow_copy(obj):
527+
"""Ensure dict values in the config are new dicts, allowing assignment, without copying
528+
leaf objects.
529+
"""
530+
if isinstance(obj, dict):
531+
return {k: _shallow_copy(v) for k, v in obj.items()}
532+
elif isinstance(obj, list):
533+
return [_shallow_copy(v) for v in obj]
534+
else:
535+
return obj
536+
537+
496538
def make_func_schema(func) -> Type[BaseModel]:
497539
fields = get_func_fields(func)
498540
model_config = {
499541
"extra": "forbid",
500542
"arbitrary_types_allowed": True,
501543
"alias_generator": alias_generator,
502544
}
503-
return create_model("ArgModel", __config__=model_config, **fields) # type: ignore
545+
return create_model(f"{func.__name__}_ArgModel", __config__=model_config, **fields) # type: ignore
504546

505547

506548
def get_func_fields(func) -> Dict[str, Tuple[Type, FieldInfo]]:
507549
# Read the argument annotations and defaults from the function signature
508550
sig_args = {}
509-
for param in inspect.signature(func).parameters.values():
551+
for name, param in inspect.signature(func).parameters.items():
510552
# If no annotation is specified assume it's anything
511553
annotation = param.annotation if param.annotation != param.empty else Any
554+
annotation = _replace_forward_refs(annotation)
512555
# If no default value is specified assume that it's required
513556
default = param.default if param.default != param.empty else ...
514557
# Handle spread arguments and use their annotation as Sequence[whatever]
515558
if param.kind == param.VAR_POSITIONAL:
516559
spread_annot = Sequence[annotation] # type: ignore
517-
sig_args[ARGS_FIELD_ALIAS] = (spread_annot, Field(default))
560+
sig_args[ARGS_FIELD_ALIAS] = (spread_annot, Field(default, ))
518561
else:
519562
name = RESERVED_FIELDS.get(param.name, param.name)
520563
sig_args[name] = (annotation, Field(default))
521564
return sig_args
522565

523566

524-
def _is_model(type_):
525-
return issubclass(type_, BaseModel)
567+
def _replace_forward_refs(annot):
568+
if isinstance(annot, str) or annot == ForwardRef:
569+
return Any
570+
elif isinstance(annot, list):
571+
return [_replace_forward_refs(x) for x in annot]
572+
args = get_args(annot)
573+
if not args:
574+
return annot
575+
else:
576+
origin = get_origin(annot)
577+
if origin == Literal:
578+
return annot
579+
args = [_replace_forward_refs(a) for a in args]
580+
return origin[*args]
581+
582+
583+
def _replace_generators(data):
584+
if isinstance(data, BaseModel):
585+
return {k: _replace_generators(v) for k, v in data.model_dump().items()}
586+
elif isinstance(data, dict):
587+
return {k: _replace_generators(v) for k, v in data.items()}
588+
elif isinstance(data, GeneratorType):
589+
return []
590+
elif isinstance(data, list):
591+
return [_replace_generators(v) for v in data]
592+
elif isinstance(data, tuple):
593+
return tuple([_replace_generators(v) for v in data])
594+
else:
595+
return data

0 commit comments

Comments
 (0)