Skip to content

Commit f6b74c1

Browse files
authored
Merge pull request #52 from Point72/pit/misspelling
Add warning for common misspellings of _target_
2 parents 83d6ea2 + f7072f6 commit f6b74c1

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

ccflow/base.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import pathlib
88
import platform
9+
import warnings
910
from types import GenericAlias, MappingProxyType
1011
from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin
1112

@@ -286,12 +287,23 @@ def _get_registry_dependencies(value, types: Optional[Tuple[Type]]) -> List[List
286287
return deps
287288

288289

290+
def _is_config_model(value: Dict):
291+
"""Test whether a config value is a model, i.e. it is a dict which contains a _target_ key."""
292+
if "_target_" in value:
293+
return True
294+
for key in value:
295+
# Catch potential misspellings (which will cause the config to be ignored and treated as a dict)
296+
if key.lower().strip("_") in ("target", "arget", "trget", "taget", "taret", "targt", "targe", "tagret"):
297+
warnings.warn(f"Found config value containing `{key}`, are you sure you didn't mean '_target_'?", SyntaxWarning)
298+
return False
299+
300+
289301
def _is_config_subregistry(value):
290302
"""Test whether a config value is a subregistry, i.e. it is a dict which either
291303
contains a _target_ key, or recursively contains a dict that has a _target_ key.
292304
"""
293305
if isinstance(value, (dict, DictConfig)):
294-
if "_target_" in value:
306+
if _is_config_model(value):
295307
return True
296308
else:
297309
for v in value.values():
@@ -579,13 +591,14 @@ def _make_subregistries(self, cfg, registries: List[ModelRegistry]) -> List[Tupl
579591
# Skip config "variables", i.e. strings, etc that could be re-used by reference across the
580592
# object configs
581593
continue
582-
elif "_target_" in v:
594+
elif _is_config_model(v):
583595
models_to_register.append((registries, k, v))
584596
elif _is_config_subregistry(v):
585597
# Config value represents a sub-registry
586598
subregistry = ModelRegistry(name=k)
587599
registry.add(k, subregistry, overwrite=self._overwrite)
588600
models_to_register.extend(self._make_subregistries(v, registries + [subregistry]))
601+
589602
return models_to_register
590603

591604
def load_config(self, cfg: DictConfig, registry: ModelRegistry, skip_exceptions: bool = False) -> ModelRegistry:

ccflow/tests/test_base_registry.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,21 @@ def test_validation_error(self):
529529
with self.assertRaises(InstantiationException, msg=msg):
530530
r.load_config(cfg)
531531

532+
def test_misspelling_warning(self):
533+
cfg = OmegaConf.create(
534+
{
535+
"foo": {
536+
"_target": "ccflow.tests.test_base_registry.MyTestModel",
537+
"a": "test",
538+
"b": "string_that_should_be_a_float",
539+
},
540+
}
541+
)
542+
r = ModelRegistry(name="test")
543+
msg = "Found config value containing `_target`, are you sure you didn't mean '_target_'?"
544+
with self.assertWarnsRegex(SyntaxWarning, msg):
545+
r.load_config(cfg)
546+
532547

533548
class TestRegistryLookupContext(TestCase):
534549
def setUp(self) -> None:

0 commit comments

Comments
 (0)