Skip to content

Commit 59aca8e

Browse files
committed
Split out select_inchikeys_with_highest_ms2deepscore to make the code more modular and to select the top k inchikeys with the highest ms2deepscores
1 parent fd2bf89 commit 59aca8e

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

ms2query/benchmarking/reference_methods/predict_using_closest_tanimoto.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,34 +14,47 @@ def predict_using_closest_tanimoto(
1414
(simplified version of old MS2Query)
1515
"""
1616
inchikeys_of_best_match = []
17-
single_highest_score = []
17+
highest_scores = []
1818
for spectrum_idx in range(len(query_spectra.spectra)):
1919
inchikey_of_best_match, score = predict_using_closest_tanimoto_single_spectrum(
2020
library_spectra, query_spectra.subset_spectra([spectrum_idx]), nr_of_closest_inchikeys_to_select)
2121
inchikeys_of_best_match.append(inchikey_of_best_match)
22-
single_highest_score.append(score)
23-
return inchikeys_of_best_match, single_highest_score
22+
highest_scores.append(score)
23+
return inchikeys_of_best_match, highest_scores
2424

2525

2626
def predict_using_closest_tanimoto_single_spectrum(spectra_with_embeddings, single_spectrum_with_embeddings,
27-
nr_of_closest_inchikeys_to_select) -> Tuple[str, float]:
27+
nr_of_closest_inchikeys_to_select,
28+
nr_of_inchikeys_with_highest_ms2deepscore_to_select) -> Tuple[str, float]:
2829
if len(single_spectrum_with_embeddings.spectra) != 1:
2930
raise ValueError("expected a single spectrum")
3031
ms2deepscores = cosine_similarity_matrix(single_spectrum_with_embeddings.embeddings,
3132
spectra_with_embeddings.embeddings)[0]
33+
top_inchikeys = select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores,
34+
nr_of_inchikeys_with_highest_ms2deepscore_to_select)
3235
average_predicted_scores = {}
33-
for inchikey, spectrum_indexes in spectra_with_embeddings.spectrum_indexes_per_inchikey.items():
34-
all_ms2deepscores_for_inchikey = ms2deepscores[spectrum_indexes]
35-
if max(all_ms2deepscores_for_inchikey) > 0.7:
36-
top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
37-
spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)
38-
average_predicted_score = get_average_predictions_for_closely_related_metabolites(
39-
spectra_with_embeddings, top_k_inchikeys, ms2deepscores)
40-
average_predicted_scores[inchikey] = average_predicted_score
36+
for inchikey in top_inchikeys:
37+
top_k_inchikeys, _ = get_inchikey_and_tanimoto_scores_for_top_k(
38+
spectra_with_embeddings, inchikey, nr_of_closest_inchikeys_to_select)
39+
average_predicted_score = get_average_predictions_for_closely_related_metabolites(
40+
spectra_with_embeddings, top_k_inchikeys, ms2deepscores)
41+
average_predicted_scores[inchikey] = average_predicted_score
4142

4243
inchikey_with_highest_average_prediction, score = max(average_predicted_scores.items(), key=lambda item: item[1])
4344
return inchikey_with_highest_average_prediction, score
4445

46+
def select_inchikeys_with_highest_ms2deepscore(spectra_with_embeddings, ms2deepscores, nr_of_inchikeys_to_select=10):
    """Return the inchikeys whose best-scoring spectrum has the highest ms2deepscore.

    For every inchikey in ``spectra_with_embeddings.spectrum_indexes_per_inchikey``
    the maximum of ``ms2deepscores`` over that inchikey's spectrum indexes is
    taken; the inchikeys with the largest maxima are returned.

    Args:
        spectra_with_embeddings: Object exposing ``spectrum_indexes_per_inchikey``,
            a mapping from inchikey to the indexes of its spectra.
        ms2deepscores: Array-like of scores that is indexed with those spectrum
            indexes (presumably one score per library spectrum - confirm with caller).
        nr_of_inchikeys_to_select: Maximum number of inchikeys to return. When it
            exceeds the number of available inchikeys, all of them are returned.

    Returns:
        List of inchikeys, ordered from highest to lowest maximum score.

    Raises:
        ValueError: If ``nr_of_inchikeys_to_select`` is smaller than 1.
    """
    if nr_of_inchikeys_to_select < 1:
        # A value of 0 would otherwise slice with [-0:] and silently return everything.
        raise ValueError("nr_of_inchikeys_to_select should be at least 1")

    spectrum_indexes_per_inchikey = spectra_with_embeddings.spectrum_indexes_per_inchikey
    if not spectrum_indexes_per_inchikey:
        return []

    ms2deepscores = np.asarray(ms2deepscores)
    # Take names and scores from the same mapping in a single pass, so position i
    # in both sequences is guaranteed to refer to the same inchikey.
    all_inchikeys = []
    highest_score_for_inchikey = []
    for inchikey, spectrum_indexes in spectrum_indexes_per_inchikey.items():
        all_inchikeys.append(inchikey)
        highest_score_for_inchikey.append(ms2deepscores[spectrum_indexes].max())
    highest_score_for_inchikey = np.array(highest_score_for_inchikey)

    # np.argpartition raises when kth is out of bounds, so never ask for more than exists.
    nr_to_select = min(nr_of_inchikeys_to_select, len(all_inchikeys))
    top_indexes = np.argpartition(highest_score_for_inchikey, -nr_to_select)[-nr_to_select:]
    # argpartition leaves the selected entries unordered; sort them best-first.
    top_indexes = top_indexes[np.argsort(highest_score_for_inchikey[top_indexes])[::-1]]
    return [all_inchikeys[inchikey_index] for inchikey_index in top_indexes]
57+
4558
def get_average_predictions_for_closely_related_metabolites(spectra_with_embeddings, top_k_inchikeys,
4659
all_ms2deepscores):
4760
"""Calculates the average ms2deepscore predictions for top k closest inchikeys"""

0 commit comments

Comments
 (0)