Skip to content

Commit 7b2b3b7

Browse files
authored
Add IDs to FeatureVectorArray (#252)
What Here we add an option for the caller to pass an ids_uri to FeatureVectorArray. This can be used to keep track of external IDs alongside vectors. Testing Adds unit tests.
1 parent 69b8cbc commit 7b2b3b7

File tree

8 files changed

+663
-12
lines changed

8 files changed

+663
-12
lines changed

src/include/api/feature_vector_array.h

Lines changed: 151 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,19 @@
3838
#ifndef TILEDB_API_FEATURE_VECTOR_ARRAY_H
3939
#define TILEDB_API_FEATURE_VECTOR_ARRAY_H
4040

41+
#include <unordered_set>
4142
#include "api_defs.h"
4243
#include "concepts.h"
4344
#include "cpos.h"
4445
#include "detail/linalg/matrix.h"
46+
#include "detail/linalg/matrix_with_ids.h"
4547
#include "detail/linalg/tdb_helpers.h"
4648
#include "detail/linalg/tdb_matrix.h"
49+
#include "detail/linalg/tdb_matrix_with_ids.h"
4750
#include "scoring.h"
4851
#include "tdb_defs.h"
4952

53+
#include <type_traits>
5054
#include "utils/print_types.h"
5155

5256
class FeatureVectorArray {
@@ -66,11 +70,19 @@ class FeatureVectorArray {
6670
feature_type_ = tiledb::impl::type_to_tiledb<
6771
typename std::remove_cvref_t<T>::value_type>::tiledb_type;
6872
feature_size_ = datatype_to_size(feature_type_);
73+
74+
if constexpr (feature_vector_array_with_ids<
75+
std::remove_cvref_t<decltype(obj)>>) {
76+
ids_type_ = tiledb::impl::type_to_tiledb<
77+
typename std::remove_cvref_t<T>::ids_type>::tiledb_type;
78+
ids_size_ = datatype_to_size(ids_type_);
79+
}
6980
}
7081

7182
FeatureVectorArray(
7283
const tiledb::Context& ctx,
7384
const std::string& uri,
85+
const std::string& ids_uri = "",
7486
size_t num_vectors = 0) {
7587
auto array = tiledb_helpers::open_array(tdb_func__, ctx, uri, TILEDB_READ);
7688
feature_type_ = get_array_datatype(*array);
@@ -84,24 +96,60 @@ class FeatureVectorArray {
8496
* happen with either orientation, and so will work at the other end with
8597
* either orientation since we are just passing a pointer to the data.
8698
*/
87-
if (tdb_col_major_matrix_dispatch_table.find(feature_type_) ==
88-
tdb_col_major_matrix_dispatch_table.end()) {
89-
throw std::runtime_error("Unsupported attribute type");
99+
if (ids_uri.empty()) {
100+
if (tdb_col_major_matrix_dispatch_table.find(feature_type_) ==
101+
tdb_col_major_matrix_dispatch_table.end()) {
102+
throw std::runtime_error("Unsupported features attribute type");
103+
}
104+
vector_array = tdb_col_major_matrix_dispatch_table.at(feature_type_)(
105+
ctx, uri, num_vectors);
106+
} else {
107+
auto ids_array =
108+
tiledb_helpers::open_array(tdb_func__, ctx, ids_uri, TILEDB_READ);
109+
ids_type_ = get_array_datatype(*ids_array);
110+
array->close();
111+
ids_size_ = datatype_to_size(ids_type_);
112+
113+
auto type = std::tuple{feature_type_, ids_type_};
114+
if (tdb_col_major_matrix_with_ids_dispatch_table.find(type) ==
115+
tdb_col_major_matrix_with_ids_dispatch_table.end()) {
116+
throw std::runtime_error(
117+
"Unsupported attribute type for feature vector with ids");
118+
}
119+
vector_array = tdb_col_major_matrix_with_ids_dispatch_table.at(type)(
120+
ctx, uri, ids_uri, num_vectors);
90121
}
91-
vector_array = tdb_col_major_matrix_dispatch_table.at(feature_type_)(
92-
ctx, uri, num_vectors);
93122
(void)vector_array->load();
94123
}
95124

96-
FeatureVectorArray(size_t rows, size_t cols, const std::string type_string) {
125+
FeatureVectorArray(
126+
size_t rows,
127+
size_t cols,
128+
const std::string& type_string,
129+
const std::string& ids_type_string = "") {
97130
feature_type_ = string_to_datatype(type_string);
98131
feature_size_ = datatype_to_size(feature_type_);
99-
if (col_major_matrix_dispatch_table.find(feature_type_) ==
100-
col_major_matrix_dispatch_table.end()) {
101-
throw std::runtime_error("Unsupported attribute type");
132+
133+
if (ids_type_string.empty()) {
134+
if (col_major_matrix_dispatch_table.find(feature_type_) ==
135+
col_major_matrix_dispatch_table.end()) {
136+
throw std::runtime_error("Unsupported features attribute type");
137+
}
138+
vector_array =
139+
col_major_matrix_dispatch_table.at(feature_type_)(rows, cols);
140+
} else {
141+
ids_type_ = string_to_datatype(ids_type_string);
142+
ids_size_ = datatype_to_size(ids_type_);
143+
144+
auto type = std::tuple{feature_type_, ids_type_};
145+
if (col_major_matrix_with_ids_dispatch_table.find(type) ==
146+
col_major_matrix_with_ids_dispatch_table.end()) {
147+
throw std::runtime_error(
148+
"Unsupported attribute type for feature vector with ids");
149+
}
150+
vector_array =
151+
col_major_matrix_with_ids_dispatch_table.at(type)(rows, cols);
102152
}
103-
vector_array =
104-
col_major_matrix_dispatch_table.at(feature_type_)(rows, cols);
105153
}
106154

107155
// A FeatureVectorArray is always loaded
@@ -118,6 +166,10 @@ class FeatureVectorArray {
118166
return vector_array->data();
119167
}
120168

169+
[[nodiscard]] auto ids_data() const {
170+
return vector_array->ids_data();
171+
}
172+
121173
[[nodiscard]] auto extents() const {
122174
return _cpo::extents(*vector_array);
123175
}
@@ -130,6 +182,10 @@ class FeatureVectorArray {
130182
return _cpo::num_vectors(*vector_array);
131183
}
132184

185+
[[nodiscard]] auto num_ids() const {
186+
return _cpo::num_ids(*vector_array);
187+
}
188+
133189
[[nodiscard]] tiledb_datatype_t feature_type() const {
134190
return feature_type_;
135191
}
@@ -142,6 +198,18 @@ class FeatureVectorArray {
142198
return feature_size_;
143199
}
144200

201+
[[nodiscard]] tiledb_datatype_t ids_type() const {
202+
return ids_type_;
203+
}
204+
205+
[[nodiscard]] std::string ids_type_string() const {
206+
return datatype_to_string(ids_type_);
207+
}
208+
209+
[[nodiscard]] size_t ids_size() const {
210+
return ids_size_;
211+
}
212+
145213
/**
146214
* Non-type parameterized base class (for type erasure).
147215
*/
@@ -150,6 +218,8 @@ class FeatureVectorArray {
150218
[[nodiscard]] virtual size_t dimension() const = 0;
151219
[[nodiscard]] virtual size_t num_vectors() const = 0;
152220
[[nodiscard]] virtual void* data() const = 0;
221+
[[nodiscard]] virtual size_t num_ids() const = 0;
222+
[[nodiscard]] virtual const void* ids_data() const = 0;
153223
[[nodiscard]] virtual std::vector<size_t> extents() const = 0;
154224
[[nodiscard]] virtual bool load() = 0;
155225
};
@@ -166,12 +236,25 @@ class FeatureVectorArray {
166236
const tiledb::Context& ctx, const std::string& uri, size_t num_vectors)
167237
: impl_vector_array(ctx, uri, num_vectors) {
168238
}
239+
vector_array_impl(
240+
const tiledb::Context& ctx,
241+
const std::string& uri,
242+
const std::string& ids_uri,
243+
size_t num_vectors)
244+
: impl_vector_array(ctx, uri, ids_uri, num_vectors) {
245+
}
169246
vector_array_impl(size_t rows, size_t cols)
170247
: impl_vector_array(rows, cols) {
171248
}
172249
[[nodiscard]] void* data() const override {
173250
return _cpo::data(impl_vector_array);
174251
}
252+
[[nodiscard]] size_t num_ids() const override {
253+
return _cpo::num_ids(impl_vector_array);
254+
}
255+
[[nodiscard]] const void* ids_data() const override {
256+
return _cpo::ids(impl_vector_array).data();
257+
}
175258
[[nodiscard]] size_t dimension() const override {
176259
return _cpo::dimension(impl_vector_array);
177260
}
@@ -204,9 +287,32 @@ class FeatureVectorArray {
204287
static const tdb_col_major_matrix_table_type
205288
tdb_col_major_matrix_dispatch_table;
206289

290+
using col_major_matrix_with_ids_constructor_function =
291+
std::function<std::unique_ptr<vector_array_base>(size_t, size_t)>;
292+
using col_major_matrix_with_ids_table_type = std::map<
293+
std::tuple<tiledb_datatype_t, tiledb_datatype_t>,
294+
col_major_matrix_with_ids_constructor_function>;
295+
static const col_major_matrix_with_ids_table_type
296+
col_major_matrix_with_ids_dispatch_table;
297+
298+
using tdb_col_major_matrix_with_ids_constructor_function =
299+
std::function<std::unique_ptr<vector_array_base>(
300+
const tiledb::Context&,
301+
const std::string&,
302+
const std::string&,
303+
size_t)>;
304+
using tdb_col_major_matrix_with_ids_table_type = std::map<
305+
std::tuple<tiledb_datatype_t, tiledb_datatype_t>,
306+
tdb_col_major_matrix_with_ids_constructor_function>;
307+
static const tdb_col_major_matrix_with_ids_table_type
308+
tdb_col_major_matrix_with_ids_dispatch_table;
309+
207310
tiledb_datatype_t feature_type_{TILEDB_ANY};
208311
size_t feature_size_{0};
209312

313+
tiledb_datatype_t ids_type_{TILEDB_ANY};
314+
size_t ids_size_{0};
315+
210316
// @todo const????
211317
std::unique_ptr</*const*/ vector_array_base> vector_array;
212318
};
@@ -298,6 +404,40 @@ const FeatureVectorArray::tdb_col_major_matrix_table_type
298404
}},
299405
};
300406

407+
// clang-format off
408+
const FeatureVectorArray::col_major_matrix_with_ids_table_type FeatureVectorArray::col_major_matrix_with_ids_dispatch_table = {
409+
{{TILEDB_FLOAT32, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<float, uint32_t>>>(rows, cols); }},
410+
{{TILEDB_UINT8, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint8_t, uint32_t>>>(rows, cols); }},
411+
{{TILEDB_INT32, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int32_t, uint32_t>>>(rows, cols); }},
412+
{{TILEDB_UINT32, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint32_t, uint32_t>>>(rows, cols); }},
413+
{{TILEDB_INT64, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int64_t, uint32_t>>>(rows, cols); }},
414+
{{TILEDB_UINT64, TILEDB_UINT32},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint64_t, uint32_t>>>(rows, cols); }},
415+
416+
{{TILEDB_FLOAT32, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<float, uint64_t>>>(rows, cols); }},
417+
{{TILEDB_UINT8, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint8_t, uint64_t>>>(rows, cols); }},
418+
{{TILEDB_INT32, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int32_t, uint64_t>>>(rows, cols); }},
419+
{{TILEDB_UINT32, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint32_t, uint64_t>>>(rows, cols); }},
420+
{{TILEDB_INT64, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<int64_t, uint64_t>>>(rows, cols); }},
421+
{{TILEDB_UINT64, TILEDB_UINT64},[](size_t rows, size_t cols) { return std::make_unique<FeatureVectorArray::vector_array_impl<ColMajorMatrixWithIds<uint64_t, uint64_t>>>(rows, cols); }},
422+
};
423+
424+
const FeatureVectorArray::tdb_col_major_matrix_with_ids_table_type FeatureVectorArray::tdb_col_major_matrix_with_ids_dispatch_table = {
425+
{{TILEDB_FLOAT32, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<float, uint32_t>>>(ctx, uri, ids_uri, num_vectors);}},
426+
{{TILEDB_UINT8, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint8_t, uint32_t>>>(ctx, uri, ids_uri, num_vectors);}},
427+
{{TILEDB_INT32, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int32_t, uint32_t>>>(ctx, uri, ids_uri, num_vectors);}},
428+
{{TILEDB_UINT32, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint32_t, uint32_t>>>(ctx, uri, ids_uri, num_vectors);}},
429+
{{TILEDB_INT64, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int64_t, uint32_t>>>(ctx, uri, ids_uri, num_vectors);}},
430+
{{TILEDB_UINT64, TILEDB_UINT32},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint64_t, uint32_t>>>(ctx, uri, ids_uri, num_vectors);}},
431+
432+
{{TILEDB_FLOAT32, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<float, uint64_t>>>(ctx, uri, ids_uri, num_vectors);}},
433+
{{TILEDB_UINT8, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) { return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint8_t, uint64_t>>>(ctx, uri, ids_uri, num_vectors);}},
434+
{{TILEDB_INT32, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int32_t, uint64_t>>>(ctx, uri, ids_uri, num_vectors);}},
435+
{{TILEDB_UINT32, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint32_t, uint64_t>>>(ctx, uri, ids_uri, num_vectors);}},
436+
{{TILEDB_INT64, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<int64_t, uint64_t>>>(ctx, uri, ids_uri, num_vectors);}},
437+
{{TILEDB_UINT64, TILEDB_UINT64},[](const tiledb::Context& ctx, const std::string& uri, const std::string& ids_uri, size_t num_vectors) {return std::make_unique<FeatureVectorArray::vector_array_impl<tdbColMajorMatrixWithIds<uint64_t, uint64_t>>>(ctx, uri, ids_uri, num_vectors);}},
438+
};
439+
// clang-format on
440+
301441
using QueryVectorArray = FeatureVectorArray;
302442

303443
bool validate_top_k(const FeatureVectorArray& a, const FeatureVectorArray& b) {

src/include/concepts.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,12 @@ concept feature_vector_array = requires(D d, size_t n) {
158158
{ d[n] } -> feature_vector; // Maybe redundant
159159
};
160160

161+
template <class D>
162+
concept feature_vector_array_with_ids =
163+
feature_vector_array<D> && requires(D d) {
164+
{ d.ids() };
165+
};
166+
161167
/**
162168
* @brief A concept for contiguous vector ranges. The member function data()
163169
* returns a pointer to the underlying contiguous one-dimensional storage.

src/include/cpos.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ concept _member_num_cols = requires(T t) {
5959
{ t.num_cols() } -> semi_integral;
6060
};
6161

62+
template <class T>
63+
concept _member_num_ids = requires(T t) {
64+
{ t.num_ids() };
65+
};
66+
67+
template <class T>
68+
concept _member_ids = requires(T t) {
69+
{ t.ids() };
70+
};
71+
6272
template <class T>
6373
concept row_major = std::
6474
same_as<typename std::remove_cvref_t<T>::layout_policy, stdx::layout_right>;
@@ -235,6 +245,57 @@ inline namespace _cpo {
235245
inline constexpr auto data = _data::_fn{};
236246
} // namespace _cpo
237247

248+
// ----------------------------------------------------------------------------
249+
// num_ids CPO
250+
// ----------------------------------------------------------------------------
251+
namespace _num_ids {
252+
void num_ids(auto&) = delete;
253+
void num_ids(const auto&) = delete;
254+
255+
struct _fn {
256+
template <class T>
257+
requires(_member_num_ids<T>)
258+
constexpr auto operator()(T&& t) const noexcept {
259+
return t.num_ids();
260+
}
261+
262+
template <class T>
263+
requires(!_member_num_ids<T>)
264+
constexpr auto operator()(T&& t) const noexcept {
265+
return 0;
266+
}
267+
};
268+
} // namespace _num_ids
269+
inline namespace _cpo {
270+
inline constexpr auto num_ids = _num_ids::_fn{};
271+
} // namespace _cpo
272+
273+
// ----------------------------------------------------------------------------
274+
// ids CPO
275+
// @todo Figure out what is wrong with const
276+
// ----------------------------------------------------------------------------
277+
namespace _ids {
278+
void ids(auto&) = delete;
279+
void ids(const auto&) = delete;
280+
281+
struct _fn {
282+
template <class T>
283+
requires(_member_ids<T>)
284+
constexpr const auto& operator()(T&& t) const noexcept {
285+
return t.ids();
286+
}
287+
288+
template <class T>
289+
requires(!_member_ids<T>)
290+
constexpr const auto& operator()(T&& t) const noexcept {
291+
return std::vector<typename std::remove_cvref_t<T>::value_type>{};
292+
}
293+
};
294+
} // namespace _ids
295+
inline namespace _cpo {
296+
inline constexpr auto ids = _ids::_fn{};
297+
} // namespace _cpo
298+
238299
// ----------------------------------------------------------------------------
239300
// extents CPO
240301
// ----------------------------------------------------------------------------

0 commit comments

Comments
 (0)