@@ -871,16 +871,39 @@ class BFIndex {
871871 CustomFilterFunctor idFilter (filter);
872872 CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr ;
873873
874- ParallelFor (0 , rows, num_threads, [&](size_t row, size_t threadId) {
875- std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
876- (void *)items.data (row), k, p_idFilter);
877- for (int i = k - 1 ; i >= 0 ; i--) {
878- auto & result_tuple = result.top ();
879- data_numpy_d[row * k + i] = result_tuple.first ;
880- data_numpy_l[row * k + i] = result_tuple.second ;
881- result.pop ();
882- }
883- });
874+ if (!normalize) {
875+ ParallelFor (0 , rows, num_threads, [&](size_t row, size_t threadId) {
876+ std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
877+ (void *)items.data (row), k, p_idFilter);
878+ if (result.size () != k)
879+ throw std::runtime_error (
880+ " Cannot return the results in a contiguous 2D array. There are not enough elements." );
881+ for (int i = k - 1 ; i >= 0 ; i--) {
882+ auto & result_tuple = result.top ();
883+ data_numpy_d[row * k + i] = result_tuple.first ;
884+ data_numpy_l[row * k + i] = result_tuple.second ;
885+ result.pop ();
886+ }
887+ });
888+ } else {
889+ std::vector<float > norm_array (num_threads * features);
890+ ParallelFor (0 , rows, num_threads, [&](size_t row, size_t threadId) {
891+ size_t start_idx = threadId * dim;
892+ normalize_vector ((float *)items.data (row), norm_array.data () + start_idx);
893+
894+ std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
895+ (void *)(norm_array.data () + start_idx), k, p_idFilter);
896+ if (result.size () != k)
897+ throw std::runtime_error (
898+ " Cannot return the results in a contiguous 2D array. There are not enough elements." );
899+ for (int i = k - 1 ; i >= 0 ; i--) {
900+ auto & result_tuple = result.top ();
901+ data_numpy_d[row * k + i] = result_tuple.first ;
902+ data_numpy_l[row * k + i] = result_tuple.second ;
903+ result.pop ();
904+ }
905+ });
906+ }
884907 }
885908
886909 py::capsule free_when_done_l (data_numpy_l, [](void *f) {
0 commit comments