
Commit 2875385

ensemble: add MLP layer on top of ensemble models
1 parent 9513fea commit 2875385

2 files changed: +41 -33 lines

chebai/models/ensemble.py

Lines changed: 37 additions & 33 deletions
@@ -14,8 +14,6 @@ class _EnsembleBase(ChebaiBaseNet, ABC):
     def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
         super().__init__(**kwargs)
 
-        self._validate_model_configs(model_configs)
-
         self.models: Dict[str, ChebaiBaseNet] = {}
         self.model_configs: Dict[str, ModelConfig] = model_configs
 
@@ -41,6 +39,23 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
         # else:
         #     self.threshold = int(kwargs["threshold"])
 
+    @abstractmethod
+    def _get_prediction_and_labels(
+        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
+    ) -> (torch.Tensor, torch.Tensor):
+        pass
+
+
+class ChebiEnsemble(_EnsembleBase):
+
+    NAME = "ChebiEnsemble"
+
+    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
+        self._validate_model_configs(model_configs)
+        super().__init__(model_configs, **kwargs)
+        # Add a dummy trainable parameter
+        self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))
+
     @classmethod
     def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
         path_set = set()
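
Two related details in this hunk: `_validate_model_configs` now runs in `ChebiEnsemble.__init__` (before the `super()` call) rather than in the base class, so subclasses such as `ChebiEnsembleLearning` are no longer forced through the path/weight checks; and the voting ensemble registers a one-element `dummy_param` because it owns no trainable weights of its own, and a PyTorch optimizer rejects an empty parameter list. A minimal sketch of that failure mode (the `VotingOnly` class is hypothetical, not part of the repo):

import torch

class VotingOnly(torch.nn.Module):
    """Hypothetical stand-in for ChebiEnsemble: it only aggregates outputs
    of frozen sub-models and owns no trainable weights of its own."""

    def __init__(self):
        super().__init__()
        # Without this line, parameters() yields nothing and the optimizer
        # below fails with "optimizer got an empty parameter list".
        self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))

optimizer = torch.optim.Adam(VotingOnly().parameters(), lr=1e-3)  # succeeds
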
@@ -80,22 +95,6 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
                     f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}."
                 )
 
-    @abstractmethod
-    def _get_prediction_and_labels(
-        self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
-    ) -> (torch.Tensor, torch.Tensor):
-        pass
-
-
-class ChebiEnsemble(_EnsembleBase):
-
-    NAME = "ChebiEnsemble"
-
-    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
-        super().__init__(model_configs, **kwargs)
-        # Add a dummy trainable parameter
-        self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))
-
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         predictions = {}
         confidences = {}
@@ -255,30 +254,35 @@ class ChebiEnsembleLearning(_EnsembleBase):
 
     NAME = "ChebiEnsembleLearning"
 
-    def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
+    def __init__(self, model_configs: Dict[str, Dict], **kwargs):
         super().__init__(model_configs, **kwargs)
-        self.ensemble_classifier = torch.nn.Linear(
-            in_features=len(self.models) * self.out_dim, out_features=self.out_dim
-        )
 
-    def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
-        predictions = {}
-        confidences = {}
+        from chebai.models.ffn import FFN
 
+        ffn_kwargs = kwargs.copy()
+        ffn_kwargs["input_size"] = len(self.model_configs) * int(kwargs["out_dim"])
+        self.ffn: FFN = FFN(**ffn_kwargs)
+
+    def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
+        logits_list = []
         for name, model in self.models.items():
-            output = model(data["features"])
-            confidence = torch.sigmoid(output)  # Assuming confidence scores
-            predictions[name] = output.argmax(dim=1)  # Convert logits to class
-            confidences[name] = confidence.max(dim=1).values  # Max confidence
+            output = model(data)
+            logits_list.append(output["logits"])
 
-        # Aggregate predictions using weighted voting
-        final_preds = self.aggregate_predictions(predictions, confidences)
-        return final_preds
+        return self.ffn({"features": torch.cat(logits_list, dim=1)})
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
     ) -> (torch.Tensor, torch.Tensor):
-        pass
+        return self.ffn._get_prediction_and_labels(data, labels, output)
+
+    def _process_for_loss(
+        self,
+        model_output: Dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        loss_kwargs: Dict[str, Any],
+    ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
+        return self.ffn._process_for_loss(model_output, labels, loss_kwargs)
 
 
 if __name__ == "__main__":
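
Net effect of this hunk: `ChebiEnsembleLearning` no longer does confidence-weighted voting over per-model class predictions; it concatenates the raw logits of all sub-models and trains an `FFN` (the MLP from the commit title) on top of them, with `input_size = len(model_configs) * out_dim`. A shape-level sketch of that stacking step, with hypothetical sizes and a plain `torch.nn.Sequential` standing in for `chebai.models.ffn.FFN`:

import torch

# Hypothetical sizes; the diff derives the input width the same way:
# len(model_configs) * out_dim.
n_models, out_dim, batch_size = 3, 854, 2

# Plain MLP standing in for chebai.models.ffn.FFN.
ffn = torch.nn.Sequential(
    torch.nn.Linear(n_models * out_dim, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, out_dim),
)

# One logits tensor of shape (batch_size, out_dim) per sub-model, as
# collected in the new forward() loop.
logits_list = [torch.randn(batch_size, out_dim) for _ in range(n_models)]

stacked = torch.cat(logits_list, dim=1)  # (batch_size, n_models * out_dim)
final_logits = ffn(stacked)              # (batch_size, out_dim)
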
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+class_path: chebai.models.ensemble.ChebiEnsembleLearning
+init_args:
+  optimizer_kwargs:
+    lr: 1e-3
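
This new four-line config wires the stacked ensemble into the `class_path`/`init_args` convention used by Lightning-CLI-style tooling: each key under `init_args` becomes a keyword argument of `ChebiEnsembleLearning.__init__`. A rough Python equivalent (the `model_configs` entries and `out_dim` are hypothetical; this commit's YAML only pins the class path and the learning rate):

from chebai.models.ensemble import ChebiEnsembleLearning

# Hypothetical sub-model entries; the real keys follow chebai's ModelConfig
# convention (checkpoint path, trust weight, ...) and are not part of this commit.
model_configs = {
    "model_a": {"ckpt_path": "a.ckpt"},
    "model_b": {"ckpt_path": "b.ckpt"},
}

model = ChebiEnsembleLearning(
    model_configs=model_configs,
    out_dim=854,                    # hypothetical label count; FFN input_size = 2 * 854
    optimizer_kwargs={"lr": 1e-3},  # the only init_arg pinned by the YAML above
)
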
