Skip to content

Commit 5ee2aaa

Browse files
Add query with driver implementation (#398)
1 parent cbe5577 commit 5ee2aaa

File tree

6 files changed

+259
-151
lines changed

6 files changed

+259
-151
lines changed

apis/python/src/tiledb/vector_search/flat_index.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,32 @@ class FlatIndex(index.Index):
3535
timestamp: int or tuple(int)
3636
If int, open the index at a given timestamp.
3737
If tuple, open at the given start and end timestamps.
38+
open_for_remote_query_execution: bool
39+
If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
40+
If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
3841
"""
3942

4043
def __init__(
4144
self,
4245
uri: str,
4346
config: Optional[Mapping[str, Any]] = None,
4447
timestamp=None,
48+
open_for_remote_query_execution: bool = False,
4549
**kwargs,
4650
):
51+
self.index_open_kwargs = {
52+
"uri": uri,
53+
"config": config,
54+
"timestamp": timestamp,
55+
}
56+
self.index_open_kwargs.update(kwargs)
4757
self.index_type = INDEX_TYPE
48-
super().__init__(uri=uri, config=config, timestamp=timestamp)
58+
super().__init__(
59+
uri=uri,
60+
config=config,
61+
timestamp=timestamp,
62+
open_for_remote_query_execution=open_for_remote_query_execution,
63+
)
4964
self._index = None
5065
self.db_uri = self.group[
5166
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
@@ -69,7 +84,7 @@ def __init__(
6984
].uri
7085
else:
7186
self.ids_uri = ""
72-
if self.size > 0:
87+
if self.size > 0 and not open_for_remote_query_execution:
7388
self._db = load_as_matrix(
7489
self.db_uri,
7590
ctx=self.ctx,
@@ -121,8 +136,6 @@ def query_internal(
121136
(queries.shape[0], k), MAX_UINT64
122137
)
123138

124-
assert queries.dtype == np.float32
125-
126139
queries_m = array_to_matrix(np.transpose(queries))
127140
d, i = query_vq_heap(self._db, queries_m, self._ids, k, nthreads)
128141

apis/python/src/tiledb/vector_search/index.py

Lines changed: 103 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from typing import Any, Mapping, Optional
66

7+
from tiledb.cloud.dag import Mode
78
from tiledb.vector_search import _tiledbvspy as vspy
89
from tiledb.vector_search.module import *
910
from tiledb.vector_search.storage_formats import storage_formats
@@ -35,11 +36,15 @@ class Index:
3536
timestamp: int or tuple(int)
3637
If int, open the index at a given timestamp.
3738
If tuple, open at the given start and end timestamps.
39+
open_for_remote_query_execution: bool
40+
If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
41+
If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
3842
"""
3943

4044
def __init__(
4145
self,
4246
uri: str,
47+
open_for_remote_query_execution: bool,
4348
config: Optional[Mapping[str, Any]] = None,
4449
timestamp=None,
4550
):
@@ -48,6 +53,7 @@ def __init__(
4853
config = dict(config)
4954

5055
self.uri = uri
56+
self.open_for_remote_query_execution = open_for_remote_query_execution
5157
self.config = config
5258
self.ctx = vspy.Ctx(config)
5359
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(config))
@@ -154,7 +160,66 @@ def __init__(
154160
self.thread_executor = futures.ThreadPoolExecutor()
155161
self.has_updates = self._check_has_updates()
156162

157-
def query(self, queries: np.ndarray, k: int, **kwargs):
163+
def _query_with_driver(
164+
self,
165+
queries: np.ndarray,
166+
k: int,
167+
driver_mode=None,
168+
driver_resources=None,
169+
driver_access_credentials_name=None,
170+
**kwargs,
171+
):
172+
from tiledb.cloud import dag
173+
174+
def query_udf(index_type, index_open_kwargs, query_kwargs):
175+
from tiledb.vector_search.flat_index import FlatIndex
176+
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
177+
from tiledb.vector_search.vamana_index import VamanaIndex
178+
179+
# Open index
180+
if index_type == "FLAT":
181+
index = FlatIndex(**index_open_kwargs)
182+
elif index_type == "IVF_FLAT":
183+
index = IVFFlatIndex(**index_open_kwargs)
184+
elif index_type == "VAMANA":
185+
index = VamanaIndex(**index_open_kwargs)
186+
187+
# Query index
188+
return index.query(**query_kwargs)
189+
190+
d = dag.DAG(
191+
name="vector-query",
192+
mode=driver_mode,
193+
max_workers=1,
194+
)
195+
query_kwargs = {
196+
"queries": queries,
197+
"k": k,
198+
}
199+
query_kwargs.update(kwargs)
200+
node = d.submit(
201+
query_udf,
202+
self.index_type,
203+
self.index_open_kwargs,
204+
query_kwargs,
205+
name="vector-query-driver",
206+
resources=driver_resources,
207+
image_name="vectorsearch",
208+
access_credentials_name=driver_access_credentials_name,
209+
)
210+
d.compute()
211+
d.wait()
212+
return node.result()
213+
214+
def query(
215+
self,
216+
queries: np.ndarray,
217+
k: int,
218+
driver_mode: Mode = None,
219+
driver_resources: Optional[str] = None,
220+
driver_access_credentials_name: Optional[str] = None,
221+
**kwargs,
222+
):
158223
"""
159224
Queries an index with a set of query vectors, retrieving the `k` most similar vectors for each query.
160225
@@ -164,12 +229,23 @@ def query(self, queries: np.ndarray, k: int, **kwargs):
164229
- Calls the algorithm specific implementation of `query_internal` to query the base data.
165230
- Merges the results applying the updated data.
166231
232+
You can control where the query is executed by setting the `driver_mode` parameter:
233+
- With `driver_mode = None`, the driver logic for the query will be executed locally.
234+
- If `driver_mode` is not `None`, we will use a TileDB cloud taskgraph to re-open the index and run the query.
235+
With both options, certain implementations, e.g. IVF Flat, may let you create further TileDB taskgraphs as defined in the implementation-specific `query_internal` methods.
236+
167237
Parameters
168238
----------
169239
queries: np.ndarray
170240
2D array of query vectors. This can be used as a batch query interface by passing multiple queries in one call.
171241
k: int
172242
Number of results to return per query vector.
243+
driver_mode: Mode
244+
If not `None`, the query will be executed in a TileDB cloud taskgraph using the driver mode specified.
245+
driver_resources: Optional[str]
246+
If `driver_mode` was not `None`, the resources to use for the driver execution.
247+
driver_access_credentials_name: Optional[str]
248+
If `driver_mode` was not `None`, the access credentials name to use for the driver execution.
173249
**kwargs
174250
Extra kwargs passed here are passed to the `query_internal` implementation of the concrete index class.
175251
"""
@@ -184,6 +260,32 @@ def query(self, queries: np.ndarray, k: int, **kwargs):
184260
f"A query in queries has {query_dimensions} dimensions, but the indexed data had {self.dimensions} dimensions"
185261
)
186262

263+
if queries.dtype != np.float32:
264+
raise TypeError(
265+
f"Expected queries to have dtype np.float32, but it had dtype {queries.dtype}"
266+
)
267+
268+
if driver_mode == Mode.LOCAL:
269+
# @todo: Fix bug with driver_mode=Mode.LOCAL and remove this check.
270+
raise TypeError(
271+
"Cannot pass driver_mode=Mode.LOCAL to query() - use driver_mode=None to query locally."
272+
)
273+
274+
if driver_mode is not None:
275+
return self._query_with_driver(
276+
queries,
277+
k,
278+
driver_mode,
279+
driver_resources,
280+
driver_access_credentials_name,
281+
**kwargs,
282+
)
283+
284+
if self.open_for_remote_query_execution:
285+
raise ValueError(
286+
"Cannot query an index with driver_mode=None without loading the index data in main memory. Set open_for_remote_query_execution=False when creating the index to load the index data before query."
287+
)
288+
187289
with tiledb.scope_ctx(ctx_or_config=self.config):
188290
if not self.has_updates:
189291
if self.query_base_array:
@@ -575,7 +677,6 @@ def _query_additions(
575677
timestamp=None,
576678
config=None,
577679
):
578-
assert queries.dtype == np.float32
579680
additions_vectors, additions_external_ids, updated_ids = Index._read_additions(
580681
updates_array_uri, timestamp, config
581682
)

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2855,9 +2855,7 @@ def consolidate_and_vacuum(
28552855
if index_type == "FLAT":
28562856
return flat_index.FlatIndex(uri=index_group_uri, config=config)
28572857
elif index_type == "VAMANA":
2858-
return vamana_index.VamanaIndex(
2859-
uri=index_group_uri, config=config, debug=True
2860-
)
2858+
return vamana_index.VamanaIndex(uri=index_group_uri, config=config)
28612859
elif index_type == "IVF_FLAT":
28622860
return ivf_flat_index.IVFFlatIndex(
28632861
uri=index_group_uri, memory_budget=1000000, config=config

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class IVFFlatIndex(index.Index):
7070
If not provided, all index data are loaded in main memory.
7171
Otherwise, no index data are loaded in main memory and this memory budget is
7272
applied during queries.
73+
open_for_remote_query_execution: bool
74+
If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`. We then load index data in the taskgraph based on `memory_budget`.
75+
If `False`, load index data in main memory locally according to `memory_budget`. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
7376
"""
7477

7578
def __init__(
@@ -78,10 +81,23 @@ def __init__(
7881
config: Optional[Mapping[str, Any]] = None,
7982
timestamp=None,
8083
memory_budget: int = -1,
84+
open_for_remote_query_execution: bool = False,
8185
**kwargs,
8286
):
87+
self.index_open_kwargs = {
88+
"uri": uri,
89+
"config": config,
90+
"timestamp": timestamp,
91+
"memory_budget": memory_budget,
92+
}
93+
self.index_open_kwargs.update(kwargs)
8394
self.index_type = INDEX_TYPE
84-
super().__init__(uri=uri, config=config, timestamp=timestamp)
95+
super().__init__(
96+
uri=uri,
97+
config=config,
98+
timestamp=timestamp,
99+
open_for_remote_query_execution=open_for_remote_query_execution,
100+
)
85101
self.db_uri = self.group[
86102
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
87103
+ self.index_version
@@ -125,28 +141,29 @@ def __init__(
125141
else:
126142
self.partitions = self.partition_history[self.history_index]
127143

128-
self._centroids = load_as_matrix(
129-
self.centroids_uri,
130-
ctx=self.ctx,
131-
size=self.partitions,
132-
config=config,
133-
timestamp=self.base_array_timestamp,
134-
)
135-
self._index = read_vector_u64(
136-
self.ctx,
137-
self.index_array_uri,
138-
0,
139-
self.partitions + 1,
140-
self.base_array_timestamp,
141-
)
144+
if not open_for_remote_query_execution:
145+
self._centroids = load_as_matrix(
146+
self.centroids_uri,
147+
ctx=self.ctx,
148+
size=self.partitions,
149+
config=config,
150+
timestamp=self.base_array_timestamp,
151+
)
152+
self._index = read_vector_u64(
153+
self.ctx,
154+
self.index_array_uri,
155+
0,
156+
self.partitions + 1,
157+
self.base_array_timestamp,
158+
)
142159

143160
if self.base_size == -1:
144161
self.size = self._index[self.partitions]
145162
else:
146163
self.size = self.base_size
147164

148165
# TODO pass in a context
149-
if self.memory_budget == -1:
166+
if not open_for_remote_query_execution and self.memory_budget == -1:
150167
self._db = load_as_matrix(
151168
self.db_uri,
152169
ctx=self.ctx,
@@ -225,8 +242,6 @@ def query_internal(
225242
if (mode != Mode.REALTIME and mode != Mode.BATCH) and resource_class:
226243
raise TypeError("Can only pass resource_class in REALTIME or BATCH mode")
227244

228-
assert queries.dtype == np.float32
229-
230245
if queries.ndim == 1:
231246
queries = np.array([queries])
232247

@@ -391,7 +406,6 @@ def dist_qv_udf(
391406
results.append(tmp_results)
392407
return results
393408

394-
assert queries.dtype == np.float32
395409
if num_partitions == -1:
396410
num_partitions = 5
397411
if num_workers == -1:

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,33 @@ class VamanaIndex(index.Index):
3838
URI of the index.
3939
config: Optional[Mapping[str, Any]]
4040
TileDB config dictionary.
41+
open_for_remote_query_execution: bool
42+
If `True`, do not load any index data in main memory locally, and instead load index data in the TileDB Cloud taskgraph created when a non-`None` `driver_mode` is passed to `query()`.
43+
If `False`, load index data in main memory locally. Note that you can still use a taskgraph for query execution, you'll just end up loading the data both on your local machine and in the cloud taskgraph.
4144
"""
4245

4346
def __init__(
4447
self,
4548
uri: str,
4649
config: Optional[Mapping[str, Any]] = None,
4750
timestamp=None,
51+
open_for_remote_query_execution: bool = False,
4852
**kwargs,
4953
):
50-
super().__init__(uri=uri, config=config, timestamp=timestamp)
54+
self.index_open_kwargs = {
55+
"uri": uri,
56+
"config": config,
57+
"timestamp": timestamp,
58+
}
59+
self.index_open_kwargs.update(kwargs)
60+
super().__init__(
61+
uri=uri,
62+
config=config,
63+
timestamp=timestamp,
64+
open_for_remote_query_execution=open_for_remote_query_execution,
65+
)
5166
self.index_type = INDEX_TYPE
67+
# TODO(SC-48710): Add support for `open_for_remote_query_execution`. We don't leave `self.index` as `None` because we need to be able to call index.dimensions().
5268
self.index = vspy.IndexVamana(self.ctx, uri, to_temporal_policy(timestamp))
5369
self.db_uri = self.group[
5470
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
@@ -96,13 +112,11 @@ def query_internal(
96112
opt_l: int
97113
How deep to search. Should be >= k, and if it's not, we will set it to k.
98114
"""
99-
warnings.warn("The Vamana index is not yet supported, please use with caution.")
100115
if self.size == 0:
101116
return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
102117
(queries.shape[0], k), MAX_UINT64
103118
)
104119

105-
assert queries.dtype == np.float32
106120
if opt_l < k:
107121
warnings.warn(f"opt_l ({opt_l}) should be >= k ({k}), setting to k")
108122
opt_l = k
@@ -144,7 +158,6 @@ def create(
144158
The TileDB vector search storage version to use.
145159
If not provided, use the latest stable storage version.
146160
"""
147-
warnings.warn("The Vamana index is not yet supported, please use with caution.")
148161
validate_storage_version(storage_version)
149162
ctx = vspy.Ctx(config)
150163
index = vspy.IndexVamana(
@@ -160,4 +173,4 @@ def create(
160173
index.train(empty_vector)
161174
index.add(empty_vector)
162175
index.write_index(ctx, uri, vspy.TemporalPolicy(0), storage_version)
163-
return VamanaIndex(uri=uri, config=config, memory_budget=1000000)
176+
return VamanaIndex(uri=uri, config=config)

0 commit comments

Comments
 (0)