@@ -15,10 +15,9 @@ class NNWrapper(BaseWrapper):
1515
1616 def __init__ (
1717 self ,
18- reader_cls : Type [DataReader ],
19- reader_kwargs : Optional [dict ] = None ,
2018 ** kwargs ,
2119 ):
20+ self ._validate_model_configs (** kwargs )
2221 super ().__init__ (** kwargs )
2322
2423 self ._model_ckpt_path = self ._model_config [MODEL_CKPT_PATH ]
@@ -29,10 +28,37 @@ def __init__(
2928 else dict ()
3029 )
3130
32- self ._reader = self ._load_class (self ._reader_class_path )(** self ._reader_kwargs )
31+ reader_cls : Type [DataReader ] = self ._load_class (self ._reader_class_path )
32+ assert issubclass (reader_cls , DataReader ), ""
33+ self ._reader = reader_cls (** self ._reader_kwargs )
3334 self ._collator = reader_cls .COLLATOR ()
3435 self ._model : ChebaiBaseNet = self ._load_model_ ()
3536
37+ @classmethod
38+ def _validate_model_configs (
39+ cls , model_config : dict [str , str ], model_name : str
40+ ) -> None :
41+ """
42+ Validates model configuration dictionary for required keys and uniqueness.
43+
44+ Args:
45+ model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary.
46+
47+ Raises:
48+ AttributeError: If any model config is missing required keys.
49+ ValueError: If duplicate paths are found for model checkpoint, class, or labels.
50+ """
51+ required_keys = {
52+ MODEL_CKPT_PATH ,
53+ READER_CLS_PATH ,
54+ }
55+
56+ missing_keys = required_keys - model_config .keys ()
57+ if missing_keys :
58+ raise AttributeError (
59+ f"Missing keys { missing_keys } in model '{ model_name } ' configuration."
60+ )
61+
3662 def _load_model_ (self ) -> ChebaiBaseNet :
3763 """
3864 Loads a model checkpoint and its label-related properties.
0 commit comments