
Commit 7f7c6a0

ensemble: abstract code
1 parent 4fda565 commit 7f7c6a0

File tree

1 file changed: +264 -0 lines changed


chebai/models/ensemble.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
import os.path
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import Tensor

from chebai.custom_typehints import ModelConfig
from chebai.models import ChebaiBaseNet, Electra
from chebai.preprocessing.structures import XYData


class _EnsembleBase(ChebaiBaseNet, ABC):
    """Base class for ensembles built from frozen Electra checkpoints."""

    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
        super().__init__(**kwargs)

        self._validate_model_configs(model_configs)

        self.models: Dict[str, ChebaiBaseNet] = {}
        self.model_configs: Dict[str, ModelConfig] = model_configs

        # Load each member model from its checkpoint path.
        for model_name in self.model_configs:
            model_path = self.model_configs[model_name]["path"]
            if os.path.exists(model_path):
                self.models[model_name] = Electra.load_from_checkpoint(
                    model_path, map_location="cpu"
                )
            else:
                raise FileNotFoundError(
                    f"Model '{model_name}' does not exist at the given path '{model_path}'"
                )

        # Member models are frozen; only ensemble-level parameters are trained.
        for model in self.models.values():
            model.freeze()

        # TODO: Discuss later whether this threshold should be independent of the metric threshold.
        # if kwargs.get("threshold") is None:
        #     first_metric_key = next(iter(self.train_metrics))  # Get the first key
        #     first_metric = self.train_metrics[first_metric_key]  # Get the metric object
        #     self.threshold = int(first_metric.threshold)  # Access threshold
        # else:
        #     self.threshold = int(kwargs["threshold"])

    @classmethod
    def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
        path_set = set()
        required_keys = {"path", "TPV", "FPV"}

        for model_name, config in model_configs.items():
            missing_keys = required_keys - config.keys()

            if missing_keys:
                raise AttributeError(
                    f"Missing keys {missing_keys} in model '{model_name}' configuration."
                )

            model_path = config["path"]
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"Model path '{model_path}' for '{model_name}' does not exist."
                )

            # if model_path in path_set:
            #     raise ValueError(
            #         f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
            #     )

            path_set.add(model_path)

            # Validate that 'TPV' and 'FPV' are floats (or convertible to float)
            # and non-negative. The range check sits outside the try block so
            # its ValueError is not swallowed and re-raised with the wrong message.
            for key in ["TPV", "FPV"]:
                try:
                    value = float(config[key])
                except (TypeError, ValueError):
                    raise ValueError(
                        f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}."
                    )
                if value < 0:
                    raise ValueError(
                        f"'{key}' in model '{model_name}' must be non-negative, but got {value}."
                    )

    @abstractmethod
    def _get_prediction_and_labels(
        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pass


class ChebiEnsemble(_EnsembleBase):

    NAME = "ChebiEnsemble"

    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
        super().__init__(model_configs, **kwargs)
        # Add a dummy trainable parameter so the optimizer has something to update.
        self.dummy_param = torch.nn.Parameter(torch.randn(1))

    def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
        predictions = {}
        confidences = {}
        total_logits = torch.zeros(
            data["labels"].shape[0], data["labels"].shape[1], device=self.device
        )

        for name, model in self.models.items():
            output = model(data)
            confidences[name] = torch.sigmoid(output["logits"])
            predictions[name] = (
                torch.sigmoid(output["logits"]) > 0.5
            ).long()  # Multi-label classification
            total_logits += output["logits"]

        return {
            "logits": total_logits,
            "pred_dict": predictions,
            "conf_dict": confidences,
        }

    def _get_prediction_and_labels(self, data, labels, model_output):
        d = model_output["logits"]
        # Aggregate predictions using weighted voting
        metrics_preds = self.aggregate_predictions(
            model_output["pred_dict"], model_output["conf_dict"]
        )
        loss_kwargs = data.get("loss_kwargs", dict())
        if "non_null_labels" in loss_kwargs:
            n = loss_kwargs["non_null_labels"]
            d = d[n]
            metrics_preds = metrics_preds[n]
        return (
            torch.sigmoid(d),
            labels.int() if labels is not None else None,
            metrics_preds,
        )

    def _execute(
        self,
        batch: XYData,
        batch_idx: int,
        metrics: Optional[torch.nn.Module] = None,
        prefix: Optional[str] = "",
        log: Optional[bool] = True,
        sync_dist: Optional[bool] = False,
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Executes the model on a batch of data and returns the model output and predictions.

        Args:
            batch (XYData): The input batch of data.
            batch_idx (int): The index of the current batch.
            metrics (torch.nn.Module, optional): A dictionary of metrics to track.
            prefix (str, optional): A prefix to add to the metric names. Defaults to "".
            log (bool, optional): Whether to log the metrics. Defaults to True.
            sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False.

        Returns:
            Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels,
                model output, predictions, and loss (if applicable).
        """
        assert isinstance(batch, XYData)
        batch = batch.to(self.device)
        data = self._process_batch(batch, batch_idx)
        labels = data["labels"]
        model_output = self(data, **data.get("model_kwargs", dict()))
        pr, tar, metrics_preds = self._get_prediction_and_labels(
            data, labels, model_output
        )
        d = dict(data=data, labels=labels, output=model_output, preds=pr)
        if log:
            if self.criterion is not None:
                loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss(
                    model_output, labels, data.get("loss_kwargs", dict())
                )
                loss_kwargs = dict()
                if self.pass_loss_kwargs:
                    loss_kwargs = loss_kwargs_candidates
                loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
                if isinstance(loss, tuple):
                    loss_additional = loss[1:]
                    for i, loss_add in enumerate(loss_additional):
                        self.log(
                            f"{prefix}loss_{i}",
                            loss_add if isinstance(loss_add, int) else loss_add.item(),
                            batch_size=len(batch),
                            on_step=True,
                            on_epoch=False,
                            prog_bar=False,
                            logger=True,
                            sync_dist=sync_dist,
                        )
                    loss = loss[0]

                d["loss"] = loss
                self.log(
                    f"{prefix}loss",
                    loss.item(),
                    batch_size=len(batch),
                    on_step=True,
                    on_epoch=True,
                    prog_bar=True,
                    logger=True,
                    sync_dist=sync_dist,
                )
            if metrics and labels is not None:
                for metric_name, metric in metrics.items():
                    metric.update(metrics_preds, tar)
                self._log_metrics(prefix, metrics, len(batch))
        return d

    def aggregate_predictions(self, predictions, confidences):
        """Implements weighted voting based on trustworthiness."""
        batch_size, num_classes = list(predictions.values())[0].shape

        true_scores = torch.zeros(batch_size, num_classes, device=self.device)
        false_scores = torch.zeros(batch_size, num_classes, device=self.device)

        for model, preds in predictions.items():
            tpv = float(self.model_configs[model]["TPV"])
            fpv = float(self.model_configs[model]["FPV"])

            # Weight each vote by the model's confidence, scaled by TPV for
            # positive votes and FPV for negative votes.
            confidence = confidences[model]
            weight = confidence * (tpv * preds + fpv * (1 - preds))

            true_scores += weight * preds
            false_scores += weight * (1 - preds)

        return (true_scores > false_scores).long()  # Final class decision


class ChebiEnsembleLearning(_EnsembleBase):

    NAME = "ChebiEnsembleLearning"

    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
        super().__init__(model_configs, **kwargs)
        # Trainable stacking layer over the concatenated member outputs.
        self.ensemble_classifier = torch.nn.Linear(
            in_features=len(self.models) * self.out_dim, out_features=self.out_dim
        )

    def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
        predictions = {}
        confidences = {}

        for name, model in self.models.items():
            output = model(data["features"])
            confidence = torch.sigmoid(output)  # Assuming confidence scores
            predictions[name] = output.argmax(dim=1)  # Convert logits to class
            confidences[name] = confidence.max(dim=1).values  # Max confidence

        # Aggregate predictions using weighted voting
        # TODO: aggregate_predictions is currently defined on ChebiEnsemble, not on
        # _EnsembleBase, so this call does not resolve; it needs to be moved to the
        # base class or reimplemented here.
        final_preds = self.aggregate_predictions(predictions, confidences)
        return final_preds

    def _get_prediction_and_labels(
        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        pass


if __name__ == "__main__":
    pass
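
For context, a minimal usage sketch of the weighted-voting ensemble follows. The model names, checkpoint paths, TPV/FPV values, and the out_dim keyword are illustrative assumptions, not values from this commit.

# Hypothetical usage sketch -- names, paths, trust values, and out_dim are placeholders.
model_configs = {
    "electra_a": {"path": "checkpoints/electra_a.ckpt", "TPV": 0.90, "FPV": 0.80},
    "electra_b": {"path": "checkpoints/electra_b.ckpt", "TPV": 0.85, "FPV": 0.75},
}
ensemble = ChebiEnsemble(model_configs=model_configs, out_dim=854)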

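To sanity-check the voting arithmetic in aggregate_predictions, here is a standalone toy computation for one sample and one class, with assumed confidence and TPV/FPV values: a positive vote wins when its confidence-scaled TPV weight exceeds the negative vote's confidence-scaled FPV weight.

# Toy check of the voting rule (assumed values; mirrors aggregate_predictions):
import torch

preds_a = torch.tensor([[1.0]]); conf_a = torch.tensor([[0.9]])  # model A votes positive
preds_b = torch.tensor([[0.0]]); conf_b = torch.tensor([[0.6]])  # model B votes negative
tpv_a, fpv_a = 0.8, 0.7
tpv_b, fpv_b = 0.9, 0.5

weight_a = conf_a * (tpv_a * preds_a + fpv_a * (1 - preds_a))  # 0.9 * 0.8 = 0.72
weight_b = conf_b * (tpv_b * preds_b + fpv_b * (1 - preds_b))  # 0.6 * 0.5 = 0.30
true_scores = weight_a * preds_a + weight_b * preds_b                # 0.72
false_scores = weight_a * (1 - preds_a) + weight_b * (1 - preds_b)  # 0.30
final = (true_scores > false_scores).long()  # tensor([[1]]) -- positive vote wins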