|
1 | 1 | import numpy as np |
2 | 2 | from common import * |
3 | | -from tiledb.cloud.dag import Mode |
| 3 | +import pytest |
4 | 4 |
|
| 5 | +from tiledb.cloud.dag import Mode |
5 | 6 | from tiledb.vector_search.flat_index import FlatIndex |
6 | 7 | from tiledb.vector_search.index import Index |
7 | 8 | from tiledb.vector_search.ingestion import ingest |
@@ -416,6 +417,7 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path): |
416 | 417 | _, result = index.query(query_vectors, k=k, nprobe=nprobe) |
417 | 418 | assert accuracy(result, gt_i, updated_ids=updated_ids) > 0.99 |
418 | 419 |
|
| 420 | + |
419 | 421 | def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path): |
420 | 422 | dataset_dir = os.path.join(tmp_path, "dataset") |
421 | 423 | index_uri = os.path.join(tmp_path, "array") |
@@ -669,6 +671,73 @@ def test_ivf_flat_ingestion_with_additions_and_timetravel(tmp_path): |
669 | 671 | _, result = index.query(query_vectors, k=k, nprobe=index.partitions) |
670 | 672 | assert 0.45 < accuracy(result, gt_i) < 0.55 |
671 | 673 |
|
| 674 | + |
| 675 | +def test_storage_versions(tmp_path): |
| 676 | + dataset_dir = os.path.join(tmp_path, "dataset") |
| 677 | + k = 10 |
| 678 | + size = 1000 |
| 679 | + partitions = 10 |
| 680 | + dimensions = 128 |
| 681 | + nqueries = 100 |
| 682 | + data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir) |
| 683 | + source_uri = os.path.join(dataset_dir, "data.u8bin") |
| 684 | + |
| 685 | + dtype = np.uint8 |
| 686 | + query_vectors = get_queries(dataset_dir, dtype=dtype) |
| 687 | + gt_i, _ = get_groundtruth(dataset_dir, k) |
| 688 | + |
| 689 | + indexes = ["FLAT", "IVF_FLAT"] |
| 690 | + index_classes = [FlatIndex, IVFFlatIndex] |
| 691 | + index_files = [tiledb.vector_search.flat_index, tiledb.vector_search.ivf_flat_index] |
| 692 | + for index_type, index_class, index_file in zip(indexes, index_classes, index_files): |
| 693 | + # First we test with an invalid storage version. |
| 694 | + with pytest.raises(ValueError) as error: |
| 695 | + index_uri = os.path.join(tmp_path, f"array_{index_type}_invalid") |
| 696 | + ingest( |
| 697 | + index_type=index_type, |
| 698 | + index_uri=index_uri, |
| 699 | + source_uri=source_uri, |
| 700 | + partitions=partitions, |
| 701 | + storage_version="Foo" |
| 702 | + ) |
| 703 | + assert "Invalid storage version" in str(error.value) |
| 704 | + |
| 705 | + with pytest.raises(ValueError) as error: |
| 706 | + index_file.create(uri=index_uri, dimensions=3, vector_type=np.dtype(dtype), storage_version="Foo") |
| 707 | + assert "Invalid storage version" in str(error.value) |
| 708 | + |
| 709 | + # Then we test with valid storage versions. |
| 710 | + for storage_version, _ in tiledb.vector_search.storage_formats.items(): |
| 711 | + index_uri = os.path.join(tmp_path, f"array_{index_type}_{storage_version}") |
| 712 | + index = ingest( |
| 713 | + index_type=index_type, |
| 714 | + index_uri=index_uri, |
| 715 | + source_uri=source_uri, |
| 716 | + partitions=partitions, |
| 717 | + storage_version=storage_version |
| 718 | + ) |
| 719 | + _, result = index.query(query_vectors, k=k) |
| 720 | + assert accuracy(result, gt_i) >= MINIMUM_ACCURACY |
| 721 | + |
| 722 | + update_ids_offset = MAX_UINT64 - size |
| 723 | + updated_ids = {} |
| 724 | + for i in range(10): |
| 725 | + index.delete(external_id=i) |
| 726 | + index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset) |
| 727 | + updated_ids[i] = i + update_ids_offset |
| 728 | + |
| 729 | + _, result = index.query(query_vectors, k=k) |
| 730 | + assert accuracy(result, gt_i, updated_ids=updated_ids) >= MINIMUM_ACCURACY |
| 731 | + |
| 732 | + index = index.consolidate_updates(partitions=20) |
| 733 | + _, result = index.query(query_vectors, k=k) |
| 734 | + assert accuracy(result, gt_i, updated_ids=updated_ids) >= MINIMUM_ACCURACY |
| 735 | + |
| 736 | + index_ram = index_class(uri=index_uri) |
| 737 | + _, result = index_ram.query(query_vectors, k=k) |
| 738 | + assert accuracy(result, gt_i) > MINIMUM_ACCURACY |
| 739 | + |
| 740 | + |
672 | 741 | def test_kmeans(): |
673 | 742 | k = 128 |
674 | 743 | d = 16 |
|
0 commit comments