|
1 | 1 | import importlib |
| 2 | +import json |
2 | 3 | import os.path |
3 | 4 | from abc import ABC, abstractmethod |
4 | 5 | from typing import Any, Dict, Optional, Tuple, Union |
|
12 | 13 |
|
13 | 14 |
|
14 | 15 | class _EnsembleBase(ChebaiBaseNet, ABC): |
15 | | - def __init__(self, model_configs: Dict[str, Dict], **kwargs): |
| 16 | + def __init__( |
| 17 | + self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs |
| 18 | + ): |
16 | 19 | super().__init__(**kwargs) |
17 | 20 | self._validate_model_configs(model_configs) |
18 | 21 |
|
| 22 | + self.data_processed_dir_main = data_processed_dir_main |
19 | 23 | self.models: Dict[str, LightningModule] = {} |
20 | 24 | self.model_configs = model_configs |
| 25 | + self.dm_labels: Dict[str, int] = {} |
21 | 26 |
|
22 | | - for model_name in self.model_configs: |
23 | | - model_ckpt_path = self.model_configs[model_name]["ckpt_path"] |
24 | | - model_class_path = self.model_configs[model_name]["class_path"] |
25 | | - if not os.path.exists(model_ckpt_path): |
26 | | - raise FileNotFoundError( |
27 | | - f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." |
28 | | - ) |
29 | | - |
30 | | - class_name = model_class_path.split(".")[-1] |
31 | | - module_path = ".".join(model_class_path.split(".")[:-1]) |
32 | | - |
33 | | - try: |
34 | | - module = importlib.import_module(module_path) |
35 | | - lightning_cls: LightningModule = getattr(module, class_name) |
36 | | - |
37 | | - model = lightning_cls.load_from_checkpoint(model_ckpt_path) |
38 | | - model.eval() |
39 | | - model.freeze() |
40 | | - self.models[model_name] = model |
41 | | - |
42 | | - except ModuleNotFoundError: |
43 | | - print(f"Module '{module_path}' not found!") |
44 | | - except AttributeError: |
45 | | - print(f"Class '{class_name}' not found in '{module_path}'!") |
46 | | - |
47 | | - except Exception as e: |
48 | | - raise RuntimeError( |
49 | | - f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" |
50 | | - ) |
| 27 | + self._load_data_module_labels() |
| 28 | + self._load_ensemble_models() |
51 | 29 |
|
52 | 30 | # TODO: Later discuss whether this threshold should be independent of metric threshold or not ? |
53 | 31 | # if kwargs.get("threshold") is None: |
@@ -98,6 +76,47 @@ def _extra_validation( |
98 | 76 | ): |
99 | 77 | pass |
100 | 78 |
|
| 79 | + def _load_ensemble_models(self): |
| 80 | + for model_name in self.model_configs: |
| 81 | + model_ckpt_path = self.model_configs[model_name]["ckpt_path"] |
| 82 | + model_class_path = self.model_configs[model_name]["class_path"] |
| 83 | + if not os.path.exists(model_ckpt_path): |
| 84 | + raise FileNotFoundError( |
| 85 | + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." |
| 86 | + ) |
| 87 | + |
| 88 | + class_name = model_class_path.split(".")[-1] |
| 89 | + module_path = ".".join(model_class_path.split(".")[:-1]) |
| 90 | + |
| 91 | + try: |
| 92 | + module = importlib.import_module(module_path) |
| 93 | + lightning_cls: LightningModule = getattr(module, class_name) |
| 94 | + |
| 95 | + model = lightning_cls.load_from_checkpoint(model_ckpt_path) |
| 96 | + model.eval() |
| 97 | + model.freeze() |
| 98 | + self.models[model_name] = model |
| 99 | + |
| 100 | + except ModuleNotFoundError: |
| 101 | + print(f"Module '{module_path}' not found!") |
| 102 | + except AttributeError: |
| 103 | + print(f"Class '{class_name}' not found in '{module_path}'!") |
| 104 | + |
| 105 | + except Exception as e: |
| 106 | + raise RuntimeError( |
| 107 | + f"Failed to load model '{model_name}' from {model_ckpt_path}: \n {e}" |
| 108 | + ) |
| 109 | + |
| 110 | + def _load_data_module_labels(self): |
| 111 | + classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt") |
| 112 | + if not os.path.exists(classes_txt_file): |
| 113 | + raise FileNotFoundError(f"{classes_txt_file} does not exist") |
| 114 | + else: |
| 115 | + with open(classes_txt_file, "r") as f: |
| 116 | + for line in f: |
| 117 | + if line.strip() not in self.dm_labels: |
| 118 | + self.dm_labels[line.strip()] = len(self.dm_labels) |
| 119 | + |
101 | 120 | @abstractmethod |
102 | 121 | def _get_prediction_and_labels( |
103 | 122 | self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor |
@@ -132,49 +151,73 @@ def _extra_validation( |
132 | 151 | # ) |
133 | 152 | sets_["labels"].add(labels_path) |
134 | 153 |
|
135 | | - if "TPV" not in config.keys() or "FPV" not in config.keys(): |
136 | | - raise AttributeError( |
137 | | - f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration." |
138 | | - ) |
| 154 | + with open(labels_path, "r") as f: |
| 155 | + model_labels = json.load(f) |
139 | 156 |
|
140 | | - # Validate 'tpv' and 'fpv' are either floats or convertible to float |
141 | | - for key in ["TPV", "FPV"]: |
142 | | - try: |
143 | | - value = float(config[key]) |
144 | | - if value < 0: |
| 157 | + for label, label_dict in model_labels.items(): |
| 158 | + |
| 159 | + if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys(): |
| 160 | + raise AttributeError( |
| 161 | + f"Missing keys 'TPV' and/or 'FPV' in model '{model_name}' configuration." |
| 162 | + ) |
| 163 | + |
| 164 | + # Validate 'tpv' and 'fpv' are either floats or convertible to float |
| 165 | + for key in ["TPV", "FPV"]: |
| 166 | + try: |
| 167 | + value = float(label_dict[key]) |
| 168 | + if value < 0: |
| 169 | + raise ValueError( |
| 170 | + f"'{key}' in model '{model_name}' and label '{label}' must be non-negative, but got {value}." |
| 171 | + ) |
| 172 | + except (TypeError, ValueError): |
145 | 173 | raise ValueError( |
146 | | - f"'{key}' in model '{model_name}' must be non-negative, but got {value}." |
| 174 | + f"'{key}' in model '{model_name}' and label '{label}' must be a float or convertible to float, but got {label_dict[key]}." |
147 | 175 | ) |
148 | | - except (TypeError, ValueError): |
149 | | - raise ValueError( |
150 | | - f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}." |
151 | | - ) |
152 | 176 |
|
153 | 177 | def _generate_model_label_mask(self): |
154 | | - labels_dict = {} |
155 | 178 | num_models_per_label = torch.zeros(1, self.out_dim, device=self.device) |
| 179 | + |
156 | 180 | for model_name, model_config in self.model_configs.items(): |
157 | 181 | labels_path = model_config["labels_path"] |
158 | 182 | if not os.path.exists(labels_path): |
159 | 183 | raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.") |
160 | 184 |
|
161 | 185 | with open(labels_path, "r") as f: |
162 | | - labels_list = [int(line.strip()) for line in f] |
| 186 | + labels_dict = json.load(f) |
163 | 187 |
|
164 | | - model_label_indices = [] |
165 | | - for label in labels_list: |
166 | | - if label not in labels_dict: |
167 | | - labels_dict[label] = len(labels_dict) |
| 188 | + model_label_indices, tpv_label_values, fpv_label_values = [], [], [] |
| 189 | + for label in labels_dict.keys(): |
| 190 | + if label in self.dm_labels: |
| 191 | + model_label_indices.append(self.dm_labels[label]) |
| 192 | + tpv_label_values.append(float(labels_dict[label]["TPV"])) |
| 193 | + fpv_label_values.append(float(labels_dict[label]["FPV"])) |
168 | 194 |
|
169 | | - model_label_indices.append(labels_dict[label]) |
| 195 | + if not all([model_label_indices, tpv_label_values, fpv_label_values]): |
| 196 | + raise ValueError(f"Values are empty for labels of model {model_name}") |
170 | 197 |
|
171 | 198 | # Create masks to apply predictions only to known classes |
172 | 199 | mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool) |
173 | 200 | mask[ |
174 | 201 | torch.tensor(model_label_indices, dtype=torch.int, device=self.device) |
175 | 202 | ] = True |
176 | 203 |
|
| 204 | + tpv_tensor = torch.full_like( |
| 205 | + mask, -1, dtype=torch.float, device=self.device |
| 206 | + ) |
| 207 | + fpv_tensor = torch.full_like( |
| 208 | + mask, -1, dtype=torch.float, device=self.device |
| 209 | + ) |
| 210 | + |
| 211 | + tpv_tensor[mask] = torch.tensor( |
| 212 | + tpv_label_values, dtype=torch.float, device=self.device |
| 213 | + ) |
| 214 | + fpv_tensor[mask] = torch.tensor( |
| 215 | + fpv_label_values, dtype=torch.float, device=self.device |
| 216 | + ) |
| 217 | + |
177 | 218 | self.model_configs[model_name]["labels_mask"] = mask |
| 219 | + self.model_configs[model_name]["tpv_tensor"] = tpv_tensor |
| 220 | + self.model_configs[model_name]["fpv_tensor"] = fpv_tensor |
178 | 221 | num_models_per_label += mask |
179 | 222 |
|
180 | 223 | self._num_models_per_label = num_models_per_label |
@@ -312,8 +355,8 @@ def aggregate_predictions(self, predictions, confidences): |
312 | 355 | false_scores = torch.zeros(batch_size, num_classes, device=self.device) |
313 | 356 |
|
314 | 357 | for model, conf in confidences.items(): |
315 | | - tpv = float(self.model_configs[model]["TPV"]) |
316 | | - npv = float(self.model_configs[model]["FPV"]) |
| 358 | + tpv = self.model_configs[model]["tpv_tensor"] |
| 359 | + npv = self.model_configs[model]["fpv_tensor"] |
317 | 360 |
|
318 | 361 | # Determine which classes the model provides predictions for |
319 | 362 | mask = self.model_configs[model]["labels_mask"] |
|
0 commit comments