77
88import torch
99from lightning import LightningModule
10+ from lightning_utilities .core .rank_zero import rank_zero_info
1011
1112from chebai .models import ChebaiBaseNet
1213from chebai .result .classification import print_metrics
@@ -37,7 +38,7 @@ def __init__(
3738 data_processed_dir_main (str): Path to the processed data directory.
3839 **kwargs: Additional arguments for initialization.
3940 """
40- if kwargs .get ("_validate_configs" , False ):
41+ if bool ( kwargs .get ("_validate_configs" , True ) ):
4142 self ._validate_model_configs (model_configs )
4243
4344 self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
@@ -113,6 +114,8 @@ def _load_data_module_labels(self):
113114 FileNotFoundError: If the classes.txt file does not exist.
114115 """
115116 classes_txt_file = os .path .join (self .data_processed_dir_main , "classes.txt" )
117+ rank_zero_info (f"Loading { classes_txt_file } ...." )
118+
116119 if not os .path .exists (classes_txt_file ):
117120 raise FileNotFoundError (f"{ classes_txt_file } does not exist" )
118121 else :
@@ -128,19 +131,25 @@ def run_ensemble(self):
128131 false_scores = torch .zeros (batch_size , self .num_of_labels , device = self .device )
129132
130133 while self ._model_queue :
131- model , model_props = self ._load_model_and_its_props (
132- self ._model_queue .popleft ()
133- )
134+ model_name = self ._model_queue .popleft ()
135+ rank_zero_info (f"Processing model: { model_name } " )
136+ model , model_props = self ._load_model_and_its_props (model_name )
137+
138+ rank_zero_info ("\t Passing model to controller to generate predictions..." )
134139 pred_conf_dict = self ._controller (model , model_props )
135140 del model # Model can be huge to keep it in memory, delete as no longer needed
136141
142+ rank_zero_info ("\t Passing predictions to consolidator to aggregation" )
137143 self ._consolidator (
138144 pred_conf_dict ,
139145 model_props ,
140146 true_scores = true_scores ,
141147 false_scores = false_scores ,
142148 )
143149
150+ rank_zero_info (
151+ f"Consolidate predictions of the ensemble: { self .__class__ .__name__ } "
152+ )
144153 final_preds = self ._consolidate_on_finish (
145154 true_scores = true_scores , false_scores = false_scores
146155 )
@@ -172,19 +181,21 @@ def _load_model_and_its_props(self, model_name):
172181 lightning_cls , ChebaiBaseNet
173182 ), f"{ class_name } must inherit from ChebaiBaseNet"
174183
175- model = lightning_cls .load_from_checkpoint (
176- model_ckpt_path , input_dim = self .input_dim
177- )
178- model .eval ()
179- model .freeze ()
180-
181- model_label_props = self ._generate_model_label_props (
182- model_name , model_labels_path
183- )
184+ try :
185+ model = lightning_cls .load_from_checkpoint (
186+ model_ckpt_path , input_dim = self .input_dim
187+ )
188+ model .eval ()
189+ model .freeze ()
190+ model_label_props = self ._generate_model_label_props (model_labels_path )
191+ except Exception as e :
192+ raise RuntimeError (
193+ f"For model { model_name } following exception as occurred \n Error: { e } "
194+ )
184195
185196 return model , model_label_props
186197
187- def _generate_model_label_props (self , model_name : str , labels_path : str ):
198+ def _generate_model_label_props (self , labels_path : str ):
188199 """
189200 Generates a mask indicating the labels handled by each model, and retrieves the corresponding TPV and FPV values
190201 as tensors.
@@ -193,6 +204,7 @@ def _generate_model_label_props(self, model_name: str, labels_path: str):
193204 FileNotFoundError: If the labels path does not exist.
194205 ValueError: If label values are empty for any model.
195206 """
207+ rank_zero_info ("\t Generating mask model's labels and other properties" )
196208 labels_dict = self ._load_model_labels (labels_path )
197209
198210 model_label_indices , tpv_label_values , fpv_label_values = [], [], []
@@ -208,7 +220,7 @@ def _generate_model_label_props(self, model_name: str, labels_path: str):
208220 fpv_label_values .append (labels_dict [label ]["FPV" ])
209221
210222 if not all ([model_label_indices , tpv_label_values , fpv_label_values ]):
211- raise ValueError (f"Values are empty for labels of model { model_name } " )
223+ raise ValueError (f"Values are empty for labels of the model " )
212224
213225 # Create masks to apply predictions only to known classes
214226 mask = torch .zeros (self .num_of_labels , device = self .device , dtype = torch .bool )
0 commit comments