Skip to content

Commit 2f0e63c

Browse files
committed
Make predict_using_closest_tanimoto use TopKTanimotoScores
1 parent c809095 commit 2f0e63c

File tree

2 files changed

+24
-88
lines changed

2 files changed

+24
-88
lines changed

ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py

Lines changed: 24 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
from tqdm import tqdm
55
from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet
66
from ms2query.benchmarking.Fingerprints import Fingerprints
7+
from ms2query.benchmarking.reference_methods.predict_top_ms2deepscores import select_inchikeys_with_highest_ms2deepscore
8+
from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores
79
from ms2query.metrics import generalized_tanimoto_similarity_matrix
810

911

@@ -17,66 +19,34 @@ def predict_using_closest_tanimoto(
1719
"""Predict best inchikey, by taking the average score over all spectra for the 10 closest related library inchikeys.
1820
(simplified version of old MS2Query)
1921
"""
22+
top_k_tanimoto_scores = TopKTanimotoScores.calculate_from_fingerprints(
23+
library_fingerprints,
24+
library_fingerprints,
25+
k=nr_of_closest_inchikeys_to_select,
26+
)
27+
ms2deepscores = cosine_similarity_matrix(query_spectra.embeddings.embeddings, library_spectra.embeddings.embeddings)
28+
inchikeys_with_highest_ms2deepscores = select_inchikeys_with_highest_ms2deepscore(
29+
query_spectra, library_spectra, nr_of_inchikeys_with_highest_ms2deepscore_to_select, ms2deepscores=ms2deepscores
30+
)
31+
2032
inchikeys_of_best_match = []
2133
highest_scores = []
2234
for spectrum_idx in tqdm(range(len(query_spectra.spectra)), "Predicting using closest tanimoto"):
23-
inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum(
24-
library_spectra,
25-
query_spectra.subset_spectra([spectrum_idx]),
26-
nr_of_closest_inchikeys_to_select,
27-
nr_of_inchikeys_with_highest_ms2deepscore_to_select,
28-
library_fingerprints,
29-
)
30-
inchikeys_of_best_match.append(inchikey_of_best_match)
31-
highest_scores.append(score)
32-
return inchikeys_of_best_match, highest_scores
35+
average_predicted_scores = {}
36+
for inchikey in inchikeys_with_highest_ms2deepscores[spectrum_idx]:
37+
top_k_inchikeys = top_k_tanimoto_scores.select_top_k_inchikeys(inchikey)
3338

39+
average_predicted_score = get_average_predictions_for_closely_related_metabolites(
40+
library_spectra, top_k_inchikeys, ms2deepscores[spectrum_idx]
41+
)
42+
average_predicted_scores[inchikey] = average_predicted_score
3443

35-
def predict_using_closest_tanimoto_single_spectrum(
36-
spectra_with_embeddings: AnnotatedSpectrumSet,
37-
single_spectrum_with_embeddings: AnnotatedSpectrumSet,
38-
nr_of_closest_inchikeys_to_select,
39-
nr_of_inchikeys_with_highest_ms2deepscore_to_select,
40-
fingerprints,
41-
) -> Tuple[str, float]:
42-
if len(single_spectrum_with_embeddings.spectra) != 1:
43-
raise ValueError("expected a single spectrum")
44-
ms2deepscores = cosine_similarity_matrix(
45-
single_spectrum_with_embeddings.embeddings.embeddings, spectra_with_embeddings.embeddings.embeddings
46-
)[0]
47-
top_inchikeys = select_inchikeys_with_highest_ms2deepscore(
48-
spectra_with_embeddings, ms2deepscores, nr_of_inchikeys_with_highest_ms2deepscore_to_select
49-
)
50-
average_predicted_scores = {}
51-
for inchikey in top_inchikeys:
52-
top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
53-
fingerprints, inchikey, nr_of_closest_inchikeys_to_select
44+
inchikey_with_highest_average_prediction, score = max(
45+
average_predicted_scores.items(), key=lambda item: item[1]
5446
)
55-
average_predicted_score = get_average_predictions_for_closely_related_metabolites(
56-
spectra_with_embeddings, top_k_inchikeys, ms2deepscores
57-
)
58-
average_predicted_scores[inchikey] = average_predicted_score
59-
60-
inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1])
61-
return inchikey_with_highest_average_prediction, score
62-
63-
64-
def select_inchikeys_with_highest_ms2deepscore(
65-
spectra_with_embeddings: AnnotatedSpectrumSet, ms2deepscores, nr_of_inchikeys_to_select=10
66-
):
67-
highest_score_for_inchikey = []
68-
for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indices_per_inchikey.items():
69-
all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes,]
70-
highest_score_for_inchikey.append(max(all_ms2deepscores_for_inchikey))
71-
inchikey_indexes_with_highest_ms2deepscore = np.argpartition(
72-
np.array(highest_score_for_inchikey), -nr_of_inchikeys_to_select
73-
)[-nr_of_inchikeys_to_select:]
74-
75-
top_inchikeys = [
76-
spectra_with_embeddings.inchikeys[inchikey_index]
77-
for inchikey_index in inchikey_indexes_with_highest_ms2deepscore
78-
]
79-
return top_inchikeys
47+
inchikeys_of_best_match.append(inchikey_with_highest_average_prediction)
48+
highest_scores.append(score)
49+
return inchikeys_of_best_match, highest_scores
8050

8151

8252
def get_average_predictions_for_closely_related_metabolites(

tests/test_benchmarking/test_predict_using_closest_tanimoto.py

Lines changed: 0 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,6 @@
66
get_average_predictions_for_closely_related_metabolites,
77
get_inchikey_and_tanimoto_scores_for_top_k,
88
predict_using_closest_tanimoto,
9-
predict_using_closest_tanimoto_single_spectrum,
10-
select_inchikeys_with_highest_ms2deepscore,
119
)
1210
from tests.helper_functions import create_test_spectra, ms2deepscore_model
1311

@@ -28,38 +26,6 @@ def test_predict_using_closest_tanimoto():
2826
assert len(scores) == 3
2927

3028

31-
def test_predict_using_closest_tanimoto_single_spectrum():
32-
"""Only very basic test that the function runs and that the output is the right format"""
33-
model = ms2deepscore_model()
34-
library_spectra = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra(nr_of_inchikeys=7))
35-
test_spectra = AnnotatedSpectrumSet.create_spectrum_set(create_test_spectra(1, nr_of_inchikeys=1))
36-
library_spectra.add_embeddings(model)
37-
test_spectra.add_embeddings(model)
38-
fingerprints = Fingerprints.from_spectrum_set(library_spectra, "daylight", 2048)
39-
40-
predicted_inchikey, score = predict_using_closest_tanimoto_single_spectrum(
41-
library_spectra, test_spectra, 3, 3, fingerprints
42-
)
43-
44-
assert isinstance(predicted_inchikey, str)
45-
assert len(predicted_inchikey) == 14
46-
assert isinstance(score, float)
47-
48-
49-
def test_select_inchikeys_with_highest_ms2deepscore():
50-
test_spectra = create_test_spectra(nr_of_inchikeys=7)
51-
spectra = AnnotatedSpectrumSet.create_spectrum_set(test_spectra)
52-
53-
ms2deepscores = np.zeros(len(test_spectra))
54-
ms2deepscores[2] = 0.4
55-
ms2deepscores[5] = 0.9
56-
ms2deepscores[7] = 0.6
57-
inchikeys_with_highest_ms2deepscore = select_inchikeys_with_highest_ms2deepscore(spectra, ms2deepscores, 3)
58-
expected_inchikeys = list(spectra.spectrum_indices_per_inchikey.keys())[:3]
59-
assert set(expected_inchikeys) == set(inchikeys_with_highest_ms2deepscore)
60-
print(inchikeys_with_highest_ms2deepscore)
61-
62-
6329
def test_get_average_predictions_for_closely_related_metabolites():
6430
test_spectra = create_test_spectra(nr_of_inchikeys=7)
6531
# Select different number per inchikey (only one for the first) to check that it is correctly weighted.

0 commit comments

Comments
 (0)