99from torch import Tensor
1010
1111from chebai .models import ChebaiBaseNet
12+ from chebai .models .ffn import FFN
1213from chebai .preprocessing .structures import XYData
1314
1415
1516class _EnsembleBase (ChebaiBaseNet , ABC ):
17+ """
18+ Base class for ensemble models in the Chebai framework.
19+
20+ Inherits from ChebaiBaseNet and provides functionality to load multiple models,
21+ validate configuration, and manage predictions.
22+
23+ Attributes:
24+ data_processed_dir_main (str): Directory where the processed data is stored.
25+ models (Dict[str, LightningModule]): A dictionary of loaded models.
26+ model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble.
27+ dm_labels (Dict[str, int]): Mapping of label names to integer indices.
28+ """
29+
1630 def __init__ (
1731 self , model_configs : Dict [str , Dict ], data_processed_dir_main : str , ** kwargs
1832 ):
33+ """
34+ Initializes the ensemble model and loads configuration, models, and labels.
35+
36+ Args:
37+ model_configs (Dict[str, Dict]): Dictionary of model configurations.
38+ data_processed_dir_main (str): Path to the processed data directory.
39+ **kwargs: Additional arguments for initialization.
40+ """
1941 super ().__init__ (** kwargs )
2042 self ._validate_model_configs (model_configs )
2143
@@ -27,16 +49,18 @@ def __init__(
2749 self ._load_data_module_labels ()
2850 self ._load_ensemble_models ()
2951
30- # TODO: Later discuss whether this threshold should be independent of metric threshold or not ?
31- # if kwargs.get("threshold") is None:
32- # first_metric_key = next(iter(self.train_metrics)) # Get the first key
33- # first_metric = self.train_metrics[first_metric_key] # Get the metric object
34- # self.threshold = int(first_metric.threshold) # Access threshold
35- # else:
36- # self.threshold = int(kwargs["threshold"])
37-
3852 @classmethod
3953 def _validate_model_configs (cls , model_configs : Dict [str , Dict ]):
54+ """
55+ Validates the model configurations to ensure required keys are present.
56+
57+ Args:
58+ model_configs (Dict[str, Dict]): Dictionary of model configurations.
59+
60+ Raises:
61+ AttributeError: If required keys are missing in the configuration.
62+ ValueError: If there are duplicate model paths or class paths.
63+ """
4064 path_set = set ()
4165 class_set = set ()
4266 labels_set = set ()
@@ -55,15 +79,15 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
5579 model_path = config ["ckpt_path" ]
5680 class_path = config ["class_path" ]
5781
58- # if model_path in path_set:
59- # raise ValueError(
60- # f"Duplicate model path detected: '{model_path}'. Each model must have a unique path."
61- # )
82+ if model_path in path_set :
83+ raise ValueError (
84+ f"Duplicate model path detected: '{ model_path } '. Each model must have a unique path."
85+ )
6286
63- # if class_path not in class_set:
64- # raise ValueError(
65- # f"Duplicate class path detected: '{class_path}'. Each model must have a unique path."
66- # )
87+ if class_path not in class_set :
88+ raise ValueError (
89+ f"Duplicate class path detected: '{ class_path } '. Each model must have a unique path."
90+ )
6791
6892 path_set .add (model_path )
6993 class_set .add (class_path )
@@ -74,9 +98,27 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
7498 def _extra_validation (
7599 cls , model_name : str , config : Dict [str , Any ], sets_ : Dict [str , set ]
76100 ):
101+ """
102+ Perform extra validation on the model configuration, if necessary.
103+
104+ This method can be extended by subclasses to add additional validation logic.
105+
106+ Args:
107+ model_name (str): The name of the model.
108+ config (Dict[str, Any]): The configuration dictionary for the model.
109+ sets_ (Dict[str, set]): A dictionary of sets to track model paths, class paths, and labels.
110+ """
77111 pass
78112
79113 def _load_ensemble_models (self ):
114+ """
115+ Loads the models specified in the configuration and initializes them.
116+
117+ Raises:
118+ FileNotFoundError: If the model checkpoint path does not exist.
119+ ModuleNotFoundError: If the module containing the model class is not found.
120+ AttributeError: If the specified class is not found within the module.
121+ """
80122 for model_name in self .model_configs :
81123 model_ckpt_path = self .model_configs [model_name ]["ckpt_path" ]
82124 model_class_path = self .model_configs [model_name ]["class_path" ]
@@ -108,6 +150,12 @@ def _load_ensemble_models(self):
108150 )
109151
110152 def _load_data_module_labels (self ):
153+ """
154+ Loads the label mapping from the classes.txt file for loaded data.
155+
156+ Raises:
157+ FileNotFoundError: If the classes.txt file does not exist.
158+ """
111159 classes_txt_file = os .path .join (self .data_processed_dir_main , "classes.txt" )
112160 if not os .path .exists (classes_txt_file ):
113161 raise FileNotFoundError (f"{ classes_txt_file } does not exist" )
@@ -120,15 +168,43 @@ def _load_data_module_labels(self):
120168 @abstractmethod
121169 def _get_prediction_and_labels (
122170 self , data : Dict [str , Any ], labels : torch .Tensor , output : torch .Tensor
123- ) -> (torch .Tensor , torch .Tensor ):
171+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
172+ """
173+ Abstract method for obtaining predictions and labels.
174+
175+ Args:
176+ data (Dict[str, Any]): The input data.
177+ labels (torch.Tensor): The target labels.
178+ output (torch.Tensor): The model output.
179+
180+ Returns:
181+ Tuple[torch.Tensor, torch.Tensor]: The predicted labels and the ground truth labels.
182+ """
124183 pass
125184
126185
127186class ChebiEnsemble (_EnsembleBase ):
187+ """
188+ Ensemble model that aggregates predictions from multiple models for the Chebai task.
189+
190+ This model combines the outputs of several individual models and aggregates their predictions
191+ using a weighted voting strategy based on trustworthiness (TPV and FPV). This strategy can modified by overriding
192+ `aggregate_predictions` method by subclasses, as per needs.
193+
194+ There is are relevant trainable parameters for this ensemble model, hence trainer.max_epochs should be set to 1.
195+ `_dummy_param` exists for only lighting module completeness and compatability purpose.
196+ """
128197
129198 NAME = "ChebiEnsemble"
130199
131200 def __init__ (self , model_configs : Dict [str , Dict ], ** kwargs ):
201+ """
202+ Initializes the ensemble model and computes the model-label mask.
203+
204+ Args:
205+ model_configs (Dict[str, Dict]): Dictionary of model configurations.
206+ **kwargs: Additional arguments for initialization.
207+ """
132208 super ().__init__ (model_configs , ** kwargs )
133209
134210 # Add a dummy trainable parameter
@@ -140,15 +216,27 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
140216 def _extra_validation (
141217 cls , model_name : str , config : Dict [str , Any ], sets_ : Dict [str , set ]
142218 ):
219+ """
220+ Additional validation for the ensemble model configuration.
143221
222+ Args:
223+ model_name (str): The model name.
224+ config (Dict[str, Any]): The configuration dictionary.
225+ sets_ (Dict[str, set]): The set of paths for labels.
226+
227+ Raises:
228+ AttributeError: If the 'labels_path' key is missing.
229+ ValueError: If the 'labels_path' contains duplicate entries or certain are not convertible to float.
230+ """
144231 if "labels_path" not in config :
145232 raise AttributeError ("Missing 'labels_path' key in config!" )
146233
147234 labels_path = config ["labels_path" ]
148- # if labels_path not in sets_["labels"]:
149- # raise ValueError(
150- # f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path."
151- # )
235+ if labels_path not in sets_ ["labels" ]:
236+ raise ValueError (
237+ f"Duplicate labels path detected: '{ labels_path } '. Each model must have a unique path."
238+ )
239+
152240 sets_ ["labels" ].add (labels_path )
153241
154242 with open (labels_path , "r" ) as f :
@@ -175,6 +263,14 @@ def _extra_validation(
175263 )
176264
177265 def _generate_model_label_mask (self ):
266+ """
267+ Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values
268+ as tensors.
269+
270+ Raises:
271+ FileNotFoundError: If the labels path does not exist.
272+ ValueError: If label values are empty for any model.
273+ """
178274 num_models_per_label = torch .zeros (1 , self .out_dim , device = self .device )
179275
180276 for model_name , model_config in self .model_configs .items ():
@@ -223,6 +319,16 @@ def _generate_model_label_mask(self):
223319 self ._num_models_per_label = num_models_per_label
224320
225321 def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
322+ """
323+ Forward pass through the ensemble model, aggregating predictions from all models.
324+
325+ Args:
326+ data (Dict[str, Tensor]): Input data including features and labels.
327+ **kwargs: Additional arguments for the forward pass.
328+
329+ Returns:
330+ Dict[str, Any]: The aggregated logits, predictions, and confidences.
331+ """
226332 predictions = {}
227333 confidences = {}
228334
@@ -257,6 +363,17 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
257363 }
258364
259365 def _get_prediction_and_labels (self , data , labels , model_output ):
366+ """
367+ Gets predictions and labels from the model output.
368+
369+ Args:
370+ data (Dict[str, Any]): The input data.
371+ labels (torch.Tensor): The target labels.
372+ model_output (Dict[str, Tensor]): The model's output.
373+
374+ Returns:
375+ Tuple[torch.Tensor, torch.Tensor]: The predictions and the ground truth labels.
376+ """
260377 d = model_output ["logits" ]
261378 # Aggregate predictions using weighted voting
262379 metrics_preds = self .aggregate_predictions (
@@ -348,8 +465,30 @@ def _execute(
348465 self ._log_metrics (prefix , metrics , len (batch ))
349466 return d
350467
351- def aggregate_predictions (self , predictions , confidences ):
352- """Implements weighted voting based on trustworthiness."""
468+ def aggregate_predictions (
469+ self , predictions : Dict [str , torch .Tensor ], confidences : Dict [str , torch .Tensor ]
470+ ) -> torch .Tensor :
471+ """
472+ Implements weighted voting based on trustworthiness.
473+
474+ This method aggregates predictions from multiple models using a weighted voting mechanism.
475+ The weight of each model's prediction is determined by its True Positive Value (TPV) and
476+ False Positive Value (FPV), scaled by the confidence score.
477+
478+ Args:
479+ predictions (Dict[str, torch.Tensor]):
480+ A dictionary mapping model names to their respective binary class predictions
481+ (shape: `[batch_size, num_classes]`).
482+ confidences (Dict[str, torch.Tensor]):
483+ A dictionary mapping model names to their respective confidence scores
484+ (shape: `[batch_size, num_classes]`).
485+
486+ Returns:
487+ torch.Tensor:
488+ A tensor of final aggregated predictions based on weighted voting
489+ (shape: `[batch_size, num_classes]`), where values are `True` for positive class
490+ and `False` otherwise.
491+ """
353492 batch_size , num_classes = list (predictions .values ())[0 ].shape
354493 true_scores = torch .zeros (batch_size , num_classes , device = self .device )
355494 false_scores = torch .zeros (batch_size , num_classes , device = self .device )
@@ -369,7 +508,7 @@ def aggregate_predictions(self, predictions, confidences):
369508 # Avoid division by zero: Set valid_counts to 1 where it's zero
370509 valid_counts = self ._num_models_per_label .clamp (min = 1 )
371510
372- # Normalize by valid contributions to prevent bias, this step can be optional depending upon scenario
511+ # Normalize by valid contributions to prevent bias
373512 final_preds = (true_scores / valid_counts ) > (false_scores / valid_counts )
374513
375514 return final_preds
@@ -398,33 +537,76 @@ def _process_for_loss(
398537
399538
400539class ChebiEnsembleLearning (_EnsembleBase ):
540+ """
541+ A specialized ensemble learning class for ChEBI classification.
542+
543+ This ensemble combines multiple models by concatenating their logits and
544+ passing them through a feedforward neural network (FFN) for final predictions.
545+ """
401546
402547 NAME = "ChebiEnsembleLearning"
403548
404- def __init__ (self , model_configs : Dict [str , Dict ], ** kwargs ):
405- super ().__init__ (model_configs , ** kwargs )
549+ def __init__ (self , model_configs : Dict [str , Dict ], ** kwargs : Any ):
550+ """
551+ Initializes the ChebiEnsembleLearning class.
406552
407- from chebai .models .ffn import FFN
553+ Args:
554+ model_configs (Dict[str, Dict]): Configuration dictionary for each model.
555+ **kwargs (Any): Additional keyword arguments for configuring the FFN.
556+ """
557+ super ().__init__ (model_configs , ** kwargs )
408558
409559 ffn_kwargs = kwargs .copy ()
410560 ffn_kwargs ["input_size" ] = len (self .model_configs ) * int (kwargs ["out_dim" ])
411561 self .ffn : FFN = FFN (** ffn_kwargs )
412562
413563 def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
564+ """
565+ Performs a forward pass through the ensemble model.
566+
567+ Args:
568+ data (Dict[str, Tensor]): Input data dictionary for the models.
569+ **kwargs (Any): Additional keyword arguments.
570+
571+ Returns:
572+ Dict[str, Any]: Output from the FFN model.
573+ """
414574 logits_list = [model (data )["logits" ] for model in self .models .values ()]
415575 return self .ffn ({"features" : torch .cat (logits_list , dim = 1 )})
416576
417577 def _get_prediction_and_labels (
418- self , data : Dict [str , Any ], labels : torch .Tensor , output : torch .Tensor
419- ) -> (torch .Tensor , torch .Tensor ):
578+ self , data : Dict [str , Any ], labels : Tensor , output : Tensor
579+ ) -> Tuple [Tensor , Tensor ]:
580+ """
581+ Extracts predictions and labels for evaluation.
582+
583+ Args:
584+ data (Dict[str, Any]): Input data dictionary.
585+ labels (Tensor): Ground truth labels.
586+ output (Tensor): Model output.
587+
588+ Returns:
589+ Tuple[Tensor, Tensor]: Processed predictions and labels.
590+ """
420591 return self .ffn ._get_prediction_and_labels (data , labels , output )
421592
422593 def _process_for_loss (
423594 self ,
424- model_output : Dict [str , torch . Tensor ],
425- labels : torch . Tensor ,
595+ model_output : Dict [str , Tensor ],
596+ labels : Tensor ,
426597 loss_kwargs : Dict [str , Any ],
427- ) -> (torch .Tensor , torch .Tensor , Dict [str , Any ]):
598+ ) -> Tuple [Tensor , Tensor , Dict [str , Any ]]:
599+ """
600+ Processes model output and labels for computing the loss.
601+
602+ Args:
603+ model_output (Dict[str, Tensor]): Output dictionary from the model.
604+ labels (Tensor): Ground truth labels.
605+ loss_kwargs (Dict[str, Any]): Additional arguments for loss computation.
606+
607+ Returns:
608+ Tuple[Tensor, Tensor, Dict[str, Any]]: Loss, processed predictions, and additional info.
609+ """
428610 return self .ffn ._process_for_loss (model_output , labels , loss_kwargs )
429611
430612
0 commit comments