Skip to content

Commit 9eefe29

Browse files
committed
Two main changes:
1. searchKnn will first check if the graph is empty. 2. searchKnn will return a min-heap. The test code in sift_1b is changed and tested.
1 parent 0f3f72a commit 9eefe29

File tree

4 files changed

+42
-15
lines changed

4 files changed

+42
-15
lines changed

hnswlib/bruteforce.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ namespace hnswlib {
8484
}
8585

8686

87-
std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
87+
retType<dist_t> searchKnn(const void *query_data, size_t k) const {
88+
retType<dist_t> result;
89+
if (cur_element_count == 0) return result;
90+
8891
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
8992
for (int i = 0; i < k; i++) {
9093
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
@@ -103,7 +106,12 @@ namespace hnswlib {
103106
}
104107

105108
}
106-
return topResults;
109+
while (!topResults.empty()) {
110+
auto each = topResults.top();
111+
result.push(each);
112+
topResults.pop();
113+
}
114+
return result;
107115
};
108116

109117
void saveIndex(const std::string &location) {

hnswlib/hnswalg.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,10 @@ namespace hnswlib {
895895
return cur_c;
896896
};
897897

898-
std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
898+
retType<dist_t> searchKnn(const void *query_data, size_t k) const {
899+
retType<dist_t> result;
900+
if (cur_element_count == 0) return result;
901+
899902
tableint currObj = enterpoint_node_;
900903
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
901904

@@ -934,16 +937,15 @@ namespace hnswlib {
934937
currObj, query_data, std::max(ef_, k));
935938
top_candidates.swap(top_candidates1);
936939
}
937-
std::priority_queue<std::pair<dist_t, labeltype >> results;
938940
while (top_candidates.size() > k) {
939941
top_candidates.pop();
940942
}
941943
while (top_candidates.size() > 0) {
942944
std::pair<dist_t, tableint> rez = top_candidates.top();
943-
results.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
945+
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
944946
top_candidates.pop();
945947
}
946-
return results;
948+
return result;
947949
};
948950

949951

hnswlib/hnswlib.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,29 @@
2424
#endif
2525

2626
#include <queue>
27+
#include <vector>
2728

2829
#include <string.h>
2930

3031
namespace hnswlib {
3132
typedef size_t labeltype;
3233

34+
template <typename T>
35+
class pairGreater {
36+
public:
37+
bool operator()(const T& p1, const T& p2) {
38+
return p1.first > p2.first;
39+
}
40+
};
41+
42+
template <typename T>
43+
using retType = std::priority_queue<
44+
std::pair<T, labeltype>,
45+
std::vector<std::pair<T, labeltype>>,
46+
pairGreater<std::pair<T, labeltype>>
47+
>;
48+
49+
3350
template<typename T>
3451
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
3552
out.write((char *) &podRef, sizeof(T));
@@ -61,7 +78,7 @@ namespace hnswlib {
6178
class AlgorithmInterface {
6279
public:
6380
virtual void addPoint(const void *datapoint, labeltype label)=0;
64-
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
81+
virtual retType<dist_t> searchKnn(const void *, size_t) const = 0;
6582
virtual void saveIndex(const std::string &location)=0;
6683
virtual ~AlgorithmInterface(){
6784
}

sift_1b.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ static size_t getCurrentRSS() {
147147

148148
static void
149149
get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t vecsize, size_t qsize, L2SpaceI &l2space,
150-
size_t vecdim, vector<std::priority_queue<std::pair<int, labeltype >>> &answers, size_t k) {
150+
size_t vecdim, vector<retType<int>> &answers, size_t k) {
151151

152152

153-
(vector<std::priority_queue<std::pair<int, labeltype >>>(qsize)).swap(answers);
153+
(vector<retType<int>>(qsize)).swap(answers);
154154
DISTFUNC<int> fstdistfunc_ = l2space.get_dist_func();
155155
cout << qsize << "\n";
156156
for (int i = 0; i < qsize; i++) {
@@ -162,15 +162,15 @@ get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t v
162162

163163
static float
164164
test_approx(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW<int> &appr_alg, size_t vecdim,
165-
vector<std::priority_queue<std::pair<int, labeltype >>> &answers, size_t k) {
165+
vector<retType<int>> &answers, size_t k) {
166166
size_t correct = 0;
167167
size_t total = 0;
168168
//uncomment to test in parallel mode:
169169
//#pragma omp parallel for
170170
for (int i = 0; i < qsize; i++) {
171171

172-
std::priority_queue<std::pair<int, labeltype >> result = appr_alg.searchKnn(massQ + vecdim * i, k);
173-
std::priority_queue<std::pair<int, labeltype >> gt(answers[i]);
172+
retType<int> result = appr_alg.searchKnn(massQ + vecdim * i, k);
173+
retType<int> gt(answers[i]);
174174
unordered_set<labeltype> g;
175175
total += gt.size();
176176

@@ -196,7 +196,7 @@ test_approx(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW<
196196

197197
static void
198198
test_vs_recall(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW<int> &appr_alg, size_t vecdim,
199-
vector<std::priority_queue<std::pair<int, labeltype >>> &answers, size_t k) {
199+
vector<retType<int>> &answers, size_t k) {
200200
vector<size_t> efs;// = { 10,10,10,10,10 };
201201
for (int i = k; i < 30; i++) {
202202
efs.push_back(i);
@@ -231,7 +231,7 @@ inline bool exists_test(const std::string &name) {
231231
void sift_test1B() {
232232

233233

234-
int subset_size_milllions = 200;
234+
int subset_size_milllions = 1;
235235
int efConstruction = 40;
236236
int M = 16;
237237

@@ -351,7 +351,7 @@ void sift_test1B() {
351351
}
352352

353353

354-
vector<std::priority_queue<std::pair<int, labeltype >>> answers;
354+
vector<retType<int>> answers;
355355
size_t k = 1;
356356
cout << "Parsing gt:\n";
357357
get_gt(massQA, massQ, mass, vecsize, qsize, l2space, vecdim, answers, k);

0 commit comments

Comments
 (0)