Skip to content

Commit 93abb2b

Browse files
committed
- Apache 2 license
1 parent 02575b0 commit 93abb2b

File tree

7 files changed

+2891
-0
lines changed

7 files changed

+2891
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#pragma once
2+
#include <unordered_map>
3+
#include <fstream>
4+
#include <mutex>
5+
#include <algorithm>
6+
#include <assert.h>
7+
8+
namespace hnswlib {
9+
template<typename dist_t>
10+
class BruteforceSearch : public AlgorithmInterface<dist_t> {
11+
public:
12+
char *data_;
13+
size_t maxelements_;
14+
size_t cur_element_count;
15+
size_t size_per_element_;
16+
17+
size_t data_size_;
18+
DISTFUNC <dist_t> fstdistfunc_;
19+
void *dist_func_param_;
20+
std::mutex index_lock;
21+
22+
std::unordered_map<labeltype, size_t > dict_external_to_internal;
23+
24+
25+
BruteforceSearch(SpaceInterface <dist_t> *s)
26+
: data_(nullptr),
27+
maxelements_(0),
28+
cur_element_count(0),
29+
size_per_element_(0),
30+
data_size_(0),
31+
dist_func_param_(nullptr) {
32+
}
33+
34+
35+
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
36+
: data_(nullptr),
37+
maxelements_(0),
38+
cur_element_count(0),
39+
size_per_element_(0),
40+
data_size_(0),
41+
dist_func_param_(nullptr) {
42+
loadIndex(location, s);
43+
}
44+
45+
46+
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
47+
maxelements_ = maxElements;
48+
data_size_ = s->get_data_size();
49+
fstdistfunc_ = s->get_dist_func();
50+
dist_func_param_ = s->get_dist_func_param();
51+
size_per_element_ = data_size_ + sizeof(labeltype);
52+
data_ = (char *) malloc(maxElements * size_per_element_);
53+
if (data_ == nullptr)
54+
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
55+
cur_element_count = 0;
56+
}
57+
58+
59+
~BruteforceSearch() {
60+
free(data_);
61+
}
62+
63+
64+
void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
65+
int idx;
66+
{
67+
std::unique_lock<std::mutex> lock(index_lock);
68+
69+
auto search = dict_external_to_internal.find(label);
70+
if (search != dict_external_to_internal.end()) {
71+
idx = search->second;
72+
} else {
73+
if (cur_element_count >= maxelements_) {
74+
throw std::runtime_error("The number of elements exceeds the specified limit\n");
75+
}
76+
idx = cur_element_count;
77+
dict_external_to_internal[label] = idx;
78+
cur_element_count++;
79+
}
80+
}
81+
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
82+
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
83+
}
84+
85+
86+
void removePoint(labeltype cur_external) {
87+
std::unique_lock<std::mutex> lock(index_lock);
88+
89+
auto found = dict_external_to_internal.find(cur_external);
90+
if (found == dict_external_to_internal.end()) {
91+
return;
92+
}
93+
94+
dict_external_to_internal.erase(found);
95+
96+
size_t cur_c = found->second;
97+
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
98+
dict_external_to_internal[label] = cur_c;
99+
memcpy(data_ + size_per_element_ * cur_c,
100+
data_ + size_per_element_ * (cur_element_count-1),
101+
data_size_+sizeof(labeltype));
102+
cur_element_count--;
103+
}
104+
105+
106+
std::priority_queue<std::pair<dist_t, labeltype >>
107+
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
108+
assert(k <= cur_element_count);
109+
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
110+
if (cur_element_count == 0) return topResults;
111+
for (int i = 0; i < k; i++) {
112+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
113+
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
114+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
115+
topResults.emplace(dist, label);
116+
}
117+
}
118+
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
119+
for (int i = k; i < cur_element_count; i++) {
120+
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
121+
if (dist <= lastdist) {
122+
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
123+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
124+
topResults.emplace(dist, label);
125+
}
126+
if (topResults.size() > k)
127+
topResults.pop();
128+
129+
if (!topResults.empty()) {
130+
lastdist = topResults.top().first;
131+
}
132+
}
133+
}
134+
return topResults;
135+
}
136+
137+
138+
void saveIndex(const std::string &location) {
139+
std::ofstream output(location, std::ios::binary);
140+
std::streampos position;
141+
142+
writeBinaryPOD(output, maxelements_);
143+
writeBinaryPOD(output, size_per_element_);
144+
writeBinaryPOD(output, cur_element_count);
145+
146+
output.write(data_, maxelements_ * size_per_element_);
147+
148+
output.close();
149+
}
150+
151+
152+
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
153+
std::ifstream input(location, std::ios::binary);
154+
std::streampos position;
155+
156+
readBinaryPOD(input, maxelements_);
157+
readBinaryPOD(input, size_per_element_);
158+
readBinaryPOD(input, cur_element_count);
159+
160+
data_size_ = s->get_data_size();
161+
fstdistfunc_ = s->get_dist_func();
162+
dist_func_param_ = s->get_dist_func_param();
163+
size_per_element_ = data_size_ + sizeof(labeltype);
164+
data_ = (char *) malloc(maxelements_ * size_per_element_);
165+
if (data_ == nullptr)
166+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
167+
168+
input.read(data_, maxelements_ * size_per_element_);
169+
170+
input.close();
171+
}
172+
};
173+
} // namespace hnswlib

0 commit comments

Comments
 (0)