@@ -74,26 +74,40 @@ static void declareColMajorMatrix(py::module& mod, std::string const& suffix) {
7474
7575}
7676
77+ template <typename T>
78+ static void declare_pyarray_to_matrix (py::module & m, const std::string& suffix) {
79+ m.def ((" pyarray_copyto_matrix" + suffix).c_str (),
80+ [](py::array_t <T> arr) -> ColMajorMatrix<T> {
81+ py::buffer_info info = arr.request ();
82+ if (info.ndim != 2 )
83+ throw std::runtime_error (" Number of dimensions must be two" );
84+ if (info.format != py::format_descriptor<T>::format ())
85+ throw std::runtime_error (" Mismatched buffer format!" );
86+
87+ auto data = std::unique_ptr<T[]>{new T[info.shape [0 ] * info.shape [1 ]]};
88+ auto r = ColMajorMatrix<T>(std::move (data), info.shape [0 ], info.shape [1 ]);
89+ return r;
90+ });
91+ }
92+
7793template <typename T>
7894static void declare_kmeans_query (py::module & m, const std::string& suffix) {
7995 m.def ((" kmeans_query_" + suffix).c_str (),
80- [](Ctx ctx,
81- const std::string& part_uri,
96+ [](const ColMajorMatrix<T>& parts,
8297 const ColMajorMatrix<float >& centroids,
8398 const ColMajorMatrix<float >& query_vectors,
8499 std::vector<uint64_t >& indices,
85- const std::string& id_uri ,
100+ std::vector<shuffled_ids_type>& ids ,
86101 size_t nprobe,
87102 size_t k_nn,
88103 bool nth,
89- size_t nthreads) -> ColMajorMatrix<size_t > {
104+ size_t nthreads) -> ColMajorMatrix<size_t > { // TODO change return type
90105 auto r = detail::ivf::qv_query_heap_infinite_ram<T>(
91- ctx,
92- part_uri,
106+ parts,
93107 centroids,
94108 query_vectors,
95109 indices,
96- id_uri ,
110+ ids ,
97111 nprobe,
98112 k_nn,
99113 nth,
@@ -102,22 +116,6 @@ static void declare_kmeans_query(py::module& m, const std::string& suffix) {
102116 }, py::keep_alive<1 ,2 >());
103117}
104118
105- template <typename T>
106- static void declare_pyarray_to_matrix (py::module & m, const std::string& suffix) {
107- m.def ((" pyarray_copyto_matrix" + suffix).c_str (),
108- [](py::array_t <T> arr) -> ColMajorMatrix<T> {
109- py::buffer_info info = arr.request ();
110- if (info.ndim != 2 )
111- throw std::runtime_error (" Number of dimensions must be two" );
112- if (info.format != py::format_descriptor<T>::format ())
113- throw std::runtime_error (" Mismatched buffer format!" );
114-
115- auto data = std::unique_ptr<T[]>{new T[info.shape [0 ] * info.shape [1 ]]};
116- auto r = ColMajorMatrix<T>(std::move (data), info.shape [0 ], info.shape [1 ]);
117- return r;
118- });
119- }
120-
121119
122120// Declarations for typed subclasses of ColMajorMatrix
123121template <typename P>
0 commit comments