 import torch
 import tqdm
 from chebai.preprocessing.datasets.chebi import ChEBIOver50
-from chebai.result.analyse_sem import PredictionSmoother
+from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph

 from chebifier.prediction_models.base_predictor import BasePredictor

@@ -15,6 +15,14 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
         # Deferred Import: To avoid circular import error
         from chebifier.model_registry import MODEL_TYPES

+        self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
+        self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
+        self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
+        self.disjoint_files = [
+            os.path.join("data", "disjoint_chebi.csv"),
+            os.path.join("data", "disjoint_additional.csv"),
+        ]
+
         self.models = []
         self.positive_prediction_threshold = 0.5
         for model_name, model_config in model_configs.items():
@@ -25,17 +33,12 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
             else:
                 hugging_face_kwargs = {}
             model_instance = model_cls(
-                model_name, **model_config, **hugging_face_kwargs
+                model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph
             )
             assert isinstance(model_instance, BasePredictor)
             self.models.append(model_instance)

-        self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
-        self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.disjoint_files = [
-            os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv"),
-        ]
+

         self.smoother = PredictionSmoother(
             self.chebi_dataset,
@@ -54,7 +57,7 @@ def gather_predictions(self, smiles_list):
             if logits_for_smiles is not None:
                 for cls in logits_for_smiles:
                     predicted_classes.add(cls)
-        print("Sorting predictions...")
+        print(f"Sorting predictions from {len(model_predictions)} models...")
         predicted_classes = sorted(list(predicted_classes))
         predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
         ordered_logits = (
@@ -75,7 +78,7 @@ def gather_predictions(self, smiles_list):

         return ordered_logits, predicted_classes

-    def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
+    def consolidate_predictions(self, predictions, classwise_weights, predicted_classes, **kwargs):
         """
         Aggregates predictions from multiple models using weighted majority voting.
         Optimized version using tensor operations instead of for loops.
@@ -124,8 +127,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):

         # Determine which classes to include for each SMILES
         net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
+
+        # Smooth predictions
+        start_time = time.perf_counter()
+        class_names = list(predicted_classes.keys())
+        self.smoother.set_label_names(class_names)
+        smooth_net_score = self.smoother(net_score)
+        end_time = time.perf_counter()
+        print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
+
         class_decisions = (
-            net_score > 0
+            smooth_net_score > 0.5
         ) & has_valid_predictions  # Shape: (num_smiles, num_classes)

         complete_failure = torch.all(~has_valid_predictions, dim=1)
@@ -139,14 +151,16 @@ def calculate_classwise_weights(self, predicted_classes):
         return positive_weights, negative_weights

     def predict_smiles_list(
-        self, smiles_list, load_preds_if_possible=True, **kwargs
+        self, smiles_list, load_preds_if_possible=False, **kwargs
     ) -> list:
         preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
         predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
         if not load_preds_if_possible or not os.path.isfile(preds_file):
             ordered_predictions, predicted_classes = self.gather_predictions(
                 smiles_list
             )
+            if len(predicted_classes) == 0:
+                print("Warning: No classes have been predicted for the given SMILES list.")
             # save predictions
             torch.save(ordered_predictions, preds_file)
             with open(predicted_classes_file, "w") as f:
@@ -165,15 +179,8 @@ def predict_smiles_list(

         classwise_weights = self.calculate_classwise_weights(predicted_classes)
         class_decisions, is_failure = self.consolidate_predictions(
-            ordered_predictions, classwise_weights, **kwargs
+            ordered_predictions, classwise_weights, predicted_classes, **kwargs
         )
-        # Smooth predictions
-        start_time = time.perf_counter()
-        class_names = list(predicted_classes.keys())
-        self.smoother.set_label_names(class_names)
-        class_decisions = self.smoother(class_decisions)
-        end_time = time.perf_counter()
-        print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")

         class_names = list(predicted_classes.keys())
         class_indices = {predicted_classes[cls]: cls for cls in class_names}
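
For orientation, here is a minimal, self-contained sketch of the consolidation flow this diff changes: per-model votes are combined with class-wise weights into a net score, the net score is smoothed before thresholding (previously, smoothing was applied to the boolean decisions afterwards), and the decision moves from "net score > 0" to "smoothed score > 0.5". All names, shapes, and the use of torch.sigmoid as a stand-in for PredictionSmoother are illustrative assumptions, not the chebifier API.

import torch

# Illustrative sketch only, not code from the diff. torch.sigmoid is a
# hypothetical placeholder for PredictionSmoother, which in the real code
# enforces consistency with the ChEBI hierarchy before thresholding.
num_models, num_smiles, num_classes = 3, 2, 5
logits = torch.randn(num_models, num_smiles, num_classes)  # stacked per-model outputs
logits[0, 1, :] = float("nan")                             # pretend model 0 failed on SMILES 1

pos_w = torch.ones(num_models, num_classes)                # class-wise positive vote weights
neg_w = torch.ones(num_models, num_classes)                # class-wise negative vote weights

valid = ~torch.isnan(logits)                               # which (model, smiles, class) votes exist
is_pos = (logits > 0.5) & valid                            # votes above the positive prediction threshold
is_neg = (logits <= 0.5) & valid

positive_sum = (is_pos.float() * pos_w[:, None, :]).sum(dim=0)  # (num_smiles, num_classes)
negative_sum = (is_neg.float() * neg_w[:, None, :]).sum(dim=0)
net_score = positive_sum - negative_sum

smooth_net_score = torch.sigmoid(net_score)                # placeholder for self.smoother(net_score)

has_valid_predictions = valid.any(dim=0)                   # at least one model voted on this class
class_decisions = (smooth_net_score > 0.5) & has_valid_predictions
complete_failure = torch.all(~has_valid_predictions, dim=1)  # SMILES where every model failed
print(class_decisions)
print(complete_failure)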