Skip to content

Commit e089a2f

Browse files
author
Nikos Papailiou
committed
Fix empty results
1 parent 3e0d06b commit e089a2f

File tree

5 files changed

+27
-32
lines changed

5 files changed

+27
-32
lines changed

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ static void declare_qv_query_heap_infinite_ram(py::module& m, const std::string&
126126
size_t k_nn,
127127
size_t nthreads) -> py::tuple { //std::pair<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
128128

129-
auto r = detail::ivf::qv_query_heap_infinite_ram(
129+
auto r = detail::ivf::qv_query_heap_infinite_ram<Id_Type>(
130130
parts,
131131
centroids,
132132
query_vectors,

src/include/detail/flat/vq.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ auto vq_query_heap(
138138
}
139139

140140
consolidate_scores(scores);
141-
auto top_k = get_top_k_with_scores(scores, k_nn);
141+
auto top_k = get_top_k_with_scores<fixed_min_pair_heap<float, Index>, Index>(scores, k_nn);
142142

143143
return top_k;
144144
}
@@ -223,7 +223,7 @@ auto vq_query_heap_tiled(
223223
} while (load(db));
224224

225225
consolidate_scores(scores);
226-
auto top_k = get_top_k_with_scores(scores, k_nn);
226+
auto top_k = get_top_k_with_scores<fixed_min_pair_heap<float, Index>, Index>(scores, k_nn);
227227

228228
return top_k;
229229
}
@@ -300,7 +300,7 @@ auto vq_query_heap_2(
300300
} while (load(db));
301301

302302
consolidate_scores(scores);
303-
auto top_k = get_top_k_with_scores(scores, k_nn);
303+
auto top_k = get_top_k_with_scores<fixed_min_pair_heap<float, Index>, Index>(scores, k_nn);
304304

305305
return top_k;
306306
}

src/include/detail/ivf/qv.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ namespace detail::ivf {
7878
* Overload for already opened arrays. Since the array is already opened, we
7979
* don't need to specify its type with a template parameter.
8080
*/
81+
template <class ids_type = size_t>
8182
auto qv_query_heap_infinite_ram(
8283
auto&& partitioned_db,
8384
auto&& centroids,
@@ -113,7 +114,7 @@ auto qv_query_heap_infinite_ram(
113114
auto partitioned_db = tdbColMajorMatrix<T>(ctx, part_uri);
114115
auto partitioned_ids = read_vector<partitioned_ids_type>(ctx, id_uri);
115116

116-
return qv_query_heap_infinite_ram(
117+
return qv_query_heap_infinite_ram<partitioned_ids_type>(
117118
partitioned_db,
118119
centroids,
119120
q,
@@ -148,6 +149,7 @@ auto qv_query_heap_infinite_ram(
148149
* @param nthreads How many threads to use for parallel execution
149150
* @return The indices of the top_k neighbors for each query vector
150151
*/
152+
template <class ids_type = size_t>
151153
auto qv_query_heap_infinite_ram(
152154
const std::string& part_uri,
153155
auto&& centroids,
@@ -158,7 +160,7 @@ auto qv_query_heap_infinite_ram(
158160
size_t k_nn,
159161
size_t nthreads) {
160162
tiledb::Context ctx;
161-
return qv_query_heap_infinite_ram(
163+
return qv_query_heap_infinite_ram<ids_type>(
162164
ctx, part_uri, centroids, q, indices, id_uri, nprobe, k_nn, nthreads);
163165
}
164166

@@ -188,6 +190,7 @@ auto qv_query_heap_infinite_ram(
188190
* @return The indices of the top_k neighbors for each query vector
189191
*/
190192
// @todo We should still order the queries so partitions are searched in order
193+
template <class ids_type>
191194
auto qv_query_heap_infinite_ram(
192195
auto&& partitioned_db,
193196
auto&& centroids,
@@ -221,8 +224,8 @@ auto qv_query_heap_infinite_ram(
221224
auto top_centroids =
222225
detail::flat::qv_query_heap_0(centroids, q, nprobe, nthreads);
223226

224-
auto min_scores = std::vector<fixed_min_pair_heap<float, size_t>>(
225-
size(q), fixed_min_pair_heap<float, size_t>(k_nn));
227+
auto min_scores = std::vector<fixed_min_pair_heap<float, ids_type>>(
228+
size(q), fixed_min_pair_heap<float, ids_type>(k_nn));
226229

227230
// Parallelizing over q is not going to be very efficient
228231
{
@@ -244,7 +247,7 @@ auto qv_query_heap_infinite_ram(
244247
});
245248
}
246249

247-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
250+
auto top_k = get_top_k_with_scores<fixed_min_pair_heap<float, ids_type>, ids_type>(min_scores, k_nn);
248251
return top_k;
249252
}
250253

src/include/scoring.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include <cmath>
4848
#include <future>
4949
#include <iostream>
50+
#include <limits>
5051
#include <memory>
5152
#include <numeric>
5253
#include <queue>
@@ -59,9 +60,6 @@
5960
#include "utils/fixed_min_heap.h"
6061
#include "utils/timer.h"
6162

62-
63-
64-
6563
// ----------------------------------------------------------------------------
6664
// Helper utilities
6765
//----------------------------------------------------------------------------
@@ -292,7 +290,7 @@ inline auto get_top_k(std::vector<std::vector<Heap>>& scores, size_t k_nn) {
292290
// ----------------------------------------------------------------------------
293291
// Functions for computing top k neighbors with scores
294292
// ----------------------------------------------------------------------------
295-
293+
template <class Index = size_t, class score_type = float>
296294
inline void get_top_k_with_scores_from_heap(
297295
auto&& min_scores, auto&& top_k, auto&& top_k_scores) {
298296
std::sort_heap(begin(min_scores), end(min_scores), [](auto&& a, auto&& b) {
@@ -306,6 +304,10 @@ inline void get_top_k_with_scores_from_heap(
306304
begin(min_scores), end(min_scores), begin(top_k), ([](auto&& e) {
307305
return std::get<1>(e);
308306
}));
307+
for (size_t i = min_scores.size(); i < top_k.size(); ++i) {
308+
top_k[i] = std::numeric_limits<Index>::max();
309+
top_k_scores[i] = std::numeric_limits<score_type>::max();
310+
}
309311
}
310312

311313
// Overload for one-d scores
@@ -320,7 +322,7 @@ inline auto get_top_k_with_scores(std::vector<Heap>& scores, size_t k_nn) {
320322
ColMajorMatrix<score_type> top_scores(k_nn, num_queries);
321323

322324
for (size_t j = 0; j < num_queries; ++j) {
323-
get_top_k_with_scores_from_heap(scores[j], top_k[j], top_scores[j]);
325+
get_top_k_with_scores_from_heap<Index, score_type>(scores[j], top_k[j], top_scores[j]);
324326
}
325327
return std::make_tuple(std::move(top_scores), std::move(top_k));
326328
}

src/include/utils/fixed_min_heap.h

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,8 @@ class fixed_min_set_heap_1 : public std::vector<T> {
5757

5858
void insert(T const& x) {
5959
if (Base::size() < max_size) {
60-
Base::push_back(x);
61-
// std::push_heap(begin(*this), end(*this), std::less<T>());
62-
if (Base::size() == max_size) {
63-
std::make_heap(begin(*this), end(*this), std::less<T>());
64-
}
60+
this->push_back(x);
61+
std::push_heap(begin(*this), end(*this), std::less<T>());
6562
} else if (x < this->front()) {
6663
std::pop_heap(begin(*this), end(*this), std::less<T>());
6764
this->pop_back();
@@ -91,12 +88,8 @@ class fixed_min_set_heap_2 : public std::vector<T> {
9188

9289
void insert(T const& x) {
9390
if (Base::size() < max_size) {
94-
Base::push_back(x);
95-
// std::push_heap(begin(*this), end(*this), std::less<T>());
96-
if (Base::size() == max_size) {
97-
// std::make_heap(begin(*this), end(*this), std::less<T>());
98-
std::make_heap(begin(*this), end(*this));
99-
}
91+
this->push_back(x);
92+
std::push_heap(begin(*this), end(*this));
10093
} else if (x < this->front()) {
10194
// std::pop_heap(begin(*this), end(*this), std::less<T>());
10295
std::pop_heap(begin(*this), end(*this));
@@ -138,13 +131,10 @@ class fixed_min_pair_heap : public std::vector<std::tuple<T, U>> {
138131

139132
void insert(const T& x, const U& y) {
140133
if (Base::size() < max_size) {
141-
Base::emplace_back(x, y);
142-
// std::push_heap(begin(*this), end(*this), std::less<T>());
143-
if (Base::size() == max_size) {
144-
std::make_heap(begin(*this), end(*this), [&](auto& a, auto& b) {
145-
return std::get<0>(a) < std::get<0>(b);
146-
});
147-
}
134+
this->emplace_back(x, y);
135+
std::push_heap(begin(*this), end(*this), [&](auto& a, auto& b) {
136+
return std::get<0>(a) < std::get<0>(b);
137+
});
148138
} else if (x < std::get<0>(this->front())) {
149139
std::pop_heap(begin(*this), end(*this), [&](auto& a, auto& b) {
150140
return std::get<0>(a) < std::get<0>(b);

0 commit comments

Comments
 (0)