
Commit 1d91167

Author: Nikos Papailiou
Message: Address review comments
Parent: 334618c

File tree

3 files changed: +141 −115 lines


apis/python/src/tiledb/vector_search/index.py
106 additions & 69 deletions
@@ -1,3 +1,4 @@
+import multiprocessing
 import os
 import math
 
@@ -74,10 +75,10 @@ def query(
 
         assert targets.dtype == np.float32
 
-        targets_m = array_to_matrix(targets)
+        targets_m = array_to_matrix(np.transpose(targets))
 
         r = query_vq(self._db, targets_m, k, nqueries, nthreads)
-        return np.array(r)
+        return np.transpose(np.array(r))
 
 
 class IVFFlatIndex(Index):
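
The transpose calls above reflect a shape convention: the Python API takes queries with one vector per row, while array_to_matrix feeds a column-major (dimension × num_queries) matrix to the C++ kernels (ColMajorMatrix in module.cc). A minimal numpy sketch of that convention — the array sizes are illustrative, not from this commit:

import numpy as np

# Hypothetical sizes: 100 query vectors of dimension 128, one per row.
queries = np.random.rand(100, 128).astype(np.float32)

# np.transpose yields the (dim, num_queries) layout handed to array_to_matrix.
col_major = np.transpose(queries)
assert col_major.shape == (128, 100)

# Results are transposed back on the way out, giving one result row per
# query — which is why the result loop later iterates over queries.shape[0].
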
@@ -118,92 +119,124 @@ def __init__(
 
     def query(
         self,
-        targets: np.ndarray,
-        k=10,
-        nqueries=10,
-        nthreads=8,
-        nprobe=1,
+        queries: np.ndarray,
+        k: int = 10,
+        nprobe: int = 10,
+        nthreads: int = -1,
         use_nuv_implementation: bool = False,
+        mode: Mode = None,
+        num_partitions: int = -1,
+        num_workers: int = -1,
     ):
         """
         Query an IVF_FLAT index
 
         Parameters
         ----------
-        targets: numpy.ndarray
-            ND Array of query targets
+        queries: numpy.ndarray
+            ND Array of queries
         k: int
             Number of top results to return per target
-        nqueries: int
-            Number of queries
-        nthreads: int
-            Number of threads to use for query
         nprobe: int
             number of probes
+        nthreads: int
+            Number of threads to use for query
         use_nuv_implementation: bool
             wether to use the nuv query implementation. Default: False
+        mode: Mode
+            If provided the query will be executed using TileDB cloud taskgraphs.
+            For distributed execution you can use REALTIME or BATCH mode
+        num_partitions: int
+            Only relevant for taskgraph based execution.
+            If provided, we split the query execution in that many partitions.
+        num_workers: int
+            Only relevant for taskgraph based execution.
+            If provided, this is the number of workers to use for the query execution.
+
         """
-        assert targets.dtype == np.float32
+        assert queries.dtype == np.float32
+        if nthreads == -1:
+            nthreads = multiprocessing.cpu_count()
+        if mode is None:
+            queries_m = array_to_matrix(np.transpose(queries))
+            if self.memory_budget == -1:
+                r = ivf_query_ram(
+                    self.dtype,
+                    self._db,
+                    self._centroids,
+                    queries_m,
+                    self._index,
+                    self._ids,
+                    nprobe=nprobe,
+                    k_nn=k,
+                    nth=True, # ??
+                    nthreads=nthreads,
+                    ctx=self.ctx,
+                    use_nuv_implementation=use_nuv_implementation,
+                )
+            else:
+                r = ivf_query(
+                    self.dtype,
+                    self.parts_db_uri,
+                    self._centroids,
+                    queries_m,
+                    self._index,
+                    self.ids_uri,
+                    nprobe=nprobe,
+                    k_nn=k,
+                    memory_budget=self.memory_budget,
+                    nth=True, # ??
+                    nthreads=nthreads,
+                    ctx=self.ctx,
+                    use_nuv_implementation=use_nuv_implementation,
+                )
 
-        targets_m = array_to_matrix(targets)
-        if self.memory_budget == -1:
-            r = ivf_query_ram(
-                self.dtype,
-                self._db,
-                self._centroids,
-                targets_m,
-                self._index,
-                self._ids,
-                nprobe=nprobe,
-                k_nn=k,
-                nth=True, # ??
-                nthreads=nthreads,
-                ctx=self.ctx,
-                use_nuv_implementation=use_nuv_implementation,
-            )
+            return np.transpose(np.array(r))
         else:
-            r = ivf_query(
-                self.dtype,
-                self.parts_db_uri,
-                self._centroids,
-                targets_m,
-                self._index,
-                self.ids_uri,
-                nprobe=nprobe,
-                k_nn=k,
-                memory_budget=self.memory_budget,
-                nth=True, # ??
+            return self.taskgraph_query(
+                queries=queries,
+                k=k,
                 nthreads=nthreads,
-                ctx=self.ctx,
-                use_nuv_implementation=use_nuv_implementation,
+                nprobe=nprobe,
+                mode=mode,
+                num_partitions=num_partitions,
+                num_workers=num_workers,
             )
 
-        return np.array(r)
-
-    def distributed_query(
+    def taskgraph_query(
         self,
-        targets: np.ndarray,
-        k=10,
-        nthreads=8,
-        nprobe=1,
-        num_nodes=5,
-        mode: Mode = Mode.REALTIME,
+        queries: np.ndarray,
+        k: int = 10,
+        nprobe: int = 10,
+        nthreads: int = -1,
+        mode: Mode = None,
+        num_partitions: int = -1,
+        num_workers: int = -1,
     ):
         """
-        Distributed Query on top of an IVF_FLAT index
+        Query an IVF_FLAT index using TileDB cloud taskgraphs
 
         Parameters
         ----------
-        targets: numpy.ndarray
-            ND Array of query targets
+        queries: numpy.ndarray
+            ND Array of queries
         k: int
             Number of top results to return per target
-        nqueries: int
-            Number of queries
-        nthreads: int
-            Number of threads to use for query
         nprobe: int
             number of probes
+        nthreads: int
+            Number of threads to use for query
+        use_nuv_implementation: bool
+            wether to use the nuv query implementation. Default: False
+        mode: Mode
+            If provided the query will be executed using TileDB cloud taskgraphs.
+            For distributed execution you can use REALTIME or BATCH mode
+        num_partitions: int
+            Only relevant for taskgraph based execution.
+            If provided, we split the query execution in that many partitions.
+        num_workers: int
+            Only relevant for taskgraph based execution.
+            If provided, this is the number of workers to use for the query execution.
         """
         from tiledb.cloud import dag
         from tiledb.cloud.dag import Mode
@@ -226,12 +259,12 @@ def dist_qv_udf(
             indices: np.array,
             k_nn: int,
         ):
-            targets_m = array_to_matrix(query_vectors)
+            queries_m = array_to_matrix(np.transpose(query_vectors))
             r = dist_qv(
                 dtype=dtype,
                 parts_uri=parts_uri,
                 ids_uri=ids_uri,
-                query_vectors=targets_m,
+                query_vectors=queries_m,
                 active_partitions=active_partitions,
                 active_queries=active_queries,
                 indices=indices,
@@ -245,18 +278,22 @@ def dist_qv_udf(
                 results.append(tmp_results)
             return results
 
-        assert targets.dtype == self.dtype
+        assert queries.dtype == np.float32
+        if num_partitions == -1:
+            num_partitions = 5
+        if num_workers == -1:
+            num_workers = num_partitions
         if mode == Mode.BATCH:
             d = dag.DAG(
                 name="vector-query",
                 mode=Mode.BATCH,
-                max_workers=num_nodes,
+                max_workers=num_workers,
             )
         if mode == Mode.REALTIME:
             d = dag.DAG(
                 name="vector-query",
                 mode=Mode.REALTIME,
-                max_workers=num_nodes,
+                max_workers=num_workers,
             )
         else:
             d = dag.DAG(
@@ -269,13 +306,13 @@ def dist_qv_udf(
         if mode == Mode.BATCH or mode == Mode.REALTIME:
             submit = d.submit
 
-        targets_m = array_to_matrix(targets)
+        queries_m = array_to_matrix(np.transpose(queries))
         active_partitions, active_queries = partition_ivf_index(
-            centroids=self._centroids, query=targets_m, nprobe=nprobe, nthreads=nthreads
+            centroids=self._centroids, query=queries_m, nprobe=nprobe, nthreads=nthreads
         )
         num_parts = len(active_partitions)
 
-        parts_per_node = int(math.ceil(num_parts / num_nodes))
+        parts_per_node = int(math.ceil(num_parts / num_partitions))
         nodes = []
         for part in range(0, num_parts, parts_per_node):
             part_end = part + parts_per_node
@@ -287,7 +324,7 @@
                     dtype=self.dtype,
                     parts_uri=self.parts_db_uri,
                     ids_uri=self.ids_uri,
-                    query_vectors=targets,
+                    query_vectors=queries,
                     active_partitions=np.array(active_partitions)[part:part_end],
                     active_queries=np.array(
                         active_queries[part:part_end], dtype=object
@@ -307,7 +344,7 @@
             results.append(res)
 
         results_per_query = []
-        for q in range(targets.shape[1]):
+        for q in range(queries.shape[0]):
             tmp_results = []
             for j in range(k):
                 for r in results:
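
Taken together, the index.py changes rename distributed_query to taskgraph_query and make query dispatch to it whenever a mode is given. A minimal usage sketch based on the new signature — the index instance, query count, and vector dimension are hypothetical, not from this commit:

import numpy as np
from tiledb.cloud.dag import Mode
from tiledb.vector_search.index import IVFFlatIndex


def run_queries(index: IVFFlatIndex) -> None:
    # float32 is required: query() now asserts queries.dtype == np.float32.
    queries = np.random.rand(10, 128).astype(np.float32)

    # mode=None (the default) executes locally via ivf_query_ram/ivf_query.
    local_results = index.query(queries, k=10, nprobe=10)

    # Passing a Mode routes execution through TileDB cloud taskgraphs;
    # num_partitions left at -1 defaults to 5, and num_workers defaults
    # to num_partitions.
    cloud_results = index.query(
        queries,
        k=10,
        nprobe=10,
        mode=Mode.REALTIME,
        num_partitions=4,
        num_workers=4,
    )
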

apis/python/src/tiledb/vector_search/module.cc
3 additions & 3 deletions
@@ -357,20 +357,20 @@ void declarePartitionIvfIndex(py::module& m, const std::string& suffix) {
       );
 }
 
-template <typename query_type, typename shuffled_ids_type = uint64_t>
+template <typename T, typename shuffled_ids_type = uint64_t>
 static void declare_dist_qv(py::module& m, const std::string& suffix) {
   m.def(("dist_qv_" + suffix).c_str(),
         [](tiledb::Context& ctx,
            const std::string& part_uri,
            std::vector<int>& active_partitions,
-           ColMajorMatrix<query_type>& query,
+           ColMajorMatrix<float>& query,
            std::vector<std::vector<int>>& active_queries,
            std::vector<shuffled_ids_type>& indices,
            const std::string& id_uri,
            size_t k_nn
            /* size_t nthreads TODO: optional arg w/ fallback to C++ default arg */
         ) { /* TODO return type */
-          return detail::ivf::dist_qv_finite_ram_part<query_type, shuffled_ids_type>(
+          return detail::ivf::dist_qv_finite_ram_part<T, shuffled_ids_type>(
              ctx,
              part_uri,
              active_partitions,
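
This module.cc change narrows the binding so the query matrix is always ColMajorMatrix<float>, while the remaining template parameter T selects only the stored vector type. A minimal Python-side sketch of the resulting dtype contract — the uint8 index dtype is an illustrative assumption, not taken from this commit:

import numpy as np

# The index's stored vectors may use another supported dtype (T on the
# C++ side), for example uint8 ...
index_dtype = np.uint8  # hypothetical index dtype

# ... but query vectors are always float32, matching the Python-side
# `assert queries.dtype == np.float32` added in this commit.
queries = np.random.rand(10, 128).astype(np.float32)
assert queries.dtype == np.float32  # holds regardless of index_dtype
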
