1212
1313class BaseEnsemble :
1414
15- def __init__ (self , model_configs : dict , chebi_version : int = 241 , resolve_inconsistencies : bool = True ):
15+ def __init__ (
16+ self ,
17+ model_configs : dict ,
18+ chebi_version : int = 241 ,
19+ resolve_inconsistencies : bool = True ,
20+ ):
1621 # Deferred Import: To avoid circular import error
1722 from chebifier .model_registry import MODEL_TYPES
1823
@@ -28,33 +33,43 @@ def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_incons
2833 if os .path .isfile (file ):
2934 self .disjoint_files .append (file )
3035 else :
31- print (f"Disjoint axiom file { file } not found. Loading from huggingface instead..." )
36+ print (
37+ f"Disjoint axiom file { file } not found. Loading from huggingface instead..."
38+ )
3239 from chebifier .hugging_face import download_model_files
33- self .disjoint_files .append (download_model_files ({
34- "repo_id" : "chebai/chebifier" ,
35- "repo_type" : "dataset" ,
36- "files" : {"disjoint_file" : os .path .basename (file )},
37- })["disjoint_file" ])
40+
41+ self .disjoint_files .append (
42+ download_model_files (
43+ {
44+ "repo_id" : "chebai/chebifier" ,
45+ "repo_type" : "dataset" ,
46+ "files" : {"disjoint_file" : os .path .basename (file )},
47+ }
48+ )["disjoint_file" ]
49+ )
3850
3951 self .models = []
4052 self .positive_prediction_threshold = 0.5
4153 for model_name , model_config in model_configs .items ():
4254 model_cls = MODEL_TYPES [model_config ["type" ]]
4355 if "hugging_face" in model_config :
4456 from chebifier .hugging_face import download_model_files
57+
4558 hugging_face_kwargs = download_model_files (model_config ["hugging_face" ])
4659 else :
4760 hugging_face_kwargs = {}
4861 if "package_name" in model_config :
4962 check_package_installed (model_config ["package_name" ])
5063
5164 model_instance = model_cls (
52- model_name , ** model_config , ** hugging_face_kwargs , chebi_graph = self .chebi_graph
65+ model_name ,
66+ ** model_config ,
67+ ** hugging_face_kwargs ,
68+ chebi_graph = self .chebi_graph ,
5369 )
5470 assert isinstance (model_instance , BasePredictor )
5571 self .models .append (model_instance )
5672
57-
5873 if resolve_inconsistencies :
5974 self .smoother = PredictionSmoother (
6075 self .chebi_dataset ,
@@ -96,7 +111,9 @@ def gather_predictions(self, smiles_list):
96111
97112 return ordered_logits , predicted_classes
98113
99- def consolidate_predictions (self , predictions , classwise_weights , predicted_classes , ** kwargs ):
114+ def consolidate_predictions (
115+ self , predictions , classwise_weights , predicted_classes , ** kwargs
116+ ):
100117 """
101118 Aggregates predictions from multiple models using weighted majority voting.
102119 Optimized version using tensor operations instead of for loops.
@@ -152,9 +169,13 @@ def consolidate_predictions(self, predictions, classwise_weights, predicted_clas
152169 if self .smoother is not None :
153170 self .smoother .set_label_names (class_names )
154171 smooth_net_score = self .smoother (net_score )
155- class_decisions = (smooth_net_score > 0.5 ) & has_valid_predictions # Shape: (num_smiles, num_classes)
172+ class_decisions = (
173+ smooth_net_score > 0.5
174+ ) & has_valid_predictions # Shape: (num_smiles, num_classes)
156175 else :
157- class_decisions = (net_score > 0 ) & has_valid_predictions # Shape: (num_smiles, num_classes)
176+ class_decisions = (
177+ net_score > 0
178+ ) & has_valid_predictions # Shape: (num_smiles, num_classes)
158179 end_time = time .perf_counter ()
159180 print (f"Prediction smoothing took { end_time - start_time :.2f} seconds" )
160181
@@ -178,7 +199,9 @@ def predict_smiles_list(
178199 smiles_list
179200 )
180201 if len (predicted_classes ) == 0 :
181- print (f"Warning: No classes have been predicted for the given SMILES list." )
202+ print (
203+ f"Warning: No classes have been predicted for the given SMILES list."
204+ )
182205 # save predictions
183206 torch .save (ordered_predictions , preds_file )
184207 with open (predicted_classes_file , "w" ) as f :
@@ -203,7 +226,14 @@ def predict_smiles_list(
203226 class_names = list (predicted_classes .keys ())
204227 class_indices = {predicted_classes [cls ]: cls for cls in class_names }
205228 result = [
206- [class_indices [idx .item ()] for idx in torch .nonzero (i , as_tuple = True )[0 ]] if not failure else None
229+ (
230+ [
231+ class_indices [idx .item ()]
232+ for idx in torch .nonzero (i , as_tuple = True )[0 ]
233+ ]
234+ if not failure
235+ else None
236+ )
207237 for i , failure in zip (class_decisions , is_failure )
208238 ]
209239
@@ -240,7 +270,11 @@ def predict_smiles_list(
240270 }
241271 )
242272 r = ensemble .predict_smiles_list (
243- ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O" , "C[C@H](N)C(=O)NCC(O)=O#" , "" ],
273+ [
274+ "[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O" ,
275+ "C[C@H](N)C(=O)NCC(O)=O#" ,
276+ "" ,
277+ ],
244278 load_preds_if_possible = False ,
245279 )
246280 print (len (r ), r [0 ])
0 commit comments