@@ -39,15 +39,15 @@ def gather_predictions(self, smiles_list):
3939 predicted_classes .add (cls )
4040 print (f"Sorting predictions..." )
4141 predicted_classes = sorted (list (predicted_classes ))
42- predicted_classes = {cls : i for i , cls in enumerate (predicted_classes )}
42+ predicted_classes_dict = {cls : i for i , cls in enumerate (predicted_classes )}
4343 ordered_logits = torch .zeros (len (smiles_list ), len (predicted_classes ), len (self .models )) * torch .nan
4444 for i , model_prediction in enumerate (model_predictions ):
4545 for j , logits_for_smiles in tqdm .tqdm (enumerate (model_prediction ),
4646 total = len (model_prediction ),
4747 desc = f"Sorting predictions for { self .models [i ].model_name } " ):
4848 if logits_for_smiles is not None :
4949 for cls in logits_for_smiles :
50- ordered_logits [j , predicted_classes [cls ], i ] = logits_for_smiles [cls ]
50+ ordered_logits [j , predicted_classes_dict [cls ], i ] = logits_for_smiles [cls ]
5151
5252 return ordered_logits , predicted_classes
5353
@@ -114,7 +114,7 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
114114 preds_file = f"predictions_by_model_{ '_' .join (model .model_name for model in self .models )} .pt"
115115 predicted_classes_file = f"predicted_classes_{ '_' .join (model .model_name for model in self .models )} .txt"
116116 if not load_preds_if_possible or not os .path .isfile (preds_file ):
117- ordered_predictions = predicted_classes = self .gather_predictions (smiles_list )
117+ ordered_predictions , predicted_classes = self .gather_predictions (smiles_list )
118118 # save predictions
119119 torch .save (ordered_predictions , preds_file )
120120 with open (predicted_classes_file , "w" ) as f :
0 commit comments