Commit dabe5ff

update code change

1 parent 0ec03b1 commit dabe5ff

File tree: 1 file changed (+40, -42 lines)

chebai/models/ensemble.py
Lines changed: 40 additions & 42 deletions

@@ -9,7 +9,6 @@
 from torch import Tensor

 from chebai.models import ChebaiBaseNet
-from chebai.models.ffn import FFN
 from chebai.preprocessing.structures import XYData


@@ -39,7 +38,8 @@ def __init__(
             **kwargs: Additional arguments for initialization.
         """
         super().__init__(**kwargs)
-        self._validate_model_configs(model_configs)
+        if kwargs.get("_validate_configs", True):
+            self._validate_model_configs(model_configs)

         self.data_processed_dir_main = data_processed_dir_main
         self.models: Dict[str, LightningModule] = {}
@@ -79,7 +79,8 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):

             if model_path in path_set:
                 raise ValueError(
-                    f"Duplicate model path detected: '{model_path}'. Each model must have a unique model-checkpoint path."
+                    f"Duplicate model path detected: '{model_path}'. "
+                    f"Each model must have a unique model-checkpoint path."
                 )

             if class_path in class_set:
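
Side note on the reformatted message above: splitting the f-string across two adjacent literals does not change the raised text, since Python concatenates adjacent string literals at parse time. A standalone illustration (the model_path value here is made up):

model_path = "checkpoints/model_a.ckpt"  # hypothetical value for illustration
message = (
    f"Duplicate model path detected: '{model_path}'. "
    f"Each model must have a unique model-checkpoint path."
)
# Both literals fuse into a single string, identical to the one-line version.
print(message)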
@@ -94,16 +95,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):

             path_set.add(model_path)
             class_set.add(class_path)
-            labels_path.add(labels_path)
+            labels_set.add(labels_path)

     def _load_ensemble_models(self):
         """
         Loads the models specified in the configuration and initializes them.
-
-        Raises:
-            FileNotFoundError: If the model checkpoint path does not exist.
-            ModuleNotFoundError: If the module containing the model class is not found.
-            AttributeError: If the specified class is not found within the module.
         """
         for model_name in self.model_configs:
             model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
@@ -116,33 +112,38 @@ def _load_ensemble_models(self):

             class_name = model_class_path.split(".")[-1]
             module_path = ".".join(model_class_path.split(".")[:-1])
+            module = importlib.import_module(module_path)
+            lightning_cls: LightningModule = getattr(module, class_name)

-            try:
-                module = importlib.import_module(module_path)
-                lightning_cls: LightningModule = getattr(module, class_name)
+            model = lightning_cls.load_from_checkpoint(
+                model_ckpt_path, input_dim=self.input_dim
+            )
+            model.eval()
+            model.freeze()

-                model = lightning_cls.load_from_checkpoint(
-                    model_ckpt_path, input_dim=self.input_dim
-                )
-                model.eval()
-                model.freeze()
-                self.models[model_name] = model
-                self.models_configs[model_name]["labels"] = self._load_model_labels(
-                    model_labels_path
-                )
+            self.models[model_name] = model
+            self.model_configs[model_name]["labels"] = self._load_model_labels(
+                model_labels_path, model_name
+            )

-            except ModuleNotFoundError:
-                print(f"Module '{module_path}' not found!")
-            except AttributeError:
-                print(f"Class '{class_name}' not found in '{module_path}'!")
+    def _load_data_module_labels(self):
+        """
+        Loads the label mapping from the classes.txt file for loaded data.

-            except Exception as e:
-                raise RuntimeError(
-                    f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}"
-                )
+        Raises:
+            FileNotFoundError: If the classes.txt file does not exist.
+        """
+        classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
+        if not os.path.exists(classes_txt_file):
+            raise FileNotFoundError(f"{classes_txt_file} does not exist")
+        else:
+            with open(classes_txt_file, "r") as f:
+                for line in f:
+                    if line.strip() not in self.dm_labels:
+                        self.dm_labels[line.strip()] = len(self.dm_labels)

     @staticmethod
-    def _load_model_labels(labels_path: str) -> Dict[str, float]:
+    def _load_model_labels(labels_path: str, model_name: str) -> Dict[str, float]:
         if not os.path.exists(labels_path):
             raise FileNotFoundError(f"{labels_path} does not exist.")

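To make the behaviour of the new _load_data_module_labels concrete, here is a standalone sketch of the same line-to-index mapping, using a made-up classes.txt written to a temporary directory (the real file is expected under data_processed_dir_main; the label IDs are invented):

import os
import tempfile

# Hypothetical classes.txt content; one class label per line, duplicates ignored.
tmp_dir = tempfile.mkdtemp()
classes_txt_file = os.path.join(tmp_dir, "classes.txt")
with open(classes_txt_file, "w") as f:
    f.write("CHEBI:25350\nCHEBI:33917\nCHEBI:25350\n")  # note the duplicate

dm_labels = {}
with open(classes_txt_file, "r") as f:
    for line in f:
        # Each unique, stripped label gets the next consecutive index.
        if line.strip() not in dm_labels:
            dm_labels[line.strip()] = len(dm_labels)

print(dm_labels)  # {'CHEBI:25350': 0, 'CHEBI:33917': 1}
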
@@ -154,7 +155,7 @@ def _load_model_labels(labels_path: str) -> Dict[str, float]:

         labels_dict = {}
         for label, label_dict in model_labels.items():
-            msg = f"for label {label}"
+            msg = f"for model {model_name} for label {label}"
             if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys():
                 raise AttributeError(f"Missing keys 'TPV' and/or 'FPV' {msg}")

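For context, _load_model_labels appears to expect per-label metadata shaped roughly as below; the label IDs and numbers here are invented for illustration. Each entry must carry 'TPV' and 'FPV' keys whose values are float-convertible, matching the checks above:

# Hypothetical labels file content for one ensemble member (keys and values invented):
model_labels = {
    "CHEBI:25350": {"TPV": 0.93, "FPV": 0.04},
    "CHEBI:33917": {"TPV": "0.88", "FPV": "0.07"},  # strings are accepted if float-convertible
}

for label, label_dict in model_labels.items():
    # Mirrors the validation above: both keys must exist and convert to float.
    assert "TPV" in label_dict and "FPV" in label_dict
    for key in ("TPV", "FPV"):
        float(label_dict[key])  # raises ValueError if not convertible
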
@@ -170,7 +171,7 @@ def _load_model_labels(labels_path: str) -> Dict[str, float]:
                     raise ValueError(
                         f"'{key}' must be a float or convertible to float, but got {label_dict[key]} {msg}"
                     )
-                labels_dict[label][key] = value
+                labels_dict.setdefault(label, {})[key] = value
         return labels_dict

     @abstractmethod
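
The move to dict.setdefault fixes a latent KeyError: the previous assignment assumed an inner dict for the label already existed. A minimal standalone demonstration (label name invented):

labels_dict = {}

# Old pattern: fails because no inner dict was ever created for the label.
try:
    labels_dict["CHEBI:25350"]["TPV"] = 0.93
except KeyError as e:
    print("KeyError:", e)

# New pattern: setdefault creates the inner dict on first use, then reuses it.
labels_dict.setdefault("CHEBI:25350", {})["TPV"] = 0.93
labels_dict.setdefault("CHEBI:25350", {})["FPV"] = 0.04
print(labels_dict)  # {'CHEBI:25350': {'TPV': 0.93, 'FPV': 0.04}}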
@@ -193,7 +194,9 @@ def _get_prediction_and_labels(
     def controller(self):
         pass

-    def consolidator(self):
+    def consolidator(
+        self,
+    ):
         pass


@@ -238,19 +241,14 @@ def _generate_model_label_mask(self):
         num_models_per_label = torch.zeros(1, self.out_dim, device=self.device)

         for model_name, model_config in self.model_configs.items():
-            labels_path = model_config["labels_path"]
-            if not os.path.exists(labels_path):
-                raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.")
-
-            with open(labels_path, "r") as f:
-                labels_dict = json.load(f)
+            labels_dict = model_config["labels"]

             model_label_indices, tpv_label_values, fpv_label_values = [], [], []
             for label in labels_dict.keys():
                 if label in self.dm_labels:
                     model_label_indices.append(self.dm_labels[label])
-                    tpv_label_values.append(float(labels_dict[label]["TPV"]))
-                    fpv_label_values.append(float(labels_dict[label]["FPV"]))
+                    tpv_label_values.append(labels_dict[label]["TPV"])
+                    fpv_label_values.append(labels_dict[label]["FPV"])

             if not all([model_label_indices, tpv_label_values, fpv_label_values]):
                 raise ValueError(f"Values are empty for labels of model {model_name}")
@@ -318,7 +316,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
             confidences[name] = confidence
             total_logits += output[
                 "logits"
-            ]  # Don't play a role here, just for lightning flow completeness
+            ]  # This doesn't play a role here, just for lightning flow completeness

         return {
             "logits": total_logits,
