Skip to content

Commit 890e426

Browse files
mdouzefacebook-github-bot
authored andcommitted
Add benchmark to measure the ResultHandler overhead (#4778)
Summary: There seems to be a performance regression in Faiss after the ResultHandler's introduction. This test attemtps to reproduce it. It is difficult to run fine-grained benchmarks on devservers (when trying to spot a 1% regression), but here are the results: before inlined scanner: P2154463854 after inlined scanner: P2154465531 Especially very fast distance computations (like Flat with small dimensions) clearly benefit from it. Reviewed By: junjieqi Differential Revision: D91784507
1 parent 851ce24 commit 890e426

File tree

2 files changed

+205
-2
lines changed

2 files changed

+205
-2
lines changed

benchs/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,15 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
6+
find_package(BLAS REQUIRED)
7+
find_package(LAPACK REQUIRED)
8+
find_package(OpenMP REQUIRED)
79

810
add_executable(bench_ivf_selector EXCLUDE_FROM_ALL bench_ivf_selector.cpp)
9-
target_link_libraries(bench_ivf_selector PRIVATE faiss)
11+
target_link_libraries(bench_ivf_selector PRIVATE faiss_avx512 ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES} OpenMP::OpenMP_CXX)
12+
target_compile_options(bench_ivf_selector PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma -mf16c -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw -mpopcnt>)
1013

14+
add_executable(bench_result_handler_overhead EXCLUDE_FROM_ALL
15+
bench_result_handler_overhead.cpp)
16+
target_link_libraries(bench_result_handler_overhead PRIVATE faiss_avx512 ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES} OpenMP::OpenMP_CXX)
17+
target_compile_options(bench_result_handler_overhead PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma -mf16c -mavx512f -mavx512cd -mavx512vl -mavx512dq -mavx512bw -mpopcnt>)
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <faiss/AutoTune.h>
9+
#include <faiss/Index.h>
10+
#include <faiss/IndexIVF.h>
11+
#include <faiss/index_factory.h>
12+
#include <faiss/utils/random.h>
13+
#include <faiss/utils/utils.h>
14+
#include <omp.h>
15+
16+
#include <cmath>
17+
#include <iomanip>
18+
#include <iostream>
19+
#include <map>
20+
#include <vector>
21+
22+
namespace faiss {
23+
24+
namespace {
25+
26+
constexpr int nb = 100000;
27+
constexpr int nq = 1000;
28+
constexpr int nrun = 100;
29+
constexpr float min_run_len_ms = 2000.0;
30+
31+
struct IndexData {
32+
std::unique_ptr<Index> index;
33+
std::vector<float> xq;
34+
};
35+
36+
struct BenchmarkResult {
37+
std::string index_factory;
38+
int d;
39+
int k;
40+
int nprobe;
41+
double mean_time;
42+
double std_time;
43+
};
44+
45+
std::pair<double, double> run_search(
46+
IndexData& data,
47+
int d,
48+
int k,
49+
int nprobe,
50+
const char* factory_string) {
51+
ParameterSpace().set_index_parameter(data.index.get(), "nprobe", nprobe);
52+
53+
omp_set_num_threads(1);
54+
55+
std::vector<float> distances(nq * k);
56+
std::vector<idx_t> labels(nq * k);
57+
58+
// Warmup
59+
data.index->search(nq, data.xq.data(), k, distances.data(), labels.data());
60+
61+
// Timed runs - stop if total time exceeds 2 seconds
62+
double t0 = getmillisecs();
63+
std::vector<double> search_times;
64+
for (int run = 0; run < nrun; run++) {
65+
indexIVF_stats.reset();
66+
data.index->search(
67+
nq, data.xq.data(), k, distances.data(), labels.data());
68+
search_times.push_back(indexIVF_stats.search_time);
69+
if (getmillisecs() - t0 > min_run_len_ms) {
70+
break;
71+
}
72+
}
73+
74+
// Compute mean and std (in us/query)
75+
double sum = 0.0;
76+
for (double t : search_times) {
77+
sum += t;
78+
}
79+
double mean = sum / search_times.size() / nq * 1000.0;
80+
81+
double sq_sum = 0.0;
82+
for (double t : search_times) {
83+
double t_us = t / nq * 1000.0;
84+
sq_sum += (t_us - mean) * (t_us - mean);
85+
}
86+
double std = search_times.size() > 1
87+
? std::sqrt(sq_sum / (search_times.size() - 1))
88+
: 0.0;
89+
90+
return {mean, std};
91+
}
92+
93+
IndexData build_index(int d, const char* factory_string) {
94+
omp_set_num_threads(32);
95+
96+
int nt = std::max(nb, 1024);
97+
98+
std::vector<float> xt(nt * d);
99+
std::vector<float> xb(nb * d);
100+
101+
rand_smooth_vectors(nt, d, xt.data(), 12345);
102+
rand_smooth_vectors(nb, d, xb.data(), 23456);
103+
104+
IndexData data;
105+
data.index.reset(index_factory(d, factory_string));
106+
data.index->train(nt, xt.data());
107+
data.index->add(nb, xb.data());
108+
109+
data.xq.resize(nq * d);
110+
rand_smooth_vectors(nq, d, data.xq.data(), 34567);
111+
112+
return data;
113+
}
114+
115+
void print_results_table(
116+
const std::string& index_factory,
117+
int d,
118+
const std::vector<BenchmarkResult>& results) {
119+
std::vector<int> ks_list = {1, 4, 16};
120+
std::vector<int> nprobes_list = {1, 4, 16};
121+
122+
std::map<std::pair<int, int>, std::pair<double, double>> result_map;
123+
for (const auto& r : results) {
124+
result_map[{r.k, r.nprobe}] = {r.mean_time, r.std_time};
125+
}
126+
127+
std::cout << "\n"
128+
<< index_factory << " d=" << d
129+
<< " (time in us/query, mean ± stddev)\n";
130+
std::cout << std::string(76, '-') << "\n";
131+
132+
std::cout << std::setw(8) << "k \\ np"
133+
<< " |";
134+
for (int np : nprobes_list) {
135+
std::cout << std::setw(16) << np << " |";
136+
}
137+
std::cout << "\n";
138+
std::cout << std::string(76, '-') << "\n";
139+
140+
for (int k : ks_list) {
141+
std::cout << std::setw(8) << k << " |";
142+
for (int np : nprobes_list) {
143+
auto it = result_map.find({k, np});
144+
if (it != result_map.end()) {
145+
std::ostringstream oss;
146+
oss << std::fixed << std::setprecision(1) << it->second.first
147+
<< " ± " << it->second.second;
148+
std::cout << std::setw(16) << oss.str() << " |";
149+
} else {
150+
std::cout << std::setw(16) << "N/A"
151+
<< " |";
152+
}
153+
}
154+
std::cout << "\n";
155+
}
156+
}
157+
158+
} // namespace
159+
160+
} // namespace faiss
161+
162+
int main() {
163+
std::vector<std::pair<int, std::string>> indexes = {
164+
// 256 bit codes
165+
{64, "IVF256,SQ4"},
166+
{256, "IVF256,RaBitQ"},
167+
{16, "IVF256,SQfp16"},
168+
// 512 bit codes
169+
{128, "IVF256,SQ4"},
170+
{512, "IVF256,RaBitQ"},
171+
{32, "IVF256,SQfp16"},
172+
};
173+
std::vector<int> ks = {1, 4, 16};
174+
std::vector<int> nprobes = {1, 4, 16};
175+
176+
for (auto p : indexes) {
177+
std::string index_factory = p.second;
178+
int d = p.first;
179+
std::cout << "Building " << index_factory << " d=" << d << "..."
180+
<< std::flush;
181+
faiss::IndexData data = faiss::build_index(d, index_factory.c_str());
182+
std::cout << " done\n";
183+
184+
std::vector<faiss::BenchmarkResult> results;
185+
for (int k : ks) {
186+
for (int nprobe : nprobes) {
187+
auto [mean, std] = faiss::run_search(
188+
data, d, k, nprobe, index_factory.c_str());
189+
results.push_back({index_factory, d, k, nprobe, mean, std});
190+
}
191+
}
192+
faiss::print_results_table(index_factory, d, results);
193+
}
194+
195+
return 0;
196+
}

0 commit comments

Comments
 (0)