@@ -22,13 +22,17 @@ class EnsembleBase(ABC):
2222
2323 Attributes:
2424 data_processed_dir_main (str): Directory where the processed data is stored.
25- models (Dict[str, LightningModule]): A dictionary of loaded models.
25+ _models (Dict[str, LightningModule]): A dictionary of loaded models.
2626 model_configs (Dict[str, Dict]): Configuration dictionary for models in the ensemble.
27- dm_labels (Dict[str, int]): Mapping of label names to integer indices.
27+ _dm_labels (Dict[str, int]): Mapping of label names to integer indices.
2828 """
2929
3030 def __init__ (
31- self , model_configs : Dict [str , Dict ], data_processed_dir_main : str , ** kwargs
31+ self ,
32+ model_configs : Dict [str , Dict ],
33+ data_processed_dir_main : str ,
34+ reader_dir_name : str = "smiles_token" ,
35+ ** kwargs ,
3236 ):
3337 """
3438 Initializes the ensemble model and loads configuration, models, and labels.
@@ -41,22 +45,25 @@ def __init__(
4145 if bool (kwargs .get ("_validate_configs" , True )):
4246 self ._validate_model_configs (model_configs )
4347
44- self .device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
48+ self .model_configs = model_configs
49+ self .data_processed_dir_main = data_processed_dir_main
50+ self .reader_dir_name = reader_dir_name
4551 self .input_dim = kwargs .get ("input_dim" , None )
46- self .num_of_labels : Optional [int ] = (
52+
53+ self ._device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
54+ self ._num_of_labels : Optional [int ] = (
4755 None # will be set by `_load_data_module_labels` method
4856 )
49- self .data_processed_dir_main = data_processed_dir_main
50- self .models : Dict [str , LightningModule ] = {}
51- self .model_configs = model_configs
52- self .dm_labels : Dict [str , int ] = {}
57+ self ._models : Dict [str , LightningModule ] = {}
58+ self ._dm_labels : Dict [str , int ] = {}
5359
5460 self ._load_data_module_labels ()
5561 self ._num_models_per_label : torch .Tensor = torch .zeros (
56- 1 , self .num_of_labels , device = self .device
62+ 1 , self ._num_of_labels , device = self ._device
5763 )
5864 self ._model_queue : Deque = deque ()
5965 self ._collated_data = None
66+ self ._total_data_size : Optional [int ] = None
6067
6168 @classmethod
6269 def _validate_model_configs (cls , model_configs : Dict [str , Dict ]):
@@ -121,14 +128,17 @@ def _load_data_module_labels(self):
121128 else :
122129 with open (classes_txt_file , "r" ) as f :
123130 for line in f :
124- if line .strip () not in self .dm_labels :
125- self .dm_labels [line .strip ()] = len (self .dm_labels )
126- self .num_of_labels = len (self .dm_labels )
131+ if line .strip () not in self ._dm_labels :
132+ self ._dm_labels [line .strip ()] = len (self ._dm_labels )
133+ self ._num_of_labels = len (self ._dm_labels )
127134
128135 def run_ensemble (self ):
129- batch_size = 10
130- true_scores = torch .zeros (batch_size , self .num_of_labels , device = self .device )
131- false_scores = torch .zeros (batch_size , self .num_of_labels , device = self .device )
136+ true_scores = torch .zeros (
137+ self ._total_data_size , self ._num_of_labels , device = self ._device
138+ )
139+ false_scores = torch .zeros (
140+ self ._total_data_size , self ._num_of_labels , device = self ._device
141+ )
132142
133143 while self ._model_queue :
134144 model_name = self ._model_queue .popleft ()
@@ -156,8 +166,8 @@ def run_ensemble(self):
156166 print_metrics (
157167 final_preds ,
158168 self ._collated_data .y ,
159- self .device ,
160- classes = list (self .dm_labels .keys ()),
169+ self ._device ,
170+ classes = list (self ._dm_labels .keys ()),
161171 )
162172
163173 def _load_model_and_its_props (self , model_name ):
@@ -209,33 +219,33 @@ def _generate_model_label_props(self, labels_path: str):
209219
210220 model_label_indices , tpv_label_values , fpv_label_values = [], [], []
211221 for label in labels_dict .keys ():
212- if label in self .dm_labels :
222+ if label in self ._dm_labels :
213223 try :
214224 self ._validate_model_labels_json_element (labels_dict [label ])
215225 except Exception as e :
216226 raise Exception (f"Label '{ label } ' has an unexpected error: { e } " )
217227
218- model_label_indices .append (self .dm_labels [label ])
228+ model_label_indices .append (self ._dm_labels [label ])
219229 tpv_label_values .append (labels_dict [label ]["TPV" ])
220230 fpv_label_values .append (labels_dict [label ]["FPV" ])
221231
222232 if not all ([model_label_indices , tpv_label_values , fpv_label_values ]):
223233 raise ValueError (f"Values are empty for labels of the model" )
224234
225235 # Create masks to apply predictions only to known classes
226- mask = torch .zeros (self .num_of_labels , device = self .device , dtype = torch .bool )
227- mask [torch . tensor ( model_label_indices , dtype = torch . int , device = self . device )] = (
228- True
229- )
236+ mask = torch .zeros (self ._num_of_labels , device = self ._device , dtype = torch .bool )
237+ mask [
238+ torch . tensor ( model_label_indices , dtype = torch . int , device = self . _device )
239+ ] = True
230240
231- tpv_tensor = torch .full_like (mask , - 1 , dtype = torch .float , device = self .device )
232- fpv_tensor = torch .full_like (mask , - 1 , dtype = torch .float , device = self .device )
241+ tpv_tensor = torch .full_like (mask , - 1 , dtype = torch .float , device = self ._device )
242+ fpv_tensor = torch .full_like (mask , - 1 , dtype = torch .float , device = self ._device )
233243
234244 tpv_tensor [mask ] = torch .tensor (
235- tpv_label_values , dtype = torch .float , device = self .device
245+ tpv_label_values , dtype = torch .float , device = self ._device
236246 )
237247 fpv_tensor [mask ] = torch .tensor (
238- fpv_label_values , dtype = torch .float , device = self .device
248+ fpv_label_values , dtype = torch .float , device = self ._device
239249 )
240250 self ._num_models_per_label += mask
241251 return {"mask" : mask , "tpv_tensor" : tpv_tensor , "fpv_tensor" : fpv_tensor }
0 commit comments