Skip to content

Commit c5b4caf

Browse files
committed
test: add unittest and benchmark test for IVF
1 parent bd10fd4 commit c5b4caf

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

benchmark/sift1m-ivf.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import sys
2+
import os
3+
import time
4+
import argparse
5+
import numpy as np
6+
7+
# allow import of the built extension
8+
sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'build')))
9+
from zenann import IVFFlatIndex
10+
11+
12+
def read_fvecs(fname):
13+
"""
14+
Read FLANN .fvecs file: each vector stored as int dim, then float data
15+
"""
16+
data = np.fromfile(fname, dtype='int32')
17+
dim = data[0]
18+
assert data.size % (dim + 1) == 0, "Invalid .fvecs file"
19+
vects = data.reshape(-1, dim+1)[:, 1:].astype('float32')
20+
return vects
21+
22+
23+
def read_ivecs(fname):
24+
"""
25+
Read FLANN .ivecs file: each entry stored as int dim, then int indices
26+
"""
27+
data = np.fromfile(fname, dtype='int32')
28+
k = data[0]
29+
assert data.size % (k + 1) == 0, "Invalid .ivecs file"
30+
vects = data.reshape(-1, k+1)[:, 1:]
31+
return vects
32+
33+
34+
def compute_recall(results, groundtruth, k):
35+
"""
36+
Compute recall@k: fraction of queries where gt[0] in results
37+
"""
38+
num_q = results.shape[0]
39+
hits = 0
40+
for i in range(num_q):
41+
if groundtruth[i, 0] in results[i, :k]:
42+
hits += 1
43+
return hits / num_q
44+
45+
46+
def main(args):
47+
# Load dataset
48+
print("Loading dataset...")
49+
base = read_fvecs(args.base)
50+
queries = read_fvecs(args.query)
51+
gt = read_ivecs(args.groundtruth)
52+
53+
print(f"Base vectors: {base.shape}")
54+
print(f"Queries: {queries.shape}")
55+
print(f"Groundtruth: {gt.shape}")
56+
57+
# Build IVF index
58+
print("Building IVF index...")
59+
index = IVFFlatIndex(dim=base.shape[1], nlist=args.nlist, nprobe=args.nprobe)
60+
t0 = time.time()
61+
index.build(base)
62+
t_build = time.time() - t0
63+
print(f"Index build time: {t_build:.3f} s")
64+
65+
# Search
66+
print("Running search...")
67+
K = args.k
68+
results = np.empty((queries.shape[0], K), dtype=np.int64)
69+
t0 = time.time()
70+
for i, q in enumerate(queries):
71+
res = index.search(q.tolist(), K)
72+
results[i, :len(res.indices)] = res.indices
73+
t_search = time.time() - t0
74+
qps = queries.shape[0] / t_search
75+
76+
# Compute recall
77+
recall = compute_recall(results, gt, K)
78+
79+
print(f"Recall@{K}: {recall * 100:.2f}%")
80+
print(f"Search QPS: {qps:.2f} queries/sec")
81+
82+
if __name__ == '__main__':
83+
parser = argparse.ArgumentParser("ZenANN IVF SIFT1M benchmark")
84+
parser.add_argument("--base", required=True, help="Path to sift1M_base.fvecs")
85+
parser.add_argument("--query", required=True, help="Path to sift1M_query.fvecs")
86+
parser.add_argument("--groundtruth", required=True, help="Path to sift1M_groundtruth.ivecs")
87+
parser.add_argument("--nlist", type=int, default=1024, help="Number of IVF clusters")
88+
parser.add_argument("--nprobe", type=int, default=4, help="Number of clusters to probe")
89+
parser.add_argument("--k", type=int, default=10, help="Number of nearest neighbors (k)")
90+
args = parser.parse_args()
91+
main(args)

tests/test_ivfflat.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import sys
2+
import unittest
3+
4+
# allow import of the built extension
5+
sys.path.append('..')
6+
sys.path.append('build')
7+
8+
from zenann import IVFFlatIndex, SearchResult
9+
10+
class TestIVFBasic(unittest.TestCase):
11+
def setUp(self):
12+
# one-hot vectors in 4-d for simple testing
13+
self.dim = 4
14+
self.data = []
15+
for i in range(self.dim):
16+
v = [0.0] * self.dim
17+
v[i] = 1.0
18+
self.data.append(v)
19+
20+
def test_search_sorted(self):
21+
# Test that search returns distances sorted in ascending order
22+
idx = IVFFlatIndex(dim=self.dim, nlist=self.dim, nprobe=self.dim)
23+
idx.build(self.data)
24+
result = idx.search(self.data[0], k=3)
25+
# distances should be non-decreasing
26+
dists = result.distances
27+
self.assertTrue(all(dists[i] <= dists[i+1] for i in range(len(dists)-1)))
28+
29+
def test_search_batch(self):
30+
idx = IVFFlatIndex(dim=self.dim, nlist=self.dim, nprobe=2)
31+
idx.build(self.data)
32+
# batch of two queries
33+
batch = [self.data[0], self.data[1]]
34+
results = idx.search_batch(batch, k=2)
35+
self.assertEqual(len(results), 2)
36+
for res, query in zip(results, batch):
37+
self.assertIsInstance(res, SearchResult)
38+
# verify the true index is in results
39+
true_id = self.data.index(query)
40+
self.assertIn(true_id, res.indices)
41+
42+
def test_nprobe_effect(self):
43+
# nprobe=1 returns only 1 result
44+
idx1 = IVFFlatIndex(dim=self.dim, nlist=self.dim, nprobe=1)
45+
idx1.build(self.data)
46+
res1 = idx1.search(self.data[2], k=2)
47+
self.assertEqual(len(res1.indices), 1)
48+
# nprobe=2 can return up to 2
49+
idx2 = IVFFlatIndex(dim=self.dim, nlist=self.dim, nprobe=2)
50+
idx2.build(self.data)
51+
res2 = idx2.search(self.data[2], k=2)
52+
self.assertTrue(1 <= len(res2.indices) <= 2)
53+
54+
if __name__ == '__main__':
55+
unittest.main()
56+

0 commit comments

Comments
 (0)