Skip to content

Commit e7ae919

Browse files
authored
Re-enable IVF PQ tests (#450)
1 parent 2e5d929 commit e7ae919

File tree

3 files changed

+50
-11
lines changed

3 files changed

+50
-11
lines changed

apis/python/src/tiledb/vector_search/ivf_pq_index.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(
6060
)
6161
# TODO(SC-48710): Add support for `open_for_remote_query_execution`. We don't leave `self.index`` as `None` because we need to be able to call index.dimensions().
6262
self.index = vspy.IndexIVFPQ(self.ctx, uri, to_temporal_policy(timestamp))
63-
# TODO(paris): This is incorrect - should be fixed when we fix consolidation.
6463
self.db_uri = self.group[
6564
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
6665
].uri

apis/python/test/test_index.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,46 @@ def test_ivf_pq_index(tmp_path):
350350
index, np.array([[2, 2, 2]], dtype=np.float32), 2, [[0, 3]], [[2, 1]]
351351
)
352352

353-
# TODO(paris): Add tests for consolidation once we enable it.
353+
index = index.consolidate_updates()
354+
355+
# During the first ingestion we overwrite the metadata and end up with a single base size and ingestion timestamp.
356+
ingestion_timestamps, base_sizes = load_metadata(uri)
357+
assert base_sizes == [5]
358+
timestamp_5_minutes_from_now = int((time.time() + 5 * 60) * 1000)
359+
timestamp_5_minutes_ago = int((time.time() - 5 * 60) * 1000)
360+
assert (
361+
ingestion_timestamps[0] > timestamp_5_minutes_ago
362+
and ingestion_timestamps[0] < timestamp_5_minutes_from_now
363+
)
364+
365+
# Test that we can query with multiple query vectors.
366+
for i in range(5):
367+
query_and_check_distances(
368+
index,
369+
np.array([[i, i, i], [i, i, i]], dtype=np.float32),
370+
1,
371+
[[0], [0]],
372+
[[i], [i]],
373+
)
374+
375+
# Test that we can query with k > 1.
376+
query_and_check_distances(
377+
index, np.array([[0, 0, 0]], dtype=np.float32), 2, [[0, 3]], [[0, 1]]
378+
)
379+
380+
# Test that we can query with multiple query vectors and k > 1.
381+
query_and_check_distances(
382+
index,
383+
np.array([[0, 0, 0], [4, 4, 4]], dtype=np.float32),
384+
2,
385+
[[0, 3], [0, 3]],
386+
[[0, 1], [4, 3]],
387+
)
388+
389+
vfs = tiledb.VFS()
390+
assert vfs.dir_size(uri) > 0
391+
Index.delete_index(uri=uri, config={})
392+
assert vfs.dir_size(uri) == 0
354393

355394

356395
def test_delete_invalid_index(tmp_path):

apis/python/test/test_ingestion.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -824,12 +824,7 @@ def test_ingestion_with_updates(tmp_path):
824824
ingestion_timestamp = ingestion_timestamps[0]
825825

826826
_, result = index.query(queries, k=k, nprobe=nprobe)
827-
if index_type == "IVF_PQ":
828-
# TODO(paris): We get 0.989 accuracy instead of 1.0. Investigate why - it should be 1.0
829-
# when we have `nprobe = partitions` and `num_subspaces = dimensions`.
830-
assert accuracy(result, gt_i) > 0.9
831-
continue
832-
assert accuracy(result, gt_i) == 1.0
827+
assert accuracy(result, gt_i) >= (0.998 if index_type == "IVF_PQ" else 1.0)
833828

834829
update_ids_offset = MAX_UINT64 - size
835830
updated_ids = {}
@@ -841,16 +836,22 @@ def test_ingestion_with_updates(tmp_path):
841836
updated_ids[i] = i + update_ids_offset
842837

843838
_, result = index.query(queries, k=k, nprobe=nprobe)
844-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
839+
assert accuracy(result, gt_i, updated_ids=updated_ids) >= (
840+
0.998 if index_type == "IVF_PQ" else 1.0
841+
)
845842

846843
index = index.consolidate_updates(retrain_index=True, partitions=20)
847844
_, result = index.query(queries, k=k, nprobe=20)
848-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
845+
assert accuracy(result, gt_i, updated_ids=updated_ids) >= (
846+
0.998 if index_type == "IVF_PQ" else 1.0
847+
)
849848

850849
index_uri = move_local_index_to_new_location(index_uri)
851850
index = index_class(uri=index_uri)
852851
_, result = index.query(queries, k=k, nprobe=20)
853-
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
852+
assert accuracy(result, gt_i, updated_ids=updated_ids) >= (
853+
0.998 if index_type == "IVF_PQ" else 1.0
854+
)
854855

855856
ingestion_timestamps, base_sizes = load_metadata(index_uri)
856857
assert base_sizes == [1000, 1000]

0 commit comments

Comments
 (0)