Skip to content

Commit 75d79ff

Browse files
committed
implement run_ms2query
1 parent 431e802 commit 75d79ff

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

ms2query/run_ms2query.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import json
2+
from pathlib import Path
3+
from typing import Sequence, Tuple
4+
import numpy as np
5+
import pandas as pd
6+
from matchms import Spectrum
7+
from matchms.importing import load_spectra
8+
from ms2deepscore.models import load_model
9+
from ms2deepscore.vector_operations import cosine_similarity_matrix
10+
from tqdm import tqdm
11+
from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet
12+
from ms2query.benchmarking.Embeddings import Embeddings
13+
from ms2query.benchmarking.Fingerprints import Fingerprints
14+
from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores
15+
16+
17+
def run_ms2query(
18+
query_embeddings: Embeddings,
19+
library_embeddings: Embeddings,
20+
library_metadata: pd.DataFrame,
21+
spectrum_indices_per_inchikey: dict[str, Tuple[int, ...]],
22+
top_k_tanimoto_scores: TopKTanimotoScores,
23+
batch_size: int = 1000,
24+
):
25+
num_of_query_embeddings = query_embeddings.embeddings.shape[0]
26+
27+
library_index_highest_ms2deepscore = np.zeros((num_of_query_embeddings), dtype=int)
28+
ms2query_scores = []
29+
for start_idx in tqdm(
30+
range(0, num_of_query_embeddings, batch_size),
31+
desc="Predicting highest ms2deepscore per batch of "
32+
+ str(min(batch_size, num_of_query_embeddings))
33+
+ " embeddings",
34+
):
35+
# Do MS2DeepScore predictions for batch
36+
end_idx = min(start_idx + batch_size, num_of_query_embeddings)
37+
selected_query_embeddings = query_embeddings.embeddings[start_idx:end_idx]
38+
score_matrix = cosine_similarity_matrix(selected_query_embeddings, library_embeddings.embeddings)
39+
highest_score_idx = np.argmax(score_matrix, axis=1)
40+
library_index_highest_ms2deepscore[start_idx:end_idx] = highest_score_idx
41+
42+
# get predicted inchikeys
43+
predicted_inchikeys = library_metadata.iloc[highest_score_idx]["inchikey"]
44+
# Compute MS2Query reliability score
45+
ms2query_scores.extend(
46+
get_ms2query_reliability_prediction(
47+
predicted_inchikeys, spectrum_indices_per_inchikey, top_k_tanimoto_scores, score_matrix
48+
)
49+
)
50+
51+
# construct results df
52+
results = library_metadata.iloc[library_index_highest_ms2deepscore]
53+
results["ms2query_reliability_prediction"] = ms2query_scores
54+
return results
55+
56+
57+
def get_ms2query_reliability_prediction(
58+
predicted_inchikeys: list[str],
59+
spectrum_indices_per_inchikey,
60+
top_k_tanimoto_scores: TopKTanimotoScores,
61+
ms2deepscore_score_matrix,
62+
) -> list[float]:
63+
ms2query_scores = []
64+
for query_spectrum_index, library_inchikey in enumerate(predicted_inchikeys):
65+
top_k_inchikeys = top_k_tanimoto_scores.select_top_k_inchikeys(library_inchikey[:14])
66+
maximum_ms2deepscores = np.zeros(top_k_tanimoto_scores.k, dtype=float)
67+
for i, inchikey in enumerate(top_k_inchikeys):
68+
spectrum_indexes = spectrum_indices_per_inchikey[inchikey]
69+
highest_ms2deepscore = np.max(ms2deepscore_score_matrix[query_spectrum_index, spectrum_indexes])
70+
maximum_ms2deepscores[i] = highest_ms2deepscore
71+
ms2query_scores.append(np.mean(maximum_ms2deepscores))
72+
# todo get the spectrum hashes instead of the indexes for lookup later.
73+
return ms2query_scores

0 commit comments

Comments
 (0)