@@ -498,81 +498,3 @@ def _process_for_loss(
498498 if labels is not None :
499499 labels = labels .float ()
500500 return model_output ["logits" ], labels , kwargs_copy
501-
502-
503- class ChebiEnsembleLearning (_EnsembleBase ):
504- """
505- A specialized ensemble learning class for ChEBI classification.
506-
507- This ensemble combines multiple models by concatenating their logits and
508- passing them through a feedforward neural network (FFN) for final predictions.
509- """
510-
511- NAME = "ChebiEnsembleLearning"
512-
513- def __init__ (self , model_configs : Dict [str , Dict ], ** kwargs : Any ):
514- """
515- Initializes the ChebiEnsembleLearning class.
516-
517- Args:
518- model_configs (Dict[str, Dict]): Configuration dictionary for each model.
519- **kwargs (Any): Additional keyword arguments for configuring the FFN.
520- """
521- super ().__init__ (model_configs , ** kwargs )
522-
523- ffn_kwargs = kwargs .copy ()
524- ffn_kwargs ["input_size" ] = len (self .model_configs ) * int (kwargs ["out_dim" ])
525- self .ffn : FFN = FFN (** ffn_kwargs )
526-
527- def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
528- """
529- Performs a forward pass through the ensemble model.
530-
531- Args:
532- data (Dict[str, Tensor]): Input data dictionary for the models.
533- **kwargs (Any): Additional keyword arguments.
534-
535- Returns:
536- Dict[str, Any]: Output from the FFN model.
537- """
538- logits_list = [model (data )["logits" ] for model in self .models .values ()]
539- return self .ffn ({"features" : torch .cat (logits_list , dim = 1 )})
540-
541- def _get_prediction_and_labels (
542- self , data : Dict [str , Any ], labels : Tensor , output : Tensor
543- ) -> Tuple [Tensor , Tensor ]:
544- """
545- Extracts predictions and labels for evaluation.
546-
547- Args:
548- data (Dict[str, Any]): Input data dictionary.
549- labels (Tensor): Ground truth labels.
550- output (Tensor): Model output.
551-
552- Returns:
553- Tuple[Tensor, Tensor]: Processed predictions and labels.
554- """
555- return self .ffn ._get_prediction_and_labels (data , labels , output )
556-
557- def _process_for_loss (
558- self ,
559- model_output : Dict [str , Tensor ],
560- labels : Tensor ,
561- loss_kwargs : Dict [str , Any ],
562- ) -> Tuple [Tensor , Tensor , Dict [str , Any ]]:
563- """
564- Processes model output and labels for computing the loss.
565-
566- Args:
567- model_output (Dict[str, Tensor]): Output dictionary from the model.
568- labels (Tensor): Ground truth labels.
569- loss_kwargs (Dict[str, Any]): Additional arguments for loss computation.
570-
571- Returns:
572- Tuple[Tensor, Tensor, Dict[str, Any]]: Loss, processed predictions, and additional info.
573- """
574- return self .ffn ._process_for_loss (model_output , labels , loss_kwargs )
575-
576-
577- if __name__ == "__main__" :
578- pass
0 commit comments