Skip to content

Commit 3a3426f

Browse files
author
Nikos Papailiou
committed
Fix uint64 error
1 parent 187c831 commit 3a3426f

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ def query(self, queries: np.ndarray, k, **kwargs):
7070
if res in updated_ids:
7171
internal_results_d[query_id, res_id] = MAX_FLOAT_32
7272
internal_results_i[query_id, res_id] = MAX_UINT64
73+
if (
74+
internal_results_d[query_id, res_id] == 0
75+
and internal_results_i[query_id, res_id] == 0
76+
):
77+
internal_results_d[query_id, res_id] = MAX_FLOAT_32
78+
internal_results_i[query_id, res_id] = MAX_UINT64
7379
res_id += 1
7480
query_id += 1
7581
sort_index = np.argsort(internal_results_d, axis=1)

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
402402
const std::vector<uint64_t> &ids,
403403
int k,
404404
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
405-
auto r = detail::flat::vq_query_heap(data, query_vectors, ids, k, nthreads);
405+
auto r = detail::flat::vq_query_heap<tdbColMajorMatrix<T>, ColMajorMatrix<float>, uint64_t>(data, query_vectors, ids, k, nthreads);
406406
return r;
407407
});
408408
}
@@ -415,7 +415,7 @@ static void declare_vq_query_heap_pyarray(py::module& m, const std::string& suff
415415
const std::vector<uint64_t> &ids,
416416
int k,
417417
size_t nthreads) -> std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> {
418-
auto r = detail::flat::vq_query_heap(data, query_vectors, ids, k, nthreads);
418+
auto r = detail::flat::vq_query_heap<ColMajorMatrix<T>, ColMajorMatrix<float>, uint64_t>(data, query_vectors, ids, k, nthreads);
419419
return r;
420420
});
421421
}

src/include/detail/flat/vq.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ auto vq_query_heap(
8383
unsigned nthreads) {
8484
// @todo Need to get the total number of queries, not just the first block
8585
// @todo Use Matrix here rather than vector of vectors
86-
std::vector<std::vector<fixed_min_pair_heap<float, unsigned>>> scores(
86+
std::vector<std::vector<fixed_min_pair_heap<float, Index>>> scores(
8787
nthreads,
88-
std::vector<fixed_min_pair_heap<float, unsigned>>(
89-
size(q), fixed_min_pair_heap<float, unsigned>(k_nn)));
88+
std::vector<fixed_min_pair_heap<float, Index>>(
89+
size(q), fixed_min_pair_heap<float, Index>(k_nn)));
9090

9191
unsigned size_q = size(q);
9292
auto par = stdx::execution::indexed_parallel_policy{nthreads};
@@ -184,10 +184,10 @@ auto vq_query_heap_tiled(
184184
unsigned nthreads) {
185185
// @todo Need to get the total number of queries, not just the first block
186186
// @todo Use Matrix here rather than vector of vectors
187-
std::vector<std::vector<fixed_min_pair_heap<float, unsigned>>> scores(
187+
std::vector<std::vector<fixed_min_pair_heap<float, Index>>> scores(
188188
nthreads,
189-
std::vector<fixed_min_pair_heap<float, unsigned>>(
190-
size(q), fixed_min_pair_heap<float, unsigned>(k_nn)));
189+
std::vector<fixed_min_pair_heap<float, Index>>(
190+
size(q), fixed_min_pair_heap<float, Index>(k_nn)));
191191

192192
unsigned size_q = size(q);
193193
auto par = stdx::execution::indexed_parallel_policy{nthreads};
@@ -261,10 +261,10 @@ auto vq_query_heap_2(
261261
unsigned nthreads) {
262262
// @todo Need to get the total number of queries, not just the first block
263263
// @todo Use Matrix here rather than vector of vectors
264-
std::vector<std::vector<fixed_min_pair_heap<float, size_t>>> scores(
264+
std::vector<std::vector<fixed_min_pair_heap<float, Index>>> scores(
265265
nthreads,
266-
std::vector<fixed_min_pair_heap<float, size_t>>(
267-
size(q), fixed_min_pair_heap<float, size_t>(k_nn)));
266+
std::vector<fixed_min_pair_heap<float, Index>>(
267+
size(q), fixed_min_pair_heap<float, Index>(k_nn)));
268268

269269
unsigned size_q = size(q);
270270
auto par = stdx::execution::indexed_parallel_policy{nthreads};

0 commit comments

Comments
 (0)