Skip to content

Commit 39bc6af

Browse files
Add missing normalization check to BFIndex
1 parent 76f5aff commit 39bc6af

File tree

1 file changed

+33
-13
lines changed

1 file changed

+33
-13
lines changed

python_bindings/bindings.cpp

Lines changed: 33 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -871,19 +871,39 @@ class BFIndex {
871871
CustomFilterFunctor idFilter(filter);
872872
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;
873873

874-
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
875-
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
876-
(void*)items.data(row), k, p_idFilter);
877-
if (result.size() != k)
878-
throw std::runtime_error(
879-
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
880-
for (int i = k - 1; i >= 0; i--) {
881-
auto& result_tuple = result.top();
882-
data_numpy_d[row * k + i] = result_tuple.first;
883-
data_numpy_l[row * k + i] = result_tuple.second;
884-
result.pop();
885-
}
886-
});
874+
if (!normalize) {
875+
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
876+
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
877+
(void*)items.data(row), k, p_idFilter);
878+
if (result.size() != k)
879+
throw std::runtime_error(
880+
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
881+
for (int i = k - 1; i >= 0; i--) {
882+
auto& result_tuple = result.top();
883+
data_numpy_d[row * k + i] = result_tuple.first;
884+
data_numpy_l[row * k + i] = result_tuple.second;
885+
result.pop();
886+
}
887+
});
888+
} else {
889+
std::vector<float> norm_array(num_threads * features);
890+
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
891+
size_t start_idx = threadId * dim;
892+
normalize_vector((float*)items.data(row), norm_array.data() + start_idx);
893+
894+
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
895+
(void*)(norm_array.data() + start_idx), k, p_idFilter);
896+
if (result.size() != k)
897+
throw std::runtime_error(
898+
"Cannot return the results in a contiguous 2D array. There are not enough elements.");
899+
for (int i = k - 1; i >= 0; i--) {
900+
auto& result_tuple = result.top();
901+
data_numpy_d[row * k + i] = result_tuple.first;
902+
data_numpy_l[row * k + i] = result_tuple.second;
903+
result.pop();
904+
}
905+
});
906+
}
887907
}
888908

889909
py::capsule free_when_done_l(data_numpy_l, [](void *f) {

0 commit comments

Comments (0)