99
1010#include < vector>
1111#include < iostream>
12+ #include < thread>
1213
1314namespace
1415{
1516
1617using idx_t = hnswlib::labeltype;
1718
/**
 * Calls fn(id, threadId) once for every id in the half-open range [start, end),
 * spreading the ids dynamically over worker threads.
 *
 * @param start       first id to process (inclusive)
 * @param end         one past the last id to process; nothing runs when end <= start
 * @param numThreads  upper bound on the number of worker threads; 0 means
 *                    "use std::thread::hardware_concurrency()"
 * @param fn          callable invoked as fn(size_t id, size_t threadId); with several
 *                    workers it is called concurrently, and threadId is always
 *                    smaller than the resolved thread count
 *
 * If fn throws, no further ids are handed out, all workers are joined, and one of
 * the thrown exceptions is rethrown to the caller. When only one worker is used,
 * fn runs on the calling thread with threadId == 0.
 */
template <class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    // numThreads is unsigned, so "not specified" can only mean exactly 0.
    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
        // hardware_concurrency() may return 0 when the value is not computable;
        // fall back to serial execution instead of silently doing no work.
        if (numThreads == 0) {
            numThreads = 1;
        }
    }

    if (start >= end) {
        return;  // empty range: nothing to schedule
    }

    // Never start more workers than there are items to process.
    const size_t itemCount = end - start;
    if (numThreads > itemCount) {
        numThreads = itemCount;
    }

    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
        return;
    }

    std::vector<std::thread> threads;
    threads.reserve(numThreads);
    // Shared cursor: each worker claims the next unprocessed id from it.
    std::atomic<size_t> current(start);

    // keep track of exceptions in threads
    // https://stackoverflow.com/a/32428427/1713196
    std::exception_ptr lastException = nullptr;
    std::mutex lastExceptMutex;

    try {
        for (size_t threadId = 0; threadId < numThreads; ++threadId) {
            threads.emplace_back([&, threadId] {
                while (true) {
                    const size_t id = current.fetch_add(1);

                    if (id >= end) {
                        break;
                    }

                    try {
                        fn(id, threadId);
                    } catch (...) {
                        std::lock_guard<std::mutex> lastExceptLock(lastExceptMutex);
                        lastException = std::current_exception();
                        // Move the cursor to the end so the other workers stop
                        // claiming new ids as soon as they finish their current one.
                        current = end;
                        break;
                    }
                }
            });
        }
    } catch (...) {
        // Starting a thread failed. Destroying a joinable std::thread calls
        // std::terminate, so stop handing out work and join what was started
        // before propagating the error.
        current = end;
        for (auto &thread : threads) {
            thread.join();
        }
        throw;
    }

    for (auto &thread : threads) {
        thread.join();
    }
    if (lastException) {
        std::rethrow_exception(lastException);
    }
}
74+
1875void test () {
1976 int d = 4 ;
2077 idx_t n = 100 ;
@@ -40,10 +97,18 @@ void test() {
4097 hnswlib::AlgorithmInterface<float >* alg_brute = new hnswlib::BruteforceSearch<float >(&space, 2 * n);
4198 hnswlib::AlgorithmInterface<float >* alg_hnsw = new hnswlib::HierarchicalNSW<float >(&space, 2 * n);
4299
43- for (size_t i = 0 ; i < n; ++i) {
44- alg_brute->addPoint (data.data () + d * i, i);
100+ // for (size_t i = 0; i < n; ++i) {
101+ // alg_brute->addPoint(data.data() + d * i, i);
102+ // alg_hnsw->addPoint(data.data() + d * i, i);
103+ // }
104+
105+ ParallelFor (0 , n, 4 , [&](size_t i, size_t threadId) {
45106 alg_hnsw->addPoint (data.data () + d * i, i);
46- }
107+ });
108+
109+ ParallelFor (0 , n, 4 , [&](size_t i, size_t threadId) {
110+ alg_brute->addPoint (data.data () + d * i, i);
111+ });
47112
48113 // test searchKnnCloserFirst of BruteforceSearch
49114 for (size_t j = 0 ; j < nq; ++j) {
0 commit comments