@@ -14,34 +14,47 @@ def predict_using_closest_tanimoto(
1414 (simplified version of old MS2Query)
1515 """
1616 inchikeys_of_best_match = []
17- single_highest_score = []
17+ highest_scores = []
1818 for spectrum_idx in range (len (query_spectra .spectra )):
1919 inchikey_of_best_match , score = predict_using_closest_tanimoto_single_spectrum (
2020 library_spectra , query_spectra .subset_spectra ([spectrum_idx ]), nr_of_closest_inchikeys_to_select )
2121 inchikeys_of_best_match .append (inchikey_of_best_match )
22- single_highest_score .append (score )
23- return inchikeys_of_best_match , single_highest_score
22+ highest_scores .append (score )
23+ return inchikeys_of_best_match , highest_scores
2424
2525
def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings,
                                                   nr_of_closest_inchikeys_to_select,
                                                   nr_of_inchikeys_with_highest_ms2deepscore_to_select=10,
                                                   ) -> Tuple[str, float]:
    """Predict the best matching library inchikey for one query spectrum.

    The query embedding is compared against every library embedding. Only the
    library inchikeys with the highest single MS2DeepScore are kept as
    candidates. Each candidate is then scored with the value returned by
    get_average_predictions_for_closely_related_metabolites for the inchikeys
    returned by get_inchikey_and_tanimoto_scores_for_top_k, and the candidate
    with the highest such score wins.

    Args:
        spectra_with_embeddings: Library spectra; must expose `embeddings` and
            whatever the helper functions called below require.
        single_spectrum_with_embeddings: Query container holding exactly one
            spectrum (`spectra`) and its `embeddings`.
        nr_of_closest_inchikeys_to_select: Passed on to
            get_inchikey_and_tanimoto_scores_for_top_k as the number of
            inchikeys to select per candidate.
        nr_of_inchikeys_with_highest_ms2deepscore_to_select: Number of
            candidate inchikeys to pre-select by MS2DeepScore. Defaults to 10
            so that callers that only pass the first three arguments keep
            working.

    Returns:
        Tuple of (inchikey with the highest average prediction, that score).

    Raises:
        ValueError: If the query does not contain exactly one spectrum, or if
            no candidate inchikey could be scored.
    """
    if len(single_spectrum_with_embeddings.spectra) != 1:
        raise ValueError("expected a single spectrum")

    # The query holds a single spectrum, so only the first row of the matrix is relevant.
    ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings,
                                             spectra_with_embeddings.embeddings)[0]
    top_inchikeys = select_inchikeys_with_highest_ms2deepscore(
        spectra_with_embeddings, ms2deepscores, nr_of_inchikeys_with_highest_ms2deepscore_to_select)

    average_predicted_scores = {}
    for inchikey in top_inchikeys:
        top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
            spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)
        average_predicted_scores[inchikey] = get_average_predictions_for_closely_related_metabolites(
            spectra_with_embeddings, top_k_inchikeys, ms2deepscores)

    if not average_predicted_scores:
        # max() on an empty dict would raise an uninformative ValueError; be explicit instead.
        raise ValueError("no candidate inchikeys were selected for the query spectrum")

    inchikey_with_highest_average_prediction, score = max(
        average_predicted_scores.items(), key=lambda item: item[1])
    return inchikey_with_highest_average_prediction, score
4445
def select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores, nr_of_inchikeys_to_select=10):
    """Select the inchikeys whose best library spectrum scores highest.

    For every inchikey in `spectra_with_embeddings.spectrum_indexes_per_inchikey`
    the maximum of `ms2deepscores` over that inchikey's spectrum indexes is
    taken, and the inchikeys with the largest maxima are returned.

    Args:
        spectra_with_embeddings: Object exposing `spectrum_indexes_per_inchikey`,
            a mapping from inchikey to the indexes of its spectra.
        ms2deepscores: Array of scores that is indexed with those spectrum
            indexes.
        nr_of_inchikeys_to_select: Maximum number of inchikeys to return. If it
            exceeds the number of available inchikeys, all of them are returned.

    Returns:
        List of the selected inchikeys, in no particular order. Empty when
        there are no inchikeys or when a non-positive number is requested.
    """
    spectrum_indexes_per_inchikey = spectra_with_embeddings.spectrum_indexes_per_inchikey
    # Names and scores are read from the same mapping so that position i in the
    # score array is guaranteed to belong to inchikeys[i].
    inchikeys = list(spectrum_indexes_per_inchikey)

    # Clamp to the available number; np.argpartition raises if kth is out of bounds.
    nr_to_select = min(nr_of_inchikeys_to_select, len(inchikeys))
    if nr_to_select <= 0:
        # Also avoids the `[-0:]` slice, which would return every inchikey.
        return []

    highest_score_per_inchikey = np.array(
        [np.max(ms2deepscores[spectrum_indexes]) for spectrum_indexes in spectrum_indexes_per_inchikey.values()])
    # argpartition places the nr_to_select largest values in the last positions without a full sort.
    top_indexes = np.argpartition(highest_score_per_inchikey, -nr_to_select)[-nr_to_select:]
    return [inchikeys[index] for index in top_indexes]
57+
4558def get_average_predictions_for_closely_related_metabolites (spectra_with_embeddings , top_k_inchikeys ,
4659 all_ms2deepscores ):
4760 """Calculates the average ms2deepscore predictions for top k closest inchikeys"""
0 commit comments