Skip to content

Commit 2bead4a

Browse files
committed
Use None values to mark samples for which all methods failed (usually due to a faulty SMILES string)
1 parent 90aedd4 commit 2bead4a

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
2+
import time
3+
24
import torch
35
import tqdm
46
from chebai.preprocessing.datasets.chebi import ChEBIOver50
57
from chebai.result.analyse_sem import PredictionSmoother
68

7-
from api.hugging_face import download_model_files
89
from chebifier.prediction_models.base_predictor import BasePredictor
910

1011

@@ -19,6 +20,7 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
1920
for model_name, model_config in model_configs.items():
2021
model_cls = MODEL_TYPES[model_config["type"]]
2122
if "hugging_face" in model_config:
23+
from api.hugging_face import download_model_files
2224
hugging_face_kwargs = download_model_files(model_config["hugging_face"])
2325
else:
2426
hugging_face_kwargs = {}
@@ -118,9 +120,10 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
118120
net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
119121
class_decisions = (
120122
net_score > 0
121-
) & has_valid_predictions # Shape: (num_smiles, num_classes)
123+
) & has_valid_predictions # Shape: (num_smiles, num_classes)
122124

123-
return class_decisions
125+
complete_failure = torch.all(~has_valid_predictions, dim=1)
126+
return class_decisions, complete_failure
124127

125128
def calculate_classwise_weights(self, predicted_classes):
126129
"""No weights, simple majority voting"""
@@ -155,24 +158,27 @@ def predict_smiles_list(
155158
}
156159

157160
classwise_weights = self.calculate_classwise_weights(predicted_classes)
158-
class_decisions = self.consolidate_predictions(
161+
class_decisions, is_failure = self.consolidate_predictions(
159162
ordered_predictions, classwise_weights, **kwargs
160163
)
161164
# Smooth predictions
165+
start_time = time.perf_counter()
162166
class_names = list(predicted_classes.keys())
163-
# initialise new smoother class since we don't know the labels beforehand (this could be more efficient)
167+
# initialise new smoother class since we don't know the labels beforehand (#todo this could be more efficient)
164168
new_smoother = PredictionSmoother(
165169
self.chebi_dataset,
166170
label_names=class_names,
167171
disjoint_files=self.disjoint_files,
168172
)
169173
class_decisions = new_smoother(class_decisions)
174+
end_time = time.perf_counter()
175+
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
170176

171177
class_names = list(predicted_classes.keys())
172178
class_indices = {predicted_classes[cls]: cls for cls in class_names}
173179
result = [
174-
[class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]]
175-
for i in class_decisions
180+
[class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]] if not failure else None
181+
for i, failure in zip(class_decisions, is_failure)
176182
]
177183

178184
return result
@@ -208,7 +214,7 @@ def predict_smiles_list(
208214
}
209215
)
210216
r = ensemble.predict_smiles_list(
211-
["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"],
217+
["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O", "C[C@H](N)C(=O)NCC(O)=O#", ""],
212218
load_preds_if_possible=False,
213219
)
214220
print(len(r), r[0])

0 commit comments

Comments
 (0)