Skip to content

Commit 0ec03b1

Browse files
committed
remove ensemble learning class
1 parent 405026e commit 0ec03b1

File tree

1 file changed

+0
-78
lines changed

1 file changed

+0
-78
lines changed

chebai/models/ensemble.py

Lines changed: 0 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -498,81 +498,3 @@ def _process_for_loss(
498498
if labels is not None:
499499
labels = labels.float()
500500
return model_output["logits"], labels, kwargs_copy
501-
502-
503-
class ChebiEnsembleLearning(_EnsembleBase):
504-
"""
505-
A specialized ensemble learning class for ChEBI classification.
506-
507-
This ensemble combines multiple models by concatenating their logits and
508-
passing them through a feedforward neural network (FFN) for final predictions.
509-
"""
510-
511-
NAME = "ChebiEnsembleLearning"
512-
513-
def __init__(self, model_configs: Dict[str, Dict], **kwargs: Any):
514-
"""
515-
Initializes the ChebiEnsembleLearning class.
516-
517-
Args:
518-
model_configs (Dict[str, Dict]): Configuration dictionary for each model.
519-
**kwargs (Any): Additional keyword arguments for configuring the FFN.
520-
"""
521-
super().__init__(model_configs, **kwargs)
522-
523-
ffn_kwargs = kwargs.copy()
524-
ffn_kwargs["input_size"] = len(self.model_configs) * int(kwargs["out_dim"])
525-
self.ffn: FFN = FFN(**ffn_kwargs)
526-
527-
def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
528-
"""
529-
Performs a forward pass through the ensemble model.
530-
531-
Args:
532-
data (Dict[str, Tensor]): Input data dictionary for the models.
533-
**kwargs (Any): Additional keyword arguments.
534-
535-
Returns:
536-
Dict[str, Any]: Output from the FFN model.
537-
"""
538-
logits_list = [model(data)["logits"] for model in self.models.values()]
539-
return self.ffn({"features": torch.cat(logits_list, dim=1)})
540-
541-
def _get_prediction_and_labels(
542-
self, data: Dict[str, Any], labels: Tensor, output: Tensor
543-
) -> Tuple[Tensor, Tensor]:
544-
"""
545-
Extracts predictions and labels for evaluation.
546-
547-
Args:
548-
data (Dict[str, Any]): Input data dictionary.
549-
labels (Tensor): Ground truth labels.
550-
output (Tensor): Model output.
551-
552-
Returns:
553-
Tuple[Tensor, Tensor]: Processed predictions and labels.
554-
"""
555-
return self.ffn._get_prediction_and_labels(data, labels, output)
556-
557-
def _process_for_loss(
558-
self,
559-
model_output: Dict[str, Tensor],
560-
labels: Tensor,
561-
loss_kwargs: Dict[str, Any],
562-
) -> Tuple[Tensor, Tensor, Dict[str, Any]]:
563-
"""
564-
Processes model output and labels for computing the loss.
565-
566-
Args:
567-
model_output (Dict[str, Tensor]): Output dictionary from the model.
568-
labels (Tensor): Ground truth labels.
569-
loss_kwargs (Dict[str, Any]): Additional arguments for loss computation.
570-
571-
Returns:
572-
Tuple[Tensor, Tensor, Dict[str, Any]]: Loss, processed predictions, and additional info.
573-
"""
574-
return self.ffn._process_for_loss(model_output, labels, loss_kwargs)
575-
576-
577-
if __name__ == "__main__":
578-
pass

0 commit comments

Comments
 (0)