Skip to content

Commit 3e3e7f8

Browse files
authored
Enable setting resource_class or resources when calling query() on IVFFlatIndex (#165)
1 parent 15f83d8 commit 3e3e7f8

File tree

2 files changed

+74
-5
lines changed

2 files changed

+74
-5
lines changed

apis/python/src/tiledb/vector_search/ivf_flat_index.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
def submit_local(d, func, *args, **kwargs):
1919
# Drop kwarg
2020
kwargs.pop("image_name", None)
21+
kwargs.pop("resource_class", None)
2122
kwargs.pop("resources", None)
2223
return d.submit_local(func, *args, **kwargs)
2324

@@ -133,6 +134,8 @@ def query_internal(
133134
nthreads: int = -1,
134135
use_nuv_implementation: bool = False,
135136
mode: Mode = None,
137+
resource_class: Optional[str] = None,
138+
resources: Optional[Mapping[str, Any]] = None,
136139
num_partitions: int = -1,
137140
num_workers: int = -1,
138141
):
@@ -153,7 +156,18 @@ def query_internal(
153156
wether to use the nuv query implementation. Default: False
154157
mode: Mode
155158
If provided the query will be executed using TileDB cloud taskgraphs.
156-
For distributed execution you can use REALTIME or BATCH mode
159+
For distributed execution you can use REALTIME or BATCH mode.
160+
For local execution you can use LOCAL mode.
161+
resource_class:
162+
The name of the resource class to use ("standard" or "large"). Resource classes define maximum
163+
limits for cpu and memory usage. Can only be used in REALTIME or BATCH mode.
164+
Cannot be used alongside resources.
165+
In REALTIME or BATCH mode if neither resource_class nor resources are provided,
166+
we default to the "large" resource class.
167+
resources:
168+
A specification for the amount of resources to use when executing using TileDB cloud
169+
taskgraphs, of the form: {"cpu": "6", "memory": "12Gi", "gpu": 1}. Can only be used
170+
in BATCH mode. Cannot be used alongside resource_class.
157171
num_partitions: int
158172
Only relevant for taskgraph based execution.
159173
If provided, we split the query execution in that many partitions.
@@ -167,6 +181,11 @@ def query_internal(
167181
(queries.shape[0], k), index.MAX_UINT64
168182
)
169183

184+
if mode != Mode.BATCH and resources:
185+
raise TypeError("Can only pass resources in BATCH mode")
186+
if (mode != Mode.REALTIME and mode != Mode.BATCH) and resource_class:
187+
raise TypeError("Can only pass resource_class in REALTIME or BATCH mode")
188+
170189
assert queries.dtype == np.float32
171190

172191
if queries.ndim == 1:
@@ -217,6 +236,8 @@ def query_internal(
217236
nthreads=nthreads,
218237
nprobe=nprobe,
219238
mode=mode,
239+
resource_class=resource_class,
240+
resources=resources,
220241
num_partitions=num_partitions,
221242
num_workers=num_workers,
222243
config=self.config,
@@ -229,6 +250,8 @@ def taskgraph_query(
229250
nprobe: int = 10,
230251
nthreads: int = -1,
231252
mode: Mode = None,
253+
resource_class: Optional[str] = None,
254+
resources: Optional[Mapping[str, Any]] = None,
232255
num_partitions: int = -1,
233256
num_workers: int = -1,
234257
config: Optional[Mapping[str, Any]] = None,
@@ -248,7 +271,18 @@ def taskgraph_query(
248271
Number of threads to use for query
249272
mode: Mode
250273
If provided the query will be executed using TileDB cloud taskgraphs.
251-
For distributed execution you can use REALTIME or BATCH mode
274+
For distributed execution you can use REALTIME or BATCH mode.
275+
For local execution you can use LOCAL mode.
276+
resource_class:
277+
The name of the resource class to use ("standard" or "large"). Resource classes define maximum
278+
limits for cpu and memory usage. Can only be used in REALTIME or BATCH mode.
279+
Cannot be used alongside resources.
280+
In REALTIME or BATCH mode if neither resource_class nor resources are provided,
281+
we default to the "large" resource class.
282+
resources:
283+
A specification for the amount of resources to use when executing using TileDB cloud
284+
taskgraphs, of the form: {"cpu": "6", "memory": "12Gi", "gpu": 1}. Can only be used
285+
in BATCH mode. Cannot be used alongside resource_class.
252286
num_partitions: int
253287
Only relevant for taskgraph based execution.
254288
If provided, we split the query execution in that many partitions.
@@ -268,6 +302,9 @@ def taskgraph_query(
268302
from tiledb.vector_search.module import (array_to_matrix, dist_qv,
269303
partition_ivf_index)
270304

305+
if resource_class and resources:
306+
raise TypeError("Cannot provide both resource_class and resources")
307+
271308
def dist_qv_udf(
272309
dtype: np.dtype,
273310
parts_uri: str,
@@ -373,7 +410,8 @@ def dist_qv_udf(
373410
k_nn=k,
374411
config=config,
375412
timestamp=self.base_array_timestamp,
376-
resource_class="large" if mode == Mode.REALTIME else None,
413+
resource_class="large" if (not resources and not resource_class) else resource_class,
414+
resources=resources,
377415
image_name="3.9-vectorsearch",
378416
)
379417
)

apis/python/test/test_cloud.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import unittest
33

44
from common import *
5-
from tiledb.cloud import groups
5+
from tiledb.cloud import groups, tiledb_cloud_error
66
from tiledb.cloud.dag import Mode
77

88
import tiledb.vector_search as vs
@@ -17,6 +17,8 @@ class CloudTests(unittest.TestCase):
1717

1818
@classmethod
1919
def setUpClass(cls):
20+
if not os.getenv("TILEDB_REST_TOKEN"):
21+
raise ValueError("TILEDB_REST_TOKEN not set")
2022
tiledb.cloud.login(token=os.getenv("TILEDB_REST_TOKEN"))
2123
namespace, storage_path, _ = groups._default_ns_path_cred()
2224
storage_path = storage_path.replace("//", "/").replace("/", "//", 1)
@@ -76,10 +78,39 @@ def test_cloud_ivf_flat(self):
7678
# UDF library releases.
7779
# mode=Mode.BATCH,
7880
)
81+
7982
_, result_i = index.query(query_vectors, k=k, nprobe=nprobe)
8083
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
8184

8285
_, result_i = index.query(
83-
query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, num_partitions=2
86+
query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, num_partitions=2, resource_class="standard"
87+
)
88+
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
89+
90+
_, result_i = index.query(
91+
query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, num_partitions=2
8492
)
8593
assert accuracy(result_i, gt_i) > MINIMUM_ACCURACY
94+
95+
# We now will test for invalid scenarios when setting the query() resources.
96+
resources = {"cpu": "9", "memory": "12Gi", "gpu": 0}
97+
98+
# Cannot pass resource_class or resources to LOCAL mode or to no mode.
99+
with self.assertRaises(TypeError):
100+
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, resource_class="large")
101+
with self.assertRaises(TypeError):
102+
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL, resources=resources)
103+
with self.assertRaises(TypeError):
104+
index.query(query_vectors, k=k, nprobe=nprobe, resource_class="large")
105+
with self.assertRaises(TypeError):
106+
index.query(query_vectors, k=k, nprobe=nprobe, resources=resources)
107+
108+
# Cannot pass resources to REALTIME.
109+
with self.assertRaises(TypeError):
110+
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, resources=resources)
111+
112+
# Cannot pass both resource_class and resources.
113+
with self.assertRaises(TypeError):
114+
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.REALTIME, resource_class="large", resources=resources)
115+
with self.assertRaises(TypeError):
116+
index.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.BATCH, resource_class="large", resources=resources)

0 commit comments

Comments
 (0)