Skip to content

Commit 1933c85

Browse files
committed
Move top_k_selection outside average computation for easier testing
1 parent fb4cbd8 commit 1933c85

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ms2deepscore.vector_operations import cosine_similarity_matrix
33
from typing import Tuple, List
44

5-
from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings
5+
from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings, SpectraWithFingerprints
66
from ms2query.metrics import generalized_tanimoto_similarity_matrix
77

88

@@ -33,19 +33,18 @@ def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, sing
3333
for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items():
3434
all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes]
3535
if max(all_ms2deepscores_for_inchikey) > 0.7:
36+
top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
37+
spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)
3638
average_predicted_score = get_average_predictions_for_closely_related_metabolites(
37-
spectra_with_embeddings, inchikey, ms2deepscores, nr_of_closest_inchikeys_to_select)
39+
spectra_with_embeddings, top_k_inchikeys, ms2deepscores)
3840
average_predicted_scores[inchikey] = average_predicted_score
3941

4042
inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1])
4143
return inchikey_with_highest_average_prediction, score
4244

43-
def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, inchikey,
44-
all_ms2deepscores, nr_of_closest_inchikeys_to_select):
45+
def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, top_k_inchikeys,
46+
all_ms2deepscores):
4547
"""Calculates the average ms2deepscore predictions for top k closest inchikeys"""
46-
top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
47-
spectra_with_embeddings, inchikey,nr_of_closest_inchikeys_to_select)
48-
4948
average_predicted_scores = []
5049
for top_inchikey in top_k_inchikeys:
5150
matching_spectrum_indexes = spectra_with_embeddings.spectrum_indexes_per_inchikey[top_inchikey]
@@ -54,7 +53,8 @@ def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddi
5453
average_predicted_score = sum(average_predicted_scores) / len(average_predicted_scores)
5554
return average_predicted_score
5655

57-
def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithMS2DeepScoreEmbeddings, inchikey, k) -> tuple[list[str], np.ndarray]:
56+
def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithFingerprints, inchikey, k
57+
) -> tuple[list[str], np.ndarray]:
5858
"""For an inchikey in a library the top k highest tanimoto scores in the library are predicted (including itself)"""
5959
library_fingerprints = np.vstack(list(spectra.inchikey_fingerprint_pairs.values()))
6060
fingerprint_single_inchikey = np.vstack(list([spectra.inchikey_fingerprint_pairs[inchikey]]))

0 commit comments

Comments
 (0)