#include "../hnswlib/hnswlib.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_set>
#include <utility>
#include <vector>
/*
 * Minimal stopwatch built on std::chrono::steady_clock.
 * Timing starts at construction and can be restarted with reset().
 */
class StopW
{
    using Clock = std::chrono::steady_clock;

    Clock::time_point time_begin;

public:
    StopW() : time_begin(Clock::now()) {}

    // Whole microseconds elapsed since construction or the most recent
    // reset(), converted to float.
    float getElapsedTimeMicro()
    {
        const auto elapsed =
            std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() - time_begin);
        return static_cast<float>(elapsed.count());
    }

    // Restart the measurement from the current instant.
    void reset()
    {
        time_begin = Clock::now();
    }
};
24+
/*
 * Replacement for the OpenMP '#pragma omp parallel for' directive.
 * Only handles a subset of the functionality (no reductions etc).
 *
 * Calls fn(id, threadId) once for every id from start (inclusive) to end
 * (EXCLUSIVE). Ids are handed out to the workers one at a time through a
 * shared atomic counter, so the order in which ids are processed is not
 * defined when more than one thread is used.
 *
 * numThreads: number of worker threads; 0 means "use
 *             std::thread::hardware_concurrency()", falling back to a single
 *             thread when that value is not available. The count is capped
 *             at the number of ids, so threadId is always below
 *             min(numThreads, end - start). With one thread the loop runs
 *             inline on the calling thread with threadId 0.
 *
 * If fn throws, no further ids are handed out, all workers are joined, and
 * the last recorded exception is rethrown on the calling thread.
 *
 * The method is borrowed from nmslib
 */
template <class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (start >= end) {
        return;  // empty range: nothing to do, so do not spawn any thread
    }

    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
        if (numThreads == 0) {
            // hardware_concurrency() may legitimately report 0; without this
            // fallback no worker would be started and the range would be
            // skipped silently.
            numThreads = 1;
        }
    }
    // Never start more workers than there are ids to process.
    numThreads = std::min(numThreads, end - start);

    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
        return;
    }

    std::vector<std::thread> threads;
    threads.reserve(numThreads);
    std::atomic<size_t> current(start);

    // keep track of exceptions in threads
    // https://stackoverflow.com/a/32428427/1713196
    std::exception_ptr lastException = nullptr;
    std::mutex lastExceptMutex;

    // Must be called from inside a catch handler: stores the in-flight
    // exception and moves the counter to `end` so no new id is handed out.
    auto recordFailure = [&] {
        std::lock_guard<std::mutex> lock(lastExceptMutex);
        lastException = std::current_exception();
        current = end;
    };

    for (size_t threadId = 0; threadId < numThreads; ++threadId) {
        try {
            threads.emplace_back([&, threadId] {
                while (true) {
                    // fetch_add returns the value held before the increment,
                    // so a worker that observes id >= end stops even if the
                    // increment itself wraps around.
                    const size_t id = current.fetch_add(1);
                    if (id >= end) {
                        break;
                    }

                    try {
                        fn(id, threadId);
                    } catch (...) {
                        recordFailure();
                        break;
                    }
                }
            });
        } catch (...) {
            // Thread creation failed: stop spawning, but still join the
            // workers that are already running before reporting the error
            // (destroying a joinable std::thread would call std::terminate).
            recordFailure();
            break;
        }
    }

    for (auto &thread : threads) {
        thread.join();
    }
    if (lastException) {
        std::rethrow_exception(lastException);
    }
}
87+
88+
/*
 * Reads `size` raw 4-byte elements from the binary file at `path` and
 * returns them as a vector.
 *
 * datatype: element type; must be exactly 4 bytes wide (float or int32,
 *           matching what the Python side writes).
 * path:     file to read.
 * size:     number of elements (not bytes) to read.
 *
 * Progress is printed to std::cout. The process exits with status 1 when
 * `size` is negative, the file cannot be opened, or the file holds fewer
 * than `size` elements.
 */
template <typename datatype>
std::vector<datatype> load_batch(std::string path, int size)
{
    // float or int32 (python); checked at compile time so the guard is not
    // lost in NDEBUG builds
    static_assert(sizeof(datatype) == 4, "load_batch expects 4-byte elements (float or int32)");

    std::cout << " Loading " << path << " ...";

    if (size < 0)
    {
        std::cout << "Invalid element count " << size << " for " << path << "\n";
        exit(1);
    }

    // Binary mode: the payload is raw bytes, so no newline or EOF-character
    // translation may be applied by the stream.
    std::ifstream file(path, std::ios::binary);
    if (!file.is_open())
    {
        std::cout << " Cannot open " << path << " \n ";
        exit(1);
    }

    std::vector<datatype> batch(static_cast<size_t>(size));
    const std::streamsize wanted = static_cast<std::streamsize>(batch.size() * sizeof(datatype));

    file.read(reinterpret_cast<char *>(batch.data()), wanted);
    if (file.gcount() != wanted)
    {
        // A truncated file would otherwise leave zero-filled elements in the
        // tail of the batch without any indication.
        std::cout << "Short read from " << path << ": expected " << wanted
                  << " bytes, got " << file.gcount() << "\n";
        exit(1);
    }

    std::cout << " DONE\n ";
    return batch;
}
109+
110+ template <typename d_type>
111+ static float
112+ test_approx (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
113+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
114+ {
115+ size_t correct = 0 ;
116+ size_t total = 0 ;
117+ // uncomment to test in parallel mode:
118+
119+
120+ for (int i = 0 ; i < qsize; i++)
121+ {
122+
123+ std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn ((char *)(queries.data () + vecdim * i), K);
124+ total += K;
125+ while (result.size ())
126+ {
127+ if (answers[i].find (result.top ().second ) != answers[i].end ())
128+ {
129+ correct++;
130+ }
131+ else
132+ {
133+ }
134+ result.pop ();
135+ }
136+ }
137+ return 1 .0f * correct / total;
138+ }
139+
140+ static void
141+ test_vs_recall (std::vector<float > &queries, size_t qsize, hnswlib::HierarchicalNSW<float > &appr_alg, size_t vecdim,
142+ std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
143+ {
144+ std::vector<size_t > efs = {1 };
145+ for (int i = k; i < 30 ; i++)
146+ {
147+ efs.push_back (i);
148+ }
149+ for (int i = 30 ; i < 400 ; i+=10 )
150+ {
151+ efs.push_back (i);
152+ }
153+ for (int i = 1000 ; i < 100000 ; i += 5000 )
154+ {
155+ efs.push_back (i);
156+ }
157+ std::cout << " ef\t recall\t time\t hops\t distcomp\n " ;
158+ for (size_t ef : efs)
159+ {
160+ appr_alg.setEf (ef);
161+
162+ appr_alg.metric_hops =0 ;
163+ appr_alg.metric_distance_computations =0 ;
164+ StopW stopw = StopW ();
165+
166+ float recall = test_approx<float >(queries, qsize, appr_alg, vecdim, answers, k);
167+ float time_us_per_query = stopw.getElapsedTimeMicro () / qsize;
168+ float distance_comp_per_query = appr_alg.metric_distance_computations / (1 .0f * qsize);
169+ float hops_per_query = appr_alg.metric_hops / (1 .0f * qsize);
170+
171+ std::cout << ef << " \t " << recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
172+ if (recall > 0.99 )
173+ {
174+ std::cout << " Recall is over 0.99! " <<recall << " \t " << time_us_per_query << " us \t " <<hops_per_query<<" \t " <<distance_comp_per_query << " \n " ;
175+ break ;
176+ }
177+ }
178+ }
179+
180+ int main (int argc, char **argv)
181+ {
182+
183+ int M = 16 ;
184+ int efConstruction = 200 ;
185+ int num_threads = std::thread::hardware_concurrency ();
186+
187+
188+
189+ bool update = false ;
190+
191+ if (argc == 2 )
192+ {
193+ if (std::string (argv[1 ]) == " update" )
194+ {
195+ update = true ;
196+ std::cout << " Updates are on\n " ;
197+ }
198+ else {
199+ std::cout<<" Usage ./test_updates [update]\n " ;
200+ exit (1 );
201+ }
202+ }
203+ else if (argc>2 ){
204+ std::cout<<" Usage ./test_updates [update]\n " ;
205+ exit (1 );
206+ }
207+
208+ std::string path = " ../examples/data/" ;
209+
210+
211+ int N;
212+ int dummy_data_multiplier;
213+ int N_queries;
214+ int d;
215+ int K;
216+ {
217+ std::ifstream configfile;
218+ configfile.open (path + " /config.txt" );
219+ if (!configfile.is_open ())
220+ {
221+ std::cout << " Cannot open config.txt\n " ;
222+ return 1 ;
223+ }
224+ configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K;
225+
226+ printf (" Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n " , N, dummy_data_multiplier, N_queries, d, K);
227+ }
228+
229+ hnswlib::L2Space l2space (d);
230+ hnswlib::HierarchicalNSW<float > appr_alg (&l2space, N + 1 , M, efConstruction);
231+
232+ std::vector<float > dummy_batch = load_batch<float >(path + " batch_dummy_00.bin" , N * d);
233+
234+ // Adding enterpoint:
235+
236+ appr_alg.addPoint ((void *)dummy_batch.data (), (size_t )0 );
237+
238+ StopW stopw = StopW ();
239+
240+ if (update)
241+ {
242+ std::cout << " Update iteration 0\n " ;
243+
244+
245+ ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
246+ appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
247+ });
248+ appr_alg.checkIntegrity ();
249+
250+ ParallelFor (1 , N, num_threads, [&](size_t i, size_t threadId) {
251+ appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
252+ });
253+ appr_alg.checkIntegrity ();
254+
255+ for (int b = 1 ; b < dummy_data_multiplier; b++)
256+ {
257+ std::cout << " Update iteration " << b << " \n " ;
258+ char cpath[1024 ];
259+ sprintf (cpath, " batch_dummy_%02d.bin" , b);
260+ std::vector<float > dummy_batchb = load_batch<float >(path + cpath, N * d);
261+
262+ ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
263+ appr_alg.addPoint ((void *)(dummy_batch.data () + i * d), i);
264+ });
265+ appr_alg.checkIntegrity ();
266+ }
267+ }
268+
269+ std::cout << " Inserting final elements\n " ;
270+ std::vector<float > final_batch = load_batch<float >(path + " batch_final.bin" , N * d);
271+
272+ stopw.reset ();
273+ ParallelFor (0 , N, num_threads, [&](size_t i, size_t threadId) {
274+ appr_alg.addPoint ((void *)(final_batch.data () + i * d), i);
275+ });
276+ std::cout<<" Finished. Time taken:" << stopw.getElapsedTimeMicro ()*1e-6 << " s\n " ;
277+ std::cout << " Running tests\n " ;
278+ std::vector<float > queries_batch = load_batch<float >(path + " queries.bin" , N_queries * d);
279+
280+ std::vector<int > gt = load_batch<int >(path + " gt.bin" , N_queries * K);
281+
282+ std::vector<std::unordered_set<hnswlib::labeltype>> answers (N_queries);
283+ for (int i = 0 ; i < N_queries; i++)
284+ {
285+ for (int j = 0 ; j < K; j++)
286+ {
287+ answers[i].insert (gt[i * K + j]);
288+ }
289+ }
290+
291+ for (int i = 0 ; i < 3 ; i++)
292+ {
293+ std::cout << " Test iteration " << i << " \n " ;
294+ test_vs_recall (queries_batch, N_queries, appr_alg, d, answers, K);
295+ }
296+
297+ return 0 ;
298+ };
0 commit comments