Skip to content

Commit b70d4b4

Browse files
Add open method for Index class (#503)
This adds an `open` factory function next to the `Index` class, which can be used to open an index without knowing its `index_type`.
1 parent 5dc0264 commit b70d4b4

File tree

9 files changed

+115
-34
lines changed

9 files changed

+115
-34
lines changed

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from . import utils
55
from .flat_index import FlatIndex
66
from .index import Index
7+
from .index import open
78
from .ingestion import ingest
89
from .ivf_flat_index import IVFFlatIndex
910
from .ivf_pq_index import IVFPQIndex
@@ -34,6 +35,7 @@
3435
"VamanaIndex",
3536
"IVFPQIndex",
3637
"Mode",
38+
"open",
3739
"load_as_array",
3840
"load_as_matrix",
3941
"ingest",

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
config: Optional[Mapping[str, Any]] = None,
4747
timestamp=None,
4848
open_for_remote_query_execution: bool = False,
49+
group: tiledb.Group = None,
4950
**kwargs,
5051
):
5152
self.index_open_kwargs = {
@@ -60,6 +61,7 @@ def __init__(
6061
config=config,
6162
timestamp=timestamp,
6263
open_for_remote_query_execution=open_for_remote_query_execution,
64+
group=group,
6365
)
6466
self._index = None
6567
self.db_uri = self.group[

apis/python/src/tiledb/vector_search/index.py

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
import os
44
import time
5+
from abc import ABCMeta
6+
from abc import abstractmethod
57
from typing import Any, Mapping, Optional
68

79
from tiledb.cloud.dag import Mode
@@ -16,7 +18,7 @@
1618
DATASET_TYPE = "vector_search"
1719

1820

19-
class Index:
21+
class Index(metaclass=ABCMeta):
2022
"""
2123
Abstract Vector Index class.
2224
@@ -42,22 +44,26 @@ class Index:
4244
If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
4345
"""
4446

47+
@abstractmethod
4548
def __init__(
4649
self,
4750
uri: str,
48-
open_for_remote_query_execution: bool,
51+
open_for_remote_query_execution: bool = False,
4952
config: Optional[Mapping[str, Any]] = None,
5053
timestamp=None,
54+
group: tiledb.Group = None,
5155
):
5256
# If the user passes a tiledb python Config object convert to a dictionary
5357
if isinstance(config, tiledb.Config):
5458
config = dict(config)
55-
5659
self.uri = uri
5760
self.open_for_remote_query_execution = open_for_remote_query_execution
5861
self.config = config
5962
self.ctx = vspy.Ctx(config)
60-
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(config))
63+
if group is not None:
64+
self.group = group
65+
else:
66+
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(config))
6167
self.storage_version = self.group.meta.get("storage_version", "0.1")
6268
try:
6369
self.distance_metric = vspy.DistanceMetric(
@@ -688,6 +694,7 @@ def clear_history(
688694
raise ValueError(f"Unsupported index_type: {index_type}")
689695
group.close()
690696

697+
@abstractmethod
691698
def get_dimensions(self):
692699
"""
693700
Abstract method implemented by all Vector Index implementations.
@@ -696,6 +703,7 @@ def get_dimensions(self):
696703
"""
697704
raise NotImplementedError
698705

706+
@abstractmethod
699707
def query_internal(self, queries: np.ndarray, k: int, **kwargs):
700708
"""
701709
Abstract method implemented by all Vector Index implementations.
@@ -868,3 +876,80 @@ def create_metadata(
868876
group.meta["has_updates"] = False
869877
group.meta["distance_metric"] = int(distance_metric)
870878
group.close()
879+
880+
def open(
    uri: str,
    open_for_remote_query_execution: bool = False,
    config: Optional[Mapping[str, Any]] = None,
    timestamp=None,
    **kwargs,
) -> Index:
    """
    Factory method that opens a vector index.

    Retrieves the `index_type` from the index group metadata and instantiates the appropriate `Index` subclass.

    Parameters
    ----------
    uri: str
        URI of the index.
    open_for_remote_query_execution: bool
        If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
        If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
    config: Optional[Mapping[str, Any]]
        TileDB config dictionary. It is used both to read the group metadata and to construct the index.
    timestamp: int or tuple(int)
        If int, open the index at a given timestamp.
        If tuple, open at the given start and end timestamps.
    kwargs:
        Additional arguments to be passed to the `Index` subclass constructor.

    Returns
    -------
    Index
        An instance of `FlatIndex`, `IVFFlatIndex`, `VamanaIndex` or `IVFPQIndex`, selected by the stored `index_type`.

    Raises
    ------
    ValueError
        If the stored `index_type` is not one of `FLAT`, `IVF_FLAT`, `VAMANA` or `IVF_PQ`.
    KeyError
        If the group metadata has no `index_type` entry.
    """
    # Imported inside the function rather than at module level; presumably this
    # avoids a circular import, since the index subclasses derive from `Index`.
    from tiledb.vector_search.flat_index import FlatIndex
    from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
    from tiledb.vector_search.ivf_pq_index import IVFPQIndex
    from tiledb.vector_search.vamana_index import VamanaIndex

    index_classes = {
        "FLAT": FlatIndex,
        "IVF_FLAT": IVFFlatIndex,
        "VAMANA": VamanaIndex,
        "IVF_PQ": IVFPQIndex,
    }

    # Open the group with the caller's config, exactly as `Index.__init__` does
    # when it opens the group itself, so that credentials and other settings in
    # `config` are honoured while reading `index_type`.
    group = tiledb.Group(uri, "r", ctx=tiledb.Ctx(config))
    try:
        index_type = group.meta["index_type"]
    except KeyError:
        # Not a vector index group: release the handle before propagating.
        group.close()
        raise

    index_class = index_classes.get(index_type)
    if index_class is None:
        group.close()
        raise ValueError(f"Unsupported index type {index_type}")

    # The already-open group is handed to the index so it is not opened twice.
    return index_class(
        uri=uri,
        open_for_remote_query_execution=open_for_remote_query_execution,
        config=config,
        timestamp=timestamp,
        group=group,
        **kwargs,
    )

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def ingest(
218218
from tiledb.cloud.utilities import get_logger
219219
from tiledb.cloud.utilities import set_aws_context
220220
from tiledb.vector_search import flat_index
221+
from tiledb.vector_search import index
221222
from tiledb.vector_search import ivf_flat_index
222223
from tiledb.vector_search import ivf_pq_index
223224
from tiledb.vector_search import vamana_index
@@ -3061,16 +3062,4 @@ def consolidate_and_vacuum(
30613062
group.close()
30623063

30633064
consolidate_and_vacuum(index_group_uri=index_group_uri, config=config)
3064-
3065-
if index_type == "FLAT":
3066-
return flat_index.FlatIndex(uri=index_group_uri, config=config)
3067-
elif index_type == "VAMANA":
3068-
return vamana_index.VamanaIndex(uri=index_group_uri, config=config)
3069-
elif index_type == "IVF_FLAT":
3070-
return ivf_flat_index.IVFFlatIndex(
3071-
uri=index_group_uri, memory_budget=1000000, config=config
3072-
)
3073-
elif index_type == "IVF_PQ":
3074-
return ivf_pq_index.IVFPQIndex(uri=index_group_uri, config=config)
3075-
else:
3076-
raise ValueError(f"Not supported index_type {index_type}")
3065+
return index.open(uri=index_group_uri, memory_budget=1000000, config=config)

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
timestamp=None,
8585
memory_budget: int = -1,
8686
open_for_remote_query_execution: bool = False,
87+
group: tiledb.Group = None,
8788
**kwargs,
8889
):
8990
self.index_open_kwargs = {
@@ -99,6 +100,7 @@ def __init__(
99100
config=config,
100101
timestamp=timestamp,
101102
open_for_remote_query_execution=open_for_remote_query_execution,
103+
group=group,
102104
)
103105
self.db_uri = self.group[
104106
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]

apis/python/src/tiledb/vector_search/ivf_pq_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
timestamp=None,
5050
memory_budget: int = -1,
5151
open_for_remote_query_execution: bool = False,
52+
group: tiledb.Group = None,
5253
**kwargs,
5354
):
5455
self.index_open_kwargs = {
@@ -64,6 +65,7 @@ def __init__(
6465
config=config,
6566
timestamp=timestamp,
6667
open_for_remote_query_execution=open_for_remote_query_execution,
68+
group=group,
6769
)
6870
# TODO(SC-48710): Add support for `open_for_remote_query_execution`. We don't leave `self.index` as `None` because we need to be able to call index.dimensions().
6971
self.index = vspy.IndexIVFPQ(self.ctx, uri, to_temporal_policy(timestamp))

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
config: Optional[Mapping[str, Any]] = None,
5454
timestamp=None,
5555
open_for_remote_query_execution: bool = False,
56+
group: tiledb.Group = None,
5657
**kwargs,
5758
):
5859
self.index_open_kwargs = {
@@ -67,6 +68,7 @@ def __init__(
6768
config=config,
6869
timestamp=timestamp,
6970
open_for_remote_query_execution=open_for_remote_query_execution,
71+
group=group,
7072
)
7173
# TODO(SC-48710): Add support for `open_for_remote_query_execution`. We don't leave `self.index` as `None` because we need to be able to call index.dimensions().
7274
self.index = vspy.IndexVamana(self.ctx, uri, to_temporal_policy(timestamp))

apis/python/test/test_index.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@
1212
from tiledb.vector_search import flat_index
1313
from tiledb.vector_search import ivf_flat_index
1414
from tiledb.vector_search import ivf_pq_index
15+
from tiledb.vector_search import open
1516
from tiledb.vector_search import vamana_index
16-
from tiledb.vector_search.flat_index import FlatIndex
1717
from tiledb.vector_search.index import DATASET_TYPE
1818
from tiledb.vector_search.index import create_metadata
1919
from tiledb.vector_search.ingestion import ingest
20-
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
21-
from tiledb.vector_search.ivf_pq_index import IVFPQIndex
2220
from tiledb.vector_search.utils import MAX_FLOAT32
2321
from tiledb.vector_search.utils import MAX_UINT64
2422
from tiledb.vector_search.utils import is_type_erased_index
@@ -423,9 +421,8 @@ def test_delete_index(tmp_path):
423421
vfs = tiledb.VFS()
424422

425423
indexes = ["FLAT", "IVF_FLAT", "VAMANA", "IVF_PQ"]
426-
index_classes = [FlatIndex, IVFFlatIndex, VamanaIndex, IVFPQIndex]
427424
data = np.array([[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]], dtype=np.float32)
428-
for index_type, index_class in zip(indexes, index_classes):
425+
for index_type in indexes:
429426
index_uri = os.path.join(tmp_path, f"array_{index_type}")
430427
ingest(
431428
index_type=index_type,
@@ -436,7 +433,7 @@ def test_delete_index(tmp_path):
436433
Index.delete_index(uri=index_uri, config={})
437434
assert vfs.dir_size(index_uri) == 0
438435
with pytest.raises(tiledb.TileDBError) as error:
439-
index_class(uri=index_uri)
436+
open(uri=index_uri)
440437
assert "does not exist" in str(error.value)
441438

442439

apis/python/test/test_ingestion.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
from tiledb.cloud.dag import Mode
1010
from tiledb.vector_search import _tiledbvspy as vspy
11+
from tiledb.vector_search import open
1112
from tiledb.vector_search.index import Index
1213
from tiledb.vector_search.ingestion import TrainingSamplingPolicy
1314
from tiledb.vector_search.ingestion import ingest
14-
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
1515
from tiledb.vector_search.module import array_to_matrix
1616
from tiledb.vector_search.module import kmeans_fit
1717
from tiledb.vector_search.module import kmeans_predict
@@ -72,7 +72,7 @@ def test_vamana_ingestion_u8(tmp_path):
7272
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
7373

7474
index_uri = move_local_index_to_new_location(index_uri)
75-
index_ram = VamanaIndex(uri=index_uri)
75+
index_ram = open(uri=index_uri)
7676
_, result = index_ram.query(queries, k=k)
7777
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
7878

@@ -100,7 +100,7 @@ def test_flat_ingestion_u8(tmp_path):
100100
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
101101

102102
index_uri = move_local_index_to_new_location(index_uri)
103-
index_ram = FlatIndex(uri=index_uri)
103+
index_ram = open(uri=index_uri)
104104
_, result = index_ram.query(queries, k=k)
105105
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
106106

@@ -124,7 +124,7 @@ def test_flat_ingestion_f32(tmp_path):
124124
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
125125

126126
index_uri = move_local_index_to_new_location(index_uri)
127-
index_ram = FlatIndex(uri=index_uri)
127+
index_ram = open(uri=index_uri)
128128
_, result = index_ram.query(queries, k=k)
129129
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
130130

@@ -157,7 +157,7 @@ def test_flat_ingestion_external_id_u8(tmp_path):
157157
)
158158

159159
index_uri = move_local_index_to_new_location(index_uri)
160-
index_ram = FlatIndex(uri=index_uri)
160+
index_ram = open(uri=index_uri)
161161
_, result = index_ram.query(queries, k=k)
162162
assert (
163163
accuracy(result, gt_i, external_ids_offset=external_ids_offset)
@@ -190,7 +190,7 @@ def test_ivf_flat_ingestion_u8(tmp_path):
190190
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
191191

192192
index_uri = move_local_index_to_new_location(index_uri)
193-
index_ram = IVFFlatIndex(uri=index_uri, memory_budget=int(size / 10))
193+
index_ram = open(uri=index_uri, memory_budget=int(size / 10))
194194
_, result = index_ram.query(queries, k=k, nprobe=nprobe)
195195
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
196196

@@ -237,7 +237,7 @@ def test_ivf_pq_ingestion_u8(tmp_path):
237237
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
238238

239239
index_uri = move_local_index_to_new_location(index_uri)
240-
index_ram = IVFPQIndex(uri=index_uri, memory_budget=int(size / 10))
240+
index_ram = open(uri=index_uri, memory_budget=int(size / 10))
241241
_, result = index_ram.query(queries, k=k, nprobe=nprobe)
242242
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
243243

@@ -1698,7 +1698,7 @@ def test_ivf_flat_ingestion_with_training_source_uri_f32(tmp_path):
16981698
)
16991699

17001700
index_uri = move_local_index_to_new_location(index_uri)
1701-
index = IVFFlatIndex(uri=index_uri)
1701+
index = open(uri=index_uri)
17021702
query_and_check_equals(
17031703
index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1]]
17041704
)
@@ -1796,7 +1796,7 @@ def test_ivf_flat_ingestion_with_training_source_uri_tdb(tmp_path):
17961796
index_uri = move_local_index_to_new_location(index_uri)
17971797

17981798
# Load the index again and query.
1799-
index = IVFFlatIndex(uri=index_uri)
1799+
index = open(uri=index_uri)
18001800

18011801
query_and_check_equals(
18021802
index=index,
@@ -1823,7 +1823,7 @@ def test_ivf_flat_ingestion_with_training_source_uri_tdb(tmp_path):
18231823
# Clear the index history, load, update, and query.
18241824
Index.clear_history(uri=index_uri, timestamp=index.latest_ingestion_timestamp - 1)
18251825

1826-
index = IVFFlatIndex(uri=index_uri)
1826+
index = open(uri=index_uri)
18271827

18281828
update_vectors = np.empty([2], dtype=object)
18291829
update_vectors[0] = np.array([11.0, 11.1, 11.2, 11.3], dtype=np.dtype(np.float32))
@@ -1918,7 +1918,7 @@ def test_ivf_flat_ingestion_with_training_source_uri_numpy(tmp_path):
19181918
# Test we can load the index again and query, update, and consolidate.
19191919
################################################################################################
19201920
index_uri = move_local_index_to_new_location(index_uri)
1921-
index = IVFFlatIndex(uri=index_uri)
1921+
index = open(uri=index_uri)
19221922

19231923
queries = np.array([data[1]], dtype=np.float32)
19241924
query_and_check_equals(

0 commit comments

Comments
 (0)