22from abc import ABC
33import torch
44import tqdm
5+ from chebai .preprocessing .datasets .chebi import ChEBIOver50
6+ from chebai .result .analyse_sem import PredictionSmoother
57from rdkit import Chem
68
79from chebifier .prediction_models .base_predictor import BasePredictor
1719
1820class BaseEnsemble (ABC ):
1921
20- def __init__ (self , model_configs : dict ):
22+ def __init__ (self , model_configs : dict , chebi_version : int = 241 ):
2123 self .models = []
2224 self .positive_prediction_threshold = 0.5
2325 for model_name , model_config in model_configs .items ():
@@ -26,6 +28,12 @@ def __init__(self, model_configs: dict):
2628 assert isinstance (model_instance , BasePredictor )
2729 self .models .append (model_instance )
2830
31+ self .smoother = PredictionSmoother (ChEBIOver50 (chebi_version = chebi_version ), disjoint_files = [
32+ os .path .join ("data" , "disjoint_chebi.csv" ),
33+ os .path .join ("data" , "disjoint_additional.csv" )
34+ ])
35+
36+
2937 def gather_predictions (self , smiles_list ):
3038 # get predictions from all models for the SMILES list
3139 # order them by alphabetically by label class
@@ -52,17 +60,13 @@ def gather_predictions(self, smiles_list):
5260 return ordered_logits , predicted_classes
5361
5462
55- def consolidate_predictions (self , predictions , predicted_classes , classwise_weights , ** kwargs ):
63+ def consolidate_predictions (self , predictions , classwise_weights , ** kwargs ):
5664 """
5765 Aggregates predictions from multiple models using weighted majority voting.
5866 Optimized version using tensor operations instead of for loops.
5967 """
6068 num_smiles , num_classes , num_models = predictions .shape
6169
62- # Create a mapping from class indices to class names for faster lookup
63- class_names = list (predicted_classes .keys ())
64- class_indices = {predicted_classes [cls ]: cls for cls in class_names }
65-
6670 # Get predictions for all classes
6771 valid_predictions = ~ torch .isnan (predictions )
6872 valid_counts = valid_predictions .sum (dim = 2 ) # Sum over models dimension
@@ -94,14 +98,9 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig
9498 net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
9599 class_decisions = (net_score > 0 ) & has_valid_predictions # Shape: (num_smiles, num_classes)
96100
97- # Convert tensor decisions to result list using list comprehension for efficiency
98- result = [
99- [class_indices [idx .item ()] for idx in torch .nonzero (class_decisions [i ], as_tuple = True )[0 ]]
100- for i in range (num_smiles )
101- ]
102101
103- return result
104102
103+ return class_decisions
105104
106105 def calculate_classwise_weights (self , predicted_classes ):
107106 """No weights, simple majority voting"""
@@ -128,14 +127,26 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
128127 predicted_classes = {line .strip (): i for i , line in enumerate (f .readlines ())}
129128
130129 classwise_weights = self .calculate_classwise_weights (predicted_classes )
131- aggregated_predictions = self .consolidate_predictions (ordered_predictions , predicted_classes , classwise_weights )
132- return aggregated_predictions
130+ class_decisions = self .consolidate_predictions (ordered_predictions , classwise_weights )
131+ # Smooth predictions
132+ class_names = list (predicted_classes .keys ())
133+ self .smoother .label_names = class_names
134+ class_decisions = self .smoother (class_decisions )
135+
136+ class_names = list (predicted_classes .keys ())
137+ class_indices = {predicted_classes [cls ]: cls for cls in class_names }
138+ result = [
139+ [class_indices [idx .item ()] for idx in torch .nonzero (i , as_tuple = True )[0 ]]
140+ for i in class_decisions
141+ ]
142+
143+ return result
133144
134145if __name__ == "__main__" :
135146 ensemble = BaseEnsemble ({"resgated_0ps1g189" :{
136147 "type" : "resgated" ,
137- "ckpt_path" : "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/resgated_80-10-10_0ps1g189_epoch =122.ckpt" ,
138- "target_labels_path" : "../python-chebai/ data/chebi_v241/ChEBI50/processed/classes.txt" ,
148+ "ckpt_path" : "data/0ps1g189/epoch =122.ckpt" ,
149+ "target_labels_path" : "data/chebi_v241/ChEBI50/processed/classes.txt" ,
139150 "molecular_properties" : [
140151 "chebai_graph.preprocessing.properties.AtomType" ,
141152 "chebai_graph.preprocessing.properties.NumAtomBonds" ,
@@ -148,14 +159,14 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
148159 "chebai_graph.preprocessing.properties.BondAromaticity" ,
149160 "chebai_graph.preprocessing.properties.RDKit2DNormalized" ,
150161 ],
151- "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
162+ # "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
152163 },
153164
154165"electra_14ko0zcf" : {
155166 "type" : "electra" ,
156- "ckpt_path" : "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/electra_80-10-10_14ko0zcf_epoch =193.ckpt" ,
157- "target_labels_path" : "../python-chebai/ data/chebi_v241/ChEBI50/processed/classes.txt" ,
158- "classwise_weights_path" : "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json" ,
167+ "ckpt_path" : "data/14ko0zcf/epoch =193.ckpt" ,
168+ "target_labels_path" : "data/chebi_v241/ChEBI50/processed/classes.txt" ,
169+ # "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
159170}
160171 })
161172 r = ensemble .predict_smiles_list (["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O" ], load_preds_if_possible = False )
0 commit comments