@@ -22,18 +22,19 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
                 hugging_face_kwargs = download_model_files(model_config["hugging_face"])
             else:
                 hugging_face_kwargs = {}
-            model_instance = model_cls(model_name, **model_config, **hugging_face_kwargs)
+            model_instance = model_cls(
+                model_name, **model_config, **hugging_face_kwargs
+            )
             assert isinstance(model_instance, BasePredictor)
             self.models.append(model_instance)
 
         self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
         self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.disjoint_files = [
+        self.disjoint_files = [
             os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv")
+            os.path.join("data", "disjoint_additional.csv"),
         ]
 
-
     def gather_predictions(self, smiles_list):
         # get predictions from all models for the SMILES list
         # order them by alphabetically by label class
@@ -60,11 +61,12 @@ def gather_predictions(self, smiles_list):
         ):
             if logits_for_smiles is not None:
                 for cls in logits_for_smiles:
-                    ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls]
+                    ordered_logits[j, predicted_classes_dict[cls], i] = (
+                        logits_for_smiles[cls]
+                    )
 
         return ordered_logits, predicted_classes
 
-
     def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         """
         Aggregates predictions from multiple models using weighted majority voting.
@@ -80,11 +82,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         has_valid_predictions = valid_counts > 0
 
         # Calculate positive and negative predictions for all classes at once
-        positive_mask = (predictions > self.positive_prediction_threshold) & valid_predictions
-        negative_mask = (predictions < self.positive_prediction_threshold) & valid_predictions
+        positive_mask = (
+            predictions > self.positive_prediction_threshold
+        ) & valid_predictions
+        negative_mask = (
+            predictions < self.positive_prediction_threshold
+        ) & valid_predictions
 
         if "use_confidence" in kwargs and kwargs["use_confidence"]:
-            confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)
+            confidence = 2 * torch.abs(
+                predictions.nan_to_num() - self.positive_prediction_threshold
+            )
         else:
             confidence = torch.ones_like(predictions)
 
@@ -95,18 +103,22 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         # Calculate weighted predictions using broadcasting
         # predictions shape: (num_smiles, num_classes, num_models)
         # weights shape: (num_classes, num_models)
-        positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0)
-        negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0)
+        positive_weighted = (
+            positive_mask.float() * confidence * pos_weights.unsqueeze(0)
+        )
+        negative_weighted = (
+            negative_mask.float() * confidence * neg_weights.unsqueeze(0)
+        )
 
         # Sum over models dimension
         positive_sum = positive_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
         negative_sum = negative_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
 
         # Determine which classes to include for each SMILES
         net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
-        class_decisions = (net_score > 0) & has_valid_predictions  # Shape: (num_smiles, num_classes)
-
-
+        class_decisions = (
+            net_score > 0
+        ) & has_valid_predictions  # Shape: (num_smiles, num_classes)
 
         return class_decisions
 
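The hunk above is the core of the ensemble, so here is a minimal, self-contained sketch of the same weighted-vote rule on toy tensors. The 0.5 threshold, the all-ones class weights, and the validity mask built with isnan() are illustrative assumptions standing in for self.positive_prediction_threshold, calculate_classwise_weights, and the masks computed outside the hunks shown; only the voting arithmetic mirrors the code above.

import torch

# Toy setup: 1 SMILES, 2 classes, 3 models; NaN marks a class the model does not predict.
threshold = 0.5  # assumed stand-in for self.positive_prediction_threshold
predictions = torch.tensor([[[0.55, 0.60, 0.05],
                             [0.90, float("nan"), 0.70]]])  # shape (1, 2, 3)
pos_weights = torch.ones(2, 3)  # assumed stand-in for the classwise weights
neg_weights = torch.ones(2, 3)

valid = ~predictions.isnan()                 # which (smiles, class, model) entries exist
has_valid = valid.sum(dim=2) > 0             # at least one model covers the class
positive = (predictions > threshold) & valid
negative = (predictions < threshold) & valid
confidence = 2 * torch.abs(predictions.nan_to_num() - threshold)  # use_confidence=True branch

positive_sum = (positive.float() * confidence * pos_weights.unsqueeze(0)).sum(dim=2)
negative_sum = (negative.float() * confidence * neg_weights.unsqueeze(0)).sum(dim=2)
class_decisions = (positive_sum - negative_sum > 0) & has_valid
print(class_decisions)  # class 0: False (one confident negative outweighs two weak positives), class 1: True

With use_confidence unset, confidence becomes all ones and the rule reduces to a plain weighted majority vote over the models that cover each class.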
@@ -117,29 +129,43 @@ def calculate_classwise_weights(self, predicted_classes):
 
         return positive_weights, negative_weights
 
-    def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list:
+    def predict_smiles_list(
+        self, smiles_list, load_preds_if_possible=True, **kwargs
+    ) -> list:
         preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
         predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
         if not load_preds_if_possible or not os.path.isfile(preds_file):
-            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
+            ordered_predictions, predicted_classes = self.gather_predictions(
+                smiles_list
+            )
             # save predictions
             torch.save(ordered_predictions, preds_file)
             with open(predicted_classes_file, "w") as f:
                 for cls in predicted_classes:
                     f.write(f"{cls}\n")
             predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
         else:
-            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
+            print(
+                f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
+            )
             ordered_predictions = torch.load(preds_file)
             with open(predicted_classes_file, "r") as f:
-                predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}
+                predicted_classes = {
+                    line.strip(): i for i, line in enumerate(f.readlines())
+                }
 
         classwise_weights = self.calculate_classwise_weights(predicted_classes)
-        class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights, **kwargs)
+        class_decisions = self.consolidate_predictions(
+            ordered_predictions, classwise_weights, **kwargs
+        )
         # Smooth predictions
         class_names = list(predicted_classes.keys())
         # initialise new smoother class since we don't know the labels beforehand (this could be more efficient)
-        new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files)
+        new_smoother = PredictionSmoother(
+            self.chebi_dataset,
+            label_names=class_names,
+            disjoint_files=self.disjoint_files,
+        )
         class_decisions = new_smoother(class_decisions)
 
         class_names = list(predicted_classes.keys())
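A side note on the caching in this hunk: predictions are keyed by the joined model names, so repeated runs with the same ensemble reuse the saved .pt file unless load_preds_if_possible=False is passed. Assuming each model's model_name is the key from model_configs (as suggested by __init__), the cache files for the two-model example at the bottom of this diff would be named as sketched here:

models = ["resgated_0ps1g189", "electra_14ko0zcf"]   # assumed model_name values
preds_file = f"predictions_by_model_{'_'.join(models)}.pt"
predicted_classes_file = f"predicted_classes_{'_'.join(models)}.txt"
print(preds_file)              # predictions_by_model_resgated_0ps1g189_electra_14ko0zcf.pt
print(predicted_classes_file)  # predicted_classes_resgated_0ps1g189_electra_14ko0zcf.txt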
@@ -153,31 +179,36 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs
 
 
 if __name__ == "__main__":
-    ensemble = BaseEnsemble({"resgated_0ps1g189":{
-        "type": "resgated",
-        "ckpt_path": "data/0ps1g189/epoch=122.ckpt",
-        "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
-        "molecular_properties": [
-            "chebai_graph.preprocessing.properties.AtomType",
-            "chebai_graph.preprocessing.properties.NumAtomBonds",
-            "chebai_graph.preprocessing.properties.AtomCharge",
-            "chebai_graph.preprocessing.properties.AtomAromaticity",
-            "chebai_graph.preprocessing.properties.AtomHybridization",
-            "chebai_graph.preprocessing.properties.AtomNumHs",
-            "chebai_graph.preprocessing.properties.BondType",
-            "chebai_graph.preprocessing.properties.BondInRing",
-            "chebai_graph.preprocessing.properties.BondAromaticity",
-            "chebai_graph.preprocessing.properties.RDKit2DNormalized",
-        ],
-        #"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
-    },
-
-    "electra_14ko0zcf": {
-        "type": "electra",
-        "ckpt_path": "data/14ko0zcf/epoch=193.ckpt",
-        "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
-        #"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
-    }
-    })
-    r = ensemble.predict_smiles_list(["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], load_preds_if_possible=False)
+    ensemble = BaseEnsemble(
+        {
+            "resgated_0ps1g189": {
+                "type": "resgated",
+                "ckpt_path": "data/0ps1g189/epoch=122.ckpt",
+                "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
+                "molecular_properties": [
+                    "chebai_graph.preprocessing.properties.AtomType",
+                    "chebai_graph.preprocessing.properties.NumAtomBonds",
+                    "chebai_graph.preprocessing.properties.AtomCharge",
+                    "chebai_graph.preprocessing.properties.AtomAromaticity",
+                    "chebai_graph.preprocessing.properties.AtomHybridization",
+                    "chebai_graph.preprocessing.properties.AtomNumHs",
+                    "chebai_graph.preprocessing.properties.BondType",
+                    "chebai_graph.preprocessing.properties.BondInRing",
+                    "chebai_graph.preprocessing.properties.BondAromaticity",
+                    "chebai_graph.preprocessing.properties.RDKit2DNormalized",
+                ],
+                # "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
+            },
+            "electra_14ko0zcf": {
+                "type": "electra",
+                "ckpt_path": "data/14ko0zcf/epoch=193.ckpt",
+                "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
+                # "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
+            },
+        }
+    )
+    r = ensemble.predict_smiles_list(
+        ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"],
+        load_preds_if_possible=False,
+    )
     print(len(r), r[0])
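The smoothing step only appears in this diff as a call: PredictionSmoother is built from the ChEBI dataset and the two disjoint_files, which list pairs of ChEBI classes that cannot both apply to one molecule. Its implementation is not shown here, so the toy sketch below merely illustrates that kind of constraint with a hypothetical pair list and a deliberately crude rule (clear both labels of a violated pair); it is not the PredictionSmoother API.

import torch

# Hypothetical disjointness pairs, given as column indexes into class_decisions.
# In the real pipeline they come from disjoint_chebi.csv / disjoint_additional.csv.
disjoint_pairs = [(0, 1)]

# Toy decisions for 2 SMILES over 3 classes (True = class predicted).
class_decisions = torch.tensor([[True, True, False],
                                [True, False, True]])

smoothed = class_decisions.clone()
for a, b in disjoint_pairs:
    conflict = smoothed[:, a] & smoothed[:, b]   # both disjoint classes predicted
    smoothed[conflict, a] = False                # toy rule: drop both conflicting labels
    smoothed[conflict, b] = False

print(smoothed)
# row 0 loses both of its conflicting labels, row 1 is unchanged

The real smoother presumably also uses the class hierarchy available through the dataset, which this toy deliberately ignores.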