|
6 | 6 | import logging |
7 | 7 | import pathlib |
8 | 8 | import platform |
| 9 | +import warnings |
9 | 10 | from types import GenericAlias, MappingProxyType |
10 | 11 | from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin |
11 | 12 |
|
@@ -286,12 +287,23 @@ def _get_registry_dependencies(value, types: Optional[Tuple[Type]]) -> List[List |
286 | 287 | return deps |
287 | 288 |
|
288 | 289 |
|
| 290 | +def _is_config_model(value: Dict): |
| 291 | + """Test whether a config value is a model, i.e. it is a dict which contains a _target_ key.""" |
| 292 | + if "_target_" in value: |
| 293 | + return True |
| 294 | + for key in value: |
| 295 | + # Catch potential misspellings (which will cause the config to be ignored and treated as a dict) |
| 296 | + if key.lower().strip("_") in ("target", "arget", "trget", "taget", "taret", "targt", "targe", "tagret"): |
| 297 | + warnings.warn(f"Found config value containing `{key}`, are you sure you didn't mean '_target_'?", SyntaxWarning) |
| 298 | + return False |
| 299 | + |
| 300 | + |
289 | 301 | def _is_config_subregistry(value): |
290 | 302 | """Test whether a config value is a subregistry, i.e. it is a dict which either |
291 | 303 | contains a _target_ key, or recursively contains a dict that has a _target_ key. |
292 | 304 | """ |
293 | 305 | if isinstance(value, (dict, DictConfig)): |
294 | | - if "_target_" in value: |
| 306 | + if _is_config_model(value): |
295 | 307 | return True |
296 | 308 | else: |
297 | 309 | for v in value.values(): |
@@ -579,13 +591,14 @@ def _make_subregistries(self, cfg, registries: List[ModelRegistry]) -> List[Tupl |
579 | 591 | # Skip config "variables", i.e. strings, etc that could be re-used by reference across the |
580 | 592 | # object configs |
581 | 593 | continue |
582 | | - elif "_target_" in v: |
| 594 | + elif _is_config_model(v): |
583 | 595 | models_to_register.append((registries, k, v)) |
584 | 596 | elif _is_config_subregistry(v): |
585 | 597 | # Config value represents a sub-registry |
586 | 598 | subregistry = ModelRegistry(name=k) |
587 | 599 | registry.add(k, subregistry, overwrite=self._overwrite) |
588 | 600 | models_to_register.extend(self._make_subregistries(v, registries + [subregistry])) |
| 601 | + |
589 | 602 | return models_to_register |
590 | 603 |
|
591 | 604 | def load_config(self, cfg: DictConfig, registry: ModelRegistry, skip_exceptions: bool = False) -> ModelRegistry: |
|
0 commit comments