Skip to content

Commit b7aabc2

Browse files
authored
Add more C++ and Python temporal policy unit tests (#359)
1 parent ce2ae5e commit b7aabc2

File tree

8 files changed

+239
-23
lines changed

8 files changed

+239
-23
lines changed

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ void init_type_erased_module(py::module_& m) {
314314
new (&instance) IndexVamana(
315315
ctx,
316316
group_uri,
317-
timestamp == 0 ? TemporalPolicy() :
317+
timestamp == 0 ? TemporalPolicy(TimeTravel, 0) :
318318
TemporalPolicy(TimeTravel, timestamp));
319319
},
320320
py::keep_alive<1, 2>(), // IndexVamana should keep ctx alive.

apis/python/test/common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import random
34
import shutil
@@ -392,3 +393,13 @@ def delete_uri(uri, config):
392393
else:
393394
raise err
394395
group.delete(recursive=True)
396+
397+
398+
def load_metadata(index_uri):
    """Read the ingestion timestamps and base sizes stored on an index group.

    Opens the TileDB group at ``index_uri`` in read mode and decodes the
    JSON-encoded ``ingestion_timestamps`` and ``base_sizes`` metadata entries.
    An entry that is absent from the group metadata is treated as ``"[]"``.

    Args:
        index_uri: URI of the index group to inspect.

    Returns:
        A tuple ``(ingestion_timestamps, base_sizes)``; each element is a list
        of ints.
    """
    # Use the group as a context manager so the handle is released even when
    # JSON decoding or int conversion raises part-way through.
    with tiledb.Group(index_uri, "r") as group:
        ingestion_timestamps = [
            int(x) for x in json.loads(group.meta.get("ingestion_timestamps", "[]"))
        ]
        base_sizes = [int(x) for x in json.loads(group.meta.get("base_sizes", "[]"))]
    return ingestion_timestamps, base_sizes

apis/python/test/test_index.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import json
2+
import time
23

34
import numpy as np
45
import pytest
56
from array_paths import *
67
from common import *
8+
from common import load_metadata
79

810
import tiledb.vector_search.index as ind
911
from tiledb.vector_search import Index
@@ -125,6 +127,11 @@ def test_ivf_flat_index(tmp_path):
125127
index = ivf_flat_index.create(
126128
uri=uri, dimensions=3, vector_type=vector_type, partitions=partitions
127129
)
130+
131+
ingestion_timestamps, base_sizes = load_metadata(uri)
132+
assert base_sizes == [0]
133+
assert ingestion_timestamps == [0]
134+
128135
query_and_check(
129136
index,
130137
np.array([[2, 2, 2]], dtype=np.float32),
@@ -147,6 +154,13 @@ def test_ivf_flat_index(tmp_path):
147154
)
148155

149156
index = index.consolidate_updates()
157+
# TODO(SC-46771): Investigate whether we should overwrite the existing metadata during the first
158+
# ingestion of Python indexes. I believe as it's currently written we have a bug here.
159+
# ingestion_timestamps, base_sizes = load_metadata(uri)
160+
# assert base_sizes == [5]
161+
# timestamp_5_minutes_from_now = int((time.time() + 5 * 60) * 1000)
162+
# timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
163+
# assert ingestion_timestamps[0] > timestamp_5_minutes_ago and ingestion_timestamps[0] < timestamp_5_minutes_from_now
150164

151165
query_and_check(
152166
index, np.array([[2, 2, 2]], dtype=np.float32), 3, {1, 2, 3}, nprobe=partitions
@@ -224,6 +238,10 @@ def test_vamana_index(tmp_path):
224238
vector_type=np.dtype(vector_type),
225239
)
226240

241+
ingestion_timestamps, base_sizes = load_metadata(uri)
242+
assert base_sizes == [0]
243+
assert ingestion_timestamps == [0]
244+
227245
queries = np.array([[2, 2, 2]], dtype=np.float32)
228246
distances, ids = index.query(queries, k=1)
229247
assert distances.shape == (1, 1)
@@ -251,6 +269,16 @@ def test_vamana_index(tmp_path):
251269

252270
index = index.consolidate_updates()
253271

272+
# During the first ingestion we overwrite the metadata and end up with a single base size and ingestion timestamp.
273+
ingestion_timestamps, base_sizes = load_metadata(uri)
274+
assert base_sizes == [5]
275+
timestamp_5_minutes_from_now = int((time.time() + 5 * 60) * 1000)
276+
timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
277+
assert (
278+
ingestion_timestamps[0] > timestamp_5_minutes_ago
279+
and ingestion_timestamps[0] < timestamp_5_minutes_from_now
280+
)
281+
254282
# Check that we throw if we query with an invalid opt_l.
255283
with pytest.raises(ValueError):
256284
index.query(queries, k=3, opt_l=2)

apis/python/test/test_ingestion.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from array_paths import *
66
from common import *
7+
from common import load_metadata
78

89
from tiledb.cloud.dag import Mode
910
from tiledb.vector_search.flat_index import FlatIndex
@@ -513,16 +514,16 @@ def test_ingestion_with_updates(tmp_path):
513514
partitions=partitions,
514515
)
515516

516-
# TODO(paris): Fix Vamana to have same metadata as Python and re-enable.
517-
# with tiledb.Group(index_uri, "r", ctx={}) as group:
518-
# ingestion_timestamps = [int(x) for x in list(json.loads(group.meta.get("ingestion_timestamps", "[]")))]
519-
# base_sizes = [int(x) for x in list(json.loads(group.meta.get("base_sizes", "[]")))]
520-
# assert len(ingestion_timestamps) == 1
521-
# assert len(base_sizes) == 1
522-
# assert base_sizes[0] == 1000
523-
# timestamp_2030 = 1903946089000
524-
# timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
525-
# assert ingestion_timestamps[0] > timestamp_5_minutes_ago and ingestion_timestamps[0] < timestamp_2030
517+
ingestion_timestamps, base_sizes = load_metadata(index_uri)
518+
assert base_sizes == [1000]
519+
assert len(ingestion_timestamps) == 1
520+
timestamp_5_minutes_from_now = int((time.time() + 5 * 60) * 1000)
521+
timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
522+
assert (
523+
ingestion_timestamps[0] > timestamp_5_minutes_ago
524+
and ingestion_timestamps[0] < timestamp_5_minutes_from_now
525+
)
526+
ingestion_timestamp = ingestion_timestamps[0]
526527

527528
_, result = index.query(queries, k=k, nprobe=nprobe)
528529
assert accuracy(result, gt_i) == 1.0
@@ -548,6 +549,16 @@ def test_ingestion_with_updates(tmp_path):
548549
_, result = index.query(queries, k=k, nprobe=20)
549550
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
550551

552+
ingestion_timestamps, base_sizes = load_metadata(index_uri)
553+
assert base_sizes == [1000, 1000]
554+
assert len(ingestion_timestamps) == 2
555+
assert ingestion_timestamps[0] == ingestion_timestamp
556+
assert (
557+
ingestion_timestamps[1] != ingestion_timestamp
558+
and ingestion_timestamps[1] > timestamp_5_minutes_ago
559+
and ingestion_timestamps[1] < timestamp_5_minutes_from_now
560+
)
561+
551562
assert vfs.dir_size(index_uri) > 0
552563
Index.delete_index(uri=index_uri, config={})
553564
assert vfs.dir_size(index_uri) == 0
@@ -643,6 +654,10 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
643654
index_timestamp=1,
644655
)
645656

657+
ingestion_timestamps, base_sizes = load_metadata(index_uri)
658+
assert ingestion_timestamps == [1]
659+
assert base_sizes == [1000]
660+
646661
if index_type == "IVF_FLAT":
647662
assert index.partitions == partitions
648663

@@ -651,7 +666,8 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
651666

652667
update_ids_offset = MAX_UINT64 - size
653668
updated_ids = {}
654-
for i in range(2, 102):
669+
timestamp_end = 102
670+
for i in range(2, timestamp_end):
655671
index.delete(external_id=i, timestamp=i)
656672
index.update(
657673
vector=data[i].astype(dtype),
@@ -660,6 +676,10 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
660676
)
661677
updated_ids[i] = i + update_ids_offset
662678

679+
ingestion_timestamps, base_sizes = load_metadata(index_uri)
680+
assert ingestion_timestamps == [1]
681+
assert base_sizes == [1000]
682+
663683
index = index_class(uri=index_uri)
664684
_, result = index.query(queries, k=k, nprobe=partitions)
665685
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
@@ -717,6 +737,11 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
717737

718738
# Consolidate updates
719739
index = index.consolidate_updates()
740+
741+
ingestion_timestamps, base_sizes = load_metadata(index_uri)
742+
assert ingestion_timestamps == [1, timestamp_end]
743+
assert base_sizes == [1000, 1000]
744+
720745
index = index_class(uri=index_uri)
721746
_, result = index.query(queries, k=k, nprobe=partitions)
722747
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0

src/include/index/index_group.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,13 @@ class base_index_group {
370370
auto get_previous_ingestion_timestamp() const {
371371
return metadata_.ingestion_timestamps_.back();
372372
}
373+
auto get_ingestion_timestamp() const {
374+
return metadata_.ingestion_timestamps_[timetravel_index_];
375+
}
373376
auto append_ingestion_timestamp(size_t timestamp) {
374377
metadata_.ingestion_timestamps_.push_back(timestamp);
375378
}
376-
auto get_all_ingestion_timestamps() {
379+
auto get_all_ingestion_timestamps() const {
377380
return metadata_.ingestion_timestamps_;
378381
}
379382

@@ -389,7 +392,7 @@ class base_index_group {
389392
auto append_base_size(size_t size) {
390393
metadata_.base_sizes_.push_back(size);
391394
}
392-
auto get_all_base_sizes() {
395+
auto get_all_base_sizes() const {
393396
return metadata_.base_sizes_;
394397
}
395398

@@ -407,6 +410,10 @@ class base_index_group {
407410
metadata_.dimension_ = dim;
408411
}
409412

413+
auto get_timetravel_index() const {
414+
return timetravel_index_;
415+
}
416+
410417
/**************************************************************************
411418
* Getters for names and uris
412419
**************************************************************************/

src/include/index/vamana_index.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,13 @@ class vamana_index {
949949
return true;
950950
}
951951

952+
/**
 * @brief Read-only access to the group object held by this index.
 *
 * @return Const reference to the object `group_` points to.
 * @throws std::runtime_error if `group_` is empty.
 */
const vamana_index_group<vamana_index>& group() const {
  if (group_) {
    return *group_;
  }
  throw std::runtime_error("No group available");
}
958+
952959
/**
953960
* @brief Log statistics about the index
954961
*/
@@ -1035,6 +1042,20 @@ class vamana_index {
10351042
<< std::endl;
10361043
return false;
10371044
}
1045+
// The two indexes only compare equal when their temporal policies cover the
// same [timestamp_start, timestamp_end] range. On a mismatch, print the
// timestamp values that actually differ (previously the medoids were printed
// here by mistake) and report inequality.
if (temporal_policy_.timestamp_start() !=
    rhs.temporal_policy_.timestamp_start()) {
  std::cout << "temporal_policy_.timestamp_start() != "
               "rhs.temporal_policy_.timestamp_start()"
            << temporal_policy_.timestamp_start() << " ! = "
            << rhs.temporal_policy_.timestamp_start() << std::endl;
  return false;
}
if (temporal_policy_.timestamp_end() !=
    rhs.temporal_policy_.timestamp_end()) {
  std::cout << "temporal_policy_.timestamp_end() != "
               "rhs.temporal_policy_.timestamp_end()"
            << temporal_policy_.timestamp_end() << " ! = "
            << rhs.temporal_policy_.timestamp_end() << std::endl;
  return false;
}
10381059

10391060
return true;
10401061
}

0 commit comments

Comments
 (0)