11import os
22from abc import ABC
3+
34import torch
45import tqdm
5- from rdkit import Chem
66
77from chebifier .prediction_models .base_predictor import BasePredictor
88from chebifier .prediction_models .chemlog_predictor import ChemLogPredictor
# Maps a model-type name to the predictor class registered under that name.
# NOTE(review): presumably looked up when building the ensemble from its
# model configs - confirm against BaseEnsemble.__init__.
MODEL_TYPES = {
    "electra": ElectraPredictor,
    "resgated": ResGatedPredictor,
    "chemlog": ChemLogPredictor,
}
1717
18- class BaseEnsemble (ABC ):
1918
19+ class BaseEnsemble (ABC ):
2020 def __init__ (self , model_configs : dict ):
2121 self .models = []
2222 self .positive_prediction_threshold = 0.5
@@ -37,22 +37,30 @@ def gather_predictions(self, smiles_list):
3737 if logits_for_smiles is not None :
3838 for cls in logits_for_smiles :
3939 predicted_classes .add (cls )
40- print (f "Sorting predictions..." )
40+ print ("Sorting predictions..." )
4141 predicted_classes = sorted (list (predicted_classes ))
4242 predicted_classes_dict = {cls : i for i , cls in enumerate (predicted_classes )}
43- ordered_logits = torch .zeros (len (smiles_list ), len (predicted_classes ), len (self .models )) * torch .nan
43+ ordered_logits = (
44+ torch .zeros (len (smiles_list ), len (predicted_classes ), len (self .models ))
45+ * torch .nan
46+ )
4447 for i , model_prediction in enumerate (model_predictions ):
45- for j , logits_for_smiles in tqdm .tqdm (enumerate (model_prediction ),
46- total = len (model_prediction ),
47- desc = f"Sorting predictions for { self .models [i ].model_name } " ):
48+ for j , logits_for_smiles in tqdm .tqdm (
49+ enumerate (model_prediction ),
50+ total = len (model_prediction ),
51+ desc = f"Sorting predictions for { self .models [i ].model_name } " ,
52+ ):
4853 if logits_for_smiles is not None :
4954 for cls in logits_for_smiles :
50- ordered_logits [j , predicted_classes_dict [cls ], i ] = logits_for_smiles [cls ]
55+ ordered_logits [j , predicted_classes_dict [cls ], i ] = (
56+ logits_for_smiles [cls ]
57+ )
5158
52- return ordered_logits , predicted_classes
59+ return ordered_logits , predicted_classes_dict
5360
54-
55- def consolidate_predictions (self , predictions , predicted_classes , classwise_weights , ** kwargs ):
61+ def consolidate_predictions (
62+ self , predictions , predicted_classes , classwise_weights , ** kwargs
63+ ):
5664 """
5765 Aggregates predictions from multiple models using weighted majority voting.
5866 Optimized version using tensor operations instead of for loops.
@@ -74,7 +82,9 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
7482 positive_mask = (predictions > 0.5 ) & valid_predictions
7583 negative_mask = (predictions < 0.5 ) & valid_predictions
7684
77- confidence = 2 * torch .abs (predictions .nan_to_num () - self .positive_prediction_threshold )
85+ confidence = 2 * torch .abs (
86+ predictions .nan_to_num () - self .positive_prediction_threshold
87+ )
7888
7989 # Extract positive and negative weights
8090 pos_weights = classwise_weights [0 ] # Shape: (num_classes, num_models)
@@ -83,26 +93,34 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
8393 # Calculate weighted predictions using broadcasting
8494 # predictions shape: (num_smiles, num_classes, num_models)
8595 # weights shape: (num_classes, num_models)
86- positive_weighted = positive_mask .float () * confidence * pos_weights .unsqueeze (0 )
87- negative_weighted = negative_mask .float () * confidence * neg_weights .unsqueeze (0 )
96+ positive_weighted = (
97+ positive_mask .float () * confidence * pos_weights .unsqueeze (0 )
98+ )
99+ negative_weighted = (
100+ negative_mask .float () * confidence * neg_weights .unsqueeze (0 )
101+ )
88102
89103 # Sum over models dimension
90104 positive_sum = positive_weighted .sum (dim = 2 ) # Shape: (num_smiles, num_classes)
91105 negative_sum = negative_weighted .sum (dim = 2 ) # Shape: (num_smiles, num_classes)
92106
93107 # Determine which classes to include for each SMILES
94108 net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
95- class_decisions = (net_score > 0 ) & has_valid_predictions # Shape: (num_smiles, num_classes)
109+ class_decisions = (
110+ net_score > 0
111+ ) & has_valid_predictions # Shape: (num_smiles, num_classes)
96112
97113 # Convert tensor decisions to result list using list comprehension for efficiency
98114 result = [
99- [class_indices [idx .item ()] for idx in torch .nonzero (class_decisions [i ], as_tuple = True )[0 ]]
115+ [
116+ class_indices [idx .item ()]
117+ for idx in torch .nonzero (class_decisions [i ], as_tuple = True )[0 ]
118+ ]
100119 for i in range (num_smiles )
101120 ]
102121
103122 return result
104123
105-
106124 def calculate_classwise_weights (self , predicted_classes ):
107125 """No weights, simple majority voting"""
108126 positive_weights = torch .ones (len (predicted_classes ), len (self .models ))
@@ -114,18 +132,26 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
114132 preds_file = f"predictions_by_model_{ '_' .join (model .model_name for model in self .models )} .pt"
115133 predicted_classes_file = f"predicted_classes_{ '_' .join (model .model_name for model in self .models )} .txt"
116134 if not load_preds_if_possible or not os .path .isfile (preds_file ):
117- ordered_predictions , predicted_classes = self .gather_predictions (smiles_list )
135+ ordered_predictions , predicted_classes = self .gather_predictions (
136+ smiles_list
137+ )
118138 # save predictions
119139 torch .save (ordered_predictions , preds_file )
120140 with open (predicted_classes_file , "w" ) as f :
121141 for cls in predicted_classes :
122142 f .write (f"{ cls } \n " )
123143 else :
124- print (f"Loading predictions from { preds_file } and label indexes from { predicted_classes_file } " )
144+ print (
145+ f"Loading predictions from { preds_file } and label indexes from { predicted_classes_file } "
146+ )
125147 ordered_predictions = torch .load (preds_file )
126148 with open (predicted_classes_file , "r" ) as f :
127- predicted_classes = {line .strip (): i for i , line in enumerate (f .readlines ())}
149+ predicted_classes = {
150+ line .strip (): i for i , line in enumerate (f .readlines ())
151+ }
128152
129153 classwise_weights = self .calculate_classwise_weights (predicted_classes )
130- aggregated_predictions = self .consolidate_predictions (ordered_predictions , predicted_classes , classwise_weights )
154+ aggregated_predictions = self .consolidate_predictions (
155+ ordered_predictions , predicted_classes , classwise_weights
156+ )
131157 return aggregated_predictions
0 commit comments