Skip to content

Commit df68ecb

Browse files
committed
init smoother at init to avoid re-initialising it for every prediction-call
1 parent 2bead4a commit df68ecb

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
3737
os.path.join("data", "disjoint_additional.csv"),
3838
]
3939

40+
self.smoother = PredictionSmoother(
41+
self.chebi_dataset,
42+
label_names=None,
43+
disjoint_files=self.disjoint_files,
44+
)
45+
4046
def gather_predictions(self, smiles_list):
4147
# get predictions from all models for the SMILES list
4248
# order them by alphabetically by label class
@@ -164,13 +170,8 @@ def predict_smiles_list(
164170
# Smooth predictions
165171
start_time = time.perf_counter()
166172
class_names = list(predicted_classes.keys())
167-
# initialise new smoother class since we don't know the labels beforehand (#todo this could be more efficient)
168-
new_smoother = PredictionSmoother(
169-
self.chebi_dataset,
170-
label_names=class_names,
171-
disjoint_files=self.disjoint_files,
172-
)
173-
class_decisions = new_smoother(class_decisions)
173+
self.smoother.set_label_names(class_names)
174+
class_decisions = self.smoother(class_decisions)
174175
end_time = time.perf_counter()
175176
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
176177

0 commit comments

Comments
 (0)