Skip to content

Commit 281a42a

Browse files
authored
Revert #b50eee8 (#67)
1 parent 1ea74f4 commit 281a42a

File tree

2 files changed

+13
-83
lines changed

2 files changed

+13
-83
lines changed

confection/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
try_load_json,
1818
)
1919
from ._errors import ConfigValidationError
20-
from ._registry import registry, Promise
20+
from ._registry import Promise, registry
2121
from .util import SimpleFrozenDict, SimpleFrozenList # noqa: F401
2222

2323

confection/_registry.py

Lines changed: 12 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Callable,
88
Dict,
99
Generic,
10-
Generator,
1110
List,
1211
Literal,
1312
Optional,
@@ -16,12 +15,7 @@
1615
Type,
1716
TypeVar,
1817
Union,
19-
get_args,
20-
get_origin,
21-
get_type_hints,
22-
ForwardRef
2318
)
24-
from types import GeneratorType
2519

2620
import catalogue
2721
from pydantic import BaseModel, ConfigDict, Field, ValidationError, create_model
@@ -36,7 +30,6 @@
3630
)
3731
from ._errors import ConfigValidationError
3832
from .util import is_promise
39-
from . import util
4033

4134
_PromisedType = TypeVar("_PromisedType")
4235

@@ -74,24 +67,17 @@ def validate(self) -> Any:
7467
def resolve(self, validate: bool = True) -> Any:
7568
if isinstance(self.getter, catalogue.RegistryError):
7669
raise self.getter
77-
assert self.schema is not None
7870
kwargs = _recursive_resolve(self.kwargs, validate=validate)
79-
assert isinstance(kwargs, dict)
8071
args = _recursive_resolve(self.var_args, validate=validate)
8172
args = list(args.values()) if isinstance(args, dict) else args
8273
if validate:
8374
schema_args = dict(kwargs)
8475
if args:
8576
schema_args[ARGS_FIELD] = args
86-
#schema_args = _replace_generators(schema_args)
8777
try:
88-
kwargs = self.schema.model_validate(schema_args).model_dump()
78+
_ = self.schema.model_validate(schema_args)
8979
except ValidationError as e:
9080
raise ConfigValidationError(config=kwargs, errors=e.errors()) from None
91-
if args:
92-
# Do type coercion
93-
args = kwargs.pop(ARGS_FIELD)
94-
kwargs = {RESERVED_FIELDS_REVERSE.get(k, k): v for k, v in kwargs.items()}
9581
return self.getter(*args, **kwargs) # type: ignore
9682

9783
@classmethod
@@ -161,7 +147,6 @@ def resolve(
161147
overrides: Dict[str, Any] = {},
162148
validate: bool = True,
163149
) -> Dict[str, Any]:
164-
schema = fix_forward_refs(schema)
165150
config = cls.fill(
166151
config,
167152
schema=schema,
@@ -289,18 +274,13 @@ def _make_unresolved_schema(
289274
)
290275
elif isinstance(config[name], dict):
291276
fields[name] = cls._make_unresolved_schema(
292-
_make_dummy_schema(config[name]), config[name]
277+
_make_dummy_schema(config[name]), config
293278
)
294-
elif isinstance(field.annotation, str) or field.annotation == ForwardRef:
295-
fields[name] = (Any, Field(field.default))
296279
else:
297-
fields[name] = (Any, Field(field.default))
298-
299-
model = create_model(
300-
f"{schema.__name__}_UnresolvedConfig", __config__=schema.model_config, **fields
280+
fields[name] = (field.annotation, Field(...))
281+
return create_model(
282+
"UnresolvedConfig", __config__={"extra": "forbid"}, **fields
301283
)
302-
model.model_rebuild(raise_errors=True)
303-
return model
304284

305285
@classmethod
306286
def _make_unresolved_promise_schema(cls, obj: Dict[str, Any]) -> Type[BaseModel]:
@@ -377,7 +357,7 @@ def validate_resolved(config, schema: Type[BaseModel]):
377357
# If value is a generator we can't validate type without
378358
# consuming it (which doesn't work if it's infinite – see
379359
# schedule for examples). So we skip it.
380-
config = _replace_generators(config)
360+
config = dict(config)
381361
try:
382362
_ = schema.model_validate(config)
383363
except ValidationError as e:
@@ -485,22 +465,12 @@ def fix_positionals(config):
485465
return config
486466

487467

488-
def fix_forward_refs(schema: Type[BaseModel]) -> Type[BaseModel]:
489-
fields = {}
490-
for name, field_info in schema.model_fields.items():
491-
if isinstance(field_info.annotation, str) or field_info.annotation == ForwardRef:
492-
fields[name] = (Any, field_info)
493-
else:
494-
fields[name] = (field_info.annotation, field_info)
495-
return create_model(schema.__name__, __config__=schema.model_config, **fields)
496-
497-
498468
def apply_overrides(
499469
config: Dict[str, Dict[str, Any]],
500470
overrides: Dict[str, Dict[str, Any]],
501471
) -> Dict[str, Dict[str, Any]]:
502472
"""Build first representation of the config:"""
503-
output = _shallow_copy(config)
473+
output = copy.deepcopy(config)
504474
for key, value in overrides.items():
505475
path = key.split(".")
506476
err_title = "Error parsing config overrides"
@@ -517,73 +487,33 @@ def apply_overrides(
517487
return output
518488

519489

520-
def _shallow_copy(obj):
521-
"""Ensure dict values in the config are new dicts, allowing assignment, without copying
522-
leaf objects.
523-
"""
524-
if isinstance(obj, dict):
525-
return {k: _shallow_copy(v) for k, v in obj.items()}
526-
elif isinstance(obj, list):
527-
return [_shallow_copy(v) for v in obj]
528-
else:
529-
return obj
530-
531-
532490
def make_func_schema(func) -> Type[BaseModel]:
533491
fields = get_func_fields(func)
534492
model_config = {
535493
"extra": "forbid",
536494
"arbitrary_types_allowed": True,
537495
"alias_generator": alias_generator,
538496
}
539-
return create_model(f"{func.__name__}_ArgModel", __config__=model_config, **fields) # type: ignore
497+
return create_model("ArgModel", __config__=model_config, **fields) # type: ignore
540498

541499

542500
def get_func_fields(func) -> Dict[str, Tuple[Type, FieldInfo]]:
543501
# Read the argument annotations and defaults from the function signature
544502
sig_args = {}
545-
for name, param in inspect.signature(func).parameters.items():
503+
for param in inspect.signature(func).parameters.values():
546504
# If no annotation is specified assume it's anything
547505
annotation = param.annotation if param.annotation != param.empty else Any
548-
annotation = _replace_forward_refs(annotation)
549506
# If no default value is specified assume that it's required
550507
default = param.default if param.default != param.empty else ...
551508
# Handle spread arguments and use their annotation as Sequence[whatever]
552509
if param.kind == param.VAR_POSITIONAL:
553510
spread_annot = Sequence[annotation] # type: ignore
554-
sig_args[ARGS_FIELD_ALIAS] = (spread_annot, Field(default, ))
511+
sig_args[ARGS_FIELD_ALIAS] = (spread_annot, Field(default))
555512
else:
556513
name = RESERVED_FIELDS.get(param.name, param.name)
557514
sig_args[name] = (annotation, Field(default))
558515
return sig_args
559516

560517

561-
def _replace_forward_refs(annot):
562-
if isinstance(annot, str) or annot == ForwardRef:
563-
return Any
564-
elif isinstance(annot, list):
565-
return [_replace_forward_refs(x) for x in annot]
566-
args = get_args(annot)
567-
if not args:
568-
return annot
569-
else:
570-
origin = get_origin(annot)
571-
if origin == Literal:
572-
return annot
573-
args = [_replace_forward_refs(a) for a in args]
574-
return origin[*args]
575-
576-
577-
def _replace_generators(data):
578-
if isinstance(data, BaseModel):
579-
return {k: _replace_generators(v) for k, v in data.model_dump().items()}
580-
elif isinstance(data, dict):
581-
return {k: _replace_generators(v) for k, v in data.items()}
582-
elif isinstance(data, GeneratorType):
583-
return []
584-
elif isinstance(data, list):
585-
return [_replace_generators(v) for v in data]
586-
elif isinstance(data, tuple):
587-
return tuple([_replace_generators(v) for v in data])
588-
else:
589-
return data
518+
def _is_model(type_):
519+
return issubclass(type_, BaseModel)

0 commit comments

Comments
 (0)