|
| 1 | +import numpy as np |
| 2 | +from ms2deepscore.vector_operations import cosine_similarity_matrix |
| 3 | +from typing import Tuple, List |
| 4 | + |
| 5 | +from ms2query.benchmarking.SpectrumDataSet import SpectraWithMS2DeepScoreEmbeddings |
| 6 | +from ms2query.metrics import generalized_tanimoto_similarity_matrix |
| 7 | + |
| 8 | + |
def predict_using_closest_tanimoto(
        library_spectra: SpectraWithMS2DeepScoreEmbeddings, query_spectra: SpectraWithMS2DeepScoreEmbeddings,
        nr_of_closest_inchikeys_to_select=10
) -> Tuple[List[str], List[float]]:
    """Predict the best matching library inchikey for every query spectrum.

    Each query spectrum is handled on its own by
    `predict_using_closest_tanimoto_single_spectrum`, which scores a library inchikey by
    averaging over its `nr_of_closest_inchikeys_to_select` closest related library inchikeys
    (simplified version of old MS2Query).

    Returns two lists in query order: the predicted inchikeys and their scores.
    """
    # One (inchikey, score) pair per query spectrum, in the order of query_spectra.spectra.
    predictions = [
        predict_using_closest_tanimoto_single_spectrum(
            library_spectra, query_spectra.subset_spectra([query_index]), nr_of_closest_inchikeys_to_select)
        for query_index in range(len(query_spectra.spectra))
    ]
    best_inchikeys = [best_inchikey for best_inchikey, _ in predictions]
    best_scores = [best_score for _, best_score in predictions]
    return best_inchikeys, best_scores
| 24 | + |
| 25 | + |
def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings,
                                                   nr_of_closest_inchikeys_to_select,
                                                   minimum_ms2deepscore: float = 0.7) -> Tuple[str, float]:
    """Predict the best matching library inchikey for one query spectrum.

    The similarity between the query embedding and every library embedding is computed with
    `cosine_similarity_matrix`. Every library inchikey that has at least one spectrum scoring
    above `minimum_ms2deepscore` becomes a candidate. Each candidate is scored with
    `get_average_predictions_for_closely_related_metabolites`, and the candidate with the
    highest average is returned.

    Args:
        spectra_with_embeddings: The library; must expose `embeddings` and
            `spectrum_indexes_per_inchikey`.
        single_spectrum_with_embeddings: A container holding exactly one query spectrum.
        nr_of_closest_inchikeys_to_select: Number of closest library inchikeys to average over.
        minimum_ms2deepscore: A library inchikey is only considered when one of its spectra
            scores strictly above this value. Defaults to the previously hard-coded 0.7.

    Returns:
        The inchikey with the highest average prediction and that average.

    Raises:
        ValueError: If not exactly one query spectrum is given, or if no library spectrum
            scores above `minimum_ms2deepscore` (so there is no candidate to pick from).
    """
    if len(single_spectrum_with_embeddings.spectra) != 1:
        raise ValueError("expected a single spectrum")
    # First (and only) row: scores of the single query against all library spectra.
    ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings,
                                             spectra_with_embeddings.embeddings)[0]
    average_predicted_scores = {}
    for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items():
        # Only inchikeys with at least one sufficiently similar spectrum are worth scoring.
        if np.max(ms2deepscores[spectrum_indexes]) > minimum_ms2deepscore:
            average_predicted_scores[inchikey] = get_average_predictions_for_closely_related_metabolites(
                spectra_with_embeddings, inchikey, ms2deepscores, nr_of_closest_inchikeys_to_select)

    if not average_predicted_scores:
        # Without this guard max() fails with an uninformative "empty sequence" ValueError.
        raise ValueError(
            f"no library spectrum has a score above {minimum_ms2deepscore} for the query spectrum")

    inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(),
                                                          key=lambda item: item[1])
    return inchikey_with_highest_average_prediction, score
| 42 | + |
def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, inchikey,
                                                            all_ms2deepscores, nr_of_closest_inchikeys_to_select):
    """Return the mean, over the top k closest inchikeys, of each inchikey's mean ms2deepscore.

    The top k inchikeys are looked up with `get_inchikey_and_tanimoto_scores_for_top_k`. For
    each of them the scores of its library spectra are averaged, and those per-inchikey
    averages are averaged again, so every inchikey weighs equally regardless of how many
    spectra it has.
    """
    closest_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
        spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)

    indexes_per_inchikey = spectra_with_embeddings.spectrum_indexes_per_inchikey
    # Mean score per related inchikey, taken over all of its library spectra.
    mean_score_per_inchikey = [
        all_ms2deepscores[indexes_per_inchikey[related_inchikey]].mean()
        for related_inchikey in closest_inchikeys
    ]
    return sum(mean_score_per_inchikey) / len(mean_score_per_inchikey)
| 56 | + |
def get_inchikey_and_tanimoto_scores_for_top_k(spectra: SpectraWithMS2DeepScoreEmbeddings, inchikey, k) -> tuple[list[str], np.ndarray]:
    """Return the k library inchikeys with the highest tanimoto score to `inchikey`.

    The fingerprint of `inchikey` is compared against every fingerprint in
    `spectra.inchikey_fingerprint_pairs` with `generalized_tanimoto_similarity_matrix`.
    Because `inchikey` is itself looked up in the library, it is among the compared entries.

    Args:
        spectra: The library; must expose `inchikey_fingerprint_pairs` (inchikey -> fingerprint).
        inchikey: The library inchikey to find neighbours for.
        k: Number of closest inchikeys to return. Values larger than the number of library
            inchikeys are clamped, so at most the whole library is returned.

    Returns:
        The selected inchikeys and their tanimoto scores, index-aligned. The selection comes
        from `np.argpartition`, so the k entries are NOT sorted by score.

    Raises:
        ValueError: If `k` is smaller than 1.
        KeyError: If `inchikey` is not in `spectra.inchikey_fingerprint_pairs`.
    """
    if k < 1:
        # k == 0 would make the `[-k:]` slice below return the whole library instead of nothing.
        raise ValueError(f"k must be at least 1, got {k}")

    all_inchikeys = list(spectra.inchikey_fingerprint_pairs.keys())
    library_fingerprints = np.vstack(list(spectra.inchikey_fingerprint_pairs.values()))
    fingerprint_single_inchikey = np.vstack([spectra.inchikey_fingerprint_pairs[inchikey]])
    similarity_scores = generalized_tanimoto_similarity_matrix(fingerprint_single_inchikey, library_fingerprints)[0]

    # np.argpartition raises when kth is out of bounds, so never ask for more than exists.
    k = min(k, len(similarity_scores))
    inchikey_indexes_of_top_k = np.argpartition(similarity_scores, -k)[-k:]
    tanimoto_scores_for_top_k = similarity_scores[inchikey_indexes_of_top_k]
    top_inchikeys = [all_inchikeys[inchikey_index] for inchikey_index in inchikey_indexes_of_top_k]
    return top_inchikeys, tanimoto_scores_for_top_k
0 commit comments