11import os
2+ import time
3+
24import torch
35import tqdm
46from chebai .preprocessing .datasets .chebi import ChEBIOver50
57from chebai .result .analyse_sem import PredictionSmoother
68
7- from api .hugging_face import download_model_files
89from chebifier .prediction_models .base_predictor import BasePredictor
910
1011
@@ -19,6 +20,7 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
1920 for model_name , model_config in model_configs .items ():
2021 model_cls = MODEL_TYPES [model_config ["type" ]]
2122 if "hugging_face" in model_config :
23+ from api .hugging_face import download_model_files
2224 hugging_face_kwargs = download_model_files (model_config ["hugging_face" ])
2325 else :
2426 hugging_face_kwargs = {}
@@ -118,9 +120,10 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
118120 net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
119121 class_decisions = (
120122 net_score > 0
121- ) & has_valid_predictions # Shape: (num_smiles, num_classes)
123+ ) & has_valid_predictions # Shape: (num_smiles, num_classes)
122124
123- return class_decisions
125+ complete_failure = torch .all (~ has_valid_predictions , dim = 1 )
126+ return class_decisions , complete_failure
124127
125128 def calculate_classwise_weights (self , predicted_classes ):
126129 """No weights, simple majority voting"""
@@ -155,24 +158,27 @@ def predict_smiles_list(
155158 }
156159
157160 classwise_weights = self .calculate_classwise_weights (predicted_classes )
158- class_decisions = self .consolidate_predictions (
161+ class_decisions , is_failure = self .consolidate_predictions (
159162 ordered_predictions , classwise_weights , ** kwargs
160163 )
161164 # Smooth predictions
165+ start_time = time .perf_counter ()
162166 class_names = list (predicted_classes .keys ())
163- # initialise new smoother class since we don't know the labels beforehand (this could be more efficient)
167+ # initialise new smoother class since we don't know the labels beforehand (#todo this could be more efficient)
164168 new_smoother = PredictionSmoother (
165169 self .chebi_dataset ,
166170 label_names = class_names ,
167171 disjoint_files = self .disjoint_files ,
168172 )
169173 class_decisions = new_smoother (class_decisions )
174+ end_time = time .perf_counter ()
175+ print (f"Prediction smoothing took { end_time - start_time :.2f} seconds" )
170176
171177 class_names = list (predicted_classes .keys ())
172178 class_indices = {predicted_classes [cls ]: cls for cls in class_names }
173179 result = [
174- [class_indices [idx .item ()] for idx in torch .nonzero (i , as_tuple = True )[0 ]]
175- for i in class_decisions
180+ [class_indices [idx .item ()] for idx in torch .nonzero (i , as_tuple = True )[0 ]] if not failure else None
181+ for i , failure in zip ( class_decisions , is_failure )
176182 ]
177183
178184 return result
@@ -208,7 +214,7 @@ def predict_smiles_list(
208214 }
209215 )
210216 r = ensemble .predict_smiles_list (
211- ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O" ],
217+ ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O" , "C[C@H](N)C(=O)NCC(O)=O#" , "" ],
212218 load_preds_if_possible = False ,
213219 )
214220 print (len (r ), r [0 ])
0 commit comments