Skip to content

Commit 65a51e0

Browse files
committed
add ensemble controller
1 parent 7db384a commit 65a51e0

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

chebai/ensemble/controller.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os.path
2+
from abc import ABC
3+
4+
import torch
5+
6+
from chebai.ensemble.base import EnsembleBase
7+
from chebai.models import ChebaiBaseNet
8+
from chebai.preprocessing.collate import RaggedCollator
9+
10+
11+
class _Controller(EnsembleBase, ABC):
12+
def __init__(self, **kwargs):
13+
super().__init__(**kwargs)
14+
self._collator = RaggedCollator()
15+
16+
self._collated_data = self._load_and_collate_data()
17+
self.total_data_size: int = len(self._collated_data)
18+
19+
def _load_and_collate_data(self):
20+
data = torch.load(
21+
os.path.join(self.data_processed_dir_main, "data.pt"),
22+
weights_only=False,
23+
map_location=self.device,
24+
)
25+
collated_data = self._collator(data)
26+
collated_data.x = collated_data.to_x(self.device)
27+
if collated_data.y is not None:
28+
collated_data.y = collated_data.to_y(self.device)
29+
return collated_data
30+
31+
def _forward_pass(self, model: ChebaiBaseNet):
32+
processable_data = model._process_batch(self._collated_data, 0)
33+
del processable_data["loss_kwargs"]
34+
model_output = model(processable_data, **processable_data["model_kwargs"])
35+
return model_output
36+
37+
def _get_pred_conf_from_model_output(self, model_output, model_label_mask):
38+
# Consider logits and confidence only for valid classes
39+
sigmoid_logits = torch.sigmoid(model_output["logits"])
40+
prediction = torch.full(
41+
(self.total_data_size, self.num_of_labels), -1, dtype=torch.bool
42+
)
43+
confidence = torch.full(
44+
(self.total_data_size, self.num_of_labels), -1, dtype=torch.float
45+
)
46+
prediction[:, model_label_mask] = sigmoid_logits > 0.5
47+
confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5)
48+
return {"prediction": prediction, "confidence": confidence}
49+
50+
51+
class SimpleController(_Controller):
52+
def __init__(self, **kwargs):
53+
super().__init__(**kwargs)
54+
self._model_queue = list(self.model_configs.keys())
55+
56+
def _controller(self, model, model_props, **kwargs):
57+
model_output = self._forward_pass(model)
58+
return self._get_pred_conf_from_model_output(model_output, model_props["mask"])

0 commit comments

Comments
 (0)