44import numpy as np
55import pytest
66import scipy .sparse as sp
7- from ms2query .database .ann_vector_index import EmbeddingIndex , csr_row_from_tuple , l1_norms_csr , tuples_to_csr
7+ from ms2query .database .ann_vector_index import (
8+ EmbeddingIndex ,
9+ FingerprintSparseIndex , # <-- added
10+ csr_row_from_tuple ,
11+ l1_norms_csr ,
12+ tuples_to_csr ,
13+ )
814
915
1016def _mk_unit_vecs (* rows ):
@@ -15,8 +21,8 @@ def _mk_unit_vecs(*rows):
1521
1622
1723def test_build_index_and_query_dense ():
18- X = _mk_unit_vecs ([1 ,0 , 0 ], [0 ,1 , 0 ], [0 ,0 , 1 ])
19- ids = ["a" ,"b" ,"c" ]
24+ X = _mk_unit_vecs ([1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ])
25+ ids = ["a" , "b" , "c" ]
2026 idx = EmbeddingIndex (dim = 3 )
2127 idx .build_index (X , ids )
2228 # Query close to [1,0,0]
@@ -27,11 +33,11 @@ def test_build_index_and_query_dense():
2733
2834
def test_build_index_normalizes_when_requested():
    """Index two non-unit vectors and query with a unit vector along the first axis.

    The single returned hit must be id "x" and its score must lie in
    the closed interval [0.99, 1.0].
    """
    labels = ["x", "y"]
    # Both rows have length 2.0 rather than 1.0.
    vectors = np.array([[2.0, 0, 0], [0, 2.0, 0]], dtype=np.float32)

    index = EmbeddingIndex(dim=3)
    index.build_index(vectors, labels)

    query_vec = np.array([1.0, 0, 0], dtype=np.float32)
    hits = index.query(query_vec, k=1)

    top_id, top_score = hits[0][0], hits[0][1]
    assert top_id == "x"
    assert 0.99 <= top_score <= 1.0
@@ -41,14 +47,14 @@ def test_query_errors_and_dim_check():
4147 idx = EmbeddingIndex (dim = 3 )
4248 with pytest .raises (RuntimeError ):
4349 idx .query (np .zeros (3 , np .float32 ))
44- idx .build_index (np .eye (3 , dtype = np .float32 ), ["a" ,"b" ,"c" ])
50+ idx .build_index (np .eye (3 , dtype = np .float32 ), ["a" , "b" , "c" ])
4551 with pytest .raises (ValueError , match = "dim=3" ):
4652 idx .query (np .zeros (4 , np .float32 ))
4753
4854
4955def test_save_and_load_roundtrip_dense (tmp_path ):
50- X = _mk_unit_vecs ([1 ,0 , 0 ],[0 ,1 , 0 ],[0 ,0 , 1 ])
51- ids = ["a" ,"b" ,"c" ]
56+ X = _mk_unit_vecs ([1 , 0 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ])
57+ ids = ["a" , "b" , "c" ]
5258 idx = EmbeddingIndex (dim = 3 )
5359 idx .build_index (X , ids )
5460 prefix = os .path .join (tmp_path , "emb" )
@@ -59,7 +65,7 @@ def test_save_and_load_roundtrip_dense(tmp_path):
5965 idx2 .load_index (prefix )
6066
6167 # Query should still work
62- res = idx2 .query (np .array ([1.0 ,0 , 0 ], dtype = np .float32 ), k = 1 )
68+ res = idx2 .query (np .array ([1.0 , 0 , 0 ], dtype = np .float32 ), k = 1 )
6369 assert res [0 ][0 ] == "a"
6470 # meta persisted
6571 with open (prefix + ".meta.json" ) as f :
@@ -72,7 +78,8 @@ def test_build_index_from_sqlite_streams_and_orders(batch_rows):
7278 conn = sqlite3 .connect (":memory:" )
7379 conn .execute ("CREATE TABLE embeddings(spec_id TEXT, vec BLOB, d INTEGER)" )
7480 # Add 3 vectors of dim 3
75- conn .executemany ( "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)" ,
81+ conn .executemany (
82+ "INSERT INTO embeddings(spec_id, vec, d) VALUES (?,?,?)" ,
7683 [
7784 ("id_1" , np .array ([1.0 , 0.0 , 0.0 ], np .float32 ).tobytes (), 3 ),
7885 ("id_2" , np .array ([1.0 , 1.0 , 0.0 ], np .float32 ).tobytes (), 3 ),
@@ -106,7 +113,7 @@ def test_build_index_from_sqlite_errors():
106113
107114 # Empty table should error
108115 conn2 = sqlite3 .connect (":memory:" )
109- conn2 .execute ("CREATE TABLE embeddings(spec_id TEXT, vec BLOB , d INTEGER)" )
116+ conn2 .execute ("CREATE TABLE embeddings(spec_id TEXT, vec, d INTEGER)" )
110117 with pytest .raises (ValueError , match = "No rows" ):
111118 idx .build_index_from_sqlite (conn2 , embeddings_table = "embeddings" )
112119
@@ -158,3 +165,51 @@ def test_l1_norms_csr():
158165 norms = l1_norms_csr (X )
159166 np .testing .assert_allclose (norms , [3.0 , 1.0 , 4.0 ])
160167 assert norms .dtype == np .float64
168+
169+
def test_save_and_load_roundtrip_fingerprint_sparse(tmp_path):
    """Round-trip a FingerprintSparseIndex through save_index / load_index.

    A three-compound, dim=5 sparse index is built and queried, written to
    disk under ``tmp_path``, and loaded into a fresh instance. The test then
    checks that the reloaded index returns the same top compound, that its
    ``_csr`` and ``_l1`` attributes are populated with matching row counts,
    and that the ``.meta.json`` file records the expected type and space.
    """
    n_dims = 5
    # One tuple of (int32 array, float32 array) per compound; presumably
    # (indices, values) as consumed by tuples_to_csr.
    fingerprints = [
        (np.array([0, 3], dtype=np.int32), np.array([1.0, 0.5], dtype=np.float32)),
        (np.array([1], dtype=np.int32), np.array([1.0], dtype=np.float32)),
        (np.array([2, 4], dtype=np.int32), np.array([0.2, 2.0], dtype=np.float32)),
    ]
    matrix = tuples_to_csr(fingerprints, dim=n_dims)
    compound_ids = np.array([10, 11, 12], dtype=int)

    original = FingerprintSparseIndex(dim=n_dims)
    original.build_index(
        matrix,
        compound_ids,
        keep_csr_for_rerank=True,
        compute_l1_for_rerank=True,
    )

    # Sanity check before persisting: the first fingerprint should rank
    # compound 10 at the top.
    query_fp = fingerprints[0]
    hits_before = original.query(query_fp, k=1)
    assert hits_before[0][0] == 10

    # Persist the index under a common file prefix.
    prefix = os.path.join(tmp_path, "fp")
    original.save_index(prefix)

    # A fresh instance is populated purely from disk.
    restored = FingerprintSparseIndex()
    restored.load_index(prefix)

    # The reloaded index must give the same top compound.
    hits_after = restored.query(query_fp, k=1)
    assert hits_after[0][0] == 10

    # The rerank data (CSR matrix and L1 array) must have been persisted.
    assert restored._csr is not None
    assert restored._l1 is not None
    assert restored._csr.shape == matrix.shape
    assert restored._l1.shape[0] == matrix.shape[0]

    # The metadata file must exist and describe the right index type.
    with open(prefix + ".meta.json") as meta_file:
        meta = json.load(meta_file)
    assert meta["type"] == "FingerprintSparseIndex"
    assert meta["space"] == "cosinesimil_sparse"
0 commit comments