|
| 1 | +import json |
| 2 | +from pathlib import Path |
| 3 | +from typing import Sequence, Tuple |
| 4 | +import numpy as np |
| 5 | +import pandas as pd |
| 6 | +from matchms import Spectrum |
| 7 | +from matchms.importing import load_spectra |
| 8 | +from ms2deepscore.models import load_model |
| 9 | +from ms2deepscore.vector_operations import cosine_similarity_matrix |
| 10 | +from tqdm import tqdm |
| 11 | +from ms2query.benchmarking.AnnotatedSpectrumSet import AnnotatedSpectrumSet |
| 12 | +from ms2query.benchmarking.Embeddings import Embeddings |
| 13 | +from ms2query.benchmarking.Fingerprints import Fingerprints |
| 14 | +from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores |
| 15 | + |
| 16 | + |
def run_ms2query(
    query_embeddings: Embeddings,
    library_embeddings: Embeddings,
    library_metadata: pd.DataFrame,
    spectrum_indices_per_inchikey: dict[str, Tuple[int, ...]],
    top_k_tanimoto_scores: TopKTanimotoScores,
    batch_size: int = 1000,
) -> pd.DataFrame:
    """Find the best library match per query and attach an MS2Query reliability score.

    The query embeddings are processed in batches. For each batch the similarity
    to every library embedding is computed with ``cosine_similarity_matrix`` and
    the library entry with the highest score is selected per query. A reliability
    score for each selected match is then computed by
    ``get_ms2query_reliability_prediction``.

    Args:
        query_embeddings: Object whose ``embeddings`` attribute holds one row per
            query spectrum.
        library_embeddings: Object whose ``embeddings`` attribute holds one row
            per library spectrum.
        library_metadata: Library metadata with an ``"inchikey"`` column. Rows are
            selected positionally with the column index of the score matrix, so
            they must be in the same order as ``library_embeddings.embeddings``.
        spectrum_indices_per_inchikey: Mapping from inchikey to library spectrum
            positions; forwarded unchanged to the reliability computation.
        top_k_tanimoto_scores: Forwarded unchanged to the reliability computation.
        batch_size: Number of query embeddings scored against the library at once.

    Returns:
        A new DataFrame with one row per query embedding: the metadata row of the
        best-scoring library entry plus a ``"ms2query_reliability_prediction"``
        column. Index labels are those of ``library_metadata`` and may repeat when
        several queries select the same library entry.
    """
    num_of_query_embeddings = query_embeddings.embeddings.shape[0]

    library_index_highest_ms2deepscore = np.zeros(num_of_query_embeddings, dtype=int)
    ms2query_scores = []
    progress_label = (
        f"Predicting highest ms2deepscore per batch of {min(batch_size, num_of_query_embeddings)} embeddings"
    )
    for start_idx in tqdm(range(0, num_of_query_embeddings, batch_size), desc=progress_label):
        # Do MS2DeepScore predictions for batch
        end_idx = min(start_idx + batch_size, num_of_query_embeddings)
        selected_query_embeddings = query_embeddings.embeddings[start_idx:end_idx]
        score_matrix = cosine_similarity_matrix(selected_query_embeddings, library_embeddings.embeddings)
        highest_score_idx = np.argmax(score_matrix, axis=1)
        library_index_highest_ms2deepscore[start_idx:end_idx] = highest_score_idx

        # Get predicted inchikeys (one per query in this batch, in batch order)
        predicted_inchikeys = library_metadata.iloc[highest_score_idx]["inchikey"]
        # Compute MS2Query reliability score; the score matrix rows are batch-local,
        # which matches the batch-local enumeration done by the callee.
        ms2query_scores.extend(
            get_ms2query_reliability_prediction(
                predicted_inchikeys, spectrum_indices_per_inchikey, top_k_tanimoto_scores, score_matrix
            )
        )

    # Construct results df. Take an explicit copy: assigning a column to the raw
    # result of ``.iloc[...]`` raises SettingWithCopyWarning on pandas without
    # copy-on-write and leaves it unclear whether the caller's frame is affected.
    results = library_metadata.iloc[library_index_highest_ms2deepscore].copy()
    results["ms2query_reliability_prediction"] = ms2query_scores
    return results
| 55 | + |
| 56 | + |
def get_ms2query_reliability_prediction(
    predicted_inchikeys: list[str],
    spectrum_indices_per_inchikey,
    top_k_tanimoto_scores: TopKTanimotoScores,
    ms2deepscore_score_matrix,
) -> list[float]:
    """Compute one reliability score per predicted inchikey.

    For the query at position ``p`` the first 14 characters of its predicted
    inchikey are passed to ``top_k_tanimoto_scores.select_top_k_inchikeys``. For
    every inchikey returned, the highest value in row ``p`` of
    ``ms2deepscore_score_matrix`` over that inchikey's spectrum positions is
    taken. The score is the mean of an array of ``top_k_tanimoto_scores.k``
    slots filled with those maxima.

    Args:
        predicted_inchikeys: Predicted inchikey per query, in the same order as
            the rows of ``ms2deepscore_score_matrix``.
        spectrum_indices_per_inchikey: Mapping from inchikey to the column
            positions in ``ms2deepscore_score_matrix`` belonging to that inchikey.
        top_k_tanimoto_scores: Provides ``k`` and ``select_top_k_inchikeys``.
        ms2deepscore_score_matrix: 2-D score array indexed as ``[query, library]``.

    Returns:
        One score per entry of ``predicted_inchikeys``, in input order.
    """

    def reliability_for(row_position, full_inchikey):
        neighbour_inchikeys = top_k_tanimoto_scores.select_top_k_inchikeys(full_inchikey[:14])
        # NOTE(review): slots stay 0.0 when fewer than k inchikeys are returned,
        # which lowers the mean - confirm this is intended.
        best_scores = np.zeros(top_k_tanimoto_scores.k, dtype=float)
        for slot, neighbour in enumerate(neighbour_inchikeys):
            library_columns = spectrum_indices_per_inchikey[neighbour]
            best_scores[slot] = np.max(ms2deepscore_score_matrix[row_position, library_columns])
        return np.mean(best_scores)

    # todo get the spectrum hashes instead of the indexes for lookup later.
    return [
        reliability_for(row_position, full_inchikey)
        for row_position, full_inchikey in enumerate(predicted_inchikeys)
    ]
0 commit comments