Skip to content

Commit 6936b77

Browse files
authored
Vamana type-erased index can be opened at a specific timestamp (#354)
1 parent 33d031b commit 6936b77

File tree

13 files changed

+348
-90
lines changed

13 files changed

+348
-90
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,8 @@ def __init__(
3434
timestamp=None,
3535
**kwargs,
3636
):
37-
super().__init__(uri=uri, config=config, timestamp=timestamp)
3837
self.index_type = INDEX_TYPE
38+
super().__init__(uri=uri, config=config, timestamp=timestamp)
3939
self._index = None
4040
self.db_uri = self.group[
4141
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]

apis/python/src/tiledb/vector_search/index.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from tiledb.vector_search.module import *
99
from tiledb.vector_search.storage_formats import storage_formats
1010
from tiledb.vector_search.utils import add_to_group
11+
from tiledb.vector_search.utils import is_type_erased_index
1112

1213
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
1314
MAX_INT32 = np.iinfo(np.dtype("int32")).max
@@ -466,6 +467,7 @@ def clear_history(
466467
):
467468
with tiledb.scope_ctx(ctx_or_config=config):
468469
group = tiledb.Group(uri, "r")
470+
index_type = group.meta.get("index_type", "")
469471
storage_version = group.meta.get("storage_version", "0.1")
470472
if not storage_formats[storage_version]["SUPPORT_TIMETRAVEL"]:
471473
raise ValueError(
@@ -490,7 +492,9 @@ def clear_history(
490492
if ingestion_timestamp > timestamp:
491493
new_ingestion_timestamps.append(ingestion_timestamp)
492494
new_base_sizes.append(base_sizes[i])
493-
new_partition_history.append(partition_history[i])
495+
# Type-erased indexes don't have partition_history, so skip it to avoid a crash.
496+
if not is_type_erased_index(index_type):
497+
new_partition_history.append(partition_history[i])
494498
i += 1
495499
if len(new_ingestion_timestamps) == 0:
496500
new_ingestion_timestamps = [0]
@@ -502,7 +506,9 @@ def clear_history(
502506
group = tiledb.Group(uri, "w")
503507
group.meta["ingestion_timestamps"] = json.dumps(new_ingestion_timestamps)
504508
group.meta["base_sizes"] = json.dumps(new_base_sizes)
505-
group.meta["partition_history"] = json.dumps(new_partition_history)
509+
# Type-erased indexes don't have partition_history, so skip it to avoid a crash.
510+
if not is_type_erased_index(index_type):
511+
group.meta["partition_history"] = json.dumps(new_partition_history)
506512
group.close()
507513

508514
group = tiledb.Group(uri, "r")

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1632,7 +1632,13 @@ def ingest_vamana(
16321632

16331633
ctx = vspy.Ctx(config)
16341634
index = vspy.IndexVamana(ctx, index_group_uri)
1635-
data = vspy.FeatureVectorArray(ctx, parts_array_uri, ids_array_uri)
1635+
data = vspy.FeatureVectorArray(
1636+
ctx,
1637+
parts_array_uri,
1638+
ids_array_uri,
1639+
0,
1640+
index_timestamp if index_timestamp else 0,
1641+
)
16361642
index.train(data)
16371643
index.add(data)
16381644
index.write_index(ctx, index_group_uri, index_timestamp)

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,8 @@ def __init__(
4848
memory_budget: int = -1,
4949
**kwargs,
5050
):
51-
super().__init__(uri=uri, config=config, timestamp=timestamp)
5251
self.index_type = INDEX_TYPE
52+
super().__init__(uri=uri, config=config, timestamp=timestamp)
5353
self.db_uri = self.group[
5454
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
5555
+ self.index_version

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -173,12 +173,22 @@ void init_type_erased_module(py::module_& m) {
173173
py::keep_alive<1, 2>() // FeatureVectorArray should keep ctx alive.
174174
)
175175
.def(
176-
py::init<
177-
const tiledb::Context&,
178-
const std::string&,
179-
const std::string&>(),
180-
py::keep_alive<1, 2>() // FeatureVectorArray should keep ctx alive.
181-
)
176+
"__init__",
177+
[](FeatureVectorArray& instance,
178+
const tiledb::Context& ctx,
179+
const std::string& uri,
180+
const std::string& ids_uri,
181+
size_t num_vectors,
182+
size_t timestamp) {
183+
new (&instance)
184+
FeatureVectorArray(ctx, uri, ids_uri, num_vectors, timestamp);
185+
},
186+
py::keep_alive<1, 2>(), // FeatureVectorArray should keep ctx alive.
187+
py::arg("ctx"),
188+
py::arg("uri"),
189+
py::arg("ids_uri") = "",
190+
py::arg("num_vectors") = 0,
191+
py::arg("timestamp") = 0)
182192
.def(py::init<size_t, size_t, const std::string&, const std::string&>())
183193
.def("dimension", &FeatureVectorArray::dimension)
184194
.def("num_vectors", &FeatureVectorArray::num_vectors)
@@ -261,9 +271,17 @@ void init_type_erased_module(py::module_& m) {
261271

262272
py::class_<IndexVamana>(m, "IndexVamana")
263273
.def(
264-
py::init<const tiledb::Context&, const std::string&>(),
265-
py::keep_alive<1, 2>() // IndexVamana should keep ctx alive.
266-
)
274+
"__init__",
275+
[](IndexVamana& instance,
276+
const tiledb::Context& ctx,
277+
const std::string& group_uri,
278+
size_t timestamp) {
279+
new (&instance) IndexVamana(ctx, group_uri, timestamp);
280+
},
281+
py::keep_alive<1, 2>(), // IndexVamana should keep ctx alive.
282+
py::arg("ctx"),
283+
py::arg("group_uri"),
284+
py::arg("timestamp") = 0)
267285
.def(
268286
"__init__",
269287
[](IndexVamana& instance, py::kwargs kwargs) {

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,12 @@ def __init__(
3535
):
3636
super().__init__(uri=uri, config=config, timestamp=timestamp)
3737
self.index_type = INDEX_TYPE
38-
self.index = vspy.IndexVamana(vspy.Ctx(config), uri)
38+
# TODO(paris): Support (start, end) timestamps and remove this tuple-to-end-timestamp workaround.
39+
type_erased_timestamp = timestamp
40+
if isinstance(timestamp, tuple):
41+
type_erased_timestamp = timestamp[1]
42+
type_erased_timestamp = type_erased_timestamp if type_erased_timestamp else 0
43+
self.index = vspy.IndexVamana(self.ctx, uri, type_erased_timestamp)
3944
self.db_uri = self.group[
4045
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
4146
].uri

apis/python/test/test_ingestion.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -560,11 +560,6 @@ def test_ingestion_with_batch_updates(tmp_path):
560560
gt_i, gt_d = get_groundtruth(dataset_dir, k)
561561

562562
for index_type, index_class in zip(INDEXES, INDEX_CLASSES):
563-
# TODO(paris): Fix Vamana bug and re-enable:
564-
# tiledb.cc.TileDBError: [TileDB::ArrayDirectory] Error: Cannot open array; Array does not exist.
565-
if index_type == "VAMANA":
566-
continue
567-
568563
index_uri = os.path.join(tmp_path, f"array_{index_type}")
569564
index = ingest(
570565
index_type=index_type,
@@ -665,10 +660,6 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
665660
index_uri = move_local_index_to_new_location(index_uri)
666661
index = index_class(uri=index_uri, timestamp=(0, 101))
667662
_, result = index.query(queries, k=k, nprobe=partitions)
668-
# TODO(paris): Fix Vamana accuracy bug and re-enable:
669-
# assert 0.105 == 1.0
670-
if index_type == "VAMANA":
671-
continue
672663
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
673664
index = index_class(uri=index_uri, timestamp=(0, None))
674665
_, result = index.query(queries, k=k, nprobe=partitions)
@@ -773,6 +764,9 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
773764
Index.clear_history(
774765
uri=index_uri, timestamp=index.latest_ingestion_timestamp - 1
775766
)
767+
if index_type == "VAMANA":
768+
# TODO(paris): Re-enable once we support (start, end) timestamps for Vamana.
769+
continue
776770
index = index_class(uri=index_uri, timestamp=1)
777771
_, result = index.query(queries, k=k, nprobe=partitions)
778772
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0

apis/python/test/test_type_erased_module.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -338,4 +338,4 @@ def test_inplace_build_infinite_query_IndexIVFFlat():
338338
if nprobe == 8:
339339
assert recall > 0.925
340340
if nprobe == 32:
341-
assert recall >= 0.999
341+
assert recall >= 0.998

0 commit comments

Comments
 (0)