Skip to content

Commit 333eaeb

Browse files
Fix incorrect results in bruteforce with filter
1 parent d44bd5d commit 333eaeb

File tree

1 file changed

+7
-17
lines changed

1 file changed

+7
-17
lines changed

hnswlib/bruteforce.h

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -107,27 +107,17 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
107107
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
108108
assert(k <= cur_element_count);
109109
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
110-
if (cur_element_count == 0) return topResults;
111-
for (int i = 0; i < k; i++) {
110+
dist_t lastdist = std::numeric_limits<dist_t>::max();
111+
for (int i = 0; i < cur_element_count; i++) {
112112
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
113-
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
114-
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
115-
topResults.emplace(dist, label);
116-
}
117-
}
118-
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
119-
for (int i = k; i < cur_element_count; i++) {
120-
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
121-
if (dist <= lastdist) {
113+
if (dist <= lastdist || topResults.size() < k) {
122114
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
123115
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
124116
topResults.emplace(dist, label);
125-
}
126-
if (topResults.size() > k)
127-
topResults.pop();
128-
129-
if (!topResults.empty()) {
130-
lastdist = topResults.top().first;
117+
if (topResults.size() > k)
118+
topResults.pop();
119+
if (!topResults.empty())
120+
lastdist = topResults.top().first;
131121
}
132122
}
133123
}

0 commit comments

Comments
 (0)