Skip to content

Commit ae2fad2

Browse files
Merge pull request #90 from TileDB-Inc/npapa/distributed-query
Distributed query implementation
2 parents 1d6e444 + 1d91167 commit ae2fad2

File tree

6 files changed

+378
-109
lines changed

6 files changed

+378
-109
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 228 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
1+
import multiprocessing
12
import os
3+
import math
24

35
import numpy as np
46
from tiledb.vector_search.module import *
7+
from tiledb.cloud.dag import Mode
58

69
CENTROIDS_ARRAY_NAME = "centroids.tdb"
710
INDEX_ARRAY_NAME = "index.tdb"
811
IDS_ARRAY_NAME = "ids.tdb"
912
PARTS_ARRAY_NAME = "parts.tdb"
1013

1114

15+
def submit_local(d, func, *args, **kwargs):
16+
# Drop kwarg
17+
kwargs.pop("image_name", None)
18+
kwargs.pop("resources", None)
19+
return d.submit_local(func, *args, **kwargs)
20+
21+
1222
class Index:
1323
def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
1424
raise NotImplementedError
@@ -44,7 +54,7 @@ def query(
4454
nprobe: int = 1,
4555
):
4656
"""
47-
Open a flat index
57+
Query a flat index
4858
4959
Parameters
5060
----------
@@ -65,10 +75,10 @@ def query(
6575

6676
assert targets.dtype == np.float32
6777

68-
targets_m = array_to_matrix(targets)
78+
targets_m = array_to_matrix(np.transpose(targets))
6979

7080
r = query_vq(self._db, targets_m, k, nqueries, nthreads)
71-
return np.array(r)
81+
return np.transpose(np.array(r))
7282

7383

7484
class IVFFlatIndex(Index):
@@ -109,64 +119,237 @@ def __init__(
109119

110120
def query(
111121
self,
112-
targets: np.ndarray,
113-
k=10,
114-
nqueries=10,
115-
nthreads=8,
116-
nprobe=1,
122+
queries: np.ndarray,
123+
k: int = 10,
124+
nprobe: int = 10,
125+
nthreads: int = -1,
117126
use_nuv_implementation: bool = False,
127+
mode: Mode = None,
128+
num_partitions: int = -1,
129+
num_workers: int = -1,
118130
):
119131
"""
120-
Open a flat index
132+
Query an IVF_FLAT index
121133
122134
Parameters
123135
----------
124-
targets: numpy.ndarray
125-
ND Array of query targets
136+
queries: numpy.ndarray
137+
ND Array of queries
126138
k: int
127139
Number of top results to return per query
128-
nqueries: int
129-
Number of queries
130-
nthreads: int
131-
Number of threads to use for query
132140
nprobe: int
133141
number of probes
142+
nthreads: int
143+
Number of threads to use for query
134144
use_nuv_implementation: bool
135145
whether to use the nuv query implementation. Default: False
146+
mode: Mode
147+
If provided the query will be executed using TileDB cloud taskgraphs.
148+
For distributed execution you can use REALTIME or BATCH mode
149+
num_partitions: int
150+
Only relevant for taskgraph based execution.
151+
If provided, we split the query execution in that many partitions.
152+
num_workers: int
153+
Only relevant for taskgraph based execution.
154+
If provided, this is the number of workers to use for the query execution.
155+
136156
"""
137-
assert targets.dtype == np.float32
157+
assert queries.dtype == np.float32
158+
if nthreads == -1:
159+
nthreads = multiprocessing.cpu_count()
160+
if mode is None:
161+
queries_m = array_to_matrix(np.transpose(queries))
162+
if self.memory_budget == -1:
163+
r = ivf_query_ram(
164+
self.dtype,
165+
self._db,
166+
self._centroids,
167+
queries_m,
168+
self._index,
169+
self._ids,
170+
nprobe=nprobe,
171+
k_nn=k,
172+
nth=True, # ??
173+
nthreads=nthreads,
174+
ctx=self.ctx,
175+
use_nuv_implementation=use_nuv_implementation,
176+
)
177+
else:
178+
r = ivf_query(
179+
self.dtype,
180+
self.parts_db_uri,
181+
self._centroids,
182+
queries_m,
183+
self._index,
184+
self.ids_uri,
185+
nprobe=nprobe,
186+
k_nn=k,
187+
memory_budget=self.memory_budget,
188+
nth=True, # ??
189+
nthreads=nthreads,
190+
ctx=self.ctx,
191+
use_nuv_implementation=use_nuv_implementation,
192+
)
138193

139-
targets_m = array_to_matrix(targets)
140-
if self.memory_budget == -1:
141-
r = ivf_query_ram(
142-
self.dtype,
143-
self._db,
144-
self._centroids,
145-
targets_m,
146-
self._index,
147-
self._ids,
148-
nprobe=nprobe,
149-
k_nn=k,
150-
nth=True, # ??
194+
return np.transpose(np.array(r))
195+
else:
196+
return self.taskgraph_query(
197+
queries=queries,
198+
k=k,
151199
nthreads=nthreads,
152-
ctx=self.ctx,
153-
use_nuv_implementation=use_nuv_implementation,
200+
nprobe=nprobe,
201+
mode=mode,
202+
num_partitions=num_partitions,
203+
num_workers=num_workers,
204+
)
205+
206+
def taskgraph_query(
207+
self,
208+
queries: np.ndarray,
209+
k: int = 10,
210+
nprobe: int = 10,
211+
nthreads: int = -1,
212+
mode: Mode = None,
213+
num_partitions: int = -1,
214+
num_workers: int = -1,
215+
):
216+
"""
217+
Query an IVF_FLAT index using TileDB cloud taskgraphs
218+
219+
Parameters
220+
----------
221+
queries: numpy.ndarray
222+
ND Array of queries
223+
k: int
224+
Number of top results to return per target
225+
nprobe: int
226+
number of probes
227+
nthreads: int
228+
Number of threads to use for query
229+
use_nuv_implementation: bool
230+
wether to use the nuv query implementation. Default: False
231+
mode: Mode
232+
If provided the query will be executed using TileDB cloud taskgraphs.
233+
For distributed execution you can use REALTIME or BATCH mode
234+
num_partitions: int
235+
Only relevant for taskgraph based execution.
236+
If provided, we split the query execution in that many partitions.
237+
num_workers: int
238+
Only relevant for taskgraph based execution.
239+
If provided, this is the number of workers to use for the query execution.
240+
"""
241+
from tiledb.cloud import dag
242+
from tiledb.cloud.dag import Mode
243+
from tiledb.vector_search.module import (
244+
array_to_matrix,
245+
partition_ivf_index,
246+
dist_qv,
247+
)
248+
import math
249+
import numpy as np
250+
from functools import partial
251+
252+
def dist_qv_udf(
253+
dtype: np.dtype,
254+
parts_uri: str,
255+
ids_uri: str,
256+
query_vectors: np.ndarray,
257+
active_partitions: np.array,
258+
active_queries: np.array,
259+
indices: np.array,
260+
k_nn: int,
261+
):
262+
queries_m = array_to_matrix(np.transpose(query_vectors))
263+
r = dist_qv(
264+
dtype=dtype,
265+
parts_uri=parts_uri,
266+
ids_uri=ids_uri,
267+
query_vectors=queries_m,
268+
active_partitions=active_partitions,
269+
active_queries=active_queries,
270+
indices=indices,
271+
k_nn=k_nn,
272+
)
273+
results = []
274+
for q in range(len(r)):
275+
tmp_results = []
276+
for j in range(len(r[q])):
277+
tmp_results.append(r[q][j])
278+
results.append(tmp_results)
279+
return results
280+
281+
assert queries.dtype == np.float32
282+
if num_partitions == -1:
283+
num_partitions = 5
284+
if num_workers == -1:
285+
num_workers = num_partitions
286+
if mode == Mode.BATCH:
287+
d = dag.DAG(
288+
name="vector-query",
289+
mode=Mode.BATCH,
290+
max_workers=num_workers,
291+
)
292+
if mode == Mode.REALTIME:
293+
d = dag.DAG(
294+
name="vector-query",
295+
mode=Mode.REALTIME,
296+
max_workers=num_workers,
154297
)
155298
else:
156-
r = ivf_query(
157-
self.dtype,
158-
self.parts_db_uri,
159-
self._centroids,
160-
targets_m,
161-
self._index,
162-
self.ids_uri,
163-
nprobe=nprobe,
164-
k_nn=k,
165-
memory_budget=self.memory_budget,
166-
nth=True, # ??
167-
nthreads=nthreads,
168-
ctx=self.ctx,
169-
use_nuv_implementation=use_nuv_implementation,
299+
d = dag.DAG(
300+
name="vector-query",
301+
mode=Mode.REALTIME,
302+
max_workers=1,
303+
namespace="default",
170304
)
305+
submit = partial(submit_local, d)
306+
if mode == Mode.BATCH or mode == Mode.REALTIME:
307+
submit = d.submit
308+
309+
queries_m = array_to_matrix(np.transpose(queries))
310+
active_partitions, active_queries = partition_ivf_index(
311+
centroids=self._centroids, query=queries_m, nprobe=nprobe, nthreads=nthreads
312+
)
313+
num_parts = len(active_partitions)
314+
315+
parts_per_node = int(math.ceil(num_parts / num_partitions))
316+
nodes = []
317+
for part in range(0, num_parts, parts_per_node):
318+
part_end = part + parts_per_node
319+
if part_end > num_parts:
320+
part_end = num_parts
321+
nodes.append(
322+
submit(
323+
dist_qv_udf,
324+
dtype=self.dtype,
325+
parts_uri=self.parts_db_uri,
326+
ids_uri=self.ids_uri,
327+
query_vectors=queries,
328+
active_partitions=np.array(active_partitions)[part:part_end],
329+
active_queries=np.array(
330+
active_queries[part:part_end], dtype=object
331+
),
332+
indices=np.array(self._index),
333+
k_nn=k,
334+
resource_class="large",
335+
image_name="3.9-vectorsearch",
336+
)
337+
)
338+
339+
d.compute()
340+
d.wait()
341+
results = []
342+
for node in nodes:
343+
res = node.result()
344+
results.append(res)
171345

172-
return np.array(r)
346+
results_per_query = []
347+
for q in range(queries.shape[0]):
348+
tmp_results = []
349+
for j in range(k):
350+
for r in results:
351+
if len(r[q]) > j:
352+
if r[q][j][0] > 0:
353+
tmp_results.append(r[q][j])
354+
results_per_query.append(sorted(tmp_results, key=lambda t: t[0])[0:k])
355+
return results_per_query

0 commit comments

Comments
 (0)