@@ -14,8 +14,6 @@ class _EnsembleBase(ChebaiBaseNet, ABC):
1414 def __init__ (self , model_configs : Dict [str , ModelConfig ], ** kwargs ):
1515 super ().__init__ (** kwargs )
1616
17- self ._validate_model_configs (model_configs )
18-
1917 self .models : Dict [str , ChebaiBaseNet ] = {}
2018 self .model_configs : Dict [str , ModelConfig ] = model_configs
2119
@@ -41,6 +39,23 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
4139 # else:
4240 # self.threshold = int(kwargs["threshold"])
4341
42+ @abstractmethod
43+ def _get_prediction_and_labels (
44+ self , data : Dict [str , Any ], labels : torch .Tensor , output : torch .Tensor
45+ ) -> (torch .Tensor , torch .Tensor ):
46+ pass
47+
48+
49+ class ChebiEnsemble (_EnsembleBase ):
50+
51+ NAME = "ChebiEnsemble"
52+
53+ def __init__ (self , model_configs : Dict [str , ModelConfig ], ** kwargs ):
54+ self ._validate_model_configs (model_configs )
55+ super ().__init__ (model_configs , ** kwargs )
56+ # Add a dummy trainable parameter
57+ self .dummy_param = torch .nn .Parameter (torch .randn (1 , requires_grad = True ))
58+
4459 @classmethod
4560 def _validate_model_configs (cls , model_configs : Dict [str , ModelConfig ]):
4661 path_set = set ()
@@ -80,22 +95,6 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
8095 f"'{ key } ' in model '{ model_name } ' must be a float or convertible to float, but got { config [key ]} ."
8196 )
8297
83- @abstractmethod
84- def _get_prediction_and_labels (
85- self , data : Dict [str , Any ], labels : torch .Tensor , output : torch .Tensor
86- ) -> (torch .Tensor , torch .Tensor ):
87- pass
88-
89-
90- class ChebiEnsemble (_EnsembleBase ):
91-
92- NAME = "ChebiEnsemble"
93-
94- def __init__ (self , model_configs : Dict [str , ModelConfig ], ** kwargs ):
95- super ().__init__ (model_configs , ** kwargs )
96- # Add a dummy trainable parameter
97- self .dummy_param = torch .nn .Parameter (torch .randn (1 , requires_grad = True ))
98-
9998 def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
10099 predictions = {}
101100 confidences = {}
@@ -255,30 +254,35 @@ class ChebiEnsembleLearning(_EnsembleBase):
255254
256255 NAME = "ChebiEnsembleLearning"
257256
258- def __init__ (self , model_configs : Dict [str , ModelConfig ], ** kwargs ):
257+ def __init__ (self , model_configs : Dict [str , Dict ], ** kwargs ):
259258 super ().__init__ (model_configs , ** kwargs )
260- self .ensemble_classifier = torch .nn .Linear (
261- in_features = len (self .models ) * self .out_dim , out_features = self .out_dim
262- )
263259
264- def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
265- predictions = {}
266- confidences = {}
260+ from chebai .models .ffn import FFN
267261
262+ ffn_kwargs = kwargs .copy ()
263+ ffn_kwargs ["input_size" ] = len (self .model_configs ) * int (kwargs ["out_dim" ])
264+ self .ffn : FFN = FFN (** ffn_kwargs )
265+
266+ def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
267+ logits_list = []
268268 for name , model in self .models .items ():
269- output = model (data ["features" ])
270- confidence = torch .sigmoid (output ) # Assuming confidence scores
271- predictions [name ] = output .argmax (dim = 1 ) # Convert logits to class
272- confidences [name ] = confidence .max (dim = 1 ).values # Max confidence
269+ output = model (data )
270+ logits_list .append (output ["logits" ])
273271
274- # Aggregate predictions using weighted voting
275- final_preds = self .aggregate_predictions (predictions , confidences )
276- return final_preds
272+ return self .ffn ({"features" : torch .cat (logits_list , dim = 1 )})
277273
278274 def _get_prediction_and_labels (
279275 self , data : Dict [str , Any ], labels : torch .Tensor , output : torch .Tensor
280276 ) -> (torch .Tensor , torch .Tensor ):
281- pass
277+ return self .ffn ._get_prediction_and_labels (data , labels , output )
278+
279+ def _process_for_loss (
280+ self ,
281+ model_output : Dict [str , torch .Tensor ],
282+ labels : torch .Tensor ,
283+ loss_kwargs : Dict [str , Any ],
284+ ) -> (torch .Tensor , torch .Tensor , Dict [str , Any ]):
285+ return self .ffn ._process_for_loss (model_output , labels , loss_kwargs )
282286
283287
284288if __name__ == "__main__" :
0 commit comments