@@ -19,33 +19,37 @@ class BaseEnsemble(ABC):
1919
def __init__(self, model_configs: dict):
    """Build one predictor per entry of ``model_configs``.

    Args:
        model_configs: Mapping of model name to that model's configuration
            dict. Each configuration must carry a ``"type"`` key that
            selects the class from ``MODEL_TYPES``; the whole configuration
            (``"type"`` included) is forwarded to that class as keyword
            arguments.

    Raises:
        TypeError: If a constructed model is not a ``BasePredictor``.
    """
    self.models = []
    # Decision boundary used by consolidate_predictions when turning a
    # prediction into a confidence value.
    self.positive_prediction_threshold = 0.5
    for model_name, model_config in model_configs.items():
        model_cls = MODEL_TYPES[model_config["type"]]
        model_instance = model_cls(**model_config)
        # Explicit check rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a wrong type slip through.
        if not isinstance(model_instance, BasePredictor):
            raise TypeError(
                f"Model '{model_name}' of type '{model_config['type']}' "
                f"is not a BasePredictor (got {type(model_instance).__name__})"
            )
        self.models.append(model_instance)
2728
def gather_predictions(self, smiles_list):
    """Collect the outputs of all models for ``smiles_list`` in one tensor.

    Every model in ``self.models`` is queried through
    ``predict_smiles_list``. Each per-SMILES result is either ``None``
    (ignored) or a mapping from class label to a value.
    NOTE(review): assumes each model returns one entry per input SMILES,
    in input order - confirm against the predictor implementations.

    Args:
        smiles_list: The SMILES strings to predict.

    Returns:
        A tuple ``(ordered_logits, predicted_classes)``:

        * ``ordered_logits`` - tensor of shape
          ``(len(smiles_list), num_classes, num_models)``; entries stay NaN
          wherever a model returned no value for that SMILES/class pair.
        * ``predicted_classes`` - dict mapping each class label seen in any
          model output to its index along dimension 1. Labels are indexed
          in sorted order.
    """
    model_predictions = [
        model.predict_smiles_list(smiles_list) for model in self.models
    ]

    # Union of every class label reported by any model for any SMILES.
    seen_classes = set()
    for model_prediction in model_predictions:
        for logits_for_smiles in model_prediction:
            if logits_for_smiles is not None:
                seen_classes.update(logits_for_smiles)

    print("Sorting predictions...")
    # Sorting gives every label a stable column index across runs.
    predicted_classes = {
        cls: idx for idx, cls in enumerate(sorted(seen_classes))
    }

    # NaN marks "no prediction"; only cells a model reported are overwritten.
    ordered_logits = torch.full(
        (len(smiles_list), len(predicted_classes), len(self.models)),
        torch.nan,
    )
    for i, model_prediction in enumerate(model_predictions):
        for j, logits_for_smiles in tqdm.tqdm(
            enumerate(model_prediction),
            total=len(model_prediction),
            desc=f"Sorting predictions for {self.models[i].model_name}",
        ):
            if logits_for_smiles is not None:
                for cls in logits_for_smiles:
                    ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls]

    return ordered_logits, predicted_classes
4953
5054
5155 def consolidate_predictions (self , predictions , predicted_classes , classwise_weights , ** kwargs ):
@@ -70,15 +74,17 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
7074 positive_mask = (predictions > 0.5 ) & valid_predictions
7175 negative_mask = (predictions < 0.5 ) & valid_predictions
7276
77+ confidence = 2 * torch .abs (predictions .nan_to_num () - self .positive_prediction_threshold )
78+
7379 # Extract positive and negative weights
7480 pos_weights = classwise_weights [0 ] # Shape: (num_classes, num_models)
7581 neg_weights = classwise_weights [1 ] # Shape: (num_classes, num_models)
7682
7783 # Calculate weighted predictions using broadcasting
7884 # predictions shape: (num_smiles, num_classes, num_models)
7985 # weights shape: (num_classes, num_models)
80- positive_weighted = positive_mask .float () * ( predictions . nan_to_num () - 0.5 ) * pos_weights .unsqueeze (0 )
81- negative_weighted = negative_mask .float () * ( 0.5 - predictions . nan_to_num ()) * neg_weights .unsqueeze (0 )
86+ positive_weighted = positive_mask .float () * confidence * pos_weights .unsqueeze (0 )
87+ negative_weighted = negative_mask .float () * confidence * neg_weights .unsqueeze (0 )
8288
8389 # Sum over models dimension
8490 positive_sum = positive_weighted .sum (dim = 2 ) # Shape: (num_smiles, num_classes)
@@ -96,18 +102,6 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
96102
97103 return result
98104
99- def normalize_smiles_list (self , smiles_list ):
100- new = []
101- print (f"Normalizing SMILES strings..." )
102- for smiles in tqdm .tqdm (smiles_list ):
103- try :
104- mol = Chem .MolFromSmiles (smiles )
105- canonical_smiles = Chem .MolToSmiles (mol )
106- except Exception as e :
107- print (f"Failed to parse SMILES '{ smiles } ': { e } " )
108- canonical_smiles = None
109- new .append (canonical_smiles )
110- return new
111105
112106 def calculate_classwise_weights (self , predicted_classes ):
113107 """No weights, simple majority voting"""
@@ -120,8 +114,7 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
120114 preds_file = f"predictions_by_model_{ '_' .join (model .model_name for model in self .models )} .pt"
121115 predicted_classes_file = f"predicted_classes_{ '_' .join (model .model_name for model in self .models )} .txt"
122116 if not load_preds_if_possible or not os .path .isfile (preds_file ):
123- #smiles_list = self.normalize_smiles_list(smiles_list)
124- ordered_predictions , predicted_classes = self .gather_predictions (smiles_list )
117+ ordered_predictions = predicted_classes = self .gather_predictions (smiles_list )
125118 # save predictions
126119 torch .save (ordered_predictions , preds_file )
127120 with open (predicted_classes_file , "w" ) as f :
0 commit comments