Commit 60fc836

Test that we can index different storage versions and then query them (#168)
1 parent 3e3e7f8 commit 60fc836

6 files changed: +109 additions, -18 deletions


apis/python/src/tiledb/vector_search/flat_index.py
Lines changed: 10 additions & 5 deletions

@@ -6,7 +6,8 @@
 from tiledb.vector_search import index
 from tiledb.vector_search.module import *
 from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
-                                                   storage_formats)
+                                                   storage_formats,
+                                                   validate_storage_version)

 MAX_INT32 = np.iinfo(np.dtype("int32")).max
 TILE_SIZE_BYTES = 128000000 # 128MB
@@ -119,21 +120,25 @@ def create(
     vector_type: np.dtype,
     group_exists: bool = False,
     config: Optional[Mapping[str, Any]] = None,
+    storage_version: str = STORAGE_VERSION,
     **kwargs,
 ) -> FlatIndex:
+    validate_storage_version(storage_version)
+
     index.create_metadata(
         uri=uri,
         dimensions=dimensions,
         vector_type=vector_type,
         index_type=INDEX_TYPE,
+        storage_version=storage_version,
         group_exists=group_exists,
         config=config,
     )
     with tiledb.scope_ctx(ctx_or_config=config):
         group = tiledb.Group(uri, "w")
         tile_size = TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
-        ids_array_name = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
-        parts_array_name = storage_formats[STORAGE_VERSION]["PARTS_ARRAY_NAME"]
+        ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
+        parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
         ids_uri = f"{uri}/{ids_array_name}"
         parts_uri = f"{uri}/{parts_array_name}"

@@ -147,7 +152,7 @@ def create(
         ids_attr = tiledb.Attr(
             name="values",
             dtype=np.dtype(np.uint64),
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
+            filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
         )
         ids_schema = tiledb.ArraySchema(
             domain=ids_array_dom,
@@ -175,7 +180,7 @@ def create(
         parts_attr = tiledb.Attr(
             name="values",
             dtype=vector_type,
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
+            filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
         )
         parts_schema = tiledb.ArraySchema(
             domain=parts_array_dom,

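For illustration, a minimal sketch of how a caller might pin a flat index to a specific storage format at creation time; the URIs and the "0.2" version string are assumptions (any key present in storage_formats would work), and the same pattern applies to ivf_flat_index.create further below.

import numpy as np
from tiledb.vector_search import flat_index

# Omitting storage_version keeps the previous behaviour: the latest
# STORAGE_VERSION is used.
index = flat_index.create(
    uri="/tmp/my_flat_index",            # hypothetical URI
    dimensions=128,
    vector_type=np.dtype(np.uint8),
)

# Passing an explicit version writes the group with that format's array
# names and attribute filters (assuming "0.2" is a storage_formats key).
older = flat_index.create(
    uri="/tmp/my_flat_index_0_2",        # hypothetical URI
    dimensions=128,
    vector_type=np.dtype(np.uint8),
    storage_version="0.2",
)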
apis/python/src/tiledb/vector_search/index.py
Lines changed: 2 additions & 1 deletion

@@ -489,6 +489,7 @@ def create_metadata(
     dimensions: int,
     vector_type: np.dtype,
     index_type: str,
+    storage_version: str,
     group_exists: bool = False,
     config: Optional[Mapping[str, Any]] = None,
 ):
@@ -501,7 +502,7 @@ def create_metadata(
         group = tiledb.Group(uri, "w")
         group.meta["dataset_type"] = DATASET_TYPE
         group.meta["dtype"] = np.dtype(vector_type).name
-        group.meta["storage_version"] = STORAGE_VERSION
+        group.meta["storage_version"] = storage_version
         group.meta["index_type"] = index_type
         group.meta["base_sizes"] = json.dumps([0])
         group.meta["ingestion_timestamps"] = json.dumps([0])

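Since create_metadata now records the caller's choice rather than the module constant, the version an index was written with can be read back from the group metadata. A small sketch, assuming an index group already exists at the given (hypothetical) URI:

import tiledb

uri = "/tmp/my_flat_index"   # hypothetical URI of an existing index group

group = tiledb.Group(uri)    # opened read-only by default
# "storage_version" is the metadata key written by create_metadata above.
print(group.meta["storage_version"], group.meta["index_type"])
group.close()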
apis/python/src/tiledb/vector_search/ingestion.py
Lines changed: 8 additions & 2 deletions

@@ -6,7 +6,7 @@
 from tiledb.cloud.dag import Mode

 from tiledb.vector_search._tiledbvspy import *
-from tiledb.vector_search.storage_formats import STORAGE_VERSION
+from tiledb.vector_search.storage_formats import STORAGE_VERSION, validate_storage_version


 def ingest(
@@ -88,7 +88,7 @@ def ingest(
         Max number of tasks per execution stage of ingestion,
         if not provided, is auto-configured
     storage_version: str
-        Vector index storage format version.
+        Vector index storage format version. If not provided, defaults to the latest version.
     verbose: bool
         verbose logging, defaults to False
     trace_id: Optional[str]
@@ -119,6 +119,8 @@ def ingest(
     from tiledb.vector_search.index import Index
     from tiledb.vector_search.storage_formats import storage_formats

+    validate_storage_version(storage_version)
+
     # use index_group_uri for internal clarity
     index_group_uri = index_uri

@@ -355,6 +357,7 @@ def create_arrays(
         input_vectors_work_items: int,
         vector_type: np.dtype,
         logger: logging.Logger,
+        storage_version: str,
     ) -> None:
         if index_type == "FLAT":
             if not arrays_created:
@@ -364,6 +367,7 @@ def create_arrays(
                     vector_type=vector_type,
                     group_exists=True,
                     config=config,
+                    storage_version=storage_version
                 )
         elif index_type == "IVF_FLAT":
             if not arrays_created:
@@ -373,6 +377,7 @@ def create_arrays(
                     vector_type=vector_type,
                     group_exists=True,
                     config=config,
+                    storage_version=storage_version
                 )
                 tile_size = int(
                     ivf_flat_index.TILE_SIZE_BYTES
@@ -1935,6 +1940,7 @@ def consolidate_and_vacuum(
            input_vectors_work_items=input_vectors_work_items,
            vector_type=vector_type,
            logger=logger,
+           storage_version=storage_version
         )
         group.meta["temp_size"] = size
         group.close()

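A hedged sketch of how ingestion might be driven with an explicit storage version, and what happens for an unknown one; the paths and the "0.2" version string are illustrative assumptions:

from tiledb.vector_search.ingestion import ingest

source_uri = "/tmp/dataset/data.u8bin"   # hypothetical source path

# Write the index using an older format version (assuming "0.2" is a key
# in storage_formats); omitting storage_version uses the latest.
index = ingest(
    index_type="IVF_FLAT",
    index_uri="/tmp/array_0_2",          # hypothetical index URI
    source_uri=source_uri,
    partitions=10,
    storage_version="0.2",
)

# An unknown version now fails fast, before any arrays are created.
try:
    ingest(
        index_type="FLAT",
        index_uri="/tmp/array_bad",      # hypothetical index URI
        source_uri=source_uri,
        partitions=10,
        storage_version="not-a-version",
    )
except ValueError as err:
    print(err)  # Invalid storage version: not-a-version. Valid versions are: [...]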
apis/python/src/tiledb/vector_search/ivf_flat_index.py
Lines changed: 14 additions & 9 deletions

@@ -8,7 +8,8 @@
 from tiledb.vector_search import index
 from tiledb.vector_search.module import *
 from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
-                                                   storage_formats)
+                                                   storage_formats,
+                                                   validate_storage_version)

 MAX_INT32 = np.iinfo(np.dtype("int32")).max
 TILE_SIZE_BYTES = 64000000 # 64MB
@@ -450,24 +451,28 @@ def create(
     vector_type: np.dtype,
     group_exists: bool = False,
     config: Optional[Mapping[str, Any]] = None,
+    storage_version: str = STORAGE_VERSION,
     **kwargs,
 ) -> IVFFlatIndex:
+    validate_storage_version(storage_version)
+
     index.create_metadata(
         uri=uri,
         dimensions=dimensions,
         vector_type=vector_type,
         index_type=INDEX_TYPE,
+        storage_version=storage_version,
         group_exists=group_exists,
         config=config,
     )
     with tiledb.scope_ctx(ctx_or_config=config):
         group = tiledb.Group(uri, "w")
         tile_size = int(TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions)
         group.meta["partition_history"] = json.dumps([0])
-        centroids_array_name = storage_formats[STORAGE_VERSION]["CENTROIDS_ARRAY_NAME"]
-        index_array_name = storage_formats[STORAGE_VERSION]["INDEX_ARRAY_NAME"]
-        ids_array_name = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
-        parts_array_name = storage_formats[STORAGE_VERSION]["PARTS_ARRAY_NAME"]
+        centroids_array_name = storage_formats[storage_version]["CENTROIDS_ARRAY_NAME"]
+        index_array_name = storage_formats[storage_version]["INDEX_ARRAY_NAME"]
+        ids_array_name = storage_formats[storage_version]["IDS_ARRAY_NAME"]
+        parts_array_name = storage_formats[storage_version]["PARTS_ARRAY_NAME"]
         centroids_uri = f"{uri}/{centroids_array_name}"
         index_array_uri = f"{uri}/{index_array_name}"
         ids_uri = f"{uri}/{ids_array_name}"
@@ -491,7 +496,7 @@ def create(
         centroids_attr = tiledb.Attr(
             name="centroids",
             dtype=np.dtype(np.float32),
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
+            filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
         )
         centroids_schema = tiledb.ArraySchema(
             domain=centroids_array_dom,
@@ -513,7 +518,7 @@ def create(
         index_attr = tiledb.Attr(
             name="values",
             dtype=np.dtype(np.uint64),
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
+            filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
         )
         index_schema = tiledb.ArraySchema(
             domain=index_array_dom,
@@ -535,7 +540,7 @@ def create(
         ids_attr = tiledb.Attr(
             name="values",
             dtype=np.dtype(np.uint64),
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
+            filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
         )
         ids_schema = tiledb.ArraySchema(
             domain=ids_array_dom,
@@ -563,7 +568,7 @@ def create(
         parts_attr = tiledb.Attr(
             name="values",
             dtype=vector_type,
-            filters=storage_formats[STORAGE_VERSION]["DEFAULT_ATTR_FILTERS"],
+            filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
         )
         parts_schema = tiledb.ArraySchema(
             domain=parts_array_dom,

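For context, the per-version array names and attribute filters come from the storage_formats mapping, which is what the create() functions above now key by the caller's storage_version instead of the module constant. A rough sketch of that lookup, using only the keys that appear in the diff (the "0.3" literal is just the current STORAGE_VERSION and stands in for whatever the group metadata records):

from tiledb.vector_search.storage_formats import storage_formats

storage_version = "0.3"  # e.g. read from the index group's metadata

fmt = storage_formats[storage_version]
centroids_array_name = fmt["CENTROIDS_ARRAY_NAME"]
index_array_name = fmt["INDEX_ARRAY_NAME"]
ids_array_name = fmt["IDS_ARRAY_NAME"]
parts_array_name = fmt["PARTS_ARRAY_NAME"]
default_filters = fmt["DEFAULT_ATTR_FILTERS"]

# create() joins these names onto the group URI, e.g. f"{uri}/{ids_array_name}",
# so the on-disk layout follows the chosen storage version.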
apis/python/src/tiledb/vector_search/storage_formats.py
Lines changed: 5 additions & 0 deletions

@@ -40,3 +40,8 @@
 }

 STORAGE_VERSION = "0.3"
+
+def validate_storage_version(storage_version):
+    if storage_version not in storage_formats:
+        valid_versions = ', '.join(storage_formats.keys())
+        raise ValueError(f"Invalid storage version: {storage_version}. Valid versions are: [{valid_versions}]")

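A quick illustration of the new helper's behaviour; the "bogus" string is just an example of an unsupported version:

from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
                                                   storage_formats,
                                                   validate_storage_version)

validate_storage_version(STORAGE_VERSION)   # latest version: no error
for version in storage_formats:             # every known version is accepted
    validate_storage_version(version)

try:
    validate_storage_version("bogus")
except ValueError as err:
    print(err)  # Invalid storage version: bogus. Valid versions are: [...]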
apis/python/test/test_ingestion.py
Lines changed: 70 additions & 1 deletion

@@ -1,7 +1,8 @@
 import numpy as np
 from common import *
-from tiledb.cloud.dag import Mode
+import pytest

+from tiledb.cloud.dag import Mode
 from tiledb.vector_search.flat_index import FlatIndex
 from tiledb.vector_search.index import Index
 from tiledb.vector_search.ingestion import ingest
@@ -416,6 +417,7 @@ def test_ivf_flat_ingestion_with_batch_updates(tmp_path):
     _, result = index.query(query_vectors, k=k, nprobe=nprobe)
     assert accuracy(result, gt_i, updated_ids=updated_ids) > 0.99

+
 def test_ivf_flat_ingestion_with_updates_and_timetravel(tmp_path):
     dataset_dir = os.path.join(tmp_path, "dataset")
     index_uri = os.path.join(tmp_path, "array")
@@ -669,6 +671,73 @@ def test_ivf_flat_ingestion_with_additions_and_timetravel(tmp_path):
     _, result = index.query(query_vectors, k=k, nprobe=index.partitions)
     assert 0.45 < accuracy(result, gt_i) < 0.55

+
+def test_storage_versions(tmp_path):
+    dataset_dir = os.path.join(tmp_path, "dataset")
+    k = 10
+    size = 1000
+    partitions = 10
+    dimensions = 128
+    nqueries = 100
+    data = create_random_dataset_u8(nb=size, d=dimensions, nq=nqueries, k=k, path=dataset_dir)
+    source_uri = os.path.join(dataset_dir, "data.u8bin")
+
+    dtype = np.uint8
+    query_vectors = get_queries(dataset_dir, dtype=dtype)
+    gt_i, _ = get_groundtruth(dataset_dir, k)
+
+    indexes = ["FLAT", "IVF_FLAT"]
+    index_classes = [FlatIndex, IVFFlatIndex]
+    index_files = [tiledb.vector_search.flat_index, tiledb.vector_search.ivf_flat_index]
+    for index_type, index_class, index_file in zip(indexes, index_classes, index_files):
+        # First we test with an invalid storage version.
+        with pytest.raises(ValueError) as error:
+            index_uri = os.path.join(tmp_path, f"array_{index_type}_invalid")
+            ingest(
+                index_type=index_type,
+                index_uri=index_uri,
+                source_uri=source_uri,
+                partitions=partitions,
+                storage_version="Foo"
+            )
+        assert "Invalid storage version" in str(error.value)
+
+        with pytest.raises(ValueError) as error:
+            index_file.create(uri=index_uri, dimensions=3, vector_type=np.dtype(dtype), storage_version="Foo")
+        assert "Invalid storage version" in str(error.value)
+
+        # Then we test with valid storage versions.
+        for storage_version, _ in tiledb.vector_search.storage_formats.items():
+            index_uri = os.path.join(tmp_path, f"array_{index_type}_{storage_version}")
+            index = ingest(
+                index_type=index_type,
+                index_uri=index_uri,
+                source_uri=source_uri,
+                partitions=partitions,
+                storage_version=storage_version
+            )
+            _, result = index.query(query_vectors, k=k)
+            assert accuracy(result, gt_i) >= MINIMUM_ACCURACY
+
+            update_ids_offset = MAX_UINT64 - size
+            updated_ids = {}
+            for i in range(10):
+                index.delete(external_id=i)
+                index.update(vector=data[i].astype(dtype), external_id=i + update_ids_offset)
+                updated_ids[i] = i + update_ids_offset
+
+            _, result = index.query(query_vectors, k=k)
+            assert accuracy(result, gt_i, updated_ids=updated_ids) >= MINIMUM_ACCURACY
+
+            index = index.consolidate_updates(partitions=20)
+            _, result = index.query(query_vectors, k=k)
+            assert accuracy(result, gt_i, updated_ids=updated_ids) >= MINIMUM_ACCURACY
+
+            index_ram = index_class(uri=index_uri)
+            _, result = index_ram.query(query_vectors, k=k)
+            assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+
 def test_kmeans():
     k = 128
     d = 16
