44from tqdm import tqdm
55from ms2query .benchmarking .AnnotatedSpectrumSet import AnnotatedSpectrumSet
66from ms2query .benchmarking .Fingerprints import Fingerprints
7+ from ms2query .benchmarking .reference_methods .predict_top_ms2deepscores import select_inchikeys_with_highest_ms2deepscore
8+ from ms2query .benchmarking .TopKTanimotoScores import TopKTanimotoScores
79from ms2query .metrics import generalized_tanimoto_similarity_matrix
810
911
@@ -17,66 +19,34 @@ def predict_using_closest_tanimoto(
1719 """Predict best inchikey, by taking the average score over all spectra for the 10 closest related library inchikeys.
1820 (simplified version of old MS2Query)
1921 """
22+ top_k_tanimoto_scores = TopKTanimotoScores .calculate_from_fingerprints (
23+ library_fingerprints ,
24+ library_fingerprints ,
25+ k = nr_of_closest_inchikeys_to_select ,
26+ )
27+ ms2deepscores = cosine_similarity_matrix (query_spectra .embeddings .embeddings , library_spectra .embeddings .embeddings )
28+ inchikeys_with_highest_ms2deepscores = select_inchikeys_with_highest_ms2deepscore (
29+ query_spectra , library_spectra , nr_of_inchikeys_with_highest_ms2deepscore_to_select , ms2deepscores = ms2deepscores
30+ )
31+
2032 inchikeys_of_best_match = []
2133 highest_scores = []
2234 for spectrum_idx in tqdm (range (len (query_spectra .spectra )), "Predicting using closest tanimoto" ):
23- inchikey_of_best_match , score = predict_using_closest_tanimoto_single_spectrum (
24- library_spectra ,
25- query_spectra .subset_spectra ([spectrum_idx ]),
26- nr_of_closest_inchikeys_to_select ,
27- nr_of_inchikeys_with_highest_ms2deepscore_to_select ,
28- library_fingerprints ,
29- )
30- inchikeys_of_best_match .append (inchikey_of_best_match )
31- highest_scores .append (score )
32- return inchikeys_of_best_match , highest_scores
35+ average_predicted_scores = {}
36+ for inchikey in inchikeys_with_highest_ms2deepscores [spectrum_idx ]:
37+ top_k_inchikeys = top_k_tanimoto_scores .select_top_k_inchikeys (inchikey )
3338
39+ average_predicted_score = get_average_predictions_for_closely_related_metabolites (
40+ library_spectra , top_k_inchikeys , ms2deepscores [spectrum_idx ]
41+ )
42+ average_predicted_scores [inchikey ] = average_predicted_score
3443
35- def predict_using_closest_tanimoto_single_spectrum (
36- spectra_with_embeddings : AnnotatedSpectrumSet ,
37- single_spectrum_with_embeddings : AnnotatedSpectrumSet ,
38- nr_of_closest_inchikeys_to_select ,
39- nr_of_inchikeys_with_highest_ms2deepscore_to_select ,
40- fingerprints ,
41- ) -> Tuple [str , float ]:
42- if len (single_spectrum_with_embeddings .spectra ) != 1 :
43- raise ValueError ("expected a single spectrum" )
44- ms2deepscores = cosine_similarity_matrix (
45- single_spectrum_with_embeddings .embeddings .embeddings , spectra_with_embeddings .embeddings .embeddings
46- )[0 ]
47- top_inchikeys = select_inchikeys_with_highest_ms2deepscore (
48- spectra_with_embeddings , ms2deepscores , nr_of_inchikeys_with_highest_ms2deepscore_to_select
49- )
50- average_predicted_scores = {}
51- for inchikey in top_inchikeys :
52- top_k_inchikeys , _ = get_inchikey_and_tanimoto_scores_for_top_k (
53- fingerprints , inchikey , nr_of_closest_inchikeys_to_select
44+ inchikey_with_highest_average_prediction , score = max (
45+ average_predicted_scores .items (), key = lambda item : item [1 ]
5446 )
55- average_predicted_score = get_average_predictions_for_closely_related_metabolites (
56- spectra_with_embeddings , top_k_inchikeys , ms2deepscores
57- )
58- average_predicted_scores [inchikey ] = average_predicted_score
59-
60- inchikey_with_highest_average_prediction , score = max (average_predicted_scores .items (), key = lambda item : item [1 ])
61- return inchikey_with_highest_average_prediction , score
62-
63-
64- def select_inchikeys_with_highest_ms2deepscore (
65- spectra_with_embeddings : AnnotatedSpectrumSet , ms2deepscores , nr_of_inchikeys_to_select = 10
66- ):
67- highest_score_for_inchikey = []
68- for inchikey , spectrum_indexes in spectra_with_embeddings .spectrum_indices_per_inchikey .items ():
69- all_ms2deepscores_for_inchikey = ms2deepscores [spectrum_indexes ,]
70- highest_score_for_inchikey .append (max (all_ms2deepscores_for_inchikey ))
71- inchikey_indexes_with_highest_ms2deepscore = np .argpartition (
72- np .array (highest_score_for_inchikey ), - nr_of_inchikeys_to_select
73- )[- nr_of_inchikeys_to_select :]
74-
75- top_inchikeys = [
76- spectra_with_embeddings .inchikeys [inchikey_index ]
77- for inchikey_index in inchikey_indexes_with_highest_ms2deepscore
78- ]
79- return top_inchikeys
47+ inchikeys_of_best_match .append (inchikey_with_highest_average_prediction )
48+ highest_scores .append (score )
49+ return inchikeys_of_best_match , highest_scores
8050
8151
8252def get_average_predictions_for_closely_related_metabolites (
0 commit comments