55#include < pybind11/stl.h>
66
77#include " linalg.h"
8- #include " ivf_index.h"
98#include " ivf_query.h"
109#include " flat_query.h"
1110
1211namespace py = pybind11;
1312using Ctx = tiledb::Context;
1413
1514bool global_debug = false ;
16- double global_time_of_interest;
1715
1816bool enable_stats = false ;
1917std::vector<json> core_stats;
@@ -103,6 +101,19 @@ static void declare_pyarray_to_matrix(py::module& m, const std::string& suffix)
103101 });
104102}
105103
104+ namespace {
105+ template <typename ...TArgs>
106+ py::tuple make_python_pair (std::tuple<TArgs...>&& arg) {
107+ static_assert (sizeof ...(TArgs) == 2 , " Must have exactly two arguments" );
108+
109+ return py::make_tuple<py::return_value_policy::automatic>(
110+ py::cast (std::get<0 >(arg), py::return_value_policy::move),
111+ py::cast (std::get<1 >(arg), py::return_value_policy::move)
112+ );
113+ }
114+
115+ }
116+
106117template <typename T, typename Id_Type = uint64_t >
107118static void declare_qv_query_heap_infinite_ram (py::module & m, const std::string& suffix) {
108119 m.def ((" qv_query_heap_infinite_ram_" + suffix).c_str (),
@@ -113,8 +124,7 @@ static void declare_qv_query_heap_infinite_ram(py::module& m, const std::string&
113124 std::vector<Id_Type>& ids,
114125 size_t nprobe,
115126 size_t k_nn,
116- bool nth,
117- size_t nthreads) -> ColMajorMatrix<size_t > { // TODO change return type
127+ size_t nthreads) -> py::tuple { // std::pair<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
118128
119129 auto r = detail::ivf::qv_query_heap_infinite_ram (
120130 parts,
@@ -124,9 +134,8 @@ static void declare_qv_query_heap_infinite_ram(py::module& m, const std::string&
124134 ids,
125135 nprobe,
126136 k_nn,
127- nth,
128137 nthreads);
129- return r ;
138+ return make_python_pair ( std::move (r)) ;
130139 }, py::keep_alive<1 ,2 >());
131140}
132141
@@ -142,8 +151,7 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
142151 size_t nprobe,
143152 size_t k_nn,
144153 size_t upper_bound,
145- bool nth,
146- size_t nthreads) -> ColMajorMatrix<size_t > { // TODO change return type
154+ size_t nthreads) -> py::tuple { // std::tuple<ColMajorMatrix<float>, ColMajorMatrix<size_t>> { // TODO change return type
147155
148156 auto r = detail::ivf::qv_query_heap_finite_ram<T, Id_Type>(
149157 ctx,
@@ -155,9 +163,8 @@ static void declare_qv_query_heap_finite_ram(py::module& m, const std::string& s
155163 nprobe,
156164 k_nn,
157165 upper_bound,
158- nth,
159166 nthreads);
160- return r ;
167+ return make_python_pair ( std::move (r)) ;
161168 }, py::keep_alive<1 ,2 >());
162169}
163170
@@ -171,8 +178,7 @@ static void declare_nuv_query_heap_infinite_ram(py::module& m, const std::string
171178 std::vector<Id_Type>& ids,
172179 size_t nprobe,
173180 size_t k_nn,
174- bool nth,
175- size_t nthreads) -> ColMajorMatrix<size_t > { // TODO change return type
181+ size_t nthreads) -> std::tuple<ColMajorMatrix<float >, ColMajorMatrix<size_t >> { // TODO change return type
176182
177183 auto r = detail::ivf::nuv_query_heap_infinite_ram_reg_blocked (
178184 parts,
@@ -182,7 +188,6 @@ static void declare_nuv_query_heap_infinite_ram(py::module& m, const std::string
182188 ids,
183189 nprobe,
184190 k_nn,
185- nth,
186191 nthreads);
187192 return r;
188193 }, py::keep_alive<1 ,2 >());
@@ -200,8 +205,7 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
200205 size_t nprobe,
201206 size_t k_nn,
202207 size_t upper_bound,
203- bool nth,
204- size_t nthreads) -> ColMajorMatrix<size_t > { // TODO change return type
208+ size_t nthreads) -> std::tuple<ColMajorMatrix<float >, ColMajorMatrix<size_t >> { // TODO change return type
205209
206210 auto r = detail::ivf::nuv_query_heap_finite_ram_reg_blocked<T, Id_Type>(
207211 ctx,
@@ -213,7 +217,6 @@ static void declare_nuv_query_heap_finite_ram(py::module& m, const std::string&
213217 nprobe,
214218 k_nn,
215219 upper_bound,
216- nth,
217220 nthreads);
218221 return r;
219222 }, py::keep_alive<1 ,2 >());
@@ -398,7 +401,7 @@ static void declare_vq_query_heap(py::module& m, const std::string& suffix) {
398401 ColMajorMatrix<float >& query_vectors,
399402 const std::vector<uint64_t > &ids,
400403 int k,
401- size_t nthreads) -> ColMajorMatrix<size_t > {
404+ size_t nthreads) -> std::tuple< ColMajorMatrix<float >, ColMajorMatrix< size_t > > {
402405 auto r = detail::flat::vq_query_heap (data, query_vectors, ids, k, nthreads);
403406 return r;
404407 });
@@ -411,7 +414,7 @@ static void declare_vq_query_heap_pyarray(py::module& m, const std::string& suff
411414 ColMajorMatrix<float >& query_vectors,
412415 const std::vector<uint64_t > &ids,
413416 int k,
414- size_t nthreads) -> ColMajorMatrix<size_t > {
417+ size_t nthreads) -> std::tuple< ColMajorMatrix<float >, ColMajorMatrix< size_t > > {
415418 auto r = detail::flat::vq_query_heap (data, query_vectors, ids, k, nthreads);
416419 return r;
417420 });
@@ -494,17 +497,17 @@ PYBIND11_MODULE(_tiledbvspy, m) {
494497 [](ColMajorMatrix<float >& data,
495498 ColMajorMatrix<float >& query_vectors,
496499 int k,
497- size_t nthreads) -> ColMajorMatrix<size_t > {
498- auto r = detail::flat::vq_query_nth (data, query_vectors, k, true , nthreads);
500+ size_t nthreads) -> std::tuple< ColMajorMatrix<float >, ColMajorMatrix< size_t > > {
501+ auto r = detail::flat::vq_query_heap (data, query_vectors, k, nthreads);
499502 return r;
500503 });
501504
502505 m.def (" query_vq_u8" ,
503506 [](tdbColMajorMatrix<uint8_t >& data,
504507 ColMajorMatrix<float >& query_vectors,
505508 int k,
506- size_t nthreads) -> ColMajorMatrix<size_t > {
507- auto r = detail::flat::vq_query_nth (data, query_vectors, k, true , nthreads);
509+ size_t nthreads) -> std::tuple< ColMajorMatrix<float >, ColMajorMatrix< size_t > > {
510+ auto r = detail::flat::vq_query_heap (data, query_vectors, k, nthreads);
508511 return r;
509512 });
510513
0 commit comments