3838#ifndef TILEDB_API_FEATURE_VECTOR_ARRAY_H
3939#define TILEDB_API_FEATURE_VECTOR_ARRAY_H
4040
41+ #include < unordered_set>
4142#include " api_defs.h"
4243#include " concepts.h"
4344#include " cpos.h"
4445#include " detail/linalg/matrix.h"
46+ #include " detail/linalg/matrix_with_ids.h"
4547#include " detail/linalg/tdb_helpers.h"
4648#include " detail/linalg/tdb_matrix.h"
49+ #include " detail/linalg/tdb_matrix_with_ids.h"
4750#include " scoring.h"
4851#include " tdb_defs.h"
4952
53+ #include < type_traits>
5054#include " utils/print_types.h"
5155
5256class FeatureVectorArray {
@@ -66,11 +70,19 @@ class FeatureVectorArray {
6670 feature_type_ = tiledb::impl::type_to_tiledb<
6771 typename std::remove_cvref_t <T>::value_type>::tiledb_type;
6872 feature_size_ = datatype_to_size (feature_type_);
73+
74+ if constexpr (feature_vector_array_with_ids<
75+ std::remove_cvref_t <decltype (obj)>>) {
76+ ids_type_ = tiledb::impl::type_to_tiledb<
77+ typename std::remove_cvref_t <T>::ids_type>::tiledb_type;
78+ ids_size_ = datatype_to_size (ids_type_);
79+ }
6980 }
7081
7182 FeatureVectorArray (
7283 const tiledb::Context& ctx,
7384 const std::string& uri,
85+ const std::string& ids_uri = " " ,
7486 size_t num_vectors = 0 ) {
7587 auto array = tiledb_helpers::open_array (tdb_func__, ctx, uri, TILEDB_READ);
7688 feature_type_ = get_array_datatype (*array);
@@ -84,24 +96,60 @@ class FeatureVectorArray {
8496 * happen with either orientation, and so will work at the other end with
8597 * either orientation since we are just passing a pointer to the data.
8698 */
87- if (tdb_col_major_matrix_dispatch_table.find (feature_type_) ==
88- tdb_col_major_matrix_dispatch_table.end ()) {
89- throw std::runtime_error (" Unsupported attribute type" );
99+ if (ids_uri.empty ()) {
100+ if (tdb_col_major_matrix_dispatch_table.find (feature_type_) ==
101+ tdb_col_major_matrix_dispatch_table.end ()) {
102+ throw std::runtime_error (" Unsupported features attribute type" );
103+ }
104+ vector_array = tdb_col_major_matrix_dispatch_table.at (feature_type_)(
105+ ctx, uri, num_vectors);
106+ } else {
107+ auto ids_array =
108+ tiledb_helpers::open_array (tdb_func__, ctx, ids_uri, TILEDB_READ);
109+ ids_type_ = get_array_datatype (*ids_array);
110+ array->close ();
111+ ids_size_ = datatype_to_size (ids_type_);
112+
113+ auto type = std::tuple{feature_type_, ids_type_};
114+ if (tdb_col_major_matrix_with_ids_dispatch_table.find (type) ==
115+ tdb_col_major_matrix_with_ids_dispatch_table.end ()) {
116+ throw std::runtime_error (
117+ " Unsupported attribute type for feature vector with ids" );
118+ }
119+ vector_array = tdb_col_major_matrix_with_ids_dispatch_table.at (type)(
120+ ctx, uri, ids_uri, num_vectors);
90121 }
91- vector_array = tdb_col_major_matrix_dispatch_table.at (feature_type_)(
92- ctx, uri, num_vectors);
93122 (void )vector_array->load ();
94123 }
95124
96- FeatureVectorArray (size_t rows, size_t cols, const std::string type_string) {
125+ FeatureVectorArray (
126+ size_t rows,
127+ size_t cols,
128+ const std::string& type_string,
129+ const std::string& ids_type_string = " " ) {
97130 feature_type_ = string_to_datatype (type_string);
98131 feature_size_ = datatype_to_size (feature_type_);
99- if (col_major_matrix_dispatch_table.find (feature_type_) ==
100- col_major_matrix_dispatch_table.end ()) {
101- throw std::runtime_error (" Unsupported attribute type" );
132+
133+ if (ids_type_string.empty ()) {
134+ if (col_major_matrix_dispatch_table.find (feature_type_) ==
135+ col_major_matrix_dispatch_table.end ()) {
136+ throw std::runtime_error (" Unsupported features attribute type" );
137+ }
138+ vector_array =
139+ col_major_matrix_dispatch_table.at (feature_type_)(rows, cols);
140+ } else {
141+ ids_type_ = string_to_datatype (ids_type_string);
142+ ids_size_ = datatype_to_size (ids_type_);
143+
144+ auto type = std::tuple{feature_type_, ids_type_};
145+ if (col_major_matrix_with_ids_dispatch_table.find (type) ==
146+ col_major_matrix_with_ids_dispatch_table.end ()) {
147+ throw std::runtime_error (
148+ " Unsupported attribute type for feature vector with ids" );
149+ }
150+ vector_array =
151+ col_major_matrix_with_ids_dispatch_table.at (type)(rows, cols);
102152 }
103- vector_array =
104- col_major_matrix_dispatch_table.at (feature_type_)(rows, cols);
105153 }
106154
107155 // A FeatureVectorArray is always loaded
@@ -118,6 +166,10 @@ class FeatureVectorArray {
118166 return vector_array->data ();
119167 }
120168
169+ [[nodiscard]] auto ids_data () const {
170+ return vector_array->ids_data ();
171+ }
172+
121173 [[nodiscard]] auto extents () const {
122174 return _cpo::extents (*vector_array);
123175 }
@@ -130,6 +182,10 @@ class FeatureVectorArray {
130182 return _cpo::num_vectors (*vector_array);
131183 }
132184
185+ [[nodiscard]] auto num_ids () const {
186+ return _cpo::num_ids (*vector_array);
187+ }
188+
133189 [[nodiscard]] tiledb_datatype_t feature_type () const {
134190 return feature_type_;
135191 }
@@ -142,6 +198,18 @@ class FeatureVectorArray {
142198 return feature_size_;
143199 }
144200
201+ [[nodiscard]] tiledb_datatype_t ids_type () const {
202+ return ids_type_;
203+ }
204+
205+ [[nodiscard]] std::string ids_type_string () const {
206+ return datatype_to_string (ids_type_);
207+ }
208+
209+ [[nodiscard]] size_t ids_size () const {
210+ return ids_size_;
211+ }
212+
145213 /* *
146214 * Non-type parameterized base class (for type erasure).
147215 */
@@ -150,6 +218,8 @@ class FeatureVectorArray {
150218 [[nodiscard]] virtual size_t dimension () const = 0;
151219 [[nodiscard]] virtual size_t num_vectors () const = 0;
152220 [[nodiscard]] virtual void * data () const = 0;
221+ [[nodiscard]] virtual size_t num_ids () const = 0;
222+ [[nodiscard]] virtual const void * ids_data () const = 0;
153223 [[nodiscard]] virtual std::vector<size_t > extents () const = 0;
154224 [[nodiscard]] virtual bool load () = 0;
155225 };
@@ -166,12 +236,25 @@ class FeatureVectorArray {
166236 const tiledb::Context& ctx, const std::string& uri, size_t num_vectors)
167237 : impl_vector_array(ctx, uri, num_vectors) {
168238 }
239+ vector_array_impl (
240+ const tiledb::Context& ctx,
241+ const std::string& uri,
242+ const std::string& ids_uri,
243+ size_t num_vectors)
244+ : impl_vector_array(ctx, uri, ids_uri, num_vectors) {
245+ }
169246 vector_array_impl (size_t rows, size_t cols)
170247 : impl_vector_array(rows, cols) {
171248 }
172249 [[nodiscard]] void * data () const override {
173250 return _cpo::data (impl_vector_array);
174251 }
252+ [[nodiscard]] size_t num_ids () const override {
253+ return _cpo::num_ids (impl_vector_array);
254+ }
255+ [[nodiscard]] const void * ids_data () const override {
256+ return _cpo::ids (impl_vector_array).data ();
257+ }
175258 [[nodiscard]] size_t dimension () const override {
176259 return _cpo::dimension (impl_vector_array);
177260 }
@@ -204,9 +287,32 @@ class FeatureVectorArray {
204287 static const tdb_col_major_matrix_table_type
205288 tdb_col_major_matrix_dispatch_table;
206289
290+ using col_major_matrix_with_ids_constructor_function =
291+ std::function<std::unique_ptr<vector_array_base>(size_t , size_t )>;
292+ using col_major_matrix_with_ids_table_type = std::map<
293+ std::tuple<tiledb_datatype_t , tiledb_datatype_t >,
294+ col_major_matrix_with_ids_constructor_function>;
295+ static const col_major_matrix_with_ids_table_type
296+ col_major_matrix_with_ids_dispatch_table;
297+
298+ using tdb_col_major_matrix_with_ids_constructor_function =
299+ std::function<std::unique_ptr<vector_array_base>(
300+ const tiledb::Context&,
301+ const std::string&,
302+ const std::string&,
303+ size_t )>;
304+ using tdb_col_major_matrix_with_ids_table_type = std::map<
305+ std::tuple<tiledb_datatype_t , tiledb_datatype_t >,
306+ tdb_col_major_matrix_with_ids_constructor_function>;
307+ static const tdb_col_major_matrix_with_ids_table_type
308+ tdb_col_major_matrix_with_ids_dispatch_table;
309+
207310 tiledb_datatype_t feature_type_{TILEDB_ANY};
208311 size_t feature_size_{0 };
209312
313+ tiledb_datatype_t ids_type_{TILEDB_ANY};
314+ size_t ids_size_{0 };
315+
210316 // @todo const????
211317 std::unique_ptr</* const*/ vector_array_base> vector_array;
212318};
@@ -298,6 +404,40 @@ const FeatureVectorArray::tdb_col_major_matrix_table_type
298404 }},
299405};
300406
407+ // clang-format off
408+ const FeatureVectorArray::col_major_matrix_with_ids_table_type FeatureVectorArray::col_major_matrix_with_ids_dispatch_table = {
409+ {{TILEDB_FLOAT32, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<float , uint32_t >>>(rows, cols); }},
410+ {{TILEDB_UINT8, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint8_t , uint32_t >>>(rows, cols); }},
411+ {{TILEDB_INT32, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int32_t , uint32_t >>>(rows, cols); }},
412+ {{TILEDB_UINT32, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint32_t , uint32_t >>>(rows, cols); }},
413+ {{TILEDB_INT64, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int64_t , uint32_t >>>(rows, cols); }},
414+ {{TILEDB_UINT64, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint64_t , uint32_t >>>(rows, cols); }},
415+
416+ {{TILEDB_FLOAT32, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<float , uint64_t >>>(rows, cols); }},
417+ {{TILEDB_UINT8, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint8_t , uint64_t >>>(rows, cols); }},
418+ {{TILEDB_INT32, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int32_t , uint64_t >>>(rows, cols); }},
419+ {{TILEDB_UINT32, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint32_t , uint64_t >>>(rows, cols); }},
420+ {{TILEDB_INT64, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int64_t , uint64_t >>>(rows, cols); }},
421+ {{TILEDB_UINT64, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint64_t , uint64_t >>>(rows, cols); }},
422+ };
423+
424+ const FeatureVectorArray::tdb_col_major_matrix_with_ids_table_type FeatureVectorArray::tdb_col_major_matrix_with_ids_dispatch_table = {
425+ {{TILEDB_FLOAT32, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<float , uint32_t >>>(ctx, uri, ids_uri, num_vectors);}},
426+ {{TILEDB_UINT8, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint8_t , uint32_t >>>(ctx, uri, ids_uri, num_vectors);}},
427+ {{TILEDB_INT32, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int32_t , uint32_t >>>(ctx, uri, ids_uri, num_vectors);}},
428+ {{TILEDB_UINT32, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint32_t , uint32_t >>>(ctx, uri, ids_uri, num_vectors);}},
429+ {{TILEDB_INT64, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int64_t , uint32_t >>>(ctx, uri, ids_uri, num_vectors);}},
430+ {{TILEDB_UINT64, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint64_t , uint32_t >>>(ctx, uri, ids_uri, num_vectors);}},
431+
432+ {{TILEDB_FLOAT32, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<float , uint64_t >>>(ctx, uri, ids_uri, num_vectors);}},
433+ {{TILEDB_UINT8, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint8_t , uint64_t >>>(ctx, uri, ids_uri, num_vectors);}},
434+ {{TILEDB_INT32, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int32_t , uint64_t >>>(ctx, uri, ids_uri, num_vectors);}},
435+ {{TILEDB_UINT32, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint32_t , uint64_t >>>(ctx, uri, ids_uri, num_vectors);}},
436+ {{TILEDB_INT64, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int64_t , uint64_t >>>(ctx, uri, ids_uri, num_vectors);}},
437+ {{TILEDB_UINT64, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint64_t , uint64_t >>>(ctx, uri, ids_uri, num_vectors);}},
438+ };
439+ // clang-format on
440+
301441using QueryVectorArray = FeatureVectorArray;
302442
303443bool validate_top_k (const FeatureVectorArray& a, const FeatureVectorArray& b) {
0 commit comments