|
1 | 1 | import json |
| 2 | +from collections import defaultdict |
2 | 3 | from pathlib import Path |
3 | 4 | from typing import Sequence, Tuple |
4 | 5 | import numpy as np |
|
14 | 15 | from ms2query.benchmarking.TopKTanimotoScores import TopKTanimotoScores |
15 | 16 |
|
16 | 17 |
|
| 18 | +def run_ms2query_from_files( |
| 19 | + query_spectrum_file, |
| 20 | + ms2deepscore_model_file_name, |
| 21 | + reference_embeddings_file, |
| 22 | + top_k_tanimoto_scores_file, |
| 23 | + reference_metadata_file, |
| 24 | +): |
| 25 | + reference_embeddings = Embeddings.load(reference_embeddings_file) |
| 26 | + top_k_tanimoto_scores = TopKTanimotoScores.load(top_k_tanimoto_scores_file) |
| 27 | + reference_metadata = pd.read_parquet(reference_metadata_file) |
| 28 | + # Get the spectrum_indices_per_inchikey |
| 29 | + spectrum_indices_per_inchikey = defaultdict(list) |
| 30 | + for lib_spec_index, inchikey in enumerate(reference_metadata["inchikey"]): |
| 31 | + spectrum_indices_per_inchikey[inchikey[:14]].append(lib_spec_index) |
| 32 | + |
| 33 | + query_spectra = list(tqdm(load_spectra(query_spectrum_file), desc="loading_in_query_spectra")) |
| 34 | + ms2deepscore_model = load_model(ms2deepscore_model_file_name) |
| 35 | + query_embeddings = Embeddings.create_from_spectra(query_spectra, ms2deepscore_model) |
| 36 | + run_ms2query( |
| 37 | + query_embeddings, reference_embeddings, reference_metadata, spectrum_indices_per_inchikey, top_k_tanimoto_scores |
| 38 | + ) |
| 39 | + |
| 40 | + |
17 | 41 | def run_ms2query( |
18 | 42 | query_embeddings: Embeddings, |
19 | 43 | library_embeddings: Embeddings, |
20 | 44 | library_metadata: pd.DataFrame, |
21 | | - spectrum_indices_per_inchikey: dict[str, Tuple[int, ...]], |
| 45 | + spectrum_indices_per_inchikey: defaultdict[str, list[int]], |
22 | 46 | top_k_tanimoto_scores: TopKTanimotoScores, |
23 | 47 | batch_size: int = 1000, |
24 | 48 | ): |
|
0 commit comments