Skip to content

Commit 1867637

Browse files
committed
add first set of tests
1 parent e1ad847 commit 1867637

File tree

5 files changed

+924
-0
lines changed

5 files changed

+924
-0
lines changed

tests/test_ann_index.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import io
2+
import json
3+
import sqlite3
4+
from typing import List, Tuple
5+
6+
import numpy as np
7+
import pytest
8+
from matchms import Spectrum
9+
10+
from ms2query.database import ANNIndex
11+
from ms2query.database.spectra_merging import ensure_merged_tables
12+
13+
# --- small helpers for array <-> BLOB used in tests (mirrors the production helpers) ---
14+
15+
def _ndarray_to_blob(arr: np.ndarray) -> bytes:
16+
with io.BytesIO() as f:
17+
np.save(f, arr, allow_pickle=False)
18+
return f.getvalue()
19+
20+
21+
@pytest.fixture()
22+
def conn() -> sqlite3.Connection:
23+
# In-memory DB for tests
24+
return sqlite3.connect(":memory:")
25+
26+
27+
@pytest.fixture()
28+
def ann(conn) -> ANNIndex:
29+
# Instantiate with dummy model path; we’ll monkeypatch load_model.
30+
return ANNIndex(
31+
conn=conn,
32+
model_path="dummy_model.pt",
33+
faiss_metric="ip",
34+
faiss_factory=None,
35+
normalize_embeddings=True,
36+
)
37+
38+
39+
def test_ensure_schema_creates_tables(ann: ANNIndex):
40+
"""Schema should be created with all required columns."""
41+
ann.ensure_schema()
42+
cur = ann.conn.cursor()
43+
# Check both tables exist
44+
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='merged_spectra';")
45+
assert cur.fetchone() is not None
46+
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='merged_embeddings';")
47+
assert cur.fetchone() is not None
48+
49+
# Check a few critical columns exist
50+
cur.execute("PRAGMA table_info('merged_spectra');")
51+
cols = {row[1] for row in cur.fetchall()}
52+
for required in ("merged_id", "comp_id", "precursor_mz", "mz", "intensities", "num_merged"):
53+
assert required in cols
54+
55+
56+
def _insert_synthetic_merged_rows(conn: sqlite3.Connection) -> Tuple[int, int]:
57+
"""
58+
Insert two tiny merged_spectra rows with minimal viable metadata.
59+
Returns their merged_ids (sqlite autoincrement).
60+
"""
61+
cur = conn.cursor()
62+
ensure_merged_tables(conn)
63+
64+
# Synthetic peaks
65+
mz1 = np.array([100.0, 150.0, 200.0], dtype=np.float64)
66+
it1 = np.array([0.2, 0.3, 0.5], dtype=np.float32)
67+
68+
mz2 = np.array([101.0, 151.0, 201.0], dtype=np.float64)
69+
it2 = np.array([0.4, 0.1, 0.5], dtype=np.float32)
70+
71+
base_cols = (
72+
"comp_id, ionmode, charge, precursor_mz, smiles, inchikey, inchi, name, "
73+
"instrument_type, adduct, collision_energy, num_merged, source_spec_ids, mz, intensities"
74+
)
75+
q = f"INSERT INTO merged_spectra ({base_cols}) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);"
76+
77+
cur.execute(
78+
q,
79+
(
80+
"C1", "positive", 1, 300.123, "C(CO)O", "AAAA-BBBB-CCCC", "InChI=1S/...", "Compound A",
81+
"QTOF", "[M+H]+", "NCE 20", 3, json.dumps([11, 12, 13]),
82+
sqlite3.Binary(_ndarray_to_blob(mz1)), sqlite3.Binary(_ndarray_to_blob(it1))
83+
),
84+
)
85+
id1 = cur.lastrowid
86+
87+
cur.execute(
88+
q,
89+
(
90+
"C2", "positive", 1, 450.5, "CCN(CC)CC", "XXXX-YYYY-ZZZZ", "InChI=1S/...", "Compound B",
91+
"Orbitrap", "[M+H]+", "NCE 25", 2, json.dumps([21, 22]),
92+
sqlite3.Binary(_ndarray_to_blob(mz2)), sqlite3.Binary(_ndarray_to_blob(it2))
93+
),
94+
)
95+
id2 = cur.lastrowid
96+
97+
conn.commit()
98+
return id1, id2
99+
100+
101+
def test_compute_embeddings_inserts_rows(ann: ANNIndex, monkeypatch):
102+
"""Embeddings should be computed and written; rerun with only_missing yields 0 new rows."""
103+
id1, id2 = _insert_synthetic_merged_rows(ann.conn)
104+
105+
# Monkeypatch model loading and embedding to be deterministic & light.
106+
class _DummyModel:
107+
def eval(self):
108+
pass
109+
110+
def fake_load_model(_path):
111+
return _DummyModel()
112+
113+
def fake_compute_embedding_array(model, specs: List[Spectrum]) -> np.ndarray:
114+
# simple deterministic embedding: [precursor_mz, charge, sum(intens), len(peaks)]
115+
out = []
116+
for s in specs:
117+
pmz = float(s.metadata["precursor_mz"])
118+
charge = float(s.metadata.get("charge") or 0)
119+
intens_sum = float(np.sum(s.peaks.intensities))
120+
n_peaks = float(len(s.peaks.mz))
121+
out.append([pmz, charge, intens_sum, n_peaks])
122+
return np.asarray(out, dtype=np.float32)
123+
124+
monkeypatch.setattr("ms2query.database.ann_index.load_model", fake_load_model)
125+
monkeypatch.setattr("ms2query.database.ann_index.compute_embedding_array", fake_compute_embedding_array)
126+
127+
inserted = ann.compute_embeddings_to_sqlite(batch_rows=64, only_missing=True)
128+
assert inserted == 2
129+
130+
# Confirm rows exist
131+
cur = ann.conn.cursor()
132+
cur.execute("SELECT COUNT(1) FROM merged_embeddings;")
133+
assert cur.fetchone()[0] == 2
134+
135+
# Re-run with only_missing: should insert 0
136+
inserted2 = ann.compute_embeddings_to_sqlite(batch_rows=64, only_missing=True)
137+
assert inserted2 == 0
138+
139+
140+
def test_build_index_and_query(ann: ANNIndex, monkeypatch):
141+
"""Build index from stored embeddings and query it; top-1 should be the intended nearest."""
142+
id1, id2 = _insert_synthetic_merged_rows(ann.conn)
143+
144+
# Same monkeypatch as previous test (model + embeddings)
145+
class _DummyModel:
146+
def eval(self):
147+
pass
148+
149+
def fake_load_model(_path):
150+
return _DummyModel()
151+
152+
def fake_compute_embedding_array(model, specs: List[Spectrum]) -> np.ndarray:
153+
# Embedding consistent with test_compute_embeddings
154+
out = []
155+
for s in specs:
156+
pmz = float(s.metadata["precursor_mz"])
157+
charge = float(s.metadata.get("charge") or 0)
158+
intens_sum = float(np.sum(s.peaks.intensities))
159+
n_peaks = float(len(s.peaks.mz))
160+
out.append([pmz, charge, intens_sum, n_peaks])
161+
return np.asarray(out, dtype=np.float32)
162+
163+
monkeypatch.setattr("ms2query.database.ann_index.load_model", fake_load_model)
164+
monkeypatch.setattr("ms2query.database.ann_index.compute_embedding_array", fake_compute_embedding_array)
165+
166+
# Compute embeddings
167+
ann.compute_embeddings_to_sqlite(batch_rows=64, only_missing=False)
168+
169+
# Build FAISS index
170+
index = ann.build_index()
171+
assert index.ntotal == 2
172+
173+
# Prepare a query spectrum that should be closest to the 2nd row (precursor_mz=450.5)
174+
q_mz = np.array([100.0, 200.0], dtype=np.float32)
175+
q_it = np.array([0.5, 0.5], dtype=np.float32)
176+
q_spec = Spectrum(mz=q_mz, intensities=q_it, metadata={"precursor_mz": 450.5, "ionmode": "positive", "charge": 1})
177+
178+
# Query
179+
results = ann.query(q_spec, k=2, include_metadata=True, as_dataframe=True)
180+
assert isinstance(results, list) and len(results) == 1
181+
df = results[0]
182+
assert {"rank", "merged_id", "score", "distance", "comp_id", "name"}.issubset(df.columns)
183+
184+
# Top-1 hit should be the row with precursor_mz=450.5 (id2)
185+
top1 = df.iloc[0]
186+
assert int(top1["merged_id"]) == id2
187+
assert top1["comp_id"] == "C2"
188+
assert top1["name"] == "Compound B"
189+
190+
# Scores should be non-increasing by rank
191+
assert np.all(df["score"].values[:-1] >= df["score"].values[1:])

0 commit comments

Comments
 (0)