@@ -16,7 +16,7 @@ def __init__(self, model_configs: dict):
1616 self .positive_prediction_threshold = 0.5
1717 for model_name , model_config in model_configs .items ():
1818 model_cls = MODEL_TYPES [model_config ["type" ]]
19- model_instance = model_cls (** model_config )
19+ model_instance = model_cls (model_name , ** model_config )
2020 assert isinstance (model_instance , BasePredictor )
2121 self .models .append (model_instance )
2222
@@ -73,8 +73,12 @@ def consolidate_predictions(
7373 has_valid_predictions = valid_counts > 0
7474
7575 # Calculate positive and negative predictions for all classes at once
76- positive_mask = (predictions > 0.5 ) & valid_predictions
77- negative_mask = (predictions < 0.5 ) & valid_predictions
76+ positive_mask = (
77+ predictions > self .positive_prediction_threshold
78+ ) & valid_predictions
79+ negative_mask = (
80+ predictions < self .positive_prediction_threshold
81+ ) & valid_predictions
7882
7983 confidence = 2 * torch .abs (
8084 predictions .nan_to_num () - self .positive_prediction_threshold
@@ -134,6 +138,7 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
134138 with open (predicted_classes_file , "w" ) as f :
135139 for cls in predicted_classes :
136140 f .write (f"{ cls } \n " )
141+ predicted_classes = {cls : i for i , cls in enumerate (predicted_classes )}
137142 else :
138143 print (
139144 f"Loading predictions from { preds_file } and label indexes from { predicted_classes_file } "
@@ -149,3 +154,39 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
149154 ordered_predictions , predicted_classes , classwise_weights
150155 )
151156 return aggregated_predictions
157+
158+
159+ if __name__ == "__main__" :
160+ ensemble = BaseEnsemble (
161+ {
162+ "resgated_0ps1g189" : {
163+ "type" : "resgated" ,
164+ "ckpt_path" : "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/resgated_80-10-10_0ps1g189_epoch=122.ckpt" ,
165+ "target_labels_path" : "../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt" ,
166+ "molecular_properties" : [
167+ "chebai_graph.preprocessing.properties.AtomType" ,
168+ "chebai_graph.preprocessing.properties.NumAtomBonds" ,
169+ "chebai_graph.preprocessing.properties.AtomCharge" ,
170+ "chebai_graph.preprocessing.properties.AtomAromaticity" ,
171+ "chebai_graph.preprocessing.properties.AtomHybridization" ,
172+ "chebai_graph.preprocessing.properties.AtomNumHs" ,
173+ "chebai_graph.preprocessing.properties.BondType" ,
174+ "chebai_graph.preprocessing.properties.BondInRing" ,
175+ "chebai_graph.preprocessing.properties.BondAromaticity" ,
176+ "chebai_graph.preprocessing.properties.RDKit2DNormalized" ,
177+ ],
178+ "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" ,
179+ },
180+ "electra_14ko0zcf" : {
181+ "type" : "electra" ,
182+ "ckpt_path" : "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/electra_80-10-10_14ko0zcf_epoch=193.ckpt" ,
183+ "target_labels_path" : "../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt" ,
184+ "classwise_weights_path" : "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json" ,
185+ },
186+ }
187+ )
188+ r = ensemble .predict_smiles_list (
189+ ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O" ],
190+ load_preds_if_possible = False ,
191+ )
192+ print (len (r ), r [0 ])
0 commit comments