Skip to content

Commit 00bd478

Browse files
committed
nn validate model config
1 parent 8d8a748 commit 00bd478

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

chebai/ensemble/_wrappers/_neural_network.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ class NNWrapper(BaseWrapper):
1515

1616
def __init__(
1717
self,
18-
reader_cls: Type[DataReader],
19-
reader_kwargs: Optional[dict] = None,
2018
**kwargs,
2119
):
20+
self._validate_model_configs(**kwargs)
2221
super().__init__(**kwargs)
2322

2423
self._model_ckpt_path = self._model_config[MODEL_CKPT_PATH]
@@ -29,10 +28,37 @@ def __init__(
2928
else dict()
3029
)
3130

32-
self._reader = self._load_class(self._reader_class_path)(**self._reader_kwargs)
31+
reader_cls: Type[DataReader] = self._load_class(self._reader_class_path)
32+
assert issubclass(reader_cls, DataReader), ""
33+
self._reader = reader_cls(**self._reader_kwargs)
3334
self._collator = reader_cls.COLLATOR()
3435
self._model: ChebaiBaseNet = self._load_model_()
3536

37+
@classmethod
38+
def _validate_model_configs(
39+
cls, model_config: dict[str, str], model_name: str
40+
) -> None:
41+
"""
42+
Validates model configuration dictionary for required keys and uniqueness.
43+
44+
Args:
45+
model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary.
46+
47+
Raises:
48+
AttributeError: If any model config is missing required keys.
49+
ValueError: If duplicate paths are found for model checkpoint, class, or labels.
50+
"""
51+
required_keys = {
52+
MODEL_CKPT_PATH,
53+
READER_CLS_PATH,
54+
}
55+
56+
missing_keys = required_keys - model_config.keys()
57+
if missing_keys:
58+
raise AttributeError(
59+
f"Missing keys {missing_keys} in model '{model_name}' configuration."
60+
)
61+
3662
def _load_model_(self) -> ChebaiBaseNet:
3763
"""
3864
Loads a model checkpoint and its label-related properties.

0 commit comments

Comments
 (0)