@@ -107,27 +107,17 @@ class BruteforceSearch : public AlgorithmInterface<dist_t> {
107107 searchKnn (const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr ) const {
108108 assert (k <= cur_element_count);
109109 std::priority_queue<std::pair<dist_t , labeltype >> topResults;
110- if (cur_element_count == 0 ) return topResults ;
111- for (int i = 0 ; i < k ; i++) {
110+ dist_t lastdist = std::numeric_limits< dist_t >:: max () ;
111+ for (int i = 0 ; i < cur_element_count ; i++) {
112112 dist_t dist = fstdistfunc_ (query_data, data_ + size_per_element_ * i, dist_func_param_);
113- labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
114- if ((!isIdAllowed) || (*isIdAllowed)(label)) {
115- topResults.emplace (dist, label);
116- }
117- }
118- dist_t lastdist = topResults.empty () ? std::numeric_limits<dist_t >::max () : topResults.top ().first ;
119- for (int i = k; i < cur_element_count; i++) {
120- dist_t dist = fstdistfunc_ (query_data, data_ + size_per_element_ * i, dist_func_param_);
121- if (dist <= lastdist) {
113+ if (dist <= lastdist || topResults.size () < k) {
122114 labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
123115 if ((!isIdAllowed) || (*isIdAllowed)(label)) {
124116 topResults.emplace (dist, label);
125- }
126- if (topResults.size () > k)
127- topResults.pop ();
128-
129- if (!topResults.empty ()) {
130- lastdist = topResults.top ().first ;
117+ if (topResults.size () > k)
118+ topResults.pop ();
119+ if (!topResults.empty ())
120+ lastdist = topResults.top ().first ;
131121 }
132122 }
133123 }
0 commit comments