Skip to content

Commit 62eef51

Browse files
committed
cleaning and additional test
1 parent c3d7b62 commit 62eef51

File tree

1 file changed

+67
-12
lines changed

1 file changed

+67
-12
lines changed

tests/test_ann_vector_index.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
import numpy as np
55
import pytest
66
import scipy.sparse as sp
7-
from ms2query.database.ann_vector_index import EmbeddingIndex, csr_row_from_tuple, l1_norms_csr, tuples_to_csr
7+
from ms2query.database.ann_vector_index import (
8+
EmbeddingIndex,
9+
FingerprintSparseIndex, # <-- added
10+
csr_row_from_tuple,
11+
l1_norms_csr,
12+
tuples_to_csr,
13+
)
814

915

1016
def _mk_unit_vecs(*rows):
@@ -15,8 +21,8 @@ def _mk_unit_vecs(*rows):
1521

1622

1723
def test_build_index_and_query_dense():
18-
X = _mk_unit_vecs([1,0,0], [0,1,0], [0,0,1])
19-
ids = ["a","b","c"]
24+
X = _mk_unit_vecs([1, 0, 0], [0, 1, 0], [0, 0, 1])
25+
ids = ["a", "b", "c"]
2026
idx = EmbeddingIndex(dim=3)
2127
idx.build_index(X, ids)
2228
# Query close to [1,0,0]
@@ -27,11 +33,11 @@ def test_build_index_and_query_dense():
2733

2834

2935
def test_build_index_normalizes_when_requested():
30-
X = np.array([[2.0,0,0],[0,2.0,0]], dtype=np.float32)
31-
ids = ["x","y"]
36+
X = np.array([[2.0, 0, 0], [0, 2.0, 0]], dtype=np.float32)
37+
ids = ["x", "y"]
3238
idx = EmbeddingIndex(dim=3)
3339
idx.build_index(X, ids)
34-
q = np.array([1.0,0,0], dtype=np.float32)
40+
q = np.array([1.0, 0, 0], dtype=np.float32)
3541
out = idx.query(q, k=1)
3642
assert out[0][0] == "x"
3743
assert 0.99 <= out[0][1] <= 1.0
@@ -41,14 +47,14 @@ def test_query_errors_and_dim_check():
4147
idx = EmbeddingIndex(dim=3)
4248
with pytest.raises(RuntimeError):
4349
idx.query(np.zeros(3, np.float32))
44-
idx.build_index(np.eye(3, dtype=np.float32), ["a","b","c"])
50+
idx.build_index(np.eye(3, dtype=np.float32), ["a", "b", "c"])
4551
with pytest.raises(ValueError, match="dim=3"):
4652
idx.query(np.zeros(4, np.float32))
4753

4854

4955
def test_save_and_load_roundtrip_dense(tmp_path):
50-
X = _mk_unit_vecs([1,0,0],[0,1,0],[0,0,1])
51-
ids = ["a","b","c"]
56+
X = _mk_unit_vecs([1, 0, 0], [0, 1, 0], [0, 0, 1])
57+
ids = ["a", "b", "c"]
5258
idx = EmbeddingIndex(dim=3)
5359
idx.build_index(X, ids)
5460
prefix = os.path.join(tmp_path, "emb")
@@ -59,7 +65,7 @@ def test_save_and_load_roundtrip_dense(tmp_path):
5965
idx2.load_index(prefix)
6066

6167
# Query should still work
62-
res = idx2.query(np.array([1.0,0,0], dtype=np.float32), k=1)
68+
res = idx2.query(np.array([1.0, 0, 0], dtype=np.float32), k=1)
6369
assert res[0][0] == "a"
6470
# meta persisted
6571
with open(prefix + ".meta.json") as f:
@@ -72,7 +78,8 @@ def test_build_index_from_sqlite_streams_and_orders(batch_rows):
7278
conn = sqlite3.connect(":memory:")
7379
conn.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)")
7480
# Add 3 vectors of dim 3
75-
conn.executemany( "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)",
81+
conn.executemany(
82+
"INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)",
7683
[
7784
("id_1", np.array([1.0, 0.0, 0.0], np.float32).tobytes(), 3),
7885
("id_2", np.array([1.0, 1.0, 0.0], np.float32).tobytes(), 3),
@@ -106,7 +113,7 @@ def test_build_index_from_sqlite_errors():
106113

107114
# Empty table should error
108115
conn2 = sqlite3.connect(":memory:")
109-
conn2.execute("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)")
116+
conn2.execute("CREATE TABLE embeddings(spec_id TEXT, vec, d INTEGER)")
110117
with pytest.raises(ValueError, match="No rows"):
111118
idx.build_index_from_sqlite(conn2, embeddings_table="embeddings")
112119

@@ -158,3 +165,51 @@ def test_l1_norms_csr():
158165
norms = l1_norms_csr(X)
159166
np.testing.assert_allclose(norms, [3.0, 1.0, 4.0])
160167
assert norms.dtype == np.float64
168+
169+
170+
def test_save_and_load_roundtrip_fingerprint_sparse(tmp_path):
171+
# Build a tiny sparse fingerprint matrix with 3 compounds, dim=5
172+
tuples = [
173+
(np.array([0, 3], dtype=np.int32), np.array([1.0, 0.5], dtype=np.float32)),
174+
(np.array([1], dtype=np.int32), np.array([1.0], dtype=np.float32)),
175+
(np.array([2, 4], dtype=np.int32), np.array([0.2, 2.0], dtype=np.float32)),
176+
]
177+
csr = tuples_to_csr(tuples, dim=5)
178+
comp_ids = np.array([10, 11, 12], dtype=int)
179+
180+
idx = FingerprintSparseIndex(dim=5)
181+
idx.build_index(
182+
csr,
183+
comp_ids,
184+
keep_csr_for_rerank=True,
185+
compute_l1_for_rerank=True,
186+
)
187+
188+
# Basic sanity: query with the first fingerprint should hit comp_id=10 first
189+
q = tuples[0]
190+
res = idx.query(q, k=1)
191+
assert res[0][0] == 10
192+
193+
# Save to disk
194+
prefix = os.path.join(tmp_path, "fp")
195+
idx.save_index(prefix)
196+
197+
# New instance loads back
198+
idx2 = FingerprintSparseIndex()
199+
idx2.load_index(prefix)
200+
201+
# Query should still work and return the same top compound
202+
res2 = idx2.query(q, k=1)
203+
assert res2[0][0] == 10
204+
205+
# CSR and L1 data should have been persisted
206+
assert idx2._csr is not None
207+
assert idx2._l1 is not None
208+
assert idx2._csr.shape == csr.shape
209+
assert idx2._l1.shape[0] == csr.shape[0]
210+
211+
# Meta persisted and type is correct
212+
with open(prefix + ".meta.json") as f:
213+
meta = json.load(f)
214+
assert meta["type"] == "FingerprintSparseIndex"
215+
assert meta["space"] == "cosinesimil_sparse"

0 commit comments

Comments
 (0)