Skip to content

Commit 7f76997

Browse files
committed
Add method for predicting using top 10 closest library spectra.
1 parent b4561c0 commit 7f76997

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import numpy as np
2+
from ms2deepscore.vector_operations import cosine_similarity_matrix
3+
from typing import Tuple, List
4+
5+
from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings
6+
from ms2query.metrics import generalized_tanimoto_similarity_matrix
7+
8+
9+
def predict_using_closest_tanimoto(
    library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings,
    nr_of_closest_inchikeys_to_select=10
) -> Tuple[List[str], List[float]]:
    """Predict one library inchikey (plus its score) for every query spectrum.

    Every query spectrum is processed independently: it is cut out of
    ``query_spectra`` with ``subset_spectra`` and handed to
    ``predict_using_closest_tanimoto_single_spectrum`` together with the
    complete library.
    (simplified version of old MS2Query)

    Args:
        library_spectra: Library spectra with embeddings to search in.
        query_spectra: Query spectra with embeddings to predict for.
        nr_of_closest_inchikeys_to_select: Number of closest library inchikeys
            that is passed on to the single spectrum prediction.

    Returns:
        Two lists in query order: the predicted inchikey per query spectrum
        and the score that belongs to that prediction.
    """
    # One (inchikey, score) pair per query spectrum, in query order.
    predictions = [
        predict_using_closest_tanimoto_single_spectrum(
            library_spectra, query_spectra.subset_spectra([query_idx]), nr_of_closest_inchikeys_to_select)
        for query_idx in range(len(query_spectra.spectra))
    ]
    predicted_inchikeys = [predicted_inchikey for predicted_inchikey, _ in predictions]
    predicted_scores = [predicted_score for _, predicted_score in predictions]
    return predicted_inchikeys, predicted_scores
24+
25+
26+
def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings,
                                                   nr_of_closest_inchikeys_to_select,
                                                   minimum_ms2deepscore=0.7) -> Tuple[str, float]:
    """Predict the best library inchikey for exactly one query spectrum.

    The MS2DeepScore between the query and every library spectrum is the cosine
    similarity of their embeddings. Library inchikeys with at least one spectrum
    scoring above ``minimum_ms2deepscore`` become candidates. Each candidate is
    scored with ``get_average_predictions_for_closely_related_metabolites`` and
    the candidate with the highest average is returned.

    Args:
        spectra_with_embeddings: Library; must provide ``embeddings`` and
            ``spectrum_indexes_per_inchikey``.
        single_spectrum_with_embeddings: Query set holding exactly one spectrum.
        nr_of_closest_inchikeys_to_select: Number of closest library inchikeys
            to average over for each candidate.
        minimum_ms2deepscore: A library inchikey is only a candidate when its
            best spectrum scores strictly above this value. Defaults to 0.7.

    Returns:
        The inchikey with the highest average score and that average score.

    Raises:
        ValueError: If the query does not hold exactly one spectrum, or if the
            library contains no inchikeys.
    """
    if len(single_spectrum_with_embeddings.spectra) != 1:
        raise ValueError("expected a single spectrum")
    # Row 0 is the only row, since there is exactly one query spectrum.
    ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings,
                                             spectra_with_embeddings.embeddings)[0]

    highest_score_per_inchikey = {
        inchikey: max(ms2deepscores[spectrum_indexes])
        for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items()
    }
    if not highest_score_per_inchikey:
        raise ValueError("expected a library with at least one inchikey")

    candidate_inchikeys = [inchikey for inchikey, highest_score in highest_score_per_inchikey.items()
                           if highest_score > minimum_ms2deepscore]
    if not candidate_inchikeys:
        # No library spectrum passes the threshold. Instead of failing on an
        # empty max() below, fall back to the single best matching inchikey so
        # every query still gets a (low scoring) prediction.
        candidate_inchikeys = [max(highest_score_per_inchikey, key=highest_score_per_inchikey.get)]

    average_predicted_scores = {
        inchikey: get_average_predictions_for_closely_related_metabolites(
            spectra_with_embeddings, inchikey, ms2deepscores, nr_of_closest_inchikeys_to_select)
        for inchikey in candidate_inchikeys
    }
    inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(),
                                                          key=lambda item: item[1])
    return inchikey_with_highest_average_prediction, score
42+
43+
def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, inchikey,
                                                            all_ms2deepscores, nr_of_closest_inchikeys_to_select):
    """Average the ms2deepscore predictions over the top k closest inchikeys.

    The closest library inchikeys of ``inchikey`` are looked up with
    ``get_inchikey_and_tanimoto_scores_for_top_k``. For each of them the mean of
    the scores of its library spectra is taken, and the returned value is the
    unweighted mean of those per-inchikey means.

    Args:
        spectra_with_embeddings: Library; must provide
            ``spectrum_indexes_per_inchikey``.
        inchikey: Library inchikey whose neighbourhood is evaluated.
        all_ms2deepscores: Array with one score per library spectrum, indexed
            by library spectrum index.
        nr_of_closest_inchikeys_to_select: Number of closest inchikeys to use.

    Returns:
        The mean over the selected inchikeys of their mean spectrum score.
    """
    closest_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
        spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)

    indexes_per_inchikey = spectra_with_embeddings.spectrum_indexes_per_inchikey
    # Mean per inchikey first, so inchikeys with many spectra do not dominate.
    mean_score_per_inchikey = [
        all_ms2deepscores[indexes_per_inchikey[related_inchikey]].mean()
        for related_inchikey in closest_inchikeys
    ]
    return sum(mean_score_per_inchikey) / len(mean_score_per_inchikey)
56+
57+
def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithMS2DeepScoreEmbeddings, inchikey, k) -> tuple[list[str], np.ndarray]:
    """Select the k library inchikeys that are most similar to a library inchikey.

    The fingerprint of ``inchikey`` is compared against the fingerprint of every
    inchikey in ``spectra.inchikey_fingerprint_pairs`` using
    ``generalized_tanimoto_similarity_matrix``. Since ``inchikey`` is part of
    the library, it is compared against itself as well.

    Args:
        spectra: Library; must provide ``inchikey_fingerprint_pairs``, a mapping
            of inchikey to fingerprint.
        inchikey: Inchikey to find neighbours for; must be a key of
            ``spectra.inchikey_fingerprint_pairs``.
        k: Number of inchikeys to select. When the library holds fewer than
            ``k`` inchikeys, all of them are returned.

    Returns:
        The selected inchikeys and their similarity scores, in matching order.
        The selection is not sorted by score.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        # -k would no longer index from the end; k=0 would select everything.
        raise ValueError("k must be at least 1")
    all_inchikeys = list(spectra.inchikey_fingerprint_pairs.keys())
    library_fingerprints = np.vstack(list(spectra.inchikey_fingerprint_pairs.values()))
    fingerprint_single_inchikey = np.vstack([spectra.inchikey_fingerprint_pairs[inchikey]])
    # Row 0 is the only row, since a single fingerprint is compared.
    similarity_scores = generalized_tanimoto_similarity_matrix(fingerprint_single_inchikey, library_fingerprints)[0]
    # argpartition raises when kth is out of bounds, so never ask for more
    # inchikeys than the library contains.
    k = min(k, len(similarity_scores))
    inchikey_indexes_of_top_k = np.argpartition(similarity_scores, -k)[-k:]
    tanimoto_scores_for_top_k = similarity_scores[inchikey_indexes_of_top_k]
    top_inchikeys = [all_inchikeys[inchikey_index] for inchikey_index in inchikey_indexes_of_top_k]
    return top_inchikeys, tanimoto_scores_for_top_k

0 commit comments

Comments
 (0)