Skip to content

Commit 4032847

Browse files
committed
fixes and additional tests
1 parent 077b1af commit 4032847

File tree

2 files changed

+189
-3
lines changed

2 files changed

+189
-3
lines changed

ms2query/ms2query_library.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def query_spectra_by_spectra(
146146
self,
147147
spectra: list[Spectrum],
148148
*,
149-
k_spectra: int = 100,
149+
k_spectra: int = 10,
150150
ef: Optional[int] = None,
151151
):
152152
"""
@@ -193,7 +193,7 @@ def query_compounds_by_spectra(
193193
raise ValueError("k_compounds cannot be larger than k_spectra")
194194

195195
# Step1: Query spectral embeddings
196-
results = self.query_spectra_by_spectra(spectra, k=k_spectra, ef=ef)
196+
results = self.query_spectra_by_spectra(spectra, k_spectra=k_spectra, ef=ef)
197197

198198
# Pick k_compounds top compounds from the k_spectra hits (if possible)
199199
spec_ids = results.spec_id.values
@@ -202,7 +202,7 @@ def query_compounds_by_spectra(
202202
compounds = compounds.merge(results, on="spec_id").sort_values(["query_ix", "rank"])
203203

204204
# Pick no more than k_compounds per query_ix
205-
idx = compounds.groupby(['query_ix', 'comp_id'])['score'].idxmax()
205+
idx = compounds.groupby(['query_ix', 'rank'])['score'].idxmax()
206206
best_per_pair = compounds.loc[idx]
207207

208208
# Within each query_ix, keep the top-k by score

tests/test_ms2query_library.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import json
2+
from pathlib import Path
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
from matchms.importing import load_spectra
7+
from ms2query import MS2QueryDatabase, MS2QueryLibrary
8+
from ms2query.library_io import create_new_library, load_created_library
9+
10+
11+
# Compound ID passed to ``metadata_by_comp_id`` in the smoke test
# (presumably the 14-character first block of an InChIKey — TODO confirm).
TEST_COMP_ID = "ZBSGKPYXQINNGF"
# Expected (rows, columns) of the metadata frame returned for TEST_COMP_ID.
EXPECTED_METADATA_SHAPE = (5, 11)
# Metadata columns the smoke test requires both in ``db.metadata_fields``
# and in the frame returned for TEST_COMP_ID.
EXPECTED_METADATA_FIELDS = [
    "precursor_mz", "ionmode", "smiles", "inchikey", "inchi", "name",
    "charge", "instrument_type", "adduct", "collision_energy",
]
# Names of the artifacts the smoke test looks for in the output folder.
EMB_PREFIX = "embedding_index"
SQLITE_NAME = "ms2query_library.sqlite"
MANIFEST_NAME = "ms2query_manifest.json"
20+
21+
22+
def _data_dir() -> Path:
    """Return the ``test_data`` directory located next to this test module."""
    return Path(__file__).parent / "test_data"
24+
25+
26+
def _paths() -> tuple[Path, Path]:
    """Return ``(spectra_file, model_path)`` for the bundled test fixtures.

    Both files are looked up inside the directory returned by ``_data_dir()``.

    Raises:
        AssertionError: if either fixture file does not exist on disk.
    """
    data_dir = _data_dir()
    spectra_file = data_dir / "10_spectra.mgf"
    model_path = data_dir / "ms2deepscore_testmodel_v1.pt"
    # Fail early with a readable message instead of a cryptic loader error.
    assert spectra_file.exists(), f"Missing test spectra file: {spectra_file}"
    assert model_path.exists(), f"Missing test model file: {model_path}"
    return spectra_file, model_path
33+
34+
35+
def create_lib_from_test_files(tmp_path: Path) -> MS2QueryLibrary:
    """Create a small library from test files and return the loaded MS2QueryLibrary.

    The library is written to ``tmp_path / "ms2query_out"`` and then loaded
    back from that folder, so callers get the instance produced by
    ``load_created_library`` rather than the one returned by
    ``create_new_library``.

    Args:
        tmp_path: Scratch directory (typically pytest's ``tmp_path`` fixture).

    Returns:
        The library as loaded from disk.
    """
    spectra_file, model_path = _paths()
    outdir = tmp_path / "ms2query_out"
    outdir.mkdir(parents=True, exist_ok=True)

    # Build the library (keep HNSW params small for speed)
    lib_created = create_new_library(
        spectra_files=[str(spectra_file)],
        annotation_files=[],
        output_folder=str(outdir),
        model_path=str(model_path),
        build_embedding_index=True,
        embedding_index_params={
            "M": 8,
            "ef_construction": 50,
            "post_init_ef": 50,
            "batch_rows": 100_000,
        },
        compute_embeddings_batch_rows=256,
    )
    # Sanity-check what the builder handed back before touching the disk copy.
    assert isinstance(lib_created, MS2QueryLibrary)
    assert isinstance(lib_created.db, MS2QueryDatabase)

    # Load from disk to mirror real workflow
    lib_loaded = load_created_library(str(outdir))
    assert isinstance(lib_loaded, MS2QueryLibrary)
    return lib_loaded
59+
60+
61+
# --------------------------------------------------------------------
62+
# End-to-end smoke test (also re-used by the unit tests below)
63+
# --------------------------------------------------------------------
64+
65+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_create_and_load_smoke(tmp_path: Path):
    """Build a library, reload it, and check the files and metadata it produced."""
    lib = create_lib_from_test_files(tmp_path)

    outdir = tmp_path / "ms2query_out"
    db_path = outdir / SQLITE_NAME
    assert db_path.exists(), "MS2Query database file was not created."

    # Manifest present & contains basic keys
    manifest_path = outdir / MANIFEST_NAME
    assert manifest_path.exists()
    with open(manifest_path, "r", encoding="utf-8") as f:
        manifest = json.load(f)
    assert manifest.get("sqlite_path") == SQLITE_NAME

    # DB content checks
    ms2query_db = lib.db
    meta_df = ms2query_db.metadata_by_comp_id(TEST_COMP_ID)
    assert tuple(meta_df.shape) == EXPECTED_METADATA_SHAPE
    for field in EXPECTED_METADATA_FIELDS:
        assert field in ms2query_db.metadata_fields
        assert field in meta_df.columns

    # ANN artifacts (nmslib base + .dat OR legacy .nmslib)
    emb_base = outdir / EMB_PREFIX
    two_file_ok = emb_base.exists() and emb_base.with_suffix(".dat").exists()
    legacy_ok = emb_base.with_suffix(".nmslib").exists()
    assert two_file_ok or legacy_ok, "Embedding index files missing."
93+
94+
95+
# --------------------------------------------------------------------
96+
# Unit tests for MS2QueryLibrary methods
97+
# --------------------------------------------------------------------
98+
99+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_process_and_compute_embeddings(tmp_path: Path):
    """Check ``process_spectra`` output length and ``compute_embeddings`` array properties."""
    lib = create_lib_from_test_files(tmp_path)
    spectra_path, _ = _paths()
    spectra = list(load_spectra(spectra_path))
    assert len(spectra) > 0

    # process_spectra passthrough for now
    processed = lib.process_spectra(spectra)
    assert isinstance(processed, list)
    assert len(processed) == len(spectra)

    # compute_embeddings is expected to return an (n, d) float32 array, d > 0
    n_batch = 3  # small batch
    E = lib.compute_embeddings(spectra[:n_batch])
    assert isinstance(E, np.ndarray)
    assert E.dtype == np.float32
    assert E.ndim == 2 and E.shape[0] == n_batch and E.shape[1] > 0

    # L2 normalization sanity (norm ~ 1)
    norms = np.linalg.norm(E, axis=1)
    assert np.allclose(norms, 1.0, atol=1e-4)
120+
121+
122+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_query_embedding_index_returns_dataframe(tmp_path: Path):
    """Query the index with one spectrum and check the returned DataFrame layout."""
    lib = create_lib_from_test_files(tmp_path)
    spectra_path, _ = _paths()
    spectra = list(load_spectra(spectra_path))
    q = spectra[0]

    df = lib.query_embedding_index(q, k=3, ef=40, return_dataframe=True)
    assert isinstance(df, pd.DataFrame)
    assert {"query_ix", "rank", "spec_id", "score"}.issubset(df.columns)
    assert df.shape[0] >= 1
    # spec_id values are expected to be strings (plain or NumPy)
    assert isinstance(df["spec_id"].iloc[0], (str, np.str_))
134+
135+
136+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_query_spectra_by_spectra_and_compounds(tmp_path: Path):
    """Exercise spectra-by-spectra and compounds-by-spectra queries end to end."""
    lib = create_lib_from_test_files(tmp_path)
    spectra_path, _ = _paths()
    spectra = list(load_spectra(spectra_path))

    # spectra-by-spectra (DataFrame)
    df_s = lib.query_spectra_by_spectra(spectra[:2], k_spectra=5, ef=40)
    assert isinstance(df_s, pd.DataFrame)
    assert {"query_ix", "rank", "spec_id", "score"}.issubset(df_s.columns)
    assert df_s["query_ix"].nunique() == 2

    # compounds-by-spectra (top-k compounds per query)
    k_compounds = 5
    df_c = lib.query_compounds_by_spectra(
        spectra[:3], k_spectra=20, k_compounds=k_compounds, ef=40
    )
    # Expect the metadata columns plus query_ix/spec_id/score after the merge
    required_cols = {"query_ix", "spec_id", "score", "inchikey"}.union(EXPECTED_METADATA_FIELDS)
    assert required_cols.issubset(df_c.columns)
    # per query, at most k_compounds rows
    assert (df_c.groupby("query_ix").size() <= k_compounds).all()
155+
156+
157+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_query_by_spec_ids_uses_db_embeddings(tmp_path: Path):
    """Query by spectrum IDs taken from the DB and check one result group per ID."""
    lib = create_lib_from_test_files(tmp_path)
    # grab a few ids from the DB
    spec_ids = lib.db.ref_sdb.ids()[:2]
    assert len(spec_ids) >= 1

    df = lib.query_by_spec_ids(spec_ids, k=4, ef=40, return_dataframe=True)
    assert isinstance(df, pd.DataFrame)
    assert {"query_ix", "rank", "spec_id", "score"}.issubset(df.columns)
    # same number of query_ix values as requested ids
    assert df["query_ix"].nunique() == len(spec_ids)
169+
170+
171+
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_error_paths_missing_model_or_index(tmp_path: Path):
    """Check that a missing index or a missing model path raises RuntimeError."""
    # Build a complete library first; the index is detached below.
    lib = create_lib_from_test_files(tmp_path)

    # Remove index to test guard rails
    lib.embedding_index = None
    spectra_path, _ = _paths()
    spectra = list(load_spectra(spectra_path))
    with pytest.raises(RuntimeError, match="EmbeddingIndex is not set"):
        lib.query_embedding_index(spectra[0], k=3)

    # New instance without model_path -> compute_embeddings should raise when invoked
    lib2 = MS2QueryLibrary(db=lib.db, embedding_index=None, model_path=None)
    with pytest.raises(RuntimeError, match="model_path is not set"):
        lib2.compute_embeddings(spectra[:1])

0 commit comments

Comments
 (0)