Skip to content

Commit 2138fa7

Browse files
authored
Add Python IVF PQ Index (#404)
1 parent e322d71 commit 2138fa7

20 files changed

+679
-123
lines changed

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .index import Index
77
from .ingestion import ingest
88
from .ivf_flat_index import IVFFlatIndex
9+
from .ivf_pq_index import IVFPQIndex
910
from .module import array_to_matrix
1011
from .module import ivf_index
1112
from .module import ivf_index_tdb
@@ -31,6 +32,7 @@
3132
"FlatIndex",
3233
"IVFFlatIndex",
3334
"VamanaIndex",
35+
"IVFPQIndex",
3436
"Mode",
3537
"load_as_array",
3638
"load_as_matrix",

apis/python/src/tiledb/vector_search/index.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tiledb.vector_search.utils import MAX_FLOAT32
1212
from tiledb.vector_search.utils import MAX_UINT64
1313
from tiledb.vector_search.utils import add_to_group
14+
from tiledb.vector_search.utils import is_type_erased_index
1415

1516
DATASET_TYPE = "vector_search"
1617

@@ -462,6 +463,10 @@ def consolidate_updates(self, retrain_index: bool = False, **kwargs):
462463
"""
463464
from tiledb.vector_search.ingestion import ingest
464465

466+
if self.index_type == "IVF_PQ":
467+
# TODO(SC-48888): Fix consolidation for IVF_PQ.
468+
raise ValueError("IVF_PQ indexes do not support consolidation yet.")
469+
465470
fragments_info = tiledb.array_fragments(
466471
self.updates_array_uri, ctx=tiledb.Ctx(self.config)
467472
)
@@ -566,14 +571,19 @@ def clear_history(
566571
f"Time traveling is not supported for index storage_version={storage_version}"
567572
)
568573

569-
if index_type == "VAMANA":
574+
if is_type_erased_index(index_type):
570575
if storage_formats[storage_version]["UPDATES_ARRAY_NAME"] in group:
571576
updates_array_uri = group[
572577
storage_formats[storage_version]["UPDATES_ARRAY_NAME"]
573578
].uri
574579
tiledb.Array.delete_fragments(updates_array_uri, 0, timestamp)
575580
ctx = vspy.Ctx(config)
576-
vspy.IndexVamana.clear_history(ctx, uri, timestamp)
581+
if index_type == "VAMANA":
582+
vspy.IndexVamana.clear_history(ctx, uri, timestamp)
583+
elif index_type == "IVF_PQ":
584+
vspy.IndexIVFPQ.clear_history(ctx, uri, timestamp)
585+
else:
586+
raise ValueError(f"Unsupported index_type: {index_type}")
577587
return
578588

579589
ingestion_timestamps = [

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def ingest(
5151
namespace: Optional[str] = None,
5252
size: int = -1,
5353
partitions: int = -1,
54+
num_subspaces: int = -1,
5455
training_sampling_policy: TrainingSamplingPolicy = TrainingSamplingPolicy.FIRST_N,
5556
copy_centroids_uri: str = None,
5657
training_sample_size: int = -1,
@@ -87,7 +88,7 @@ def ingest(
8788
Parameters
8889
----------
8990
index_type: str
90-
Type of vector index (FLAT, IVF_FLAT, VAMANA).
91+
Type of vector index (FLAT, IVF_FLAT, IVF_PQ, VAMANA).
9192
index_uri: str
9293
Vector index URI (stored as TileDB group).
9394
input_vectors: np.ndarray
@@ -114,7 +115,11 @@ def ingest(
114115
Number of input vectors, if not provided use the full size of the input dataset.
115116
If provided, we filter the first vectors from the input source.
116117
partitions: int
117-
Number of partitions to load the data with, if not provided, is auto-configured based on the dataset size.
118+
For IVF indexes, the number of partitions to load the data with, if not provided, is auto-configured based on the dataset size.
119+
num_subspaces: int
120+
For PQ encoded indexes, the number of subspaces to use in the PQ encoding. We will divide the dimensions into
121+
num_subspaces parts, and PQ encode each part separately. This means dimensions must
122+
be divisible by num_subspaces.
118123
copy_centroids_uri: str
119124
TileDB array URI to copy centroids from, if not provided, centroids are build running `k-means`.
120125
training_sample_size: int
@@ -199,6 +204,7 @@ def ingest(
199204
from tiledb.cloud.utilities import set_aws_context
200205
from tiledb.vector_search import flat_index
201206
from tiledb.vector_search import ivf_flat_index
207+
from tiledb.vector_search import ivf_pq_index
202208
from tiledb.vector_search import vamana_index
203209
from tiledb.vector_search.storage_formats import storage_formats
204210

@@ -1511,7 +1517,8 @@ def ingest_flat(
15111517
parts_array.close()
15121518
ids_array.close()
15131519

1514-
def ingest_vamana(
1520+
def ingest_type_erased(
1521+
index_type: str,
15151522
index_group_uri: str,
15161523
source_uri: str,
15171524
source_type: str,
@@ -1636,7 +1643,12 @@ def ingest_vamana(
16361643
from tiledb.vector_search import _tiledbvspy as vspy
16371644

16381645
ctx = vspy.Ctx(config)
1639-
index = vspy.IndexVamana(ctx, index_group_uri)
1646+
if index_type == "VAMANA":
1647+
index = vspy.IndexVamana(ctx, index_group_uri)
1648+
elif index_type == "IVF_PQ":
1649+
index = vspy.IndexIVFPQ(ctx, index_group_uri)
1650+
else:
1651+
raise ValueError(f"Unsupported index type: {index_type}")
16401652
data = vspy.FeatureVectorArray(
16411653
ctx, parts_array_uri, ids_array_uri, 0, to_temporal_policy(index_timestamp)
16421654
)
@@ -2191,9 +2203,10 @@ def create_ingestion_dag(
21912203
**kwargs,
21922204
)
21932205
return d
2194-
elif index_type == "VAMANA":
2206+
elif is_type_erased_index(index_type):
21952207
ingest_node = submit(
2196-
ingest_vamana,
2208+
ingest_type_erased,
2209+
index_type=index_type,
21972210
index_group_uri=index_group_uri,
21982211
source_uri=source_uri,
21992212
source_type=source_type,
@@ -2572,8 +2585,8 @@ def consolidate_and_vacuum(
25722585

25732586
logger.debug("Ingesting Vectors into %r", index_group_uri)
25742587
arrays_created = False
2575-
if index_type == "VAMANA":
2576-
# If we're using a type-erased index (i.e. Vamana), we create the group in C++.
2588+
if is_type_erased_index(index_type):
2589+
# If we're using a type-erased index, we create the group in C++.
25772590
try:
25782591
# Try opening the group to see if it exists.
25792592
group = tiledb.Group(index_group_uri, "r")
@@ -2583,13 +2596,26 @@ def consolidate_and_vacuum(
25832596
# If it does not then we can create it in C++.
25842597
message = str(err)
25852598
if "not exist" in message:
2586-
vamana_index.create(
2587-
uri=index_group_uri,
2588-
dimensions=dimensions,
2589-
vector_type=vector_type,
2590-
config=config,
2591-
storage_version=storage_version,
2592-
)
2599+
if index_type == "VAMANA":
2600+
vamana_index.create(
2601+
uri=index_group_uri,
2602+
dimensions=dimensions,
2603+
vector_type=vector_type,
2604+
config=config,
2605+
storage_version=storage_version,
2606+
)
2607+
elif index_type == "IVF_PQ":
2608+
ivf_pq_index.create(
2609+
uri=index_group_uri,
2610+
dimensions=dimensions,
2611+
vector_type=vector_type,
2612+
num_subspaces=num_subspaces,
2613+
partitions=partitions,
2614+
config=config,
2615+
storage_version=storage_version,
2616+
)
2617+
else:
2618+
raise ValueError(f"Unsupported index type {index_type}")
25932619
else:
25942620
raise err
25952621
else:
@@ -2860,5 +2886,7 @@ def consolidate_and_vacuum(
28602886
return ivf_flat_index.IVFFlatIndex(
28612887
uri=index_group_uri, memory_budget=1000000, config=config
28622888
)
2889+
elif index_type == "IVF_PQ":
2890+
return ivf_pq_index.IVFPQIndex(uri=index_group_uri, config=config)
28632891
else:
28642892
raise ValueError(f"Not supported index_type {index_type}")
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""
2+
IVFPQ Index implementation.
3+
"""
4+
import warnings
5+
from typing import Any, Mapping
6+
7+
import numpy as np
8+
9+
from tiledb.vector_search import _tiledbvspy as vspy
10+
from tiledb.vector_search import index
11+
from tiledb.vector_search.module import *
12+
from tiledb.vector_search.storage_formats import STORAGE_VERSION
13+
from tiledb.vector_search.storage_formats import storage_formats
14+
from tiledb.vector_search.storage_formats import validate_storage_version
15+
from tiledb.vector_search.utils import MAX_FLOAT32
16+
from tiledb.vector_search.utils import MAX_UINT64
17+
from tiledb.vector_search.utils import to_temporal_policy
18+
19+
INDEX_TYPE = "IVF_PQ"
20+
21+
22+
class IVFPQIndex(index.Index):
    """
    Opens a `IVFPQIndex`.

    Parameters
    ----------
    uri: str
        URI of the index.
    config: Optional[Mapping[str, Any]]
        TileDB config dictionary.
    timestamp: int or tuple(int)
        If int, open the index at a given timestamp.
        If tuple, open at the given start and end timestamps.
    open_for_remote_query_execution: bool
        If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
        If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
    """

    def __init__(
        self,
        uri: str,
        config: Optional[Mapping[str, Any]] = None,
        timestamp=None,
        open_for_remote_query_execution: bool = False,
        **kwargs,
    ):
        # Stash the open arguments so the index can be re-opened later
        # (e.g. after updates) with the same configuration.
        self.index_open_kwargs = {
            "uri": uri,
            "config": config,
            "timestamp": timestamp,
        }
        self.index_open_kwargs.update(kwargs)
        self.index_type = INDEX_TYPE
        super().__init__(
            uri=uri,
            config=config,
            timestamp=timestamp,
            open_for_remote_query_execution=open_for_remote_query_execution,
        )
        self.index = vspy.IndexIVFPQ(self.ctx, uri, to_temporal_policy(timestamp))
        # TODO(paris): This is incorrect - should be fixed when we fix consolidation.
        self.db_uri = self.group[
            storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
        ].uri
        self.ids_uri = self.group[
            storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
        ].uri

        schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
        self.dimensions = self.index.dimensions()

        # Bug fix: the original wrapped the group-metadata value in np.dtype()
        # *before* the None check, but np.dtype(None) is dtype('float64'), so
        # `if self.dtype is None` could never fire and the schema fallback was
        # dead code. Inspect the raw metadata value first, and only fall back
        # to the parts array schema when no dtype was recorded.
        stored_dtype = self.group.meta.get("dtype", None)
        if stored_dtype is None:
            self.dtype = np.dtype(schema.attr("values").dtype)
        else:
            self.dtype = np.dtype(stored_dtype)

        if self.base_size == -1:
            # No recorded base size: infer vector count from the parts array
            # domain (inclusive upper bound, hence the +1).
            self.size = schema.domain.dim(1).domain[1] + 1
        else:
            self.size = self.base_size

    def get_dimensions(self):
        """
        Returns the dimension of the vectors in the index.
        """
        return self.dimensions

    def query_internal(
        self,
        queries: np.ndarray,
        k: int = 10,
        nprobe: Optional[int] = 100,
        **kwargs,
    ):
        """
        Queries a `IVFPQIndex`.

        Parameters
        ----------
        queries: np.ndarray
            2D array of query vectors. This can be used as a batch query interface by passing multiple queries in one call.
        k: int
            Number of results to return per query vector.
        nprobe: int
            Number of partitions to check per query.
            Use this parameter to trade-off accuracy for latency and cost.
        """
        warnings.warn("The IVF PQ index is not yet supported, please use with caution.")
        if self.size == 0:
            # Empty index: return sentinel distances/ids of the right shape.
            return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
                (queries.shape[0], k), MAX_UINT64
            )

        if queries.ndim == 1:
            queries = np.array([queries])
        # The C++ layer expects column-major (dimensions x num_queries) data.
        queries = np.transpose(queries)
        if not queries.flags.f_contiguous:
            queries = queries.copy(order="F")
        queries_feature_vector_array = vspy.FeatureVectorArray(queries)

        distances, ids = self.index.query(
            vspy.QueryType.InfiniteRAM, queries_feature_vector_array, k, nprobe
        )

        return np.array(distances, copy=False), np.array(ids, copy=False)
128+
129+
130+
def create(
    uri: str,
    dimensions: int,
    vector_type: np.dtype,
    num_subspaces: int,
    config: Optional[Mapping[str, Any]] = None,
    storage_version: str = STORAGE_VERSION,
    partitions: Optional[int] = None,
    **kwargs,
) -> IVFPQIndex:
    """
    Creates an empty IVFPQIndex.

    Parameters
    ----------
    uri: str
        URI of the index.
    dimensions: int
        Number of dimensions for the vectors to be stored in the index.
    vector_type: np.dtype
        Datatype of vectors.
        Supported values (uint8, int8, float32).
    num_subspaces: int
        Number of subspaces to use in the PQ encoding. We will divide the dimensions into
        num_subspaces parts, and PQ encode each part separately. This means dimensions must
        be divisible by num_subspaces.
    config: Optional[Mapping[str, Any]]
        TileDB config dictionary.
    storage_version: str
        The TileDB vector search storage version to use.
        If not provided, use the latest stable storage version.
    partitions: int
        Number of partitions to load the data with, if not provided, is auto-configured
        based on the dataset size.

    Raises
    ------
    ValueError
        If `num_subspaces` is not positive, or `dimensions` is not divisible
        by `num_subspaces`.
    """
    warnings.warn("The IVF PQ index is not yet supported, please use with caution.")
    validate_storage_version(storage_version)
    ctx = vspy.Ctx(config)
    # Validate PQ parameters up front so we fail before touching storage.
    if num_subspaces <= 0:
        raise ValueError(
            f"Number of num_subspaces ({num_subspaces}) must be greater than 0."
        )
    if dimensions % num_subspaces != 0:
        raise ValueError(
            f"Number of dimensions ({dimensions}) must be divisible by num_subspaces ({num_subspaces})."
        )
    index = vspy.IndexIVFPQ(
        feature_type=np.dtype(vector_type).name,
        id_type=np.dtype(np.uint64).name,
        partitioning_index_type=np.dtype(np.uint64).name,
        dimensions=dimensions,
        # Bug fix: the original used `partitions is not -1`, which is an
        # identity comparison against an int literal — implementation-defined
        # and a SyntaxWarning on CPython >= 3.8. Use `!=` for the -1 sentinel.
        # 0 tells the C++ layer to auto-configure the number of partitions.
        n_list=partitions if (partitions is not None and partitions != -1) else 0,
        num_subspaces=num_subspaces,
    )
    # TODO(paris): Run all of this with a single C++ call.
    # Train/add on an empty vector array so the group is fully materialized
    # on disk before the index is reopened below.
    empty_vector = vspy.FeatureVectorArray(
        dimensions, 0, np.dtype(vector_type).name, np.dtype(np.uint64).name
    )
    index.train(empty_vector)
    index.add(empty_vector)
    index.write_index(ctx, uri, vspy.TemporalPolicy(0), storage_version)
    return IVFPQIndex(uri=uri, config=config)

0 commit comments

Comments
 (0)