|
| 1 | +import os.path |
| 2 | +from abc import ABC |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from chebai.ensemble.base import EnsembleBase |
| 7 | +from chebai.models import ChebaiBaseNet |
| 8 | +from chebai.preprocessing.collate import RaggedCollator |
| 9 | + |
| 10 | + |
| 11 | +class _Controller(EnsembleBase, ABC): |
| 12 | + def __init__(self, **kwargs): |
| 13 | + super().__init__(**kwargs) |
| 14 | + self._collator = RaggedCollator() |
| 15 | + |
| 16 | + self._collated_data = self._load_and_collate_data() |
| 17 | + self.total_data_size: int = len(self._collated_data) |
| 18 | + |
| 19 | + def _load_and_collate_data(self): |
| 20 | + data = torch.load( |
| 21 | + os.path.join(self.data_processed_dir_main, "data.pt"), |
| 22 | + weights_only=False, |
| 23 | + map_location=self.device, |
| 24 | + ) |
| 25 | + collated_data = self._collator(data) |
| 26 | + collated_data.x = collated_data.to_x(self.device) |
| 27 | + if collated_data.y is not None: |
| 28 | + collated_data.y = collated_data.to_y(self.device) |
| 29 | + return collated_data |
| 30 | + |
| 31 | + def _forward_pass(self, model: ChebaiBaseNet): |
| 32 | + processable_data = model._process_batch(self._collated_data, 0) |
| 33 | + del processable_data["loss_kwargs"] |
| 34 | + model_output = model(processable_data, **processable_data["model_kwargs"]) |
| 35 | + return model_output |
| 36 | + |
| 37 | + def _get_pred_conf_from_model_output(self, model_output, model_label_mask): |
| 38 | + # Consider logits and confidence only for valid classes |
| 39 | + sigmoid_logits = torch.sigmoid(model_output["logits"]) |
| 40 | + prediction = torch.full( |
| 41 | + (self.total_data_size, self.num_of_labels), -1, dtype=torch.bool |
| 42 | + ) |
| 43 | + confidence = torch.full( |
| 44 | + (self.total_data_size, self.num_of_labels), -1, dtype=torch.float |
| 45 | + ) |
| 46 | + prediction[:, model_label_mask] = sigmoid_logits > 0.5 |
| 47 | + confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5) |
| 48 | + return {"prediction": prediction, "confidence": confidence} |
| 49 | + |
| 50 | + |
| 51 | +class SimpleController(_Controller): |
| 52 | + def __init__(self, **kwargs): |
| 53 | + super().__init__(**kwargs) |
| 54 | + self._model_queue = list(self.model_configs.keys()) |
| 55 | + |
| 56 | + def _controller(self, model, model_props, **kwargs): |
| 57 | + model_output = self._forward_pass(model) |
| 58 | + return self._get_pred_conf_from_model_output(model_output, model_props["mask"]) |
0 commit comments