Skip to content

Commit 4833abe

Browse files
committed
Parallel indexing
1 parent 1866a1d commit 4833abe

File tree

1 file changed

+68
-3
lines changed

1 file changed

+68
-3
lines changed

examples/searchKnnCloserFirst_test.cpp

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,69 @@
99

1010
#include <vector>
1111
#include <iostream>
12+
#include <thread>
1213

1314
namespace
1415
{
1516

1617
using idx_t = hnswlib::labeltype;
1718

19+
/*
 * Calls fn(id, threadId) once for every id in [start, end).
 *
 * start, end   half-open range of ids to process.
 * numThreads   number of worker threads; 0 means "use
 *              std::thread::hardware_concurrency()".
 * fn           callable invoked as fn(size_t id, size_t threadId). With a
 *              single thread threadId is always 0; otherwise it is the index
 *              of the worker in [0, numThreads) that picked up the id.
 *
 * With one thread the ids are visited in increasing order on the calling
 * thread and an exception from fn propagates directly. With several threads
 * the ids are handed out through a shared atomic counter, so the order is
 * unspecified; if fn throws, the workers stop picking up new ids and, once
 * all of them are joined, the most recently stored exception is rethrown on
 * the calling thread.
 */
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    // numThreads is unsigned, so zero is the only "pick a default" value.
    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
        // hardware_concurrency() is allowed to return 0 when the value cannot
        // be determined. Without this fallback no worker would be spawned and
        // the whole range would be silently skipped.
        if (numThreads == 0) {
            numThreads = 1;
        }
    }

    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
    } else {
        std::vector<std::thread> threads;
        // Next id to hand out; shared by all workers.
        std::atomic<size_t> current(start);

        // keep track of exceptions in threads
        // https://stackoverflow.com/a/32428427/1713196
        std::exception_ptr lastException = nullptr;
        std::mutex lastExceptMutex;

        for (size_t threadId = 0; threadId < numThreads; ++threadId) {
            threads.push_back(std::thread([&, threadId] {
                while (true) {
                    // fetch_add returns the value held before the increment,
                    // so every id is claimed by exactly one worker.
                    size_t id = current.fetch_add(1);

                    if (id >= end) {
                        break;
                    }

                    try {
                        fn(id, threadId);
                    } catch (...) {
                        std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
                        lastException = std::current_exception();
                        // Move the counter to the end of the range so that the
                        // other workers see id >= end on their next fetch_add
                        // and stop as well.
                        current = end;
                        break;
                    }
                }
            }));
        }
        for (auto &thread : threads) {
            thread.join();
        }
        if (lastException) {
            std::rethrow_exception(lastException);
        }
    }
}
74+
1875
void test() {
1976
int d = 4;
2077
idx_t n = 100;
@@ -40,10 +97,18 @@ void test() {
4097
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
4198
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
4299

43-
for (size_t i = 0; i < n; ++i) {
44-
alg_brute->addPoint(data.data() + d * i, i);
100+
// for (size_t i = 0; i < n; ++i) {
101+
// alg_brute->addPoint(data.data() + d * i, i);
102+
// alg_hnsw->addPoint(data.data() + d * i, i);
103+
// }
104+
105+
ParallelFor(0, n, 4, [&](size_t i, size_t threadId) {
45106
alg_hnsw->addPoint(data.data() + d * i, i);
46-
}
107+
});
108+
109+
ParallelFor(0, n, 4, [&](size_t i, size_t threadId) {
110+
alg_brute->addPoint(data.data() + d * i, i);
111+
});
47112

48113
// test searchKnnCloserFirst of BruteforceSearch
49114
for (size_t j = 0; j < nq; ++j) {

0 commit comments

Comments
 (0)