Skip to content

Commit 9a3c03d

Browse files
committed
A new interface taking a template comparator is added; because it is a member function template, it cannot be virtual.
I modified sift_1b (not committed) to test the new interface, and the result is OK. Test result of sift_1b on 1 million data points: Loading GT: Loading queries: Loading index from sift1b_1m_ef_40_M_16.bin: Actual memory usage: 417 Mb Parsing gt: 10000 Loaded gt 1 0.2371 13.319 us 2 0.3712 15.691 us 3 0.4615 18.5166 us 4 0.5273 20.6371 us 5 0.5758 22.1235 us 6 0.6179 24.4141 us 7 0.6502 25.9906 us 8 0.6796 28.2004 us 9 0.7042 29.8559 us 10 0.7243 31.3286 us 11 0.7432 36.0276 us 12 0.7605 34.9448 us 13 0.7754 36.4176 us 14 0.7874 37.7606 us 15 0.8013 44.6698 us 16 0.8116 47.4424 us 17 0.8239 46.9154 us 18 0.8312 45.9322 us 19 0.8379 49.3406 us 20 0.8442 49.124 us 21 0.8507 52.1223 us 22 0.8566 52.4161 us 23 0.8622 56.9665 us 24 0.8675 71.5782 us 25 0.8731 72.4451 us 26 0.8768 57.0935 us 27 0.8812 58.3525 us 28 0.8845 59.5751 us 29 0.889 61.7516 us 30 0.8935 62.6091 us 40 0.9224 76.8735 us 50 0.9412 92.5431 us 60 0.9541 107.141 us 70 0.9632 121.24 us 80 0.9708 135.862 us 90 0.9756 163.516 us 100 0.9792 180.539 us 140 0.9883 228.747 us 180 0.9921 281.199 us 220 0.9942 338.32 us 260 0.9956 388.501 us 300 0.9962 445.776 us 340 0.9968 477.474 us 380 0.9975 534.054 us 420 0.9982 582.327 us 460 0.9983 625.824 us Actual memory usage: 419 Mb
1 parent cca05f7 commit 9a3c03d

File tree

4 files changed

+65
-30
lines changed

4 files changed

+65
-30
lines changed

hnswlib/bruteforce.h

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <unordered_map>
33
#include <fstream>
44
#include <mutex>
5+
#include <algorithm>
56

67
namespace hnswlib {
78
template<typename dist_t>
@@ -84,11 +85,10 @@ namespace hnswlib {
8485
}
8586

8687

87-
retType<dist_t> searchKnn(const void *query_data, size_t k) const {
88-
retType<dist_t> result;
89-
if (cur_element_count == 0) return result;
90-
88+
std::priority_queue<std::pair<dist_t, labeltype >>
89+
searchKnn(const void *query_data, size_t k) const {
9190
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
91+
if (cur_element_count == 0) return topResults;
9292
for (int i = 0; i < k; i++) {
9393
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
9494
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
@@ -106,13 +106,30 @@ namespace hnswlib {
106106
}
107107

108108
}
109-
while (!topResults.empty()) {
110-
auto each = topResults.top();
111-
result.push(each);
112-
topResults.pop();
109+
return topResults;
110+
};
111+
112+
// Comparator-aware k-NN search for the brute-force index.
// Runs the two-argument searchKnn(query_data, k), copies the returned
// priority_queue into a vector in pop order, and reverses the vector when
// comp(front, back) is false. Returns an empty vector if cur_element_count == 0.
// NOTE(review): comp is evaluated exactly once, on the first and last elements;
// the vector is never sorted with it, so only "as popped" or "reversed" orders
// can be produced.
template <typename Comp>
113+
std::vector<std::pair<dist_t, labeltype>>
114+
searchKnn(const void* query_data, size_t k, Comp comp) {
115+
std::vector<std::pair<dist_t, labeltype>> result;
116+
if (cur_element_count == 0) return result;
117+
118+
auto ret = searchKnn(query_data, k);
119+
120+
// top() of the default-ordered priority_queue is the greatest pair, so the
// vector is filled starting from the largest distance.
while (!ret.empty()) {
121+
result.push_back(ret.top());
122+
ret.pop();
113123
}
124+
125+
// With 0 or 1 results there is nothing to order, so comp is not called.
if (result.size() > 1) {
126+
if (!comp(result.front(), result.back())) {
127+
std::reverse(result.begin(), result.end());
128+
}
129+
}
130+
114131
return result;
115-
};
132+
}
116133

117134
void saveIndex(const std::string &location) {
118135
std::ofstream output(location, std::ios::binary);

hnswlib/hnswalg.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ namespace hnswlib {
484484

485485

486486
std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
487+
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
488+
if (cur_element_count == 0) return top_candidates;
487489
tableint currObj = enterpoint_node_;
488490
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
489491

@@ -510,8 +512,6 @@ namespace hnswlib {
510512
}
511513
}
512514

513-
514-
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
515515
if (has_deletions_) {
516516
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
517517
ef_);
@@ -895,8 +895,9 @@ namespace hnswlib {
895895
return cur_c;
896896
};
897897

898-
retType<dist_t> searchKnn(const void *query_data, size_t k) const {
899-
retType<dist_t> result;
898+
std::priority_queue<std::pair<dist_t, labeltype >>
899+
searchKnn(const void *query_data, size_t k) const {
900+
std::priority_queue<std::pair<dist_t, labeltype >> result;
900901
if (cur_element_count == 0) return result;
901902

902903
tableint currObj = enterpoint_node_;
@@ -948,6 +949,27 @@ namespace hnswlib {
948949
return result;
949950
};
950951

952+
// Comparator-aware k-NN search for the HNSW index.
// Delegates to the two-argument searchKnn(query_data, k), drains the returned
// priority_queue into a vector in pop order, and reverses that vector when
// comp(front, back) is false. Returns an empty vector if cur_element_count == 0.
// NOTE(review): comp is only used to compare the first and last elements; it is
// not used to sort the results.
template <typename Comp>
953+
std::vector<std::pair<dist_t, labeltype>>
954+
searchKnn(const void* query_data, size_t k, Comp comp) {
955+
std::vector<std::pair<dist_t, labeltype>> result;
956+
if (cur_element_count == 0) return result;
957+
958+
auto ret = searchKnn(query_data, k);
959+
960+
// The queue uses the default ordering, so top() is the greatest pair and the
// vector starts with the largest distance.
while (!ret.empty()) {
961+
result.push_back(ret.top());
962+
ret.pop();
963+
}
964+
965+
// Skip the comparator entirely when there are fewer than two results.
if (result.size() > 1) {
966+
if (!comp(result.front(), result.back())) {
967+
std::reverse(result.begin(), result.end());
968+
}
969+
}
970+
971+
return result;
972+
}
951973

952974
};
953975

hnswlib/hnswlib.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,6 @@ namespace hnswlib {
3939
}
4040
};
4141

42-
template <typename T>
43-
using retType = std::priority_queue<
44-
std::pair<T, labeltype>,
45-
std::vector<std::pair<T, labeltype>>,
46-
pairGreater<std::pair<T, labeltype>>
47-
>;
48-
49-
5042
template<typename T>
5143
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
5244
out.write((char *) &podRef, sizeof(T));
@@ -78,7 +70,10 @@ namespace hnswlib {
7870
class AlgorithmInterface {
7971
public:
8072
virtual void addPoint(const void *datapoint, labeltype label)=0;
81-
virtual retType<dist_t> searchKnn(const void *, size_t) const = 0;
73+
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
74+
// Comparator-taking searchKnn overload declared on the interface.
// It is a member function template and therefore not virtual.
// NOTE(review): the body is empty although the return type is a std::vector;
// flowing off the end of a value-returning function is undefined behaviour if
// this overload is ever invoked. Presumably callers are expected to use the
// same-signature templates on the concrete index classes instead — confirm, or
// give this a real body / remove it.
template <typename Comp>
75+
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
76+
}
8277
virtual void saveIndex(const std::string &location)=0;
8378
virtual ~AlgorithmInterface(){
8479
}

sift_1b.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <iostream>
22
#include <fstream>
3+
#include <queue>
34
#include <chrono>
45
#include "hnswlib/hnswlib.h"
56

@@ -146,10 +147,10 @@ static size_t getCurrentRSS() {
146147

147148
static void
148149
get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t vecsize, size_t qsize, L2SpaceI &l2space,
149-
size_t vecdim, vector<retType<int>> &answers, size_t k) {
150+
size_t vecdim, vector<std::priority_queue<std::pair<int, labeltype >>> &answers, size_t k) {
150151

151152

152-
(vector<retType<int>>(qsize)).swap(answers);
153+
(vector<std::priority_queue<std::pair<int, labeltype >>>(qsize)).swap(answers);
153154
DISTFUNC<int> fstdistfunc_ = l2space.get_dist_func();
154155
cout << qsize << "\n";
155156
for (int i = 0; i < qsize; i++) {
@@ -161,15 +162,15 @@ get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t v
161162

162163
static float
163164
test_approx(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW<int> &appr_alg, size_t vecdim,
164-
vector<retType<int>> &answers, size_t k) {
165+
vector<std::priority_queue<std::pair<int, labeltype >>> &answers, size_t k) {
165166
size_t correct = 0;
166167
size_t total = 0;
167168
//uncomment to test in parallel mode:
168169
//#pragma omp parallel for
169170
for (int i = 0; i < qsize; i++) {
170171

171-
retType<int> result = appr_alg.searchKnn(massQ + vecdim * i, k);
172-
retType<int> gt(answers[i]);
172+
std::priority_queue<std::pair<int, labeltype >> result = appr_alg.searchKnn(massQ + vecdim * i, k);
173+
std::priority_queue<std::pair<int, labeltype >> gt(answers[i]);
173174
unordered_set<labeltype> g;
174175
total += gt.size();
175176

@@ -195,7 +196,7 @@ test_approx(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW<
195196

196197
static void
197198
test_vs_recall(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW<int> &appr_alg, size_t vecdim,
198-
vector<retType<int>> &answers, size_t k) {
199+
vector<std::priority_queue<std::pair<int, labeltype >>> &answers, size_t k) {
199200
vector<size_t> efs;// = { 10,10,10,10,10 };
200201
for (int i = k; i < 30; i++) {
201202
efs.push_back(i);
@@ -230,7 +231,7 @@ inline bool exists_test(const std::string &name) {
230231
void sift_test1B() {
231232

232233

233-
int subset_size_milllions = 1;
234+
int subset_size_milllions = 200;
234235
int efConstruction = 40;
235236
int M = 16;
236237

@@ -350,7 +351,7 @@ void sift_test1B() {
350351
}
351352

352353

353-
vector<retType<int>> answers;
354+
vector<std::priority_queue<std::pair<int, labeltype >>> answers;
354355
size_t k = 1;
355356
cout << "Parsing gt:\n";
356357
get_gt(massQA, massQ, mass, vecsize, qsize, l2space, vecdim, answers, k);

0 commit comments

Comments
 (0)