Skip to content

Commit 32d17ee

Browse files
authored
Fix bug in Python where, if an index was loaded with a timestamp before the earliest ingested data, we'd still load and query data from the future (#365)
1 parent bea859b commit 32d17ee

File tree

3 files changed

+288
-34
lines changed

3 files changed

+288
-34
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -91,39 +91,56 @@ def __init__(
9191
raise ValueError(
9292
"'timestamp' argument expects either int or tuple(start: int, end: int)"
9393
)
94-
if timestamp[0] is not None:
95-
if timestamp[0] > self.ingestion_timestamps[0]:
96-
self.query_base_array = False
97-
self.update_array_timestamp = timestamp
98-
else:
94+
if (
95+
timestamp[0] is not None
96+
and timestamp[0] > self.ingestion_timestamps[0]
97+
):
98+
self.query_base_array = False
99+
self.update_array_timestamp = timestamp
100+
else:
101+
if (
102+
timestamp[1] is None
103+
or timestamp[1] >= self.ingestion_timestamps[0]
104+
):
99105
self.history_index = 0
100106
self.base_size = self.base_sizes[self.history_index]
101107
self.base_array_timestamp = self.ingestion_timestamps[
102108
self.history_index
103109
]
104-
self.update_array_timestamp = (
105-
self.base_array_timestamp + 1,
106-
timestamp[1],
107-
)
108-
else:
109-
self.history_index = 0
110-
self.base_size = self.base_sizes[self.history_index]
111-
self.base_array_timestamp = self.ingestion_timestamps[
112-
self.history_index
113-
]
110+
else:
111+
# If the timestamp is before the first ingestion, we'll have no vectors to return.
112+
self.history_index = 0
113+
self.base_size = 0
114+
self.base_array_timestamp = timestamp[1]
115+
self.query_base_array = False
116+
114117
self.update_array_timestamp = (
115118
self.base_array_timestamp + 1,
116119
timestamp[1],
117120
)
121+
118122
elif isinstance(timestamp, int):
119-
self.history_index = 0
120-
i = 0
121-
for ingestion_timestamp in self.ingestion_timestamps:
122-
if ingestion_timestamp <= timestamp:
123-
self.base_array_timestamp = ingestion_timestamp
124-
self.history_index = i
125-
self.base_size = self.base_sizes[self.history_index]
126-
i += 1
123+
# NOTE(paris): We could instead use the same logic as in the else statement above,
124+
# but we do it like this as a performance improvement so that we read less from the
125+
# updates array and more from ingestions. Above we need to read just the first
126+
# ingestion and then from the updates array in case we get a timestamp in between an
127+
# ingestion and an update.
128+
if timestamp >= self.ingestion_timestamps[0]:
129+
self.history_index = 0
130+
i = 0
131+
for ingestion_timestamp in self.ingestion_timestamps:
132+
if ingestion_timestamp <= timestamp:
133+
self.base_array_timestamp = ingestion_timestamp
134+
self.history_index = i
135+
self.base_size = self.base_sizes[self.history_index]
136+
i += 1
137+
else:
138+
# If the timestamp is before the first ingestion, we'll have no vectors to return.
139+
self.history_index = 0
140+
self.base_size = 0
141+
self.base_array_timestamp = timestamp
142+
self.query_base_array = False
143+
127144
self.update_array_timestamp = (self.base_array_timestamp + 1, timestamp)
128145
else:
129146
raise TypeError(

apis/python/test/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,11 @@ def check_equals(result_d, result_i, expected_result_d, expected_result_i):
314314
expected_result_i: int
315315
The expected indices
316316
"""
317-
assert (
318-
result_i == expected_result_i
317+
assert np.array_equal(
318+
result_i, expected_result_i
319319
), f"result_i: {result_i} != expected_result_i: {expected_result_i}"
320-
assert (
321-
result_d == expected_result_d
320+
assert np.array_equal(
321+
result_d, expected_result_d
322322
), f"result_d: {result_d} != expected_result_d: {expected_result_d}"
323323

324324

apis/python/test/test_ingestion.py

Lines changed: 244 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,242 @@ def test_ingestion_external_ids_numpy(tmp_path):
487487
assert vfs.dir_size(index_uri) == 0
488488

489489

490+
def test_ingestion_timetravel(tmp_path):
491+
for index_type, index_class in zip(INDEXES, INDEX_CLASSES):
492+
index_uri = os.path.join(tmp_path, f"array_{index_type}")
493+
494+
data = np.array([[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]], dtype=np.float32)
495+
default_result_d = [[np.finfo(np.float32).max], [np.finfo(np.float32).max]]
496+
default_result_i = [[np.iinfo(np.uint64).max], [np.iinfo(np.uint64).max]]
497+
498+
# We ingest at timestamp 10.
499+
ingest(
500+
index_type=index_type,
501+
index_uri=index_uri,
502+
input_vectors=data,
503+
index_timestamp=10,
504+
)
505+
506+
# If we load the index with any timestamp < 10, then we have no data and so have no results.
507+
query_and_check_equals(
508+
index=index_class(uri=index_uri, timestamp=0),
509+
queries=data,
510+
expected_result_d=default_result_d,
511+
expected_result_i=default_result_i,
512+
)
513+
query_and_check_equals(
514+
index=index_class(uri=index_uri, timestamp=9),
515+
queries=data,
516+
expected_result_d=default_result_d,
517+
expected_result_i=default_result_i,
518+
)
519+
query_and_check_equals(
520+
index=index_class(uri=index_uri, timestamp=(5, 9)),
521+
queries=data,
522+
expected_result_d=default_result_d,
523+
expected_result_i=default_result_i,
524+
)
525+
query_and_check_equals(
526+
index=index_class(uri=index_uri, timestamp=(None, 9)),
527+
queries=data,
528+
expected_result_d=default_result_d,
529+
expected_result_i=default_result_i,
530+
)
531+
532+
# If we load the index with timestamp >= 10 then we get results.
533+
query_and_check_equals(
534+
index=index_class(uri=index_uri, timestamp=10),
535+
queries=data,
536+
expected_result_d=[[0], [0]],
537+
expected_result_i=[[0], [1]],
538+
)
539+
query_and_check_equals(
540+
index=index_class(uri=index_uri, timestamp=1000),
541+
queries=data,
542+
expected_result_d=[[0], [0]],
543+
expected_result_i=[[0], [1]],
544+
)
545+
query_and_check_equals(
546+
index=index_class(uri=index_uri, timestamp=(5, 15)),
547+
queries=data,
548+
expected_result_d=[[0], [0]],
549+
expected_result_i=[[0], [1]],
550+
)
551+
query_and_check_equals(
552+
index=index_class(uri=index_uri, timestamp=(None, 20)),
553+
queries=data,
554+
expected_result_d=[[0], [0]],
555+
expected_result_i=[[0], [1]],
556+
)
557+
558+
# We add a third vector at timestamp 20 and consolidate updates, meaning we'll re-ingest at timestamp = 20.
559+
data = np.array(
560+
[[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3], [3.0, 3.1, 3.2, 3.3]],
561+
dtype=np.float32,
562+
)
563+
default_result_d = [
564+
[np.finfo(np.float32).max],
565+
[np.finfo(np.float32).max],
566+
[np.finfo(np.float32).max],
567+
]
568+
default_result_i = [
569+
[np.iinfo(np.uint64).max],
570+
[np.iinfo(np.uint64).max],
571+
[np.iinfo(np.uint64).max],
572+
]
573+
index = index_class(uri=index_uri)
574+
index.update(
575+
vector=data[2],
576+
external_id=2,
577+
timestamp=20,
578+
)
579+
index = index.consolidate_updates()
580+
581+
# We still have no results before timestamp 10.
582+
query_and_check_equals(
583+
index=index_class(uri=index_uri, timestamp=0),
584+
queries=data,
585+
expected_result_d=default_result_d,
586+
expected_result_i=default_result_i,
587+
)
588+
query_and_check_equals(
589+
index=index_class(uri=index_uri, timestamp=9),
590+
queries=data,
591+
expected_result_d=default_result_d,
592+
expected_result_i=default_result_i,
593+
)
594+
query_and_check_equals(
595+
index=index_class(uri=index_uri, timestamp=(5, 9)),
596+
queries=data,
597+
expected_result_d=default_result_d,
598+
expected_result_i=default_result_i,
599+
)
600+
query_and_check_equals(
601+
index=index_class(uri=index_uri, timestamp=(None, 9)),
602+
queries=data,
603+
expected_result_d=default_result_d,
604+
expected_result_i=default_result_i,
605+
)
606+
607+
# We have no results if we load in between 10 and 20.
608+
if index_type == "VAMANA":
609+
# TODO(paris): Fix Vamana and re-enable this test.
610+
continue
611+
query_and_check_equals(
612+
index=index_class(uri=index_uri, timestamp=(11, 19)),
613+
queries=data,
614+
expected_result_d=default_result_d,
615+
expected_result_i=default_result_i,
616+
)
617+
618+
# If we load the index from timestamp 0 -> 19, we are returned only the first two vectors.
619+
query_and_check_equals(
620+
index=index_class(uri=index_uri, timestamp=10),
621+
queries=data,
622+
expected_result_d=[[0], [0], [4]],
623+
expected_result_i=[[0], [1], [1]],
624+
)
625+
query_and_check_equals(
626+
index=index_class(uri=index_uri, timestamp=19),
627+
queries=data,
628+
expected_result_d=[[0], [0], [4]],
629+
expected_result_i=[[0], [1], [1]],
630+
)
631+
query_and_check_equals(
632+
index=index_class(uri=index_uri, timestamp=(0, 19)),
633+
queries=data,
634+
expected_result_d=[[0], [0], [4]],
635+
expected_result_i=[[0], [1], [1]],
636+
)
637+
query_and_check_equals(
638+
index=index_class(uri=index_uri, timestamp=(None, 19)),
639+
queries=data,
640+
expected_result_d=[[0], [0], [4]],
641+
expected_result_i=[[0], [1], [1]],
642+
)
643+
644+
# But if we load with timestamp >= 20 then we get results for all three vectors.
645+
query_and_check_equals(
646+
index=index_class(uri=index_uri, timestamp=None),
647+
queries=data,
648+
expected_result_d=[[0], [0], [0]],
649+
expected_result_i=[[0], [1], [2]],
650+
)
651+
query_and_check_equals(
652+
index=index_class(uri=index_uri, timestamp=1000),
653+
queries=data,
654+
expected_result_d=[[0], [0], [0]],
655+
expected_result_i=[[0], [1], [2]],
656+
)
657+
query_and_check_equals(
658+
index=index_class(uri=index_uri, timestamp=(0, 1000)),
659+
queries=data,
660+
expected_result_d=[[0], [0], [0]],
661+
expected_result_i=[[0], [1], [2]],
662+
)
663+
query_and_check_equals(
664+
index=index_class(uri=index_uri, timestamp=(None, 20)),
665+
queries=data,
666+
expected_result_d=[[0], [0], [0]],
667+
expected_result_i=[[0], [1], [2]],
668+
)
669+
670+
# Clear all history at timestamp 19.
671+
Index.clear_history(uri=index_uri, timestamp=19)
672+
673+
# After clearing history at timestamp 19, if we load the index with any timestamp <= 19 we get no results.
674+
query_and_check_equals(
675+
index=index_class(uri=index_uri, timestamp=10),
676+
queries=data,
677+
expected_result_d=default_result_d,
678+
expected_result_i=default_result_i,
679+
)
680+
query_and_check_equals(
681+
index=index_class(uri=index_uri, timestamp=19),
682+
queries=data,
683+
expected_result_d=default_result_d,
684+
expected_result_i=default_result_i,
685+
)
686+
query_and_check_equals(
687+
index=index_class(uri=index_uri, timestamp=(0, 19)),
688+
queries=data,
689+
expected_result_d=default_result_d,
690+
expected_result_i=default_result_i,
691+
)
692+
query_and_check_equals(
693+
index=index_class(uri=index_uri, timestamp=(None, 19)),
694+
queries=data,
695+
expected_result_d=default_result_d,
696+
expected_result_i=default_result_i,
697+
)
698+
699+
# But if we load with timestamp > 20 then we get results for all three vectors.
700+
query_and_check_equals(
701+
index=index_class(uri=index_uri, timestamp=None),
702+
queries=data,
703+
expected_result_d=[[0], [0], [0]],
704+
expected_result_i=[[0], [1], [2]],
705+
)
706+
query_and_check_equals(
707+
index=index_class(uri=index_uri, timestamp=1000),
708+
queries=data,
709+
expected_result_d=[[0], [0], [0]],
710+
expected_result_i=[[0], [1], [2]],
711+
)
712+
query_and_check_equals(
713+
index=index_class(uri=index_uri, timestamp=(0, 1000)),
714+
queries=data,
715+
expected_result_d=[[0], [0], [0]],
716+
expected_result_i=[[0], [1], [2]],
717+
)
718+
query_and_check_equals(
719+
index=index_class(uri=index_uri, timestamp=(None, 21)),
720+
queries=data,
721+
expected_result_d=[[0], [0], [0]],
722+
expected_result_i=[[0], [1], [2]],
723+
)
724+
725+
490726
def test_ingestion_with_updates(tmp_path):
491727
vfs = tiledb.VFS()
492728

@@ -798,6 +1034,7 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
7981034
assert accuracy(result, gt_i) == 1.0
7991035

8001036
# Clear history before the latest ingestion
1037+
assert index.latest_ingestion_timestamp == 102
8011038
Index.clear_history(
8021039
uri=index_uri, timestamp=index.latest_ingestion_timestamp - 1
8031040
)
@@ -806,32 +1043,32 @@ def test_ingestion_with_updates_and_timetravel(tmp_path):
8061043
continue
8071044
index = index_class(uri=index_uri, timestamp=1)
8081045
_, result = index.query(queries, k=k, nprobe=partitions)
809-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1046+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8101047
index = index_class(uri=index_uri, timestamp=51)
8111048
_, result = index.query(queries, k=k, nprobe=partitions)
812-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1049+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8131050
index = index_class(uri=index_uri, timestamp=101)
8141051
_, result = index.query(queries, k=k, nprobe=partitions)
815-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1052+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8161053
index = index_class(uri=index_uri)
8171054
_, result = index.query(queries, k=k, nprobe=partitions)
8181055
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
8191056
index = index_class(uri=index_uri, timestamp=(0, 51))
8201057
_, result = index.query(queries, k=k, nprobe=partitions)
821-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1058+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8221059
index_uri = move_local_index_to_new_location(index_uri)
8231060
index = index_class(uri=index_uri, timestamp=(0, 101))
8241061
_, result = index.query(queries, k=k, nprobe=partitions)
825-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1062+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8261063
index = index_class(uri=index_uri, timestamp=(0, None))
8271064
_, result = index.query(queries, k=k, nprobe=partitions)
8281065
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
8291066
index = index_class(uri=index_uri, timestamp=(2, 51))
8301067
_, result = index.query(queries, k=k, nprobe=partitions)
831-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1068+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8321069
index = index_class(uri=index_uri, timestamp=(2, 101))
8331070
_, result = index.query(queries, k=k, nprobe=partitions)
834-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
1071+
assert accuracy(result, gt_i, updated_ids=updated_ids) == 0.0
8351072
index = index_class(uri=index_uri, timestamp=(2, None))
8361073
_, result = index.query(queries, k=k, nprobe=partitions)
8371074
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0

0 commit comments

Comments
 (0)