|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the MIT license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import multiprocessing as mp |
| 7 | +import time |
| 8 | + |
| 9 | +import faiss |
| 10 | +import matplotlib.pyplot as plt |
| 11 | +import numpy as np |
| 12 | + |
| 13 | +try: |
| 14 | + from faiss.contrib.datasets_fb import ( |
| 15 | + DatasetSIFT1M, |
| 16 | + DatasetGIST1M, |
| 17 | + SyntheticDataset, |
| 18 | + ) |
| 19 | +except ImportError: |
| 20 | + from faiss.contrib.datasets import ( |
| 21 | + DatasetSIFT1M, |
| 22 | + DatasetGIST1M, |
| 23 | + SyntheticDataset, |
| 24 | + ) |
| 25 | + |
| 26 | + |
| 27 | +def eval_recall(index, efSearch_val, xq, gt, k): |
| 28 | + """Evaluate recall and QPS for a given efSearch value.""" |
| 29 | + t0 = time.time() |
| 30 | + _, I = index.search(xq, k=k) |
| 31 | + t = time.time() - t0 |
| 32 | + speed = t * 1000 / len(xq) |
| 33 | + qps = 1000 / speed |
| 34 | + |
| 35 | + corrects = (gt == I).sum() |
| 36 | + recall = corrects / (len(xq) * k) |
| 37 | + print( |
| 38 | + f"\tefSearch {efSearch_val:3d}, Recall@{k}: " |
| 39 | + f"{recall:.6f}, speed: {speed:.6f} ms/query, QPS: {qps:.2f}" |
| 40 | + ) |
| 41 | + |
| 42 | + return recall, qps |
| 43 | + |
| 44 | + |
| 45 | +def get_hnsw_index(index): |
| 46 | + """Extract the underlying HNSW index from a PreTransform index.""" |
| 47 | + if isinstance(index, faiss.IndexPreTransform): |
| 48 | + return faiss.downcast_index(index.index) |
| 49 | + return index |
| 50 | + |
| 51 | + |
| 52 | +def eval_and_plot(name, ds, k=10, nlevels=8, plot_data=None): |
| 53 | + """Evaluate an index configuration and collect data for plotting.""" |
| 54 | + xq = ds.get_queries() |
| 55 | + xb = ds.get_database() |
| 56 | + gt = ds.get_groundtruth() |
| 57 | + |
| 58 | + if hasattr(ds, "get_train"): |
| 59 | + xt = ds.get_train() |
| 60 | + else: |
| 61 | + # Use database as training data if no separate train set |
| 62 | + xt = xb |
| 63 | + |
| 64 | + nb, d = xb.shape |
| 65 | + nq, d = xq.shape |
| 66 | + gt = gt[:, :k] |
| 67 | + |
| 68 | + print(f"\n======{name} on {ds.__class__.__name__}======") |
| 69 | + print(f"Database: {nb} vectors, {d} dimensions") |
| 70 | + print(f"Queries: {nq} vectors") |
| 71 | + |
| 72 | + # Create index |
| 73 | + index = faiss.index_factory(d, name) |
| 74 | + |
| 75 | + faiss.omp_set_num_threads(mp.cpu_count()) |
| 76 | + index.train(xt) |
| 77 | + index.add(xb) |
| 78 | + |
| 79 | + faiss.omp_set_num_threads(1) |
| 80 | + |
| 81 | + # Get the underlying HNSW index for setting efSearch |
| 82 | + hnsw_index = get_hnsw_index(index) |
| 83 | + |
| 84 | + data = [] |
| 85 | + for efSearch in [16, 32, 64, 128, 256, 512]: |
| 86 | + hnsw_index.hnsw.efSearch = efSearch |
| 87 | + recall, qps = eval_recall(index, efSearch, xq, gt, k) |
| 88 | + data.append((recall, qps)) |
| 89 | + |
| 90 | + if plot_data is not None: |
| 91 | + data = np.array(data) |
| 92 | + plot_data.append((name, data)) |
| 93 | + |
| 94 | + |
| 95 | +def benchmark_dataset(ds, dataset_name, k=10, nlevels=8, M=32): |
| 96 | + """Benchmark both regular HNSW and HNSW Panorama on a dataset.""" |
| 97 | + d = ds.d |
| 98 | + |
| 99 | + plot_data = [] |
| 100 | + |
| 101 | + # HNSW Flat (baseline) |
| 102 | + eval_and_plot(f"HNSW{M},Flat", ds, k=k, nlevels=nlevels, plot_data=plot_data) |
| 103 | + |
| 104 | + # HNSW Flat Panorama (with PCA to concentrate energy) |
| 105 | + eval_and_plot( |
| 106 | + f"PCA{d},HNSW{M},FlatPanorama{nlevels}", |
| 107 | + ds, |
| 108 | + k=k, |
| 109 | + nlevels=nlevels, |
| 110 | + plot_data=plot_data, |
| 111 | + ) |
| 112 | + |
| 113 | + # Plot results |
| 114 | + plt.figure(figsize=(8, 6), dpi=80) |
| 115 | + for name, data in plot_data: |
| 116 | + plt.plot(data[:, 0], data[:, 1], marker="o", label=name) |
| 117 | + |
| 118 | + plt.title(f"HNSW Indexes on {dataset_name}") |
| 119 | + plt.xlabel(f"Recall@{k}") |
| 120 | + plt.ylabel("QPS") |
| 121 | + plt.yscale("log") |
| 122 | + plt.legend(bbox_to_anchor=(1.02, 0.1), loc="upper left", borderaxespad=0) |
| 123 | + plt.grid(True, alpha=0.3) |
| 124 | + |
| 125 | + output_file = f"bench_hnsw_flat_panorama_{dataset_name}.png" |
| 126 | + plt.savefig(output_file, bbox_inches="tight") |
| 127 | + print(f"Saved plot to {output_file}") |
| 128 | + plt.close() |
| 129 | + |
| 130 | + |
| 131 | +if __name__ == "__main__": |
| 132 | + k = 10 |
| 133 | + nlevels = 8 |
| 134 | + M = 32 |
| 135 | + |
| 136 | + # Test on 3 datasets with varying dimensionality: |
| 137 | + # SIFT1M (128d), GIST1M (960d), and Synthetic high-dim (2048d) |
| 138 | + datasets = [ |
| 139 | + (DatasetSIFT1M(), "SIFT1M"), |
| 140 | + (DatasetGIST1M(), "GIST1M"), |
| 141 | + # Synthetic high-dimensional dataset: 2048d, 100k train, 1M database, 10k queries |
| 142 | + (SyntheticDataset(2048, 100000, 1000000, 10000), "Synthetic2048D"), |
| 143 | + ] |
| 144 | + |
| 145 | + for ds, name in datasets: |
| 146 | + print(f"\n{'='*60}") |
| 147 | + print(f"Benchmarking on {name}") |
| 148 | + print(f"{'='*60}") |
| 149 | + benchmark_dataset(ds, name, k=k, nlevels=nlevels, M=M) |
| 150 | + |
| 151 | + print("\n" + "="*60) |
| 152 | + print("All benchmarks completed!") |
| 153 | + print("="*60) |
0 commit comments