Skip to content

Commit f4d08e3

Browse files
authored
Support Vamana in the ObjectIndex (#366)
1 parent 32d17ee commit f4d08e3

File tree

5 files changed

+85
-67
lines changed

5 files changed

+85
-67
lines changed

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .module import validate_top_k
2020
from .storage_formats import STORAGE_VERSION
2121
from .storage_formats import storage_formats
22+
from .vamana_index import VamanaIndex
2223

2324
try:
2425
from tiledb.vector_search.version import version as __version__

apis/python/src/tiledb/vector_search/object_api/object_index.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from tiledb.cloud.dag import Mode
1212
from tiledb.vector_search import FlatIndex
1313
from tiledb.vector_search import IVFFlatIndex
14+
from tiledb.vector_search import VamanaIndex
1415
from tiledb.vector_search import flat_index
1516
from tiledb.vector_search import ivf_flat_index
17+
from tiledb.vector_search import vamana_index
1618
from tiledb.vector_search.embeddings import ObjectEmbedding
1719
from tiledb.vector_search.object_readers import ObjectReader
1820
from tiledb.vector_search.storage_formats import STORAGE_VERSION
@@ -53,6 +55,12 @@ def __init__(
5355
self.index = IVFFlatIndex(
5456
uri=uri, config=config, timestamp=timestamp, **kwargs
5557
)
58+
elif self.index_type == "VAMANA":
59+
self.index = VamanaIndex(
60+
uri=uri, config=config, timestamp=timestamp, **kwargs
61+
)
62+
else:
63+
raise ValueError(f"Unsupported index type {self.index_type}")
5664

5765
self.object_reader_source_code = self.index.group.meta[
5866
"object_reader_source_code"
@@ -428,6 +436,16 @@ def create(
428436
config=config,
429437
storage_version=storage_version,
430438
)
439+
elif index_type == "VAMANA":
440+
index = vamana_index.create(
441+
uri=uri,
442+
dimensions=dimensions,
443+
vector_type=vector_type,
444+
config=config,
445+
storage_version=storage_version,
446+
)
447+
else:
448+
raise ValueError(f"Unsupported index type {index_type}")
431449

432450
group = tiledb.Group(uri, "w")
433451
group.meta["object_reader_source_code"] = get_source_code(object_reader)

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def query_internal(
9292

9393
assert queries.dtype == np.float32
9494
if opt_l < k:
95-
raise ValueError(f"opt_l ({opt_l}) should be >= k ({k})")
95+
warnings.warn(f"opt_l ({opt_l}) should be >= k ({k}), setting to k")
96+
opt_l = k
9697

9798
if queries.ndim == 1:
9899
queries = np.array([queries])

apis/python/test/test_index.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,6 @@ def test_vamana_index(tmp_path):
279279
and ingestion_timestamps[0] < timestamp_5_minutes_from_now
280280
)
281281

282-
# Check that we throw if we query with an invalid opt_l.
283-
with pytest.raises(ValueError):
284-
index.query(queries, k=3, opt_l=2)
285-
286282
# Test that we can query with multiple query vectors.
287283
for i in range(5):
288284
query_and_check_distances(

apis/python/test/test_object_index.py

Lines changed: 64 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Dict, List, OrderedDict, Tuple
23

34
import numpy as np
@@ -10,6 +11,7 @@
1011
from tiledb.vector_search.object_readers import ObjectReader
1112

1213
EMBED_DIM = 4
14+
INDEXES = ["FLAT", "IVF_FLAT", "VAMANA"]
1315

1416

1517
# TestEmbedding with vectors of EMBED_DIM size with all values being the id of the object
@@ -233,73 +235,73 @@ def df_filter(row):
233235
)
234236

235237

236-
def test_object_index_ivf_flat(tmp_path):
237-
reader = TestReader(
238-
object_id_start=0,
239-
object_id_end=1000,
240-
vector_dim_offset=0,
241-
)
242-
embedding = TestEmbedding()
243-
244-
index_uri = f"{tmp_path}/index"
238+
def test_object_index(tmp_path):
239+
for index_type in INDEXES:
240+
index_uri = os.path.join(tmp_path, f"object_index_{index_type}")
241+
reader = TestReader(
242+
object_id_start=0,
243+
object_id_end=1000,
244+
vector_dim_offset=0,
245+
)
246+
embedding = TestEmbedding()
245247

246-
index = object_index.create(
247-
uri=index_uri,
248-
index_type="IVF_FLAT",
249-
object_reader=reader,
250-
embedding=embedding,
251-
)
248+
index = object_index.create(
249+
uri=index_uri,
250+
index_type=index_type,
251+
object_reader=reader,
252+
embedding=embedding,
253+
)
252254

253-
# Check initial ingestion
254-
index.update_index(partitions=10)
255-
evaluate_query(
256-
index_uri=index_uri,
257-
query_kwargs={"nprobe": 10},
258-
dim_id=42,
259-
vector_dim_offset=0,
260-
)
255+
# Check initial ingestion
256+
index.update_index(partitions=10)
257+
evaluate_query(
258+
index_uri=index_uri,
259+
query_kwargs={"nprobe": 10, "opt_l": 250},
260+
dim_id=42,
261+
vector_dim_offset=0,
262+
)
261263

262-
# Check that updating the same data doesn't create duplicates
263-
index = object_index.ObjectIndex(uri=index_uri)
264-
index.update_index(partitions=10)
265-
evaluate_query(
266-
index_uri=index_uri,
267-
query_kwargs={"nprobe": 10},
268-
dim_id=42,
269-
vector_dim_offset=0,
270-
)
264+
# Check that updating the same data doesn't create duplicates
265+
index = object_index.ObjectIndex(uri=index_uri)
266+
index.update_index(partitions=10)
267+
evaluate_query(
268+
index_uri=index_uri,
269+
query_kwargs={"nprobe": 10, "opt_l": 500},
270+
dim_id=42,
271+
vector_dim_offset=0,
272+
)
271273

272-
# Add new data with a new reader
273-
reader = TestReader(
274-
object_id_start=1000,
275-
object_id_end=2000,
276-
vector_dim_offset=0,
277-
)
278-
index = object_index.ObjectIndex(uri=index_uri)
279-
index.update_object_reader(reader)
280-
index.update_index(partitions=10)
281-
evaluate_query(
282-
index_uri=index_uri,
283-
query_kwargs={"nprobe": 10},
284-
dim_id=1042,
285-
vector_dim_offset=0,
286-
)
274+
# Add new data with a new reader
275+
reader = TestReader(
276+
object_id_start=1000,
277+
object_id_end=2000,
278+
vector_dim_offset=0,
279+
)
280+
index = object_index.ObjectIndex(uri=index_uri)
281+
index.update_object_reader(reader)
282+
index.update_index(partitions=10)
283+
evaluate_query(
284+
index_uri=index_uri,
285+
query_kwargs={"nprobe": 10, "opt_l": 500},
286+
dim_id=1042,
287+
vector_dim_offset=0,
288+
)
287289

288-
# Check overwritting existing data
289-
reader = TestReader(
290-
object_id_start=1000,
291-
object_id_end=2000,
292-
vector_dim_offset=1000,
293-
)
294-
index = object_index.ObjectIndex(uri=index_uri)
295-
index.update_object_reader(reader)
296-
index.update_index(partitions=10)
297-
evaluate_query(
298-
index_uri=index_uri,
299-
query_kwargs={"nprobe": 10},
300-
dim_id=2042,
301-
vector_dim_offset=1000,
302-
)
290+
# Check overwritting existing data
291+
reader = TestReader(
292+
object_id_start=1000,
293+
object_id_end=2000,
294+
vector_dim_offset=1000,
295+
)
296+
index = object_index.ObjectIndex(uri=index_uri)
297+
index.update_object_reader(reader)
298+
index.update_index(partitions=10)
299+
evaluate_query(
300+
index_uri=index_uri,
301+
query_kwargs={"nprobe": 10, "opt_l": 500},
302+
dim_id=2042,
303+
vector_dim_offset=1000,
304+
)
303305

304306

305307
def test_object_index_ivf_flat_cloud(tmp_path):

0 commit comments

Comments
 (0)