|
| 1 | +import os.path |
| 2 | +from abc import ABC, abstractmethod |
| 3 | +from typing import Any, Dict, Optional, Union |
| 4 | + |
| 5 | +import torch |
| 6 | +from torch import Tensor |
| 7 | + |
| 8 | +from chebai.custom_typehints import ModelConfig |
| 9 | +from chebai.models import ChebaiBaseNet, Electra |
| 10 | +from chebai.preprocessing.structures import XYData |
| 11 | + |
| 12 | + |
class _EnsembleBase(ChebaiBaseNet, ABC):
    """Base class for ensembles built from pre-trained checkpoint models.

    ``model_configs`` maps a model name to its configuration. Each
    configuration must provide the keys ``"path"``, ``"TPV"`` and ``"FPV"``.
    Every checkpoint is loaded onto the CPU with
    ``Electra.load_from_checkpoint`` and frozen afterwards.
    """

    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
        """
        Validate the configurations, then load and freeze every member model.

        Args:
            model_configs (Dict[str, ModelConfig]): Mapping from model name to
                its configuration (see `_validate_model_configs`).
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            AttributeError: If a configuration lacks a required key.
            FileNotFoundError: If a checkpoint path does not exist.
            ValueError: If "TPV" or "FPV" is not a non-negative number.
        """
        super().__init__(**kwargs)

        self._validate_model_configs(model_configs)

        # NOTE(review): the members live in a plain dict, so they are not
        # registered as submodules and will not follow `.to(device)` calls on
        # the ensemble; they stay where `map_location="cpu"` put them.
        # Confirm this is intended before running on an accelerator.
        self.models: Dict[str, ChebaiBaseNet] = {}
        self.model_configs: Dict[str, ModelConfig] = model_configs

        for model_name, config in self.model_configs.items():
            model_path = config["path"]
            # Re-checked here because the file may have vanished since validation.
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"Model {model_name} does not exist in the given path {model_path}"
                )
            model = Electra.load_from_checkpoint(model_path, map_location="cpu")
            model.freeze()
            self.models[model_name] = model

        # TODO: Discuss whether the voting threshold should be independent of
        # the metric threshold. The idea on the table: when no "threshold"
        # kwarg is given, take it from the first entry of `self.train_metrics`.

    @classmethod
    def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
        """
        Check that every model configuration is complete and well-formed.

        Args:
            model_configs (Dict[str, ModelConfig]): Mapping from model name to
                its configuration.

        Raises:
            AttributeError: If "path", "TPV" or "FPV" is missing.
            FileNotFoundError: If the configured path does not exist.
            ValueError: If "TPV" or "FPV" cannot be converted to float, or is
                negative.
        """
        required_keys = {"path", "TPV", "FPV"}

        for model_name, config in model_configs.items():
            missing_keys = required_keys - config.keys()
            if missing_keys:
                raise AttributeError(
                    f"Missing keys {missing_keys} in model '{model_name}' configuration."
                )

            model_path = config["path"]
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"Model path '{model_path}' for '{model_name}' does not exist."
                )

            # The duplicate-path check is currently disabled, so several model
            # names may point at the same checkpoint.

            for key in ("TPV", "FPV"):
                # Only the conversion sits inside the try: the range check
                # below raises ValueError too and must keep its own message.
                try:
                    value = float(config[key])
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}."
                    ) from err
                if value < 0:
                    raise ValueError(
                        f"'{key}' in model '{model_name}' must be non-negative, but got {value}."
                    )

    @abstractmethod
    def _get_prediction_and_labels(
        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
    ) -> (torch.Tensor, torch.Tensor):
        """Turn a model output into predictions and labels; defined by subclasses."""
        pass
| 88 | + |
| 89 | + |
class ChebiEnsemble(_EnsembleBase):
    """Ensemble that sums member logits and votes with TPV/FPV-weighted confidences.

    The summed logits feed the loss, while the metrics are updated with the
    class decisions produced by `aggregate_predictions`.
    """

    NAME = "ChebiEnsemble"

    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
        """
        Build the ensemble and attach a single dummy trainable parameter.

        Args:
            model_configs (Dict[str, ModelConfig]): Mapping from model name to
                its configuration.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(model_configs, **kwargs)
        # Dummy trainable parameter. The member models are frozen by the base
        # class; presumably this gives an optimizer something to hold —
        # TODO confirm.
        self.dummy_param = torch.nn.Parameter(torch.randn(1))

    def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
        """
        Run every member model on `data` and collect its outputs.

        Args:
            data (Dict[str, Tensor]): Processed batch, passed to each member
                model as is.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Dict[str, Any]: "logits" (sum of the member logits), "pred_dict"
            (per-model 0/1 decisions at a 0.5 cut-off) and "conf_dict"
            (per-model sigmoid confidences).
        """
        predictions = {}
        confidences = {}
        total_logits = None

        for name, model in self.models.items():
            logits = model(data)["logits"]
            confidence = torch.sigmoid(logits)
            confidences[name] = confidence
            predictions[name] = (confidence > 0.5).long()  # Multi-label classification

            if total_logits is None:
                # Size the accumulator from the logits, not from the labels:
                # labels may be None or cover only the non-null rows.
                total_logits = torch.zeros(logits.shape, device=self.device)
            total_logits += logits

        if total_logits is None:
            # No member models: fall back to an all-zero tensor shaped like the labels.
            label_shape = data["labels"].shape
            total_logits = torch.zeros(
                label_shape[0], label_shape[1], device=self.device
            )

        return {
            "logits": total_logits,
            "pred_dict": predictions,
            "conf_dict": confidences,
        }

    def _get_prediction_and_labels(self, data, labels, model_output):
        """
        Derive probabilities, integer labels and voted predictions.

        Args:
            data: Processed batch; its optional "loss_kwargs" entry may hold
                "non_null_labels", an index used to select rows.
            labels: Target tensor, or None.
            model_output: Dictionary returned by `forward`.

        Returns:
            A 3-tuple: sigmoid of the summed logits, `labels.int()` (or None),
            and the voted predictions used for the metrics.
        """
        summed_logits = model_output["logits"]
        # Aggregate predictions using weighted voting
        voted = self.aggregate_predictions(
            model_output["pred_dict"], model_output["conf_dict"]
        )

        loss_kwargs = data.get("loss_kwargs", dict())
        if "non_null_labels" in loss_kwargs:
            keep = loss_kwargs["non_null_labels"]
            summed_logits = summed_logits[keep]
            voted = voted[keep]

        targets = labels.int() if labels is not None else None
        return torch.sigmoid(summed_logits), targets, voted

    def _execute(
        self,
        batch: XYData,
        batch_idx: int,
        metrics: Optional[torch.nn.Module] = None,
        prefix: Optional[str] = "",
        log: Optional[bool] = True,
        sync_dist: Optional[bool] = False,
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Executes the model on a batch of data and returns the model output and predictions.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.
            metrics (torch.nn.Module): A dictionary of metrics to track. They are updated
                with the voted predictions, not with the sigmoid of the summed logits.
            prefix (str, optional): A prefix to add to the metric names. Defaults to "".
            log (bool, optional): Whether to log the metrics. Defaults to True.
            sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output,
                predictions, and loss (if applicable).
        """
        assert isinstance(batch, XYData)
        batch = batch.to(self.device)
        data = self._process_batch(batch, batch_idx)
        labels = data["labels"]
        model_output = self(data, **data.get("model_kwargs", dict()))
        pr, tar, metrics_preds = self._get_prediction_and_labels(
            data, labels, model_output
        )
        d = dict(data=data, labels=labels, output=model_output, preds=pr)
        if log:
            if self.criterion is not None:
                loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
                    model_output, labels, data.get("loss_kwargs", dict())
                )
                loss_kwargs = dict()
                if self.pass_loss_kwargs:
                    loss_kwargs = loss_kwargs_candidates
                loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
                if isinstance(loss, tuple):
                    # A tuple loss carries the main loss first; the remaining
                    # entries are logged per step as auxiliary values.
                    loss_additional = loss[1:]
                    for i, loss_add in enumerate(loss_additional):
                        self.log(
                            f"{prefix}loss_{i}",
                            loss_add if isinstance(loss_add, int) else loss_add.item(),
                            batch_size=len(batch),
                            on_step=True,
                            on_epoch=False,
                            prog_bar=False,
                            logger=True,
                            sync_dist=sync_dist,
                        )
                    loss = loss[0]

                d["loss"] = loss
                self.log(
                    f"{prefix}loss",
                    loss.item(),
                    batch_size=len(batch),
                    on_step=True,
                    on_epoch=True,
                    prog_bar=True,
                    logger=True,
                    sync_dist=sync_dist,
                )
            if metrics and labels is not None:
                for _, metric in metrics.items():
                    metric.update(metrics_preds, tar)
                self._log_metrics(prefix, metrics, len(batch))
        return d

    def aggregate_predictions(self, predictions, confidences):
        """
        Implements weighted voting based on trustworthiness.

        Each model's vote is weighted by its confidence times its configured
        "TPV" (for a positive vote) or "FPV" (for a negative vote). A class is
        predicted positive only when the positive score is strictly larger, so
        ties resolve to negative.

        Args:
            predictions: Mapping from model name to a 0/1 tensor of shape
                (batch_size, num_classes).
            confidences: Mapping from model name to the matching confidence tensor.

        Returns:
            Integer tensor of shape (batch_size, num_classes) with the final decisions.

        Raises:
            ValueError: If `predictions` is empty.
        """
        if not predictions:
            raise ValueError("Cannot aggregate predictions of an empty ensemble.")

        batch_size, num_classes = next(iter(predictions.values())).shape

        true_scores = torch.zeros(batch_size, num_classes, device=self.device)
        false_scores = torch.zeros(batch_size, num_classes, device=self.device)

        for model, preds in predictions.items():
            tpv = float(self.model_configs[model]["TPV"])
            fpv = float(self.model_configs[model]["FPV"])

            # NOTE(review): negative votes are weighted by the sigmoid of the
            # positive class, which is at most 0.5 for them. Confirm this is
            # intended rather than (1 - confidence).
            confidence = confidences[model]
            weight = confidence * (tpv * preds + fpv * (1 - preds))

            true_scores += weight * preds
            false_scores += weight * (1 - preds)

        return (true_scores > false_scores).long()  # Final class decision
| 231 | + |
| 232 | + |
class ChebiEnsembleLearning(_EnsembleBase):
    """Ensemble that learns to combine member outputs with a linear layer.

    The logits of all member models are concatenated along the class
    dimension and mapped back to `out_dim` classes by `ensemble_classifier`.
    """

    NAME = "ChebiEnsembleLearning"

    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
        """
        Build the ensemble and its trainable combination layer.

        Args:
            model_configs (Dict[str, ModelConfig]): Mapping from model name to
                its configuration.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(model_configs, **kwargs)
        # One input block of `out_dim` features per member model.
        self.ensemble_classifier = torch.nn.Linear(
            in_features=len(self.models) * self.out_dim, out_features=self.out_dim
        )

    def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
        """
        Combine the member logits through the learned linear layer.

        Args:
            data (Dict[str, Tensor]): Processed batch, passed to each member
                model as is (same calling convention as `ChebiEnsemble`).
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Dict[str, Any]: A dictionary with the combined "logits".
        """
        # NOTE(review): the layer is fed raw member logits; confirm whether
        # sigmoid confidences were intended as its input instead.
        member_logits = [model(data)["logits"] for model in self.models.values()]
        combined = torch.cat(member_logits, dim=-1)
        return {"logits": self.ensemble_classifier(combined)}

    def _get_prediction_and_labels(
        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
    ) -> (torch.Tensor, torch.Tensor):
        """
        Derive probabilities and integer labels from the model output.

        Args:
            data (Dict[str, Any]): Processed batch; its optional "loss_kwargs"
                entry may hold "non_null_labels", an index used to select rows.
            labels (torch.Tensor): Target tensor, or None.
            output: Dictionary returned by `forward`.

        Returns:
            The sigmoid of the (optionally filtered) logits and `labels.int()`,
            or None in place of the labels when none were given.
        """
        logits = output["logits"]
        loss_kwargs = data.get("loss_kwargs", dict())
        if "non_null_labels" in loss_kwargs:
            logits = logits[loss_kwargs["non_null_labels"]]
        return torch.sigmoid(logits), labels.int() if labels is not None else None
| 261 | + |
| 262 | + |
# Placeholder entry point: running this module as a script does nothing.
if __name__ == "__main__":
    pass
0 commit comments