|
| 1 | +import sys |
| 2 | +import os |
| 3 | +import time |
| 4 | +import argparse |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +# allow import of the built extension |
| 8 | +sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'build'))) |
| 9 | +from zenann import IVFFlatIndex |
| 10 | + |
| 11 | + |
def read_fvecs(fname):
    """
    Read a .fvecs file into a 2-D float32 array.

    Each record is an int32 dimension ``d`` followed by ``d`` float32
    components, so every record occupies ``d + 1`` four-byte words.

    Args:
        fname: Path to the .fvecs file.

    Returns:
        A float32 array of shape (num_vectors, d).

    Raises:
        ValueError: If the file is empty, the leading dimension is not
            positive, or the file length is not a whole number of records.
    """
    data = np.fromfile(fname, dtype='int32')
    if data.size == 0:
        raise ValueError("Invalid .fvecs file")
    # Convert to a Python int so that `dim + 1` cannot wrap around in int32.
    dim = int(data[0])
    if dim <= 0 or data.size % (dim + 1) != 0:
        raise ValueError("Invalid .fvecs file")
    # The payload words are float32 bit patterns that were read as int32, so
    # they must be reinterpreted with view(); astype() would convert the
    # integer values numerically and produce meaningless vectors. The copy
    # makes the sliced payload contiguous before it is reinterpreted.
    vects = data.reshape(-1, dim + 1)[:, 1:].copy().view('float32')
    return vects
| 21 | + |
| 22 | + |
def read_ivecs(fname):
    """
    Read a .ivecs file into a 2-D int32 array.

    Each record is an int32 count ``k`` followed by ``k`` int32 values, so
    every record occupies ``k + 1`` four-byte words.

    Args:
        fname: Path to the .ivecs file.

    Returns:
        An int32 array of shape (num_records, k) holding the values of each
        record with the leading count column removed.

    Raises:
        ValueError: If the file is empty, the leading count is not positive,
            or the file length is not a whole number of records.
    """
    data = np.fromfile(fname, dtype='int32')
    if data.size == 0:
        raise ValueError("Invalid .ivecs file")
    # Convert to a Python int so that `k + 1` cannot wrap around in int32.
    k = int(data[0])
    # Explicit checks rather than assert: asserts are stripped under -O.
    if k <= 0 or data.size % (k + 1) != 0:
        raise ValueError("Invalid .ivecs file")
    vects = data.reshape(-1, k + 1)[:, 1:]
    return vects
| 32 | + |
| 33 | + |
| 34 | +def compute_recall(results, groundtruth, k): |
| 35 | + """ |
| 36 | + Compute recall@k: fraction of queries where gt[0] in results |
| 37 | + """ |
| 38 | + num_q = results.shape[0] |
| 39 | + hits = 0 |
| 40 | + for i in range(num_q): |
| 41 | + if groundtruth[i, 0] in results[i, :k]: |
| 42 | + hits += 1 |
| 43 | + return hits / num_q |
| 44 | + |
| 45 | + |
def main(args):
    """
    Run the benchmark: load data, build the IVF index, search, and report.

    Reads the base vectors, query vectors and ground truth from the paths in
    ``args``, builds an ``IVFFlatIndex`` over the base vectors, searches every
    query one at a time, and prints the build time, recall@k and queries per
    second.

    Args:
        args: Parsed command-line namespace providing ``base``, ``query``,
            ``groundtruth``, ``nlist``, ``nprobe`` and ``k``.
    """
    # Load dataset
    print("Loading dataset...")
    base = read_fvecs(args.base)
    queries = read_fvecs(args.query)
    gt = read_ivecs(args.groundtruth)

    print(f"Base vectors: {base.shape}")
    print(f"Queries: {queries.shape}")
    print(f"Groundtruth: {gt.shape}")

    # Build IVF index
    print("Building IVF index...")
    index = IVFFlatIndex(dim=base.shape[1], nlist=args.nlist, nprobe=args.nprobe)
    # perf_counter is monotonic, so the measurement is not disturbed by
    # wall-clock adjustments the way time.time() can be.
    t0 = time.perf_counter()
    index.build(base)
    t_build = time.perf_counter() - t0
    print(f"Index build time: {t_build:.3f} s")

    # Search
    print("Running search...")
    K = args.k
    # Pre-fill with -1 so slots left unwritten when a search returns fewer
    # than K ids cannot hold uninitialized values that happen to equal a
    # ground-truth id. NOTE(review): assumes -1 is never a valid id - confirm.
    results = np.full((queries.shape[0], K), -1, dtype=np.int64)
    t0 = time.perf_counter()
    for i, q in enumerate(queries):
        res = index.search(q.tolist(), K)
        results[i, :len(res.indices)] = res.indices
    t_search = time.perf_counter() - t0
    # Guard the division in case the elapsed time measures as zero.
    qps = queries.shape[0] / t_search if t_search > 0 else float('inf')

    # Compute recall
    recall = compute_recall(results, gt, K)

    print(f"Recall@{K}: {recall * 100:.2f}%")
    print(f"Search QPS: {qps:.2f} queries/sec")
| 81 | + |
if __name__ == '__main__':
    # Command-line entry point: parse the benchmark options and run main().
    # The banner is passed as `description`; ArgumentParser's first positional
    # parameter is `prog`, so passing it positionally would replace the
    # program name shown in usage/help output.
    parser = argparse.ArgumentParser(description="ZenANN IVF SIFT1M benchmark")
    parser.add_argument("--base", required=True, help="Path to sift1M_base.fvecs")
    parser.add_argument("--query", required=True, help="Path to sift1M_query.fvecs")
    parser.add_argument("--groundtruth", required=True, help="Path to sift1M_groundtruth.ivecs")
    parser.add_argument("--nlist", type=int, default=1024, help="Number of IVF clusters")
    parser.add_argument("--nprobe", type=int, default=4, help="Number of clusters to probe")
    parser.add_argument("--k", type=int, default=10, help="Number of nearest neighbors (k)")
    args = parser.parse_args()
    main(args)
0 commit comments