|
| 1 | +# scripts/hnsw_benchmark.py |
| 2 | +import sys |
| 3 | +import os |
| 4 | +import time |
| 5 | +import argparse |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..', 'build'))) |
| 9 | +from zenann import HNSWIndex |
| 10 | + |
def load_fvecs(filename, c_contiguous=True):
    """Load a .fvecs file into a (n, dim) float32 array.

    The .fvecs format stores, per vector, a little-endian int32 dimension
    followed by `dim` float32 components.

    Args:
        filename: Path to the .fvecs file.
        c_contiguous: If True (default), return a C-contiguous copy instead
            of a strided view into the raw file buffer.

    Returns:
        np.ndarray of shape (n, dim), dtype float32. Shape (0, 0) for an
        empty file.

    Raises:
        IOError: If the dimension header is invalid or vectors have
            non-uniform sizes.
    """
    fv = np.fromfile(filename, dtype=np.float32)
    if fv.size == 0:
        # Keep the dtype consistent with the non-empty case.
        return np.zeros((0, 0), dtype=np.float32)
    # The first 4 bytes of each record are the int32 dimension.
    dim = fv.view(np.int32)[0]
    if dim <= 0:
        # Raise instead of assert: asserts are stripped under `python -O`.
        raise IOError("Invalid vector dimension in " + filename)
    fv = fv.reshape(-1, 1 + dim)
    # Every record's dimension header must match the first one.
    if not (fv.view(np.int32)[:, 0] == dim).all():
        raise IOError("Non-uniform vector sizes in " + filename)
    fv = fv[:, 1:]  # strip the per-row dimension column
    if c_contiguous:
        fv = fv.copy()
    return fv
| 25 | + |
| 26 | + |
def load_ivecs(filename, c_contiguous=True):
    """Load a .ivecs file into a (n, dim) int32 array.

    The .ivecs format stores, per vector, a little-endian int32 dimension
    followed by `dim` int32 components (typically groundtruth neighbor ids).

    Args:
        filename: Path to the .ivecs file.
        c_contiguous: If True (default), return a C-contiguous copy instead
            of a strided view into the raw file buffer.

    Returns:
        np.ndarray of shape (n, dim), dtype int32. Shape (0, 0) for an
        empty file.

    Raises:
        IOError: If the dimension header is invalid or vectors have
            non-uniform sizes.
    """
    iv = np.fromfile(filename, dtype=np.int32)
    if iv.size == 0:
        # Mirror load_fvecs: an empty file is not an error.
        return np.zeros((0, 0), dtype=np.int32)
    dim = iv[0]
    if dim <= 0:
        # Raise instead of assert: asserts are stripped under `python -O`.
        raise IOError("Invalid vector dimension in " + filename)
    iv = iv.reshape(-1, 1 + dim)
    # Validate uniform record sizes, consistent with load_fvecs.
    if not (iv[:, 0] == dim).all():
        raise IOError("Non-uniform vector sizes in " + filename)
    iv = iv[:, 1:]  # strip the per-row dimension column
    if c_contiguous:
        # Previously this flag was accepted but ignored; honor it now so
        # callers get a contiguous buffer as the signature promises.
        iv = iv.copy()
    return iv
| 34 | + |
def compute_recall(predicted, groundtruth, k=10):
    """Return mean recall@k of `predicted` neighbor lists vs `groundtruth`.

    For each query, recall is |top-k predicted ∩ top-k true| / |top-k true|.
    Queries are paired positionally; extra rows on either side are ignored
    (zip semantics, as before).

    Args:
        predicted: Sequence of per-query predicted neighbor id lists.
        groundtruth: Sequence of per-query true neighbor id lists.
        k: Number of top neighbors considered per query.

    Returns:
        float in [0, 1]; 0.0 when there are no scorable queries.
    """
    recalls = []
    for true_neighbors, pred_neighbors in zip(groundtruth, predicted):
        true_set = set(true_neighbors[:k])
        if not true_set:
            # Skip queries with no ground truth to avoid division by zero.
            continue
        pred_set = set(pred_neighbors[:k])
        recalls.append(len(true_set & pred_set) / len(true_set))
    if not recalls:
        # np.mean([]) warns and yields NaN; return a well-defined 0.0 instead.
        return 0.0
    return float(np.mean(recalls))
| 43 | + |
def main(args):
    """Run the end-to-end HNSW benchmark.

    Loads base/query vectors and groundtruth, builds (or reloads) the index,
    reorders its memory layout, runs a batched search, and prints recall@k
    and queries-per-second.
    """
    base_vecs = load_fvecs(args.base)
    query_vecs = load_fvecs(args.query)
    truth = load_ivecs(args.groundtruth)
    print(f"Base: {base_vecs.shape}, Queries: {query_vecs.shape}, GT: {truth.shape}")

    # Reuse a previously serialized index when one exists on disk.
    if args.index_file and os.path.exists(args.index_file):
        print("Loading index...")
        index = HNSWIndex.read_index(args.index_file)
    else:
        print("Building HNSW index...")
        index = HNSWIndex(
            dim=base_vecs.shape[1], M=args.M, efConstruction=args.efConstruction
        )
        build_start = time.time()
        index.build(base_vecs)
        build_secs = time.time() - build_start
        print(f"Build time: {build_secs:.3f}s")
        if args.index_file:
            index.write_index(args.index_file)

    index.set_ef_search(args.efSearch)

    print("Searching...")
    top_k = args.k
    # reorder_layout rewrites internal ids; the mapping file lets us translate
    # groundtruth ids during recall computation below.
    mapping_file = args.mapping_file if args.mapping_file else "bfs_mapping.bin"
    index.reorder_layout(mapping_file)

    search_start = time.time()
    results = index.search_batch(query_vecs.tolist(), top_k, args.efSearch)
    elapsed = time.time() - search_start
    qps = len(query_vecs) / elapsed

    # Flatten results row-major, padding short rows with -1 up to top_k.
    flat_ids = []
    for res in results:
        ids = res.indices + [-1] * (top_k - len(res.indices))
        flat_ids.extend(ids[:top_k])

    recall = index.compute_recall_with_mapping(
        truth.tolist(), flat_ids, len(query_vecs), top_k, mapping_file
    )

    print(f"Recall@{top_k}: {recall * 100:.2f}%")
    print(f"QPS: {qps:.2f}")
| 86 | + |
if __name__ == '__main__':
    # CLI entry point: dataset paths are mandatory; HNSW knobs have defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--base", required=True)
    parser.add_argument("--query", required=True)
    parser.add_argument("--groundtruth", required=True)
    parser.add_argument("--index_file", default=None)
    parser.add_argument("--M", type=int, default=32)
    parser.add_argument("--efConstruction", type=int, default=200)
    parser.add_argument("--efSearch", type=int, default=32)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--mapping_file", default=None, help="Where to save mapping file")
    main(parser.parse_args())
0 commit comments