Skip to content

Commit dabaa1e

Browse files
authored
Support writing C++ type-erased indexes with a storage_version (#326)
1 parent bef3145 commit dabaa1e

File tree

11 files changed

+264
-52
lines changed

11 files changed

+264
-52
lines changed

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,13 @@ void init_type_erased_module(py::module_& m) {
283283
"write_index",
284284
[](IndexVamana& index,
285285
const tiledb::Context& ctx,
286-
const std::string& group_uri) {
287-
index.write_index(ctx, group_uri);
286+
const std::string& group_uri,
287+
const std::string& storage_version) {
288+
index.write_index(ctx, group_uri, storage_version);
288289
},
289290
py::arg("ctx"),
290-
py::arg("group_uri"))
291+
py::arg("group_uri"),
292+
py::arg("storage_version") = "")
291293
.def("feature_type_string", &IndexVamana::feature_type_string)
292294
.def("id_type_string", &IndexVamana::id_type_string)
293295
.def(

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def query_internal(
9797
return np.array(distances, copy=False), np.array(ids, copy=False)
9898

9999

100-
# TODO(paris): Pass more arguments to C++, i.e. storage_version.
101100
def create(
102101
uri: str,
103102
dimensions: int,
@@ -120,5 +119,5 @@ def create(
120119
)
121120
index.train(empty_vector)
122121
index.add(empty_vector)
123-
index.write_index(ctx, uri)
122+
index.write_index(ctx, uri, storage_version)
124123
return VamanaIndex(uri=uri, config=config, memory_budget=1000000)

src/include/api/ivf_flat_index.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -427,12 +427,14 @@ class IndexIVFFlat {
427427
}
428428

429429
void write_index(
430-
const tiledb::Context& ctx, const std::string& group_uri) const {
430+
const tiledb::Context& ctx,
431+
const std::string& group_uri,
432+
const std::string& storage_version = "") const {
431433
if (!index_) {
432434
throw std::runtime_error(
433435
"Cannot write_index() because there is no index.");
434436
}
435-
index_->write_index(ctx, group_uri);
437+
index_->write_index(ctx, group_uri, storage_version);
436438
}
437439

438440
constexpr auto dimension() const {
@@ -503,7 +505,9 @@ class IndexIVFFlat {
503505
virtual void remove(const IdVector& ids) const = 0;
504506

505507
virtual void write_index(
506-
const tiledb::Context& ctx, const std::string& group_uri) const = 0;
508+
const tiledb::Context& ctx,
509+
const std::string& group_uri,
510+
const std::string& storage_version) const = 0;
507511

508512
[[nodiscard]] virtual size_t dimension() const = 0;
509513

@@ -687,9 +691,11 @@ class IndexIVFFlat {
687691
// index_.update(vectors_uri, ids, options);
688692
}
689693

690-
void write_index(const tiledb::Context& ctx, const std::string& group_uri)
691-
const override {
692-
impl_index_.write_index(ctx, group_uri);
694+
void write_index(
695+
const tiledb::Context& ctx,
696+
const std::string& group_uri,
697+
const std::string& storage_version) const override {
698+
impl_index_.write_index(ctx, group_uri, storage_version);
693699
}
694700

695701
// WIP

src/include/api/vamana_index.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,14 @@ class IndexVamana {
233233
}
234234

235235
void write_index(
236-
const tiledb::Context& ctx, const std::string& group_uri) const {
236+
const tiledb::Context& ctx,
237+
const std::string& group_uri,
238+
const std::string& storage_version = "") const {
237239
if (!index_) {
238240
throw std::runtime_error(
239241
"Cannot write_index() because there is no index.");
240242
}
241-
index_->write_index(ctx, group_uri);
243+
index_->write_index(ctx, group_uri, storage_version);
242244
}
243245

244246
constexpr auto dimension() const {
@@ -287,7 +289,9 @@ class IndexVamana {
287289
std::optional<size_t> opt_L) = 0;
288290

289291
virtual void write_index(
290-
const tiledb::Context& ctx, const std::string& group_uri) const = 0;
292+
const tiledb::Context& ctx,
293+
const std::string& group_uri,
294+
const std::string& storage_version) const = 0;
291295

292296
[[nodiscard]] virtual size_t dimension() const = 0;
293297
};
@@ -390,9 +394,11 @@ class IndexVamana {
390394
}
391395
}
392396

393-
void write_index(const tiledb::Context& ctx, const std::string& group_uri)
394-
const override {
395-
impl_index_.write_index(ctx, group_uri);
397+
void write_index(
398+
const tiledb::Context& ctx,
399+
const std::string& group_uri,
400+
const std::string& storage_version) const override {
401+
impl_index_.write_index(ctx, group_uri, storage_version);
396402
}
397403

398404
size_t dimension() const override {

src/include/index/index_group.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class base_index_group {
196196
metadata_.load_metadata(read_group);
197197
if (!empty(version_) && metadata_.storage_version_ != version_) {
198198
throw std::runtime_error(
199-
"Version mismatch. Requested " + version_ + " but found " +
199+
"Version mismatch. Requested " + version_ + " but found " +
200200
metadata_.storage_version_);
201201
} else if (empty(version_)) {
202202
version_ = metadata_.storage_version_;

src/include/index/ivf_flat_index.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,18 @@ class ivf_flat_index {
737737
* all of it to a TileDB group. Since we have all of it in memory,
738738
* we write from the PartitionedMatrix base class.
739739
*
740-
* @param group_uri
741-
* @return bool indicating success or failure
740+
* @param group_uri The URI of the TileDB group where the index will be saved
741+
* @param storage_version The storage version to use. If empty, use the
742+
* default version.
743+
* @return Whether the write was successful
742744
*/
743745
auto write_index(
744-
const tiledb::Context& ctx, const std::string& group_uri) const {
746+
const tiledb::Context& ctx,
747+
const std::string& group_uri,
748+
const std::string& storage_version = "") const {
745749
// Write the group
746-
auto write_group =
747-
ivf_flat_index_group(*this, ctx, group_uri, TILEDB_WRITE, timestamp_);
750+
auto write_group = ivf_flat_index_group(
751+
*this, ctx, group_uri, TILEDB_WRITE, timestamp_, storage_version);
748752

749753
write_group.set_dimension(dimension_);
750754

src/include/index/vamana_index.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,8 @@ class vamana_index {
824824
/**
825825
* @brief Write the index to a TileDB group
826826
* @param group_uri The URI of the TileDB group where the index will be saved
827+
* @param storage_version The storage version to use. If empty, use the
828+
* default version.
827829
* @return Whether the write was successful
828830
*
829831
* The group consists of the original feature vectors, and the graph index,
@@ -836,12 +838,14 @@ class vamana_index {
836838
* the group?
837839
*/
838840
auto write_index(
839-
const tiledb::Context& ctx, const std::string& group_uri) const {
841+
const tiledb::Context& ctx,
842+
const std::string& group_uri,
843+
const std::string& storage_version = "") const {
840844
// metadata: dimension, ntotal, L, R, B, alpha_min, alpha_max, medoid
841845
// Save as a group: metadata, feature_vectors, graph edges, offsets
842846

843-
auto write_group =
844-
vamana_index_group(*this, ctx, group_uri, TILEDB_WRITE, timestamp_);
847+
auto write_group = vamana_index_group(
848+
*this, ctx, group_uri, TILEDB_WRITE, timestamp_, storage_version);
845849

846850
// @todo Make this table-driven
847851
write_group.set_dimension(dimension_);

src/include/test/unit_api_vamana_index.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,60 @@ TEST_CASE("api_vamana_index: read index and query", "[api_vamana_index]") {
441441
auto recall = ((double)intersections_a) / ((double)nt * k_nn);
442442
CHECK(recall == 1.0);
443443
}
444+
445+
TEST_CASE("api_vamana_index: storage_version", "[api_vamana_index]") {
446+
auto ctx = tiledb::Context{};
447+
using feature_type_type = uint8_t;
448+
using id_type_type = uint32_t;
449+
auto feature_type = "uint8";
450+
auto id_type = "uint32";
451+
auto adjacency_row_index_type = "uint32";
452+
size_t dimensions = 3;
453+
454+
std::string index_uri =
455+
(std::filesystem::temp_directory_path() / "api_vamana_index").string();
456+
tiledb::VFS vfs(ctx);
457+
if (vfs.is_dir(index_uri)) {
458+
vfs.remove_dir(index_uri);
459+
}
460+
461+
{
462+
// First we create the index with a storage_version.
463+
auto index = IndexVamana(std::make_optional<IndexOptions>(
464+
{{"feature_type", feature_type},
465+
{"id_type", id_type},
466+
{"adjacency_row_index_type", adjacency_row_index_type}}));
467+
468+
size_t num_vectors = 0;
469+
auto empty_training_vector_array =
470+
FeatureVectorArray(dimensions, num_vectors, feature_type, id_type);
471+
index.train(empty_training_vector_array);
472+
index.add(empty_training_vector_array);
473+
index.write_index(ctx, index_uri, "0.3");
474+
475+
CHECK(index.feature_type_string() == feature_type);
476+
CHECK(index.id_type_string() == id_type);
477+
CHECK(index.adjacency_row_index_type_string() == adjacency_row_index_type);
478+
}
479+
480+
{
481+
// Now make sure if we try to write it again with a different
482+
// storage_version, we throw.
483+
auto index = IndexVamana(ctx, index_uri);
484+
auto training = ColMajorMatrixWithIds<feature_type_type, id_type_type>{
485+
{{8, 6, 7}, {5, 3, 0}, {9, 5, 0}, {2, 7, 3}}, {10, 11, 12, 13}};
486+
487+
auto training_vector_array = FeatureVectorArray(training);
488+
index.train(training_vector_array);
489+
index.add(training_vector_array);
490+
491+
// Throw with the wrong version.
492+
CHECK_THROWS_WITH(
493+
index.write_index(ctx, index_uri, "0.4"),
494+
"Version mismatch. Requested 0.4 but found 0.3");
495+
// Succeed without a version.
496+
index.write_index(ctx, index_uri);
497+
// Succeed with the same version.
498+
index.write_index(ctx, index_uri, "0.3");
499+
}
500+
}

src/include/test/unit_flatpq_index.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,10 @@ TEST_CASE("flatpq_index: flatpq_index write and read", "[flatpq_index]") {
829829
tiledb::Context ctx;
830830
std::string flatpq_index_uri =
831831
(std::filesystem::temp_directory_path() / "tmp_flatpq_index").string();
832+
tiledb::VFS vfs(ctx);
833+
if (vfs.is_dir(flatpq_index_uri)) {
834+
vfs.remove_dir(flatpq_index_uri);
835+
}
832836
auto training_set =
833837
tdbColMajorMatrix<siftsmall_feature_type>(ctx, siftsmall_inputs_uri, 0);
834838
load(training_set);

src/include/test/unit_ivf_flat_group.cc

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ struct dummy_index {
9292
}
9393
};
9494

95-
// The catch2 check for exception doesn't seem to be working correctly
96-
// @todo Fix this
97-
#if 0
9895
TEST_CASE(
9996
"ivf_flat_group: read constructor for non-existent group",
10097
"[ivf_flat_group]") {
@@ -104,7 +101,6 @@ TEST_CASE(
104101
ivf_flat_index_group(dummy_index{}, ctx, "I dont exist"),
105102
"Group uri I dont exist does not exist.");
106103
}
107-
#endif
108104

109105
TEST_CASE("ivf_flat_group: write constructor - create", "[ivf_flat_group]") {
110106
std::string tmp_uri = (std::filesystem::temp_directory_path() /
@@ -119,7 +115,6 @@ TEST_CASE("ivf_flat_group: write constructor - create", "[ivf_flat_group]") {
119115

120116
ivf_flat_index_group x =
121117
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
122-
x.dump("Write constructor - create");
123118
}
124119

125120
TEST_CASE(
@@ -137,11 +132,9 @@ TEST_CASE(
137132

138133
ivf_flat_index_group x =
139134
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
140-
x.dump("Write constructor - create before open");
141135

142136
ivf_flat_index_group y =
143137
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
144-
x.dump("Write constructor - open");
145138
}
146139

147140
TEST_CASE(
@@ -159,11 +152,9 @@ TEST_CASE(
159152

160153
ivf_flat_index_group x =
161154
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
162-
x.dump("Write constructor - create before open");
163155

164156
ivf_flat_index_group y =
165157
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_READ);
166-
x.dump("Write constructor - open for read");
167158
}
168159

169160
TEST_CASE(
@@ -182,15 +173,12 @@ TEST_CASE(
182173

183174
ivf_flat_index_group x =
184175
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
185-
x.dump("Write constructor - create before open");
186176

187177
ivf_flat_index_group y =
188178
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
189-
x.dump("Write constructor - open for write");
190179

191180
ivf_flat_index_group z =
192181
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_READ);
193-
x.dump("Write constructor - open for read");
194182
}
195183

196184
TEST_CASE(
@@ -365,3 +353,81 @@ TEST_CASE(
365353
CHECK(x.get_temp_size() == expected_temp_size + offset);
366354
CHECK(x.get_dimension() == expected_dimension + offset);
367355
}
356+
357+
TEST_CASE("ivf_flat_group: storage version", "[ivf_flat_group]") {
358+
std::string tmp_uri =
359+
(std::filesystem::temp_directory_path() / "ivf_flat_group").string();
360+
361+
tiledb::Context ctx;
362+
tiledb::VFS vfs(ctx);
363+
if (vfs.is_dir(tmp_uri)) {
364+
vfs.remove_dir(tmp_uri);
365+
}
366+
367+
size_t expected_ingestion = 23094;
368+
size_t expected_base = 9234;
369+
size_t expected_partitions = 200;
370+
size_t expected_temp_size = 11;
371+
size_t expected_dimension = 19238;
372+
auto offset = 2345;
373+
374+
ivf_flat_index_group x =
375+
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE);
376+
377+
SECTION("0.3") {
378+
x = ivf_flat_index_group(
379+
dummy_index{}, ctx, tmp_uri, TILEDB_WRITE, 0, "0.3");
380+
}
381+
382+
SECTION("current_storage_version") {
383+
x = ivf_flat_index_group(
384+
dummy_index{}, ctx, tmp_uri, TILEDB_WRITE, 0, current_storage_version);
385+
}
386+
387+
x.set_ingestion_timestamp(expected_ingestion + offset);
388+
x.set_base_size(expected_base + offset);
389+
x.set_num_partitions(expected_partitions + offset);
390+
x.set_temp_size(expected_temp_size + offset);
391+
x.set_dimension(expected_dimension + offset);
392+
393+
CHECK(size(x.get_all_ingestion_timestamps()) == 1);
394+
CHECK(size(x.get_all_base_sizes()) == 1);
395+
CHECK(size(x.get_all_num_partitions()) == 1);
396+
CHECK(x.get_previous_ingestion_timestamp() == expected_ingestion + offset);
397+
CHECK(x.get_previous_base_size() == expected_base + offset);
398+
CHECK(x.get_previous_num_partitions() == expected_partitions + offset);
399+
CHECK(x.get_temp_size() == expected_temp_size + offset);
400+
CHECK(x.get_dimension() == expected_dimension + offset);
401+
}
402+
403+
TEST_CASE("ivf_flat_group: invalid storage version", "[ivf_flat_group]") {
404+
std::string tmp_uri =
405+
(std::filesystem::temp_directory_path() / "ivf_flat_group").string();
406+
407+
tiledb::Context ctx;
408+
tiledb::VFS vfs(ctx);
409+
if (vfs.is_dir(tmp_uri)) {
410+
vfs.remove_dir(tmp_uri);
411+
}
412+
CHECK_THROWS(ivf_flat_index_group(
413+
dummy_index{}, ctx, tmp_uri, TILEDB_WRITE, 0, "invalid"));
414+
}
415+
416+
TEST_CASE("ivf_flat_group: mismatched storage version", "[ivf_flat_group]") {
417+
std::string tmp_uri =
418+
(std::filesystem::temp_directory_path() / "ivf_flat_group").string();
419+
420+
tiledb::Context ctx;
421+
tiledb::VFS vfs(ctx);
422+
if (vfs.is_dir(tmp_uri)) {
423+
vfs.remove_dir(tmp_uri);
424+
}
425+
426+
ivf_flat_index_group x =
427+
ivf_flat_index_group(dummy_index{}, ctx, tmp_uri, TILEDB_WRITE, 0, "0.3");
428+
429+
CHECK_THROWS_WITH(
430+
ivf_flat_index_group(
431+
dummy_index{}, ctx, tmp_uri, TILEDB_WRITE, 0, "different_version"),
432+
"Version mismatch. Requested different_version but found 0.3");
433+
}

0 commit comments

Comments
 (0)