Skip to content

Commit 0541ed2

Browse files
committed
ensemble: add docstrings and typehints
1 parent 9b851c5 commit 0541ed2

File tree

1 file changed

+214
-32
lines changed

1 file changed

+214
-32
lines changed

chebai/models/ensemble.py

Lines changed: 214 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,35 @@
99
from torch import Tensor
1010

1111
from chebai.models import ChebaiBaseNet
12+
from chebai.models.ffn import FFN
1213
from chebai.preprocessing.structures import XYData
1314

1415

1516
class _EnsembleBase(ChebaiBaseNet, ABC):
17+
"""
18+
Base class for ensemble models in the Chebai framework.
19+
20+
Inherits from ChebaiBaseNet and provides functionality to load multiple models,
21+
validate configuration, and manage predictions.
22+
23+
Attributes:
24+
data_processed_dir_main (str): Directory where the processed data is stored.
25+
models (Dict[str, LightningModule]): A dictionary of loaded models.
26+
model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble.
27+
dm_labels (Dict[str, int]): Mapping of label names to integer indices.
28+
"""
29+
1630
def __init__(
1731
self, model_configs: Dict[str, Dict], data_processed_dir_main: str, **kwargs
1832
):
33+
"""
34+
Initializes the ensemble model and loads configuration, models, and labels.
35+
36+
Args:
37+
model_configs (Dict[str, Dict]): Dictionary of model configurations.
38+
data_processed_dir_main (str): Path to the processed data directory.
39+
**kwargs: Additional arguments for initialization.
40+
"""
1941
super().__init__(**kwargs)
2042
self._validate_model_configs(model_configs)
2143

@@ -27,16 +49,18 @@ def __init__(
2749
self._load_data_module_labels()
2850
self._load_ensemble_models()
2951

30-
# TODO: Later discuss whether this threshold should be independent of metric threshold or not ?
31-
# if kwargs.get("threshold") is None:
32-
# first_metric_key = next(iter(self.train_metrics)) # Get the first key
33-
# first_metric = self.train_metrics[first_metric_key] # Get the metric object
34-
# self.threshold = int(first_metric.threshold) # Access threshold
35-
# else:
36-
# self.threshold = int(kwargs["threshold"])
37-
3852
@classmethod
3953
def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
54+
"""
55+
Validates the model configurations to ensure required keys are present.
56+
57+
Args:
58+
model_configs (Dict[str, Dict]): Dictionary of model configurations.
59+
60+
Raises:
61+
AttributeError: If required keys are missing in the configuration.
62+
ValueError: If there are duplicate model paths or class paths.
63+
"""
4064
path_set = set()
4165
class_set = set()
4266
labels_set = set()
@@ -55,15 +79,15 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
5579
model_path = config["ckpt_path"]
5680
class_path = config["class_path"]
5781

58-
# if model_path in path_set:
59-
# raise ValueError(
60-
# f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
61-
# )
82+
if model_path in path_set:
83+
raise ValueError(
84+
f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
85+
)
6286

63-
# if class_path not in class_set:
64-
# raise ValueError(
65-
# f"Duplicate class path detected: '{class_path}'. Each model must have a unique path."
66-
# )
87+
if class_path not in class_set:
88+
raise ValueError(
89+
f"Duplicate class path detected: '{class_path}'. Each model must have a unique path."
90+
)
6791

6892
path_set.add(model_path)
6993
class_set.add(class_path)
@@ -74,9 +98,27 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
7498
def _extra_validation(
7599
cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
76100
):
101+
"""
102+
Perform extra validation on the model configuration, if necessary.
103+
104+
This method can be extended by subclasses to add additional validation logic.
105+
106+
Args:
107+
model_name (str): The name of the model.
108+
config (Dict[str, Any]): The configuration dictionary for the model.
109+
sets_ (Dict[str, set]): A dictionary of sets to track model paths, class paths, and labels.
110+
"""
77111
pass
78112

79113
def _load_ensemble_models(self):
114+
"""
115+
Loads the models specified in the configuration and initializes them.
116+
117+
Raises:
118+
FileNotFoundError: If the model checkpoint path does not exist.
119+
ModuleNotFoundError: If the module containing the model class is not found.
120+
AttributeError: If the specified class is not found within the module.
121+
"""
80122
for model_name in self.model_configs:
81123
model_ckpt_path = self.model_configs[model_name]["ckpt_path"]
82124
model_class_path = self.model_configs[model_name]["class_path"]
@@ -108,6 +150,12 @@ def _load_ensemble_models(self):
108150
)
109151

110152
def _load_data_module_labels(self):
153+
"""
154+
Loads the label mapping from the classes.txt file for loaded data.
155+
156+
Raises:
157+
FileNotFoundError: If the classes.txt file does not exist.
158+
"""
111159
classes_txt_file = os.path.join(self.data_processed_dir_main, "classes.txt")
112160
if not os.path.exists(classes_txt_file):
113161
raise FileNotFoundError(f"{classes_txt_file} does not exist")
@@ -120,15 +168,43 @@ def _load_data_module_labels(self):
120168
@abstractmethod
121169
def _get_prediction_and_labels(
122170
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
123-
) -> (torch.Tensor, torch.Tensor):
171+
) -> Tuple[torch.Tensor, torch.Tensor]:
172+
"""
173+
Abstract method for obtaining predictions and labels.
174+
175+
Args:
176+
data (Dict[str, Any]): The input data.
177+
labels (torch.Tensor): The target labels.
178+
output (torch.Tensor): The model output.
179+
180+
Returns:
181+
Tuple[torch.Tensor, torch.Tensor]: The predicted labels and the ground truth labels.
182+
"""
124183
pass
125184

126185

127186
class ChebiEnsemble(_EnsembleBase):
187+
"""
188+
Ensemble model that aggregates predictions from multiple models for the Chebai task.
189+
190+
This model combines the outputs of several individual models and aggregates their predictions
191+
using a weighted voting strategy based on trustworthiness (TPV and FPV). This strategy can be modified by overriding
192+
the `aggregate_predictions` method in subclasses, as needed.
193+
194+
There are no relevant trainable parameters for this ensemble model, hence trainer.max_epochs should be set to 1.
195+
`_dummy_param` exists only for Lightning module completeness and compatibility purposes.
196+
"""
128197

129198
NAME = "ChebiEnsemble"
130199

131200
def __init__(self, model_configs: Dict[str, Dict], **kwargs):
201+
"""
202+
Initializes the ensemble model and computes the model-label mask.
203+
204+
Args:
205+
model_configs (Dict[str, Dict]): Dictionary of model configurations.
206+
**kwargs: Additional arguments for initialization.
207+
"""
132208
super().__init__(model_configs, **kwargs)
133209

134210
# Add a dummy trainable parameter
@@ -140,15 +216,27 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
140216
def _extra_validation(
141217
cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
142218
):
219+
"""
220+
Additional validation for the ensemble model configuration.
143221
222+
Args:
223+
model_name (str): The model name.
224+
config (Dict[str, Any]): The configuration dictionary.
225+
sets_ (Dict[str, set]): The set of paths for labels.
226+
227+
Raises:
228+
AttributeError: If the 'labels_path' key is missing.
229+
ValueError: If the 'labels_path' contains duplicate entries or some entries are not convertible to float.
230+
"""
144231
if "labels_path" not in config:
145232
raise AttributeError("Missing 'labels_path' key in config!")
146233

147234
labels_path = config["labels_path"]
148-
# if labels_path not in sets_["labels"]:
149-
# raise ValueError(
150-
# f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path."
151-
# )
235+
if labels_path not in sets_["labels"]:
236+
raise ValueError(
237+
f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path."
238+
)
239+
152240
sets_["labels"].add(labels_path)
153241

154242
with open(labels_path, "r") as f:
@@ -175,6 +263,14 @@ def _extra_validation(
175263
)
176264

177265
def _generate_model_label_mask(self):
266+
"""
267+
Generates a mask indicating the labels handled by each model, and retrieves the corresponding TPV and FPV values
268+
as tensors.
269+
270+
Raises:
271+
FileNotFoundError: If the labels path does not exist.
272+
ValueError: If label values are empty for any model.
273+
"""
178274
num_models_per_label = torch.zeros(1, self.out_dim, device=self.device)
179275

180276
for model_name, model_config in self.model_configs.items():
@@ -223,6 +319,16 @@ def _generate_model_label_mask(self):
223319
self._num_models_per_label = num_models_per_label
224320

225321
def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
322+
"""
323+
Forward pass through the ensemble model, aggregating predictions from all models.
324+
325+
Args:
326+
data (Dict[str, Tensor]): Input data including features and labels.
327+
**kwargs: Additional arguments for the forward pass.
328+
329+
Returns:
330+
Dict[str, Any]: The aggregated logits, predictions, and confidences.
331+
"""
226332
predictions = {}
227333
confidences = {}
228334

@@ -257,6 +363,17 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
257363
}
258364

259365
def _get_prediction_and_labels(self, data, labels, model_output):
366+
"""
367+
Gets predictions and labels from the model output.
368+
369+
Args:
370+
data (Dict[str, Any]): The input data.
371+
labels (torch.Tensor): The target labels.
372+
model_output (Dict[str, Tensor]): The model's output.
373+
374+
Returns:
375+
Tuple[torch.Tensor, torch.Tensor]: The predictions and the ground truth labels.
376+
"""
260377
d = model_output["logits"]
261378
# Aggregate predictions using weighted voting
262379
metrics_preds = self.aggregate_predictions(
@@ -348,8 +465,30 @@ def _execute(
348465
self._log_metrics(prefix, metrics, len(batch))
349466
return d
350467

351-
def aggregate_predictions(self, predictions, confidences):
352-
"""Implements weighted voting based on trustworthiness."""
468+
def aggregate_predictions(
469+
self, predictions: Dict[str, torch.Tensor], confidences: Dict[str, torch.Tensor]
470+
) -> torch.Tensor:
471+
"""
472+
Implements weighted voting based on trustworthiness.
473+
474+
This method aggregates predictions from multiple models using a weighted voting mechanism.
475+
The weight of each model's prediction is determined by its True Positive Value (TPV) and
476+
False Positive Value (FPV), scaled by the confidence score.
477+
478+
Args:
479+
predictions (Dict[str, torch.Tensor]):
480+
A dictionary mapping model names to their respective binary class predictions
481+
(shape: `[batch_size, num_classes]`).
482+
confidences (Dict[str, torch.Tensor]):
483+
A dictionary mapping model names to their respective confidence scores
484+
(shape: `[batch_size, num_classes]`).
485+
486+
Returns:
487+
torch.Tensor:
488+
A tensor of final aggregated predictions based on weighted voting
489+
(shape: `[batch_size, num_classes]`), where values are `True` for positive class
490+
and `False` otherwise.
491+
"""
353492
batch_size, num_classes = list(predictions.values())[0].shape
354493
true_scores = torch.zeros(batch_size, num_classes, device=self.device)
355494
false_scores = torch.zeros(batch_size, num_classes, device=self.device)
@@ -369,7 +508,7 @@ def aggregate_predictions(self, predictions, confidences):
369508
# Avoid division by zero: Set valid_counts to 1 where it's zero
370509
valid_counts = self._num_models_per_label.clamp(min=1)
371510

372-
# Normalize by valid contributions to prevent bias, this step can be optional depending upon scenario
511+
# Normalize by valid contributions to prevent bias
373512
final_preds = (true_scores / valid_counts) > (false_scores / valid_counts)
374513

375514
return final_preds
@@ -398,33 +537,76 @@ def _process_for_loss(
398537

399538

400539
class ChebiEnsembleLearning(_EnsembleBase):
540+
"""
541+
A specialized ensemble learning class for ChEBI classification.
542+
543+
This ensemble combines multiple models by concatenating their logits and
544+
passing them through a feedforward neural network (FFN) for final predictions.
545+
"""
401546

402547
NAME = "ChebiEnsembleLearning"
403548

404-
def __init__(self, model_configs: Dict[str, Dict], **kwargs):
405-
super().__init__(model_configs, **kwargs)
549+
def __init__(self, model_configs: Dict[str, Dict], **kwargs: Any):
550+
"""
551+
Initializes the ChebiEnsembleLearning class.
406552
407-
from chebai.models.ffn import FFN
553+
Args:
554+
model_configs (Dict[str, Dict]): Configuration dictionary for each model.
555+
**kwargs (Any): Additional keyword arguments for configuring the FFN.
556+
"""
557+
super().__init__(model_configs, **kwargs)
408558

409559
ffn_kwargs = kwargs.copy()
410560
ffn_kwargs["input_size"] = len(self.model_configs) * int(kwargs["out_dim"])
411561
self.ffn: FFN = FFN(**ffn_kwargs)
412562

413563
def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
564+
"""
565+
Performs a forward pass through the ensemble model.
566+
567+
Args:
568+
data (Dict[str, Tensor]): Input data dictionary for the models.
569+
**kwargs (Any): Additional keyword arguments.
570+
571+
Returns:
572+
Dict[str, Any]: Output from the FFN model.
573+
"""
414574
logits_list = [model(data)["logits"] for model in self.models.values()]
415575
return self.ffn({"features": torch.cat(logits_list, dim=1)})
416576

417577
def _get_prediction_and_labels(
418-
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
419-
) -> (torch.Tensor, torch.Tensor):
578+
self, data: Dict[str, Any], labels: Tensor, output: Tensor
579+
) -> Tuple[Tensor, Tensor]:
580+
"""
581+
Extracts predictions and labels for evaluation.
582+
583+
Args:
584+
data (Dict[str, Any]): Input data dictionary.
585+
labels (Tensor): Ground truth labels.
586+
output (Tensor): Model output.
587+
588+
Returns:
589+
Tuple[Tensor, Tensor]: Processed predictions and labels.
590+
"""
420591
return self.ffn._get_prediction_and_labels(data, labels, output)
421592

422593
def _process_for_loss(
423594
self,
424-
model_output: Dict[str, torch.Tensor],
425-
labels: torch.Tensor,
595+
model_output: Dict[str, Tensor],
596+
labels: Tensor,
426597
loss_kwargs: Dict[str, Any],
427-
) -> (torch.Tensor, torch.Tensor, Dict[str, Any]):
598+
) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
599+
"""
600+
Processes model output and labels for computing the loss.
601+
602+
Args:
603+
model_output (Dict[str, Tensor]): Output dictionary from the model.
604+
labels (Tensor): Ground truth labels.
605+
loss_kwargs (Dict[str, Any]): Additional arguments for loss computation.
606+
607+
Returns:
608+
Tuple[Tensor, Tensor, Dict[str, Any]]: Loss, processed predictions, and additional info.
609+
"""
428610
return self.ffn._process_for_loss(model_output, labels, loss_kwargs)
429611

430612

0 commit comments

Comments
 (0)