Skip to content

Commit 5a7c478

Browse files
authored
Fix TileDB URIs with type-erased indexes, add Vamana test to test_cloud.py (#330)
1 parent dabaa1e commit 5a7c478

File tree

9 files changed

+68
-32
lines changed

9 files changed

+68
-32
lines changed

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tiledb.vector_search.module import *
99
from tiledb.vector_search.storage_formats import STORAGE_VERSION
1010
from tiledb.vector_search.storage_formats import storage_formats
11+
from tiledb.vector_search.storage_formats import validate_storage_version
1112

1213
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1314
INDEX_TYPE = "VAMANA"
@@ -106,6 +107,7 @@ def create(
106107
**kwargs,
107108
) -> VamanaIndex:
108109
warnings.warn("The Vamana index is not yet supported, please use with caution.")
110+
validate_storage_version(storage_version)
109111
ctx = vspy.Ctx(config)
110112
index = vspy.IndexVamana(
111113
feature_type=np.dtype(vector_type).name,

apis/python/test/test_cloud.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def setUpClass(cls):
2727
rand_name = random_name("vector_search")
2828
test_path = f"tiledb://{namespace}/{storage_path}/{rand_name}"
2929
cls.flat_index_uri = f"{test_path}/test_flat_array"
30+
cls.vamana_index_uri = f"{test_path}/vamana_array"
3031
cls.ivf_flat_index_uri = f"{test_path}/test_ivf_flat_array"
3132
cls.ivf_flat_random_sampling_index_uri = (
3233
f"{test_path}/test_ivf_flat_random_sampling_array"
@@ -35,6 +36,7 @@ def setUpClass(cls):
3536
@classmethod
3637
def tearDownClass(cls):
3738
vs.Index.delete_index(uri=cls.flat_index_uri, config=tiledb.cloud.Config())
39+
vs.Index.delete_index(uri=cls.vamana_index_uri, config=tiledb.cloud.Config())
3840
vs.Index.delete_index(uri=cls.ivf_flat_index_uri, config=tiledb.cloud.Config())
3941
vs.Index.delete_index(
4042
uri=cls.ivf_flat_random_sampling_index_uri, config=tiledb.cloud.Config()
@@ -70,6 +72,31 @@ def test_cloud_flat(self):
7072
_, result_i = index.query(queries, k=k)
7173
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
7274

75+
def test_cloud_vamana(self):
76+
source_uri = "tiledb://TileDB-Inc/sift_10k"
77+
queries_uri = siftsmall_query_file
78+
gt_uri = siftsmall_groundtruth_file
79+
index_uri = CloudTests.vamana_index_uri
80+
k = 100
81+
nqueries = 100
82+
83+
load_fvecs(queries_uri)
84+
gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
85+
86+
vs.ingest(
87+
index_type="VAMANA",
88+
index_uri=index_uri,
89+
source_uri=source_uri,
90+
input_vectors_per_work_item=5000,
91+
config=tiledb.cloud.Config().dict(),
92+
# TODO(paris): Fix and then change to Mode.BATCH.
93+
mode=Mode.LOCAL,
94+
)
95+
96+
# TODO(paris): Fix error from loading this URI and then re-enable, and add the rest of the test.
97+
# tiledb_index_uri = groups.info(index_uri).tiledb_uri
98+
# vs.vamana_index.VamanaIndex(uri=tiledb_index_uri)
99+
73100
def test_cloud_ivf_flat(self):
74101
source_uri = "tiledb://TileDB-Inc/sift_10k"
75102
queries_uri = siftsmall_query_file

src/include/detail/linalg/tdb_helpers.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include <tiledb/array.h>
3838
#include <tiledb/tiledb>
3939
#include "stats.h"
40+
#include "tiledb/group_experimental.h"
4041

4142
namespace tiledb_helpers {
4243

@@ -81,6 +82,19 @@ inline void submit_query(
8182
query.submit();
8283
}
8384

85+
// Adds an object to a group. Automatically infers whether to use a relative
86+
// path or absolute path. NOTE(paris): We use absolute paths for tileDB URIs
87+
// because of a bug tracked in SC39197, once that is fixed everything can use
88+
// relative paths.
89+
inline void add_to_group(
90+
tiledb::Group& group, const std::string& uri, const std::string& name) {
91+
if (uri.find("tiledb://") == 0) {
92+
group.add_member(uri, false, name);
93+
} else {
94+
group.add_member(name, true, name);
95+
}
96+
}
97+
8498
} // namespace tiledb_helpers
8599

86100
#endif

src/include/index/flatpq_index.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -653,11 +653,11 @@ class flatpq_index {
653653

654654
auto centroids_uri = group_uri + "/centroids";
655655
write_matrix(ctx, centroids_, centroids_uri);
656-
write_group.add_member("centroids", true, "centroids");
656+
tiledb_helpers::add_to_group(write_group, centroids_uri, "centroids");
657657

658658
auto pq_vectors_uri = group_uri + "/pq_vectors";
659659
write_matrix(ctx, pq_vectors_, pq_vectors_uri);
660-
write_group.add_member("pq_vectors", true, "pq_vectors");
660+
tiledb_helpers::add_to_group(write_group, pq_vectors_uri, "pq_vectors");
661661

662662
for (size_t subspace = 0; subspace < num_subspaces_; ++subspace) {
663663
std::ostringstream oss;
@@ -666,8 +666,8 @@ class flatpq_index {
666666

667667
auto distance_table_uri = group_uri + "/distance_table_" + number;
668668
write_matrix(ctx, distance_tables_[subspace], distance_table_uri);
669-
write_group.add_member(
670-
"distance_table_" + number, true, "distance_table_" + number);
669+
tiledb_helpers::add_to_group(
670+
write_group, distance_table_uri, "distance_table_" + number);
671671
}
672672
write_group.close();
673673
return true;

src/include/index/index_group.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ class base_index_group {
182182
* @param ctx
183183
*/
184184
void init_for_open(const tiledb::Config& cfg) {
185-
tiledb::VFS vfs(cached_ctx_);
186-
if (!vfs.is_dir(group_uri_)) {
185+
if (!exists(cached_ctx_)) {
187186
throw std::runtime_error(
188187
"Group uri " + std::string(group_uri_) + " does not exist.");
189188
}
@@ -264,9 +263,7 @@ class base_index_group {
264263
* @param version
265264
*/
266265
void open_for_write(const tiledb::Config& cfg) {
267-
tiledb::VFS vfs(cached_ctx_);
268-
269-
if (vfs.is_dir(group_uri_)) {
266+
if (exists(cached_ctx_)) {
270267
/** Load the current group metadata */
271268
init_for_open(cfg);
272269
if (index_timestamp_ < metadata_.ingestion_timestamps_.back()) {

src/include/index/ivf_flat_group.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,8 @@ class ivf_flat_index_group
197197
this->get_dimension(),
198198
default_tile_extent,
199199
default_compression);
200-
// write_group.add_member(centroids_uri(), true, centroids_array_name());
201-
write_group.add_member(
202-
centroids_array_name(), true, centroids_array_name());
200+
tiledb_helpers::add_to_group(
201+
write_group, centroids_uri(), centroids_array_name());
203202

204203
create_empty_for_matrix<
205204
typename index_type::feature_type,
@@ -211,22 +210,20 @@ class ivf_flat_index_group
211210
this->get_dimension(),
212211
default_tile_extent,
213212
default_compression);
214-
// write_group.add_member(parts_uri(), true, parts_array_name());
215-
write_group.add_member(parts_array_name(), true, parts_array_name());
213+
tiledb_helpers::add_to_group(write_group, parts_uri(), parts_array_name());
216214

217215
create_empty_for_vector<typename index_type::id_type>(
218216
cached_ctx_, ids_uri(), default_domain, tile_size, default_compression);
219-
// write_group.add_member(ids_uri(), true, ids_array_name());
220-
write_group.add_member(ids_array_name(), true, ids_array_name());
217+
tiledb_helpers::add_to_group(write_group, ids_uri(), ids_array_name());
221218

222219
create_empty_for_vector<typename index_type::indices_type>(
223220
cached_ctx_,
224221
indices_uri(),
225222
default_domain,
226223
default_tile_extent,
227224
default_compression);
228-
// write_group.add_member(indices_uri(), true, indices_array_name());
229-
write_group.add_member(indices_array_name(), true, indices_array_name());
225+
tiledb_helpers::add_to_group(
226+
write_group, indices_uri(), indices_array_name());
230227

231228
// Store the metadata if all of the arrays were created successfully
232229
metadata_.store_metadata(write_group);

src/include/index/ivf_flat_index.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,6 @@ class ivf_flat_index {
792792
const std::string& parts_uri,
793793
const std::string& ids_uri,
794794
const std::string& indices_uri) const {
795-
tiledb::VFS vfs(ctx);
796-
797795
write_matrix(ctx, centroids_, centroids_uri, 0, true);
798796
write_matrix(ctx, *partitioned_vectors_, parts_uri, 0, true);
799797
write_vector(ctx, partitioned_vectors_->ids(), ids_uri, 0, true);

src/include/index/vamana_group.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#ifndef TILEDB_VAMANA_GROUP_H
3333
#define TILEDB_VAMANA_GROUP_H
3434

35+
#include "detail/linalg/tdb_helpers.h"
3536
#include "index/index_defs.h"
3637
#include "index/index_group.h"
3738
#include "index/vamana_metadata.h"
@@ -280,45 +281,45 @@ class vamana_index_group : public base_index_group<vamana_index_group<Index>> {
280281
this->get_dimension(),
281282
default_tile_extent,
282283
default_compression);
283-
write_group.add_member(
284-
feature_vectors_array_name(), true, feature_vectors_array_name());
284+
tiledb_helpers::add_to_group(
285+
write_group, feature_vectors_uri(), feature_vectors_array_name());
285286

286287
create_empty_for_vector<typename index_type::id_type>(
287288
cached_ctx_,
288289
feature_vector_ids_uri(),
289290
default_domain,
290291
tile_size,
291292
default_compression);
292-
write_group.add_member(
293-
feature_vector_ids_name(), true, feature_vector_ids_name());
293+
tiledb_helpers::add_to_group(
294+
write_group, feature_vector_ids_uri(), feature_vector_ids_name());
294295

295296
create_empty_for_vector<typename index_type::score_type>(
296297
cached_ctx_,
297298
adjacency_scores_uri(),
298299
default_domain,
299300
tile_size,
300301
default_compression);
301-
write_group.add_member(
302-
adjacency_scores_array_name(), true, adjacency_scores_array_name());
302+
tiledb_helpers::add_to_group(
303+
write_group, adjacency_scores_uri(), adjacency_scores_array_name());
303304

304305
create_empty_for_vector<typename index_type::id_type>(
305306
cached_ctx_,
306307
adjacency_ids_uri(),
307308
default_domain,
308309
tile_size,
309310
default_compression);
310-
write_group.add_member(
311-
adjacency_ids_array_name(), true, adjacency_ids_array_name());
311+
tiledb_helpers::add_to_group(
312+
write_group, adjacency_ids_uri(), adjacency_ids_array_name());
312313

313314
create_empty_for_vector<typename index_type::id_type>(
314315
cached_ctx_,
315316
adjacency_row_index_uri(),
316317
default_domain,
317318
tile_size,
318319
default_compression);
319-
write_group.add_member(
320-
adjacency_row_index_array_name(),
321-
true,
320+
tiledb_helpers::add_to_group(
321+
write_group,
322+
adjacency_row_index_uri(),
322323
adjacency_row_index_array_name());
323324

324325
// Store the metadata if all of the arrays were created successfully

src/include/tdb_defs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ constexpr auto type_to_tiledb_v = tiledb::impl::type_to_tiledb<T>::tiledb_type;
9090
if (str == "uint64") {
9191
return TILEDB_UINT64;
9292
}
93-
throw std::runtime_error("Unsupported datatype");
93+
throw std::runtime_error("Unsupported datatype: " + str);
9494
}
9595

9696
// cf type_to_str(tiledb_datatype_t type) in tiledb/type.h

0 commit comments

Comments
 (0)