22from ms2deepscore .vector_operations import cosine_similarity_matrix
33from typing import Tuple , List
44
5- from ms2query .benchmarking .SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings
5+ from ms2query .benchmarking .SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings , SpectraWithFingerprints
66from ms2query .metrics import generalized_tanimoto_similarity_matrix
77
88
@@ -33,19 +33,18 @@ def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, sing
3333 for inchikey , spectrum_indexes in spectra_with_embeddings .spectrum_indexes_per_inchikey .items ():
3434 all_ms2deepscores_for_inchikey = ms2deepscores [spectrum_indexes ]
3535 if max (all_ms2deepscores_for_inchikey ) > 0.7 :
36+ top_k_inchikeys , _ = get_inchikey_and_tanimoto_scores_for_top_k (
37+ spectra_with_embeddings , inchikey , nr_of_closest_inchikeys_to_select )
3638 average_predicted_score = get_average_predictions_for_closely_related_metabolites (
37- spectra_with_embeddings , inchikey , ms2deepscores , nr_of_closest_inchikeys_to_select )
39+ spectra_with_embeddings , top_k_inchikeys , ms2deepscores )
3840 average_predicted_scores [inchikey ] = average_predicted_score
3941
4042 inchikey_with_highest_average_prediction , score = max (average_predicted_scores .items (), key = lambda item : item [1 ])
4143 return inchikey_with_highest_average_prediction , score
4244
43- def get_average_predictions_for_closely_related_metabolites (spectra_with_embeddings , inchikey ,
44- all_ms2deepscores , nr_of_closest_inchikeys_to_select ):
45+ def get_average_predictions_for_closely_related_metabolites (spectra_with_embeddings , top_k_inchikeys ,
46+ all_ms2deepscores ):
4547 """Calculates the average ms2deepscore predictions for top k closest inchikeys"""
46- top_k_inchikeys , _ = get_inchikey_and_tanimoto_scores_for_top_k (
47- spectra_with_embeddings , inchikey ,nr_of_closest_inchikeys_to_select )
48-
4948 average_predicted_scores = []
5049 for top_inchikey in top_k_inchikeys :
5150 matching_spectrum_indexes = spectra_with_embeddings .spectrum_indexes_per_inchikey [top_inchikey ]
@@ -54,7 +53,8 @@ def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddi
5453 average_predicted_score = sum (average_predicted_scores ) / len (average_predicted_scores )
5554 return average_predicted_score
5655
57- def get_inchikey_and_tanimoto_scores_for_top_k (spectra : SpectraWithMS2DeepScoreEmbeddings , inchikey , k ) -> tuple [list [str ], np .ndarray ]:
56+ def get_inchikey_and_tanimoto_scores_for_top_k (spectra : SpectraWithFingerprints , inchikey , k
57+ ) -> tuple [list [str ], np .ndarray ]:
5858 """For an inchikey in a library the top k highest tanimoto scores in the library are predicted (including itself)"""
5959 library_fingerprints = np .vstack (list (spectra .inchikey_fingerprint_pairs .values ()))
6060 fingerprint_single_inchikey = np .vstack (list ([spectra .inchikey_fingerprint_pairs [inchikey ]]))
0 commit comments