Skip to content

Commit bfada0a

Browse files
authored
Merge pull request #153 from uestc-lfs/interface-addpoint
Interface addpoint: change void* to const void*
2 parents c5c38f0 + b6181d6 commit bfada0a

File tree

3 files changed

+66
-13
lines changed

3 files changed

+66
-13
lines changed

hnswlib/bruteforce.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <unordered_map>
33
#include <fstream>
44
#include <mutex>
5+
#include <algorithm>
56

67
namespace hnswlib {
78
template<typename dist_t>
@@ -40,7 +41,7 @@ namespace hnswlib {
4041

4142
std::unordered_map<labeltype,size_t > dict_external_to_internal;
4243

43-
void addPoint(void *datapoint, labeltype label) {
44+
void addPoint(const void *datapoint, labeltype label) {
4445

4546
int idx;
4647
{
@@ -84,8 +85,10 @@ namespace hnswlib {
8485
}
8586

8687

87-
std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
88+
std::priority_queue<std::pair<dist_t, labeltype >>
89+
searchKnn(const void *query_data, size_t k) const {
8890
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
91+
if (cur_element_count == 0) return topResults;
8992
for (int i = 0; i < k; i++) {
9093
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
9194
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
@@ -106,6 +109,24 @@ namespace hnswlib {
106109
return topResults;
107110
};
108111

112+
template <typename Comp>
113+
std::vector<std::pair<dist_t, labeltype>>
114+
searchKnn(const void* query_data, size_t k, Comp comp) {
115+
std::vector<std::pair<dist_t, labeltype>> result;
116+
if (cur_element_count == 0) return result;
117+
118+
auto ret = searchKnn(query_data, k);
119+
120+
while (!ret.empty()) {
121+
result.push_back(ret.top());
122+
ret.pop();
123+
}
124+
125+
std::sort(result.begin(), result.end(), comp);
126+
127+
return result;
128+
}
129+
109130
void saveIndex(const std::string &location) {
110131
std::ofstream output(location, std::ios::binary);
111132
std::streampos position;

hnswlib/hnswalg.h

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ namespace hnswlib {
150150
}
151151

152152
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
153-
searchBaseLayer(tableint ep_id, void *data_point, int layer) {
153+
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
154154
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
155155
vl_type *visited_array = vl->mass;
156156
vl_type visited_array_tag = vl->curV;
@@ -371,7 +371,7 @@ namespace hnswlib {
371371
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
372372
};
373373

374-
void mutuallyConnectNewElement(void *data_point, tableint cur_c,
374+
void mutuallyConnectNewElement(const void *data_point, tableint cur_c,
375375
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates,
376376
int level) {
377377

@@ -484,6 +484,8 @@ namespace hnswlib {
484484

485485

486486
std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
487+
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
488+
if (cur_element_count == 0) return top_candidates;
487489
tableint currObj = enterpoint_node_;
488490
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
489491

@@ -510,8 +512,6 @@ namespace hnswlib {
510512
}
511513
}
512514

513-
514-
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
515515
if (has_deletions_) {
516516
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
517517
ef_);
@@ -779,11 +779,11 @@ namespace hnswlib {
779779
*((unsigned short int*)(ptr))=*((unsigned short int *)&size);
780780
}
781781

782-
void addPoint(void *data_point, labeltype label) {
782+
// Convenience overload: forwards to the three-argument addPoint() with
// level = -1 and discards the value that overload returns.
// NOTE(review): -1 presumably means "no explicit level requested" — confirm
// against the three-argument overload.
void addPoint(const void *data_point, labeltype label) {
783783
addPoint(data_point, label,-1);
784784
}
785785

786-
tableint addPoint(void *data_point, labeltype label, int level) {
786+
tableint addPoint(const void *data_point, labeltype label, int level) {
787787
tableint cur_c = 0;
788788
{
789789
std::unique_lock <std::mutex> lock(cur_element_count_guard_);
@@ -895,7 +895,11 @@ namespace hnswlib {
895895
return cur_c;
896896
};
897897

898-
std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
898+
std::priority_queue<std::pair<dist_t, labeltype >>
899+
searchKnn(const void *query_data, size_t k) const {
900+
std::priority_queue<std::pair<dist_t, labeltype >> result;
901+
if (cur_element_count == 0) return result;
902+
899903
tableint currObj = enterpoint_node_;
900904
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
901905

@@ -934,18 +938,34 @@ namespace hnswlib {
934938
currObj, query_data, std::max(ef_, k));
935939
top_candidates.swap(top_candidates1);
936940
}
937-
std::priority_queue<std::pair<dist_t, labeltype >> results;
938941
while (top_candidates.size() > k) {
939942
top_candidates.pop();
940943
}
941944
while (top_candidates.size() > 0) {
942945
std::pair<dist_t, tableint> rez = top_candidates.top();
943-
results.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
946+
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
944947
top_candidates.pop();
945948
}
946-
return results;
949+
return result;
947950
};
948951

952+
template <typename Comp>
953+
std::vector<std::pair<dist_t, labeltype>>
954+
searchKnn(const void* query_data, size_t k, Comp comp) {
955+
std::vector<std::pair<dist_t, labeltype>> result;
956+
if (cur_element_count == 0) return result;
957+
958+
auto ret = searchKnn(query_data, k);
959+
960+
while (!ret.empty()) {
961+
result.push_back(ret.top());
962+
ret.pop();
963+
}
964+
965+
std::sort(result.begin(), result.end(), comp);
966+
967+
return result;
968+
}
949969

950970
};
951971

hnswlib/hnswlib.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,21 @@
2424
#endif
2525

2626
#include <queue>
27+
#include <vector>
2728

2829
#include <string.h>
2930

3031
namespace hnswlib {
3132
typedef size_t labeltype;
3233

34+
/// Comparison functor that orders pair-like values by descending `first`.
///
/// Returns true when `p1.first` is strictly greater than `p2.first`; the
/// `second` member is never inspected. `T` only needs a `first` member that
/// supports `operator>`.
template <typename T>
class pairGreater {
public:
    // const-qualified so the functor can be invoked through a const object
    // or const reference, which is how comparators are commonly held.
    bool operator()(const T& p1, const T& p2) const {
        return p1.first > p2.first;
    }
};
41+
3342
template<typename T>
3443
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
3544
out.write((char *) &podRef, sizeof(T));
@@ -60,8 +69,11 @@ namespace hnswlib {
6069
template<typename dist_t>
6170
class AlgorithmInterface {
6271
public:
63-
virtual void addPoint(void *datapoint, labeltype label)=0;
72+
virtual void addPoint(const void *datapoint, labeltype label)=0;
6473
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
74+
template <typename Comp>
75+
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
76+
}
6577
virtual void saveIndex(const std::string &location)=0;
6678
virtual ~AlgorithmInterface(){
6779
}

0 commit comments

Comments
 (0)