1+ import os
12from abc import ABC
23import torch
34import tqdm
45from rdkit import Chem
56
67from chebifier .prediction_models .base_predictor import BasePredictor
8+ from chebifier .prediction_models .chemlog_predictor import ChemLogPredictor
79from chebifier .prediction_models .electra_predictor import ElectraPredictor
10+ from chebifier .prediction_models .gnn_predictor import ResGatedPredictor
811
912MODEL_TYPES = {
1013 "electra" : ElectraPredictor ,
11- # todo add other model types here
14+ "resgated" : ResGatedPredictor ,
15+ "chemlog" : ChemLogPredictor
1216}
1317
1418class BaseEnsemble (ABC ):
@@ -22,70 +26,73 @@ def __init__(self, model_configs: dict):
2226 self .models .append (model_instance )
2327
2428 def gather_predictions (self , smiles_list ):
25- """
26-
27- :param smiles_list: list of SMILES strings to predict
28- :return:
29- ordered_predictions: torch.Tensor of shape (num_smiles, num_classes, num_models)
30- predicted_classes: list of ChEBI IDs predicted by the models
31- """
3229 model_predictions = []
3330 predicted_classes = set ()
3431 for model in self .models :
3532 model_predictions .append (model .predict_smiles_list (smiles_list ))
36- for predicted_smiles in model_predictions [- 1 ]:
37- if predicted_smiles is not None :
38- for cls in predicted_smiles :
33+ for predicted_labels_for_smiles in model_predictions [- 1 ]:
34+ if predicted_labels_for_smiles is not None :
35+ for cls in predicted_labels_for_smiles :
3936 predicted_classes .add (cls )
4037 print (f"Sorting predictions..." )
4138 predicted_classes = sorted (list (predicted_classes ))
39+ predicted_classes = {cls : i for i , cls in enumerate (predicted_classes )}
4240 ordered_predictions = torch .zeros (len (smiles_list ), len (predicted_classes ), len (self .models )) * torch .nan
4341 for i , model_prediction in enumerate (model_predictions ):
44- for j , predicted_smiles in tqdm .tqdm (enumerate (model_prediction ),
42+ for j , predicted_labels_for_smiles in tqdm .tqdm (enumerate (model_prediction ),
4543 total = len (model_prediction ),
4644 desc = f"Sorting predictions for { self .models [i ].model_name } " ):
47- if predicted_smiles is not None :
48- for cls in predicted_smiles :
49- ordered_predictions [j , predicted_classes . index ( cls ) , i ] = predicted_smiles [cls ]
45+ if predicted_labels_for_smiles is not None :
46+ for cls in predicted_labels_for_smiles :
47+ ordered_predictions [j , predicted_classes [ cls ] , i ] = predicted_labels_for_smiles [cls ]
5048 return ordered_predictions , predicted_classes
5149
5250
53- def aggregate_predictions (self , predictions , predicted_classes , ** kwargs ):
51+ def consolidate_predictions (self , predictions , predicted_classes , classwise_weights , ** kwargs ):
5452 """
55- Aggregates predictions from multiple models using majority voting.
56-
57- :param predictions: torch.Tensor of shape (num_smiles, num_classes, num_models)
58- :param predicted_classes: list of ChEBI IDs predicted by the models
59- :param kwargs: Additional arguments
60- :return: list of lists, where each inner list contains the class IDs that received
61- positive predictions from the majority of models for a given SMILES
53+ Aggregates predictions from multiple models using weighted majority voting.
54+ Optimized version using tensor operations instead of for loops.
6255 """
6356 num_smiles , num_classes , num_models = predictions .shape
64- result = []
6557
66- for i in tqdm .tqdm (range (num_smiles ), total = num_smiles , desc = "Aggregating predictions" ):
67- smiles_result = []
68- for j in range (num_classes ):
69- # Get predictions for this SMILES and class across all models
70- class_predictions = predictions [i , j , :]
58+ # Create a mapping from class indices to class names for faster lookup
59+ class_names = list (predicted_classes .keys ())
60+ class_indices = {predicted_classes [cls ]: cls for cls in class_names }
61+
62+ # Get predictions for all classes
63+ valid_predictions = ~ torch .isnan (predictions )
64+ valid_counts = valid_predictions .sum (dim = 2 ) # Sum over models dimension
65+
66+ # Skip classes with no valid predictions
67+ has_valid_predictions = valid_counts > 0
68+
69+ # Calculate positive and negative predictions for all classes at once
70+ positive_mask = (predictions > 0.5 ) & valid_predictions
71+ negative_mask = (predictions < 0.5 ) & valid_predictions
7172
72- # Count models that made a prediction (not NaN)
73- valid_predictions = ~ torch . isnan ( class_predictions )
74- num_valid_predictions = valid_predictions . sum (). item ( )
73+ # Extract positive and negative weights
74+ pos_weights = classwise_weights [ 0 ] # Shape: (num_classes, num_models )
75+ neg_weights = classwise_weights [ 1 ] # Shape: (num_classes, num_models )
7576
76- # If no valid predictions, skip this class
77- if num_valid_predictions == 0 :
78- continue
77+ # Calculate weighted predictions using broadcasting
78+ # predictions shape: (num_smiles, num_classes, num_models)
79+ # weights shape: (num_classes, num_models)
80+ positive_weighted = positive_mask .float () * (predictions .nan_to_num () - 0.5 ) * pos_weights .unsqueeze (0 )
81+ negative_weighted = negative_mask .float () * (0.5 - predictions .nan_to_num ()) * neg_weights .unsqueeze (0 )
7982
80- # Count positive predictions (assuming positive is > 0)
81- positive_predictions = class_predictions > 0
82- num_positive = ( positive_predictions & valid_predictions ) .sum (). item ( )
83+ # Sum over models dimension
84+ positive_sum = positive_weighted . sum ( dim = 2 ) # Shape: (num_smiles, num_classes)
85+ negative_sum = negative_weighted .sum (dim = 2 ) # Shape: (num_smiles, num_classes )
8386
84- # If majority of models that made a prediction are positive, add this class
85- if num_positive > num_valid_predictions / 2 :
86- smiles_result . append ( predicted_classes [ j ] )
87+ # Determine which classes to include for each SMILES
88+ net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
89+ class_decisions = ( net_score > 0 ) & has_valid_predictions # Shape: (num_smiles, num_classes )
8790
88- result .append (smiles_result )
91+ # Convert tensor decisions to result list using list comprehension for efficiency
92+ result = [
93+ [class_indices [idx .item ()] for idx in torch .nonzero (class_decisions [i ], as_tuple = True )[0 ]]
94+ for i in range (num_smiles )
95+ ]
8996
9097 return result
9198
@@ -102,8 +109,30 @@ def normalize_smiles_list(self, smiles_list):
102109 new .append (canonical_smiles )
103110 return new
104111
105- def predict_smiles_list (self , smiles_list ) -> list :
106- #smiles_list = self.normalize_smiles_list(smiles_list)
107- ordered_predictions , predicted_classes = self .gather_predictions (smiles_list )
108- aggregated_predictions = self .aggregate_predictions (ordered_predictions , predicted_classes )
112+ def calculate_classwise_weights (self , predicted_classes ):
113+ """No weights, simple majority voting"""
114+ positive_weights = torch .ones (len (predicted_classes ), len (self .models ))
115+ negative_weights = torch .ones (len (predicted_classes ), len (self .models ))
116+
117+ return positive_weights , negative_weights
118+
119+ def predict_smiles_list (self , smiles_list , load_preds_if_possible = True ) -> list :
120+ preds_file = f"predictions_by_model_{ '_' .join (model .model_name for model in self .models )} .pt"
121+ predicted_classes_file = f"predicted_classes_{ '_' .join (model .model_name for model in self .models )} .txt"
122+ if not load_preds_if_possible or not os .path .isfile (preds_file ):
123+ #smiles_list = self.normalize_smiles_list(smiles_list)
124+ ordered_predictions , predicted_classes = self .gather_predictions (smiles_list )
125+ # save predictions
126+ torch .save (ordered_predictions , preds_file )
127+ with open (predicted_classes_file , "w" ) as f :
128+ for cls in predicted_classes :
129+ f .write (f"{ cls } \n " )
130+ else :
131+ print (f"Loading predictions from { preds_file } and label indexes from { predicted_classes_file } " )
132+ ordered_predictions = torch .load (preds_file )
133+ with open (predicted_classes_file , "r" ) as f :
134+ predicted_classes = {line .strip (): i for i , line in enumerate (f .readlines ())}
135+
136+ classwise_weights = self .calculate_classwise_weights (predicted_classes )
137+ aggregated_predictions = self .consolidate_predictions (ordered_predictions , predicted_classes , classwise_weights )
109138 return aggregated_predictions
0 commit comments