Skip to content

Commit 70adecb

Browse files
committed
benchmark: HNSW
1 parent d076f30 commit 70adecb

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

benchmark/sift1m-hnsw.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# benchmark/sift1m-hnsw.py
2+
import sys
3+
import os
4+
import time
5+
import argparse
6+
import numpy as np
7+
8+
sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'build')))
9+
from zenann import HNSWIndex
10+
11+
def load_fvecs(filename, c_contiguous=True):
    """Load a ``.fvecs`` file into a 2-D float32 array.

    The file is read as one flat float32 buffer. Each record is expected to
    be a single int32 dimension header followed by that many float32
    components; the header column is stripped from the returned array.

    Args:
        filename: Path of the file to read.
        c_contiguous: When true, return a C-contiguous copy instead of a
            strided view onto the raw buffer.

    Returns:
        Array of shape ``(n_vectors, dim)`` with dtype float32. An empty
        file yields an empty ``(0, 0)`` float32 array.

    Raises:
        IOError: If the leading dimension is not positive, the file does not
            hold a whole number of records, or records disagree on ``dim``.
    """
    raw = np.fromfile(filename, dtype=np.float32)
    if raw.size == 0:
        # Keep the dtype consistent with the non-empty case.
        return np.zeros((0, 0), dtype=np.float32)

    # The first 4 bytes are an int32 header, not a float: reinterpret them.
    dim = int(raw.view(np.int32)[0])
    if dim <= 0:
        raise IOError(f"Invalid vector dimension {dim} in {filename}")

    record_len = 1 + dim
    if raw.size % record_len != 0:
        raise IOError(
            f"Truncated or malformed file {filename}: {raw.size} values "
            f"is not a multiple of record length {record_len}"
        )

    records = raw.reshape(-1, record_len)
    # np.all stays in native code; the builtin all() would loop in Python.
    if not np.all(records.view(np.int32)[:, 0] == dim):
        raise IOError(f"Non-uniform vector sizes in {filename}")

    vectors = records[:, 1:]
    if c_contiguous:
        vectors = vectors.copy()
    return vectors
25+
26+
27+
def load_ivecs(filename, c_contiguous=True):
    """Load an ``.ivecs`` file into a 2-D int32 array.

    The file is read as one flat int32 buffer. Each record is expected to be
    a dimension header followed by that many int32 values; the header column
    is stripped from the returned array.

    Args:
        filename: Path of the file to read.
        c_contiguous: When true, return a C-contiguous copy instead of a
            strided view onto the raw buffer.

    Returns:
        Array of shape ``(n_vectors, dim)`` with dtype int32. An empty file
        yields an empty ``(0, 0)`` int32 array.

    Raises:
        IOError: If the leading dimension is not positive, the file does not
            hold a whole number of records, or records disagree on ``dim``.
    """
    raw = np.fromfile(filename, dtype=np.int32)
    if raw.size == 0:
        return np.zeros((0, 0), dtype=np.int32)

    dim = int(raw[0])
    if dim <= 0:
        raise IOError(f"Invalid vector dimension {dim} in {filename}")

    record_len = 1 + dim
    if raw.size % record_len != 0:
        raise IOError(
            f"Truncated or malformed file {filename}: {raw.size} values "
            f"is not a multiple of record length {record_len}"
        )

    records = raw.reshape(-1, record_len)
    if not np.all(records[:, 0] == dim):
        raise IOError(f"Non-uniform vector sizes in {filename}")

    vectors = records[:, 1:]
    if c_contiguous:
        # Honour the flag: the slice above is only a strided view.
        vectors = vectors.copy()
    return vectors
34+
35+
def compute_recall(predicted, groundtruth, k=10):
    """Return the mean recall@k of ``predicted`` against ``groundtruth``.

    Rows are paired positionally; if the two sequences differ in length the
    extra rows of the longer one are ignored. For each pair, recall is the
    fraction of the distinct first-``k`` ground-truth ids that also appear
    among the first ``k`` predicted ids.

    Args:
        predicted: Sequence of per-query predicted id sequences.
        groundtruth: Sequence of per-query true-neighbour id sequences.
        k: Number of leading ids considered on each side.

    Returns:
        Mean recall over the queries that have at least one ground-truth id
        within the first ``k``; ``0.0`` when there is no such query.
    """
    recalls = []
    for true_neighbors, pred_neighbors in zip(groundtruth, predicted):
        true_set = set(true_neighbors[:k])
        if not true_set:
            # Recall is undefined without any true neighbour; skipping the
            # row avoids dividing by zero.
            continue
        hits = true_set.intersection(pred_neighbors[:k])
        recalls.append(len(hits) / len(true_set))

    if not recalls:
        # np.mean([]) would emit a warning and return NaN.
        return np.float64(0.0)
    return np.mean(recalls)
43+
44+
def main(args):
    """Run the HNSW benchmark described by ``args`` and print the results.

    Steps: load the base, query and ground-truth files; load the index from
    ``args.index_file`` when that file exists, otherwise build it (and save
    it when a path was given); run a batched search; print recall@k and QPS.

    Args:
        args: Parsed command-line namespace providing ``base``, ``query``,
            ``groundtruth``, ``index_file``, ``M``, ``efConstruction``,
            ``efSearch``, ``k`` and ``mapping_file``.
    """
    base = load_fvecs(args.base)
    queries = load_fvecs(args.query)
    gt = load_ivecs(args.groundtruth)
    print(f"Base: {base.shape}, Queries: {queries.shape}, GT: {gt.shape}")

    if args.index_file and os.path.exists(args.index_file):
        print("Loading index...")
        idx = HNSWIndex.read_index(args.index_file)
    else:
        print("Building HNSW index...")
        idx = HNSWIndex(dim=base.shape[1], M=args.M, efConstruction=args.efConstruction)
        # perf_counter is the monotonic, high-resolution clock meant for
        # measuring intervals; time.time() can jump if the wall clock is adjusted.
        t0 = time.perf_counter()
        idx.build(base)
        t_build = time.perf_counter() - t0
        print(f"Build time: {t_build:.3f}s")
        if args.index_file:
            idx.write_index(args.index_file)

    idx.set_ef_search(args.efSearch)

    print("Searching...")
    K = args.k
    mapping_file = args.mapping_file or "bfs_mapping.bin"
    # NOTE(review): presumably reorders the index layout and stores the id
    # mapping at mapping_file -- confirm against the zenann bindings.
    idx.reorder_layout(mapping_file)

    # Convert the queries before starting the clock so that the
    # ndarray -> list conversion is not counted as search time.
    query_list = queries.tolist()

    t0 = time.perf_counter()
    all_res = idx.search_batch(query_list, K, args.efSearch)
    t_search = time.perf_counter() - t0
    # Guard against a zero-length interval (e.g. an empty query set).
    qps = len(queries) / t_search if t_search > 0 else float("inf")

    # Flatten the results to exactly K ids per query, padding short result
    # lists with -1.
    predicted_flat = []
    for r in all_res:
        padded = r.indices + [-1] * (K - len(r.indices))
        predicted_flat.extend(padded[:K])

    recall = idx.compute_recall_with_mapping(
        gt.tolist(), predicted_flat, len(queries), K, mapping_file
    )

    print(f"Recall@{K}: {recall * 100:.2f}%")
    print(f"QPS: {qps:.2f}")
86+
87+
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them and hand the
    # resulting namespace to main().
    parser = argparse.ArgumentParser()

    # Dataset inputs are mandatory.
    for required_flag in ("--base", "--query", "--groundtruth"):
        parser.add_argument(required_flag, required=True)

    # Optional path used to load an existing index or save a new one.
    parser.add_argument("--index_file", default=None)

    # Integer tuning knobs with their default values.
    int_options = (
        ("--M", 32),
        ("--efConstruction", 200),
        ("--efSearch", 32),
        ("--k", 10),
    )
    for int_flag, fallback in int_options:
        parser.add_argument(int_flag, type=int, default=fallback)

    parser.add_argument("--mapping_file", default=None, help="Where to save mapping file")

    args = parser.parse_args()
    main(args)

0 commit comments

Comments
 (0)