99from torch import Tensor
1010
1111from chebai .models import ChebaiBaseNet
12- from chebai .models .ffn import FFN
1312from chebai .preprocessing .structures import XYData
1413
1514
@@ -39,7 +38,8 @@ def __init__(
3938 **kwargs: Additional arguments for initialization.
4039 """
4140 super ().__init__ (** kwargs )
42- self ._validate_model_configs (model_configs )
41+ if kwargs .get ("_validate_configs" , True ):
42+ self ._validate_model_configs (model_configs )
4343
4444 self .data_processed_dir_main = data_processed_dir_main
4545 self .models : Dict [str , LightningModule ] = {}
@@ -79,7 +79,8 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
7979
8080 if model_path in path_set :
8181 raise ValueError (
82- f"Duplicate model path detected: '{ model_path } '. Each model must have a unique model-checkpoint path."
82+ f"Duplicate model path detected: '{ model_path } '. "
83+ f"Each model must have a unique model-checkpoint path."
8384 )
8485
8586 if class_path in class_set :
@@ -94,16 +95,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
9495
9596 path_set .add (model_path )
9697 class_set .add (class_path )
97- labels_path .add (labels_path )
98+ labels_set .add (labels_path )
9899
99100 def _load_ensemble_models (self ):
100101 """
101102 Loads the models specified in the configuration and initializes them.
102-
103- Raises:
104- FileNotFoundError: If the model checkpoint path does not exist.
105- ModuleNotFoundError: If the module containing the model class is not found.
106- AttributeError: If the specified class is not found within the module.
107103 """
108104 for model_name in self .model_configs :
109105 model_ckpt_path = self .model_configs [model_name ]["ckpt_path" ]
@@ -116,33 +112,38 @@ def _load_ensemble_models(self):
116112
117113 class_name = model_class_path .split ("." )[- 1 ]
118114 module_path = "." .join (model_class_path .split ("." )[:- 1 ])
115+ module = importlib .import_module (module_path )
116+ lightning_cls : LightningModule = getattr (module , class_name )
119117
120- try :
121- module = importlib .import_module (module_path )
122- lightning_cls : LightningModule = getattr (module , class_name )
118+ model = lightning_cls .load_from_checkpoint (
119+ model_ckpt_path , input_dim = self .input_dim
120+ )
121+ model .eval ()
122+ model .freeze ()
123123
124- model = lightning_cls .load_from_checkpoint (
125- model_ckpt_path , input_dim = self .input_dim
126- )
127- model .eval ()
128- model .freeze ()
129- self .models [model_name ] = model
130- self .models_configs [model_name ]["labels" ] = self ._load_model_labels (
131- model_labels_path
132- )
124+ self .models [model_name ] = model
125+ self .model_configs [model_name ]["labels" ] = self ._load_model_labels (
126+ model_labels_path , model_name
127+ )
133128
134- except ModuleNotFoundError :
135- print (f"Module '{ module_path } ' not found!" )
136- except AttributeError :
137- print (f"Class '{ class_name } ' not found in '{ module_path } '!" )
129+ def _load_data_module_labels (self ):
130+ """
131+ Loads the label mapping from the classes.txt file for loaded data.
138132
139- except Exception as e :
140- raise RuntimeError (
141- f"Failed to load model '{ model_name } ' from { model_ckpt_path } : \n { e } "
142- )
133+ Raises:
134+ FileNotFoundError: If the classes.txt file does not exist.
135+ """
136+ classes_txt_file = os .path .join (self .data_processed_dir_main , "classes.txt" )
137+ if not os .path .exists (classes_txt_file ):
138+ raise FileNotFoundError (f"{ classes_txt_file } does not exist" )
139+ else :
140+ with open (classes_txt_file , "r" ) as f :
141+ for line in f :
142+ if line .strip () not in self .dm_labels :
143+ self .dm_labels [line .strip ()] = len (self .dm_labels )
143144
144145 @staticmethod
145- def _load_model_labels (labels_path : str ) -> Dict [str , float ]:
146+ def _load_model_labels (labels_path : str , model_name : str ) -> Dict [str , float ]:
146147 if not os .path .exists (labels_path ):
147148 raise FileNotFoundError (f"{ labels_path } does not exist." )
148149
@@ -154,7 +155,7 @@ def _load_model_labels(labels_path: str) -> Dict[str, float]:
154155
155156 labels_dict = {}
156157 for label , label_dict in model_labels .items ():
157- msg = f"for label { label } "
158+ msg = f"for model { model_name } for label { label } "
158159 if "TPV" not in label_dict .keys () or "FPV" not in label_dict .keys ():
159160 raise AttributeError (f"Missing keys 'TPV' and/or 'FPV' { msg } " )
160161
@@ -170,7 +171,7 @@ def _load_model_labels(labels_path: str) -> Dict[str, float]:
170171 raise ValueError (
171172 f"'{ key } ' must be a float or convertible to float, but got { label_dict [key ]} { msg } "
172173 )
173- labels_dict [ label ] [key ] = value
174+ labels_dict . setdefault ( label , {}) [key ] = value
174175 return labels_dict
175176
176177 @abstractmethod
@@ -193,7 +194,9 @@ def _get_prediction_and_labels(
193194 def controller (self ):
194195 pass
195196
196- def consolidator (self ):
197+ def consolidator (
198+ self ,
199+ ):
197200 pass
198201
199202
@@ -238,19 +241,14 @@ def _generate_model_label_mask(self):
238241 num_models_per_label = torch .zeros (1 , self .out_dim , device = self .device )
239242
240243 for model_name , model_config in self .model_configs .items ():
241- labels_path = model_config ["labels_path" ]
242- if not os .path .exists (labels_path ):
243- raise FileNotFoundError (f"Labels path '{ labels_path } ' does not exist." )
244-
245- with open (labels_path , "r" ) as f :
246- labels_dict = json .load (f )
244+ labels_dict = model_config ["labels" ]
247245
248246 model_label_indices , tpv_label_values , fpv_label_values = [], [], []
249247 for label in labels_dict .keys ():
250248 if label in self .dm_labels :
251249 model_label_indices .append (self .dm_labels [label ])
252- tpv_label_values .append (float ( labels_dict [label ]["TPV" ]) )
253- fpv_label_values .append (float ( labels_dict [label ]["FPV" ]) )
250+ tpv_label_values .append (labels_dict [label ]["TPV" ])
251+ fpv_label_values .append (labels_dict [label ]["FPV" ])
254252
255253 if not all ([model_label_indices , tpv_label_values , fpv_label_values ]):
256254 raise ValueError (f"Values are empty for labels of model { model_name } " )
@@ -318,7 +316,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
318316 confidences [name ] = confidence
319317 total_logits += output [
320318 "logits"
321- ] # Don 't play a role here, just for lightning flow completeness
319+ ] # This doesn 't play a role here, just for lightning flow completeness
322320
323321 return {
324322 "logits" : total_logits ,
0 commit comments