@@ -102,6 +102,22 @@ static void declare_kmeans_query(py::module& m, const std::string& suffix) {
102102 }, py::keep_alive<1 ,2 >());
103103}
104104
105+ template <typename T>
106+ static void declare_pyarray_to_matrix (py::module & m, const std::string& suffix) {
107+ m.def ((" pyarray_copyto_matrix" + suffix).c_str (),
108+ [](py::array_t <T> arr) -> ColMajorMatrix<T> {
109+ py::buffer_info info = arr.request ();
110+ if (info.ndim != 2 )
111+ throw std::runtime_error (" Number of dimensions must be two" );
112+ if (info.format != py::format_descriptor<T>::format ())
113+ throw std::runtime_error (" Mismatched buffer format!" );
114+
115+ auto data = std::unique_ptr<T[]>{new T[info.shape [0 ] * info.shape [1 ]]};
116+ auto r = ColMajorMatrix<T>(std::move (data), info.shape [0 ], info.shape [1 ]);
117+ return r;
118+ });
119+ }
120+
105121
106122// Declarations for typed subclasses of ColMajorMatrix
107123template <typename P>
@@ -133,41 +149,48 @@ PYBIND11_MODULE(_tiledbvspy, m) {
133149 }
134150 ));
135151
136- /* Vector */
152+ /* === Vector === */
153+
137154 declareVector<uint32_t >(m, " _u32" );
138155 declareVector<uint64_t >(m, " _u64" );
139156 declareVector<float >(m, " _f32" );
140157 declareVector<double >(m, " _f64" );
141158
142159 m.def (" read_vector_u32" , &read_vector<uint32_t >, " Read a vector from TileDB" );
143160 m.def (" read_vector_u64" , &read_vector<uint64_t >, " Read a vector from TileDB" );
161+
162+
144163 /* === Matrix === */
145164
146165 // template specializations
147- // declareTdbMatrix<float>(m, "_f32");
148-
149166 declareColMajorMatrix<uint8_t >(m, " _u8" );
150167 declareColMajorMatrix<float >(m, " _f32" );
151168 declareColMajorMatrix<double >(m, " _f64" );
152169 declareColMajorMatrix<int32_t >(m, " _i32" );
153170 declareColMajorMatrix<int64_t >(m, " _i64" );
154- // declareColMajorMatrix<uint64_t>(m, "_u64");
155- declareColMajorMatrix<size_t >(m, " _szt" );
171+ declareColMajorMatrix<uint32_t >(m, " _u32" );
172+ declareColMajorMatrix<uint64_t >(m, " _u64" );
173+ if constexpr (!std::is_same<uint64_t , unsigned long >::value) {
174+ // Required for a return type, but these types are equivalent on linux :/
175+ declareColMajorMatrix<unsigned long >(m, " _ul" );
176+ }
156177
157178 declareColMajorMatrixSubclass<tdbColMajorMatrix<uint8_t >>(
158179 m, " tdbColMajorMatrix" , " _u8" );
159- declareColMajorMatrixSubclass<tdbColMajorMatrix<size_t >>(
160- m, " tdbColMajorMatrix" , " _szt " );
180+ declareColMajorMatrixSubclass<tdbColMajorMatrix<uint64_t >>(
181+ m, " tdbColMajorMatrix" , " _u64 " );
161182 declareColMajorMatrixSubclass<tdbColMajorMatrix<float >>(
162183 m, " tdbColMajorMatrix" , " _f32" );
163184 declareColMajorMatrixSubclass<tdbColMajorMatrix<int32_t >>(
164185 m, " tdbColMajorMatrix" , " _i32" );
165186 declareColMajorMatrixSubclass<tdbColMajorMatrix<int64_t >>(
166187 m, " tdbColMajorMatrix" , " _i64" );
167- // declareColMajorMatrixSubclass<tdbColMajorMatrix<uint64_t>>(
168- // m, "tdbColMajorMatrix", "_u64");
169-
170188
189+ // Converters from pyarray to matrix
190+ declare_pyarray_to_matrix<uint8_t >(m, " _u8" );
191+ declare_pyarray_to_matrix<uint64_t >(m, " _u64" );
192+ declare_pyarray_to_matrix<float >(m, " _f32" );
193+ declare_pyarray_to_matrix<double >(m, " _f64" );
171194
172195 /* Query API */
173196
@@ -176,7 +199,7 @@ PYBIND11_MODULE(_tiledbvspy, m) {
176199 const ColMajorMatrix<float >& query_vectors,
177200 int k,
178201 bool nth,
179- size_t nthreads) {
202+ size_t nthreads) -> ColMajorMatrix< uint64_t > {
180203 auto r = detail::flat::vq_query_heap (data, query_vectors, k, nthreads);
181204 return r;
182205 });
@@ -186,13 +209,13 @@ PYBIND11_MODULE(_tiledbvspy, m) {
186209 const ColMajorMatrix<float >& query_vectors,
187210 int k,
188211 bool nth,
189- size_t nthreads) {
212+ size_t nthreads) -> ColMajorMatrix< uint64_t > {
190213 auto r = detail::flat::vq_query_heap (data, query_vectors, k, nthreads);
191214 return r;
192215 });
193216
194- m.def (" validate_top_k " ,
195- [](const ColMajorMatrix<size_t >& top_k,
217+ m.def (" validate_top_k_u64 " ,
218+ [](const ColMajorMatrix<uint64_t >& top_k,
196219 const ColMajorMatrix<int32_t >& ground_truth) -> bool {
197220 return validate_top_k (top_k, ground_truth);
198221 });
0 commit comments