Commit 4d6856d

add rank_zero_info printing
1 parent b9dbd97 · commit 4d6856d

File tree

1 file changed (+27 −15 lines)


chebai/ensemble/base.py

Lines changed: 27 additions & 15 deletions
@@ -7,6 +7,7 @@
 
 import torch
 from lightning import LightningModule
+from lightning_utilities.core.rank_zero import rank_zero_info
 
 from chebai.models import ChebaiBaseNet
 from chebai.result.classification import print_metrics
@@ -37,7 +38,7 @@ def __init__(
             data_processed_dir_main (str): Path to the processed data directory.
             **kwargs: Additional arguments for initialization.
         """
-        if kwargs.get("_validate_configs", False):
+        if bool(kwargs.get("_validate_configs", True)):
             self._validate_model_configs(model_configs)
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -113,6 +114,8 @@ def _load_data_module_labels(self):
             FileNotFoundError: If the classes.txt file does not exist.
         """
         classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
+        rank_zero_info(f"Loading {classes_txt_file} ...")
+
         if not os.path.exists(classes_txt_file):
             raise FileNotFoundError(f"{classes_txt_file} does not exist")
         else:
@@ -128,19 +131,25 @@ def run_ensemble(self):
         false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)
 
         while self._model_queue:
-            model, model_props = self._load_model_and_its_props(
-                self._model_queue.popleft()
-            )
+            model_name = self._model_queue.popleft()
+            rank_zero_info(f"Processing model: {model_name}")
+            model, model_props = self._load_model_and_its_props(model_name)
+
+            rank_zero_info("\t Passing model to controller to generate predictions...")
             pred_conf_dict = self._controller(model, model_props)
             del model  # Model can be huge to keep in memory; delete as no longer needed
 
+            rank_zero_info("\t Passing predictions to consolidator for aggregation")
             self._consolidator(
                 pred_conf_dict,
                 model_props,
                 true_scores=true_scores,
                 false_scores=false_scores,
             )
 
+        rank_zero_info(
+            f"Consolidating predictions of the ensemble: {self.__class__.__name__}"
+        )
         final_preds = self._consolidate_on_finish(
             true_scores=true_scores, false_scores=false_scores
         )
@@ -172,19 +181,21 @@ def _load_model_and_its_props(self, model_name):
             lightning_cls, ChebaiBaseNet
         ), f"{class_name} must inherit from ChebaiBaseNet"
 
-        model = lightning_cls.load_from_checkpoint(
-            model_ckpt_path, input_dim=self.input_dim
-        )
-        model.eval()
-        model.freeze()
-
-        model_label_props = self._generate_model_label_props(
-            model_name, model_labels_path
-        )
+        try:
+            model = lightning_cls.load_from_checkpoint(
+                model_ckpt_path, input_dim=self.input_dim
+            )
+            model.eval()
+            model.freeze()
+            model_label_props = self._generate_model_label_props(model_labels_path)
+        except Exception as e:
+            raise RuntimeError(
+                f"For model {model_name} the following exception occurred:\n Error: {e}"
+            )
 
         return model, model_label_props
 
-    def _generate_model_label_props(self, model_name: str, labels_path: str):
+    def _generate_model_label_props(self, labels_path: str):
         """
         Generates a mask indicating the labels handled by each model, and retrieves the corresponding TPV and FPV values
         as tensors.
@@ -193,6 +204,7 @@ def _generate_model_label_props(self, model_name: str, labels_path: str):
             FileNotFoundError: If the labels path does not exist.
             ValueError: If label values are empty for any model.
         """
+        rank_zero_info("\t Generating mask for the model's labels and other properties")
         labels_dict = self._load_model_labels(labels_path)
 
         model_label_indices, tpv_label_values, fpv_label_values = [], [], []
@@ -208,7 +220,7 @@ def _generate_model_label_props(self, model_name: str, labels_path: str):
             fpv_label_values.append(labels_dict[label]["FPV"])
 
         if not all([model_label_indices, tpv_label_values, fpv_label_values]):
-            raise ValueError(f"Values are empty for labels of model {model_name}")
+            raise ValueError("Values are empty for labels of the model")
 
         # Create masks to apply predictions only to known classes
         mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool)
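For context, `rank_zero_info` (imported above from `lightning_utilities.core.rank_zero`) emits its message only on the process whose global rank is 0, so the new log lines in `run_ensemble` appear once per step rather than once per worker in multi-GPU/DDP runs. A minimal sketch of that behavior, assuming `lightning-utilities` is installed; the helper name and model names below are illustrative and not part of this commit:

```python
# Sketch only: demonstrates the rank-zero guard the new log calls rely on.
# "announce_models" and the model names are placeholders, not part of chebai.
from lightning_utilities.core.rank_zero import rank_zero_info


def announce_models(model_names):
    # Each queued model is announced once; in a multi-process (DDP) run,
    # only the global rank-0 worker actually prints these lines.
    for name in model_names:
        rank_zero_info(f"Processing model: {name}")


if __name__ == "__main__":
    announce_models(["model_a", "model_b"])  # placeholder names
```

On a single-process run every call prints as usual; when Lightning launches multiple processes, the rank-zero wrapper reads the global rank from the environment and suppresses output on non-zero ranks.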

0 commit comments
