import importlib
import json
import os
from abc import ABC, abstractmethod
from collections import deque
from typing import Deque, Dict, Optional, Type

import torch
from lightning import LightningModule

from chebai.models import ChebaiBaseNet


class EnsembleBase(ABC):
    """
    Base class for ensemble models in the Chebai framework.

    Provides functionality to load multiple ChebaiBaseNet-based models, validate
    their configuration, and manage and consolidate their predictions.

    Attributes:
        data_processed_dir_main (str): Directory where the processed data is stored.
        models (Dict[str, LightningModule]): A dictionary of loaded models.
        model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble.
        dm_labels (Dict[str, int]): Mapping of label names to integer indices.
    """

    def __init__(
        self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs
    ):
        """
        Initializes the ensemble model and loads configuration, models, and labels.

        Args:
            model_configs (Dict[str, Dict]): Dictionary of model configurations.
            data_processed_dir_main (str): Path to the processed data directory.
            **kwargs: Additional arguments for initialization.
        """
        if kwargs.get("_validate_configs", False):
            self._validate_model_configs(model_configs)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.input_dim = kwargs.get("input_dim", None)
        # Set by `_load_data_module_labels` below
        self.num_of_labels: Optional[int] = None
        self.data_processed_dir_main = data_processed_dir_main
        self.models: Dict[str, LightningModule] = {}
        self.model_configs = model_configs
        self.dm_labels: Dict[str, int] = {}

        self._load_data_module_labels()
        self._num_models_per_label: torch.Tensor = torch.zeros(
            1, self.num_of_labels, device=self.device
        )
        # Queue of model names still to be evaluated by `run_ensemble`
        self._model_queue: Deque[str] = deque()
        self._collated_data = None

    @classmethod
    def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
        """
        Validates the model configurations to ensure required keys are present.

        Args:
            model_configs (Dict[str, Dict]): Dictionary of model configurations.

        Raises:
            AttributeError: If required keys are missing in the configuration.
            ValueError: If there are duplicate checkpoint, class, or labels paths.
        """
        path_set, class_set, labels_set = set(), set(), set()

        required_keys = {"class_path", "ckpt_path", "labels_path"}

        for model_name, config in model_configs.items():
            missing_keys = required_keys - config.keys()

            if missing_keys:
                raise AttributeError(
                    f"Missing keys {missing_keys} in model '{model_name}' configuration."
                )

            model_path = config["ckpt_path"]
            class_path = config["class_path"]
            labels_path = config["labels_path"]

            if model_path in path_set:
                raise ValueError(
                    f"Duplicate model path detected: '{model_path}'. "
                    f"Each model must have a unique model-checkpoint path."
                )

            if class_path in class_set:
                raise ValueError(
                    f"Duplicate class path detected: '{class_path}'. "
                    f"Each model must have a unique class path."
                )

            if labels_path in labels_set:
                raise ValueError(
                    f"Duplicate labels path detected: '{labels_path}'. "
                    f"Each model must have a unique labels path."
                )

            path_set.add(model_path)
            class_set.add(class_path)
            labels_set.add(labels_path)

    def _load_data_module_labels(self):
        """
        Loads the label mapping from the classes.txt file of the processed data.

        Raises:
            FileNotFoundError: If the classes.txt file does not exist.
        """
        classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
        if not os.path.exists(classes_txt_file):
            raise FileNotFoundError(f"{classes_txt_file} does not exist")

        with open(classes_txt_file, "r") as f:
            for line in f:
                label = line.strip()
                # Skip blank lines; each new label gets the next free index
                if label and label not in self.dm_labels:
                    self.dm_labels[label] = len(self.dm_labels)
        self.num_of_labels = len(self.dm_labels)
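
    # `classes.txt` is expected to hold one label per line; e.g. (hypothetical
    # ChEBI ids):
    #
    #     CHEBI:25106
    #     CHEBI:33299
    #
    # would yield dm_labels == {"CHEBI:25106": 0, "CHEBI:33299": 1}.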

    def run_ensemble(self):
        """
        Runs every queued model in turn, collecting its predictions via the
        controller and accumulating them via the consolidator, then produces the
        final result once the queue is exhausted.
        """
        batch_size = 10  # NOTE: batch size is currently hardcoded
        true_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)
        false_scores = torch.zeros(batch_size, self.num_of_labels, device=self.device)

        while self._model_queue:
            model, model_props = self._load_model_and_its_props(
                self._model_queue.popleft()
            )
            pred_conf_dict = self._controller(model, model_props)
            self._consolidator(
                pred_conf_dict,
                model_props,
                true_scores=true_scores,
                false_scores=false_scores,
            )

        self._consolidate_on_finish(true_scores=true_scores, false_scores=false_scores)

    def _load_model_and_its_props(self, model_name):
        """
        Loads the model specified by `model_name` from its checkpoint and
        generates its label properties (mask, TPV tensor, and FPV tensor).
        """
        model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
        model_class_path = self.model_configs[model_name]["class_path"]
        model_labels_path = self.model_configs[model_name]["labels_path"]
        if not os.path.exists(model_ckpt_path):
            raise FileNotFoundError(
                f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
            )

        # Resolve the class from its dotted import path, e.g. "pkg.module.ClassName"
        module_path, _, class_name = model_class_path.rpartition(".")
        module = importlib.import_module(module_path)
        lightning_cls: Type[ChebaiBaseNet] = getattr(module, class_name)
        assert isinstance(lightning_cls, type), f"{class_name} is not a class."
        assert issubclass(
            lightning_cls, ChebaiBaseNet
        ), f"{class_name} must inherit from ChebaiBaseNet"

        model = lightning_cls.load_from_checkpoint(
            model_ckpt_path, input_dim=self.input_dim
        )
        model.eval()
        model.freeze()

        model_label_props = self._generate_model_label_props(
            model_name, model_labels_path
        )

        return model, model_label_props

    def _generate_model_label_props(self, model_name: str, labels_path: str):
        """
        Generates a mask indicating the labels handled by the model and retrieves
        the corresponding TPV and FPV values as tensors.

        Raises:
            FileNotFoundError: If the labels path does not exist.
            ValueError: If no labels of the model match the data module's classes.
        """
        labels_dict = self._load_model_labels(labels_path)

        model_label_indices, tpv_label_values, fpv_label_values = [], [], []
        for label in labels_dict.keys():
            if label in self.dm_labels:
                try:
                    self._validate_model_labels_json_element(labels_dict[label])
                except Exception as e:
                    raise Exception(
                        f"Validation failed for label '{label}': {e}"
                    ) from e

                model_label_indices.append(self.dm_labels[label])
                tpv_label_values.append(labels_dict[label]["TPV"])
                fpv_label_values.append(labels_dict[label]["FPV"])

        if not all([model_label_indices, tpv_label_values, fpv_label_values]):
            raise ValueError(
                f"No labels of model '{model_name}' match the data module's classes."
            )

        # Create masks to apply predictions only to known classes. Assign via an
        # explicit index tensor so TPV/FPV values line up with their label
        # indices regardless of the iteration order of the labels JSON.
        label_idx = torch.tensor(
            model_label_indices, dtype=torch.long, device=self.device
        )
        mask = torch.zeros(self.num_of_labels, device=self.device, dtype=torch.bool)
        mask[label_idx] = True

        # Positions not covered by this model keep the sentinel value -1
        tpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device)
        fpv_tensor = torch.full_like(mask, -1, dtype=torch.float, device=self.device)

        tpv_tensor[label_idx] = torch.tensor(
            tpv_label_values, dtype=torch.float, device=self.device
        )
        fpv_tensor[label_idx] = torch.tensor(
            fpv_label_values, dtype=torch.float, device=self.device
        )
        self._num_models_per_label += mask
        return {"mask": mask, "tpv_tensor": tpv_tensor, "fpv_tensor": fpv_tensor}
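
    # Worked example (hypothetical values): with dm_labels == {"A": 0, "B": 1,
    # "C": 2} and a model covering only "A" (TPV 0.9, FPV 0.1) and "C"
    # (TPV 0.8, FPV 0.2), this method returns roughly
    #
    #     mask       -> tensor([ True, False,  True])
    #     tpv_tensor -> tensor([ 0.9000, -1.0000,  0.8000])
    #     fpv_tensor -> tensor([ 0.1000, -1.0000,  0.2000])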

    @staticmethod
    def _load_model_labels(labels_path: str) -> Dict[str, Dict[str, float]]:
        """Loads the per-label TPV/FPV dictionary of a model from a JSON file."""
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"{labels_path} does not exist.")

        if not labels_path.endswith(".json"):
            raise TypeError(f"{labels_path} is not a JSON file.")

        with open(labels_path, "r") as f:
            model_labels = json.load(f)
        return model_labels
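
    # The labels JSON file is expected to map each label to its TPV and FPV
    # scores, e.g. (hypothetical labels and values):
    #
    #     {
    #         "CHEBI:25106": {"TPV": 0.92, "FPV": 0.11},
    #         "CHEBI:33299": {"TPV": 0.85, "FPV": 0.07}
    #     }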

    @staticmethod
    def _validate_model_labels_json_element(label_dict: Dict[str, float]):
        """Validates a single label entry, which must carry numeric 'TPV' and 'FPV' values."""
        if "TPV" not in label_dict.keys() or "FPV" not in label_dict.keys():
            raise AttributeError("Missing keys 'TPV' and/or 'FPV'")

        # Validate 'TPV' and 'FPV' are either floats or convertible to float
        for key in ["TPV", "FPV"]:
            try:
                value = float(label_dict[key])
            except (TypeError, ValueError):
                raise ValueError(
                    f"'{key}' must be a float or convertible to float, but got {label_dict[key]}"
                )
            # Checked outside the try block so this ValueError is not swallowed
            # by the conversion handler above
            if value < 0:
                raise ValueError(f"'{key}' must be non-negative but got {value}")

    @abstractmethod
    def _controller(self, model, model_props, **kwargs):
        """Runs a single model and returns its predictions/confidences."""
        pass

    @abstractmethod
    def _consolidator(
        self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs
    ):
        """Accumulates one model's predictions into the shared score tensors."""
        pass

    @abstractmethod
    def _consolidate_on_finish(self, *, true_scores, false_scores):
        """Produces the final ensemble decision from the accumulated scores."""
        pass
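

# Minimal subclass sketch (illustrative only, not part of chebai): a concrete
# ensemble implements the three abstract hooks above. The class name, voting
# logic, queueing, and paths below are assumptions for demonstration.
#
#     class MajorityVoteEnsemble(EnsembleBase):
#         def _controller(self, model, model_props, **kwargs):
#             ...  # run `model` on the collated data, return predictions/confidences
#
#         def _consolidator(
#             self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs
#         ):
#             ...  # weight predictions by the model's TPV/FPV tensors and accumulate
#
#         def _consolidate_on_finish(self, *, true_scores, false_scores):
#             ...  # e.g. predict a label wherever true_scores > false_scores
#
#     ensemble = MajorityVoteEnsemble(model_configs, "data/chebi/processed")
#     ensemble._model_queue.extend(model_configs.keys())
#     ensemble.run_ensemble()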