1+ import multiprocessing
12import os
3+ import math
24
35import numpy as np
46from tiledb .vector_search .module import *
7+ from tiledb .cloud .dag import Mode
58
# Names of the TileDB arrays that together make up a stored vector index.
# Their roles follow from how IVFFlatIndex.query consumes them:
CENTROIDS_ARRAY_NAME = "centroids.tdb"  # partition centroids (self._centroids)
INDEX_ARRAY_NAME = "index.tdb"  # partition offsets (self._index / indices)
IDS_ARRAY_NAME = "ids.tdb"  # external vector ids (self._ids / ids_uri)
PARTS_ARRAY_NAME = "parts.tdb"  # partitioned vector data (parts_uri)
1013
1114
def submit_local(d, func, *args, **kwargs):
    """Submit *func* to DAG *d* for local execution.

    Keyword arguments that are only meaningful for remote (TileDB Cloud)
    execution are stripped before forwarding, since the local runner does
    not accept them.
    """
    for remote_only_kwarg in ("image_name", "resources"):
        kwargs.pop(remote_only_kwarg, None)
    return d.submit_local(func, *args, **kwargs)
20+
21+
1222class Index :
    def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
        """Find nearest neighbors for each target vector.

        Abstract: concrete index types (e.g. IVFFlatIndex) override this.

        Parameters
        ----------
        targets: numpy.ndarray
            ND array of query vectors
        k: int
            Number of top results to return per target
        nqueries: int
            Number of queries
        nthreads: int
            Number of threads to use for the query
        nprobe: int
            Number of partitions to probe (used by IVF-style indexes)
        """
        raise NotImplementedError
@@ -44,7 +54,7 @@ def query(
4454 nprobe : int = 1 ,
4555 ):
4656 """
47- Open a flat index
57+ Query a flat index
4858
4959 Parameters
5060 ----------
@@ -65,10 +75,10 @@ def query(
6575
6676 assert targets .dtype == np .float32
6777
68- targets_m = array_to_matrix (targets )
78+ targets_m = array_to_matrix (np . transpose ( targets ) )
6979
7080 r = query_vq (self ._db , targets_m , k , nqueries , nthreads )
71- return np .array (r )
81+ return np .transpose ( np . array (r ) )
7282
7383
7484class IVFFlatIndex (Index ):
@@ -109,64 +119,237 @@ def __init__(
109119
    def query(
        self,
        queries: np.ndarray,
        k: int = 10,
        nprobe: int = 10,
        nthreads: int = -1,
        use_nuv_implementation: bool = False,
        mode: Mode = None,
        num_partitions: int = -1,
        num_workers: int = -1,
    ):
        """
        Query an IVF_FLAT index

        Parameters
        ----------
        queries: numpy.ndarray
            ND Array of queries
        k: int
            Number of top results to return per target
        nprobe: int
            number of probes
        nthreads: int
            Number of threads to use for query (-1 means one per CPU core)
        use_nuv_implementation: bool
            whether to use the nuv query implementation. Default: False
        mode: Mode
            If provided the query will be executed using TileDB cloud taskgraphs.
            For distributed execution you can use REALTIME or BATCH mode
        num_partitions: int
            Only relevant for taskgraph based execution.
            If provided, we split the query execution in that many partitions.
        num_workers: int
            Only relevant for taskgraph based execution.
            If provided, this is the number of workers to use for the query execution.
        """
        # Native query kernels expect float32 query vectors.
        assert queries.dtype == np.float32
        if nthreads == -1:
            nthreads = multiprocessing.cpu_count()
        if mode is None:
            # Local execution. The native kernels take queries transposed
            # (column-major), hence np.transpose on the way in and out.
            queries_m = array_to_matrix(np.transpose(queries))
            if self.memory_budget == -1:
                # Index data already resident in RAM (self._db / self._ids).
                r = ivf_query_ram(
                    self.dtype,
                    self._db,
                    self._centroids,
                    queries_m,
                    self._index,
                    self._ids,
                    nprobe=nprobe,
                    k_nn=k,
                    nth=True,  # ??
                    nthreads=nthreads,
                    ctx=self.ctx,
                    use_nuv_implementation=use_nuv_implementation,
                )
            else:
                # Out-of-core variant: reads partitions from the arrays by
                # URI, constrained by the configured memory budget.
                r = ivf_query(
                    self.dtype,
                    self.parts_db_uri,
                    self._centroids,
                    queries_m,
                    self._index,
                    self.ids_uri,
                    nprobe=nprobe,
                    k_nn=k,
                    memory_budget=self.memory_budget,
                    nth=True,  # ??
                    nthreads=nthreads,
                    ctx=self.ctx,
                    use_nuv_implementation=use_nuv_implementation,
                )

            return np.transpose(np.array(r))
        else:
            # Distributed / serverless execution via TileDB Cloud taskgraphs.
            return self.taskgraph_query(
                queries=queries,
                k=k,
                nthreads=nthreads,
                nprobe=nprobe,
                mode=mode,
                num_partitions=num_partitions,
                num_workers=num_workers,
            )
205+
206+ def taskgraph_query (
207+ self ,
208+ queries : np .ndarray ,
209+ k : int = 10 ,
210+ nprobe : int = 10 ,
211+ nthreads : int = - 1 ,
212+ mode : Mode = None ,
213+ num_partitions : int = - 1 ,
214+ num_workers : int = - 1 ,
215+ ):
216+ """
217+ Query an IVF_FLAT index using TileDB cloud taskgraphs
218+
219+ Parameters
220+ ----------
221+ queries: numpy.ndarray
222+ ND Array of queries
223+ k: int
224+ Number of top results to return per target
225+ nprobe: int
226+ number of probes
227+ nthreads: int
228+ Number of threads to use for query
229+ use_nuv_implementation: bool
230+ wether to use the nuv query implementation. Default: False
231+ mode: Mode
232+ If provided the query will be executed using TileDB cloud taskgraphs.
233+ For distributed execution you can use REALTIME or BATCH mode
234+ num_partitions: int
235+ Only relevant for taskgraph based execution.
236+ If provided, we split the query execution in that many partitions.
237+ num_workers: int
238+ Only relevant for taskgraph based execution.
239+ If provided, this is the number of workers to use for the query execution.
240+ """
241+ from tiledb .cloud import dag
242+ from tiledb .cloud .dag import Mode
243+ from tiledb .vector_search .module import (
244+ array_to_matrix ,
245+ partition_ivf_index ,
246+ dist_qv ,
247+ )
248+ import math
249+ import numpy as np
250+ from functools import partial
251+
252+ def dist_qv_udf (
253+ dtype : np .dtype ,
254+ parts_uri : str ,
255+ ids_uri : str ,
256+ query_vectors : np .ndarray ,
257+ active_partitions : np .array ,
258+ active_queries : np .array ,
259+ indices : np .array ,
260+ k_nn : int ,
261+ ):
262+ queries_m = array_to_matrix (np .transpose (query_vectors ))
263+ r = dist_qv (
264+ dtype = dtype ,
265+ parts_uri = parts_uri ,
266+ ids_uri = ids_uri ,
267+ query_vectors = queries_m ,
268+ active_partitions = active_partitions ,
269+ active_queries = active_queries ,
270+ indices = indices ,
271+ k_nn = k_nn ,
272+ )
273+ results = []
274+ for q in range (len (r )):
275+ tmp_results = []
276+ for j in range (len (r [q ])):
277+ tmp_results .append (r [q ][j ])
278+ results .append (tmp_results )
279+ return results
280+
281+ assert queries .dtype == np .float32
282+ if num_partitions == - 1 :
283+ num_partitions = 5
284+ if num_workers == - 1 :
285+ num_workers = num_partitions
286+ if mode == Mode .BATCH :
287+ d = dag .DAG (
288+ name = "vector-query" ,
289+ mode = Mode .BATCH ,
290+ max_workers = num_workers ,
291+ )
292+ if mode == Mode .REALTIME :
293+ d = dag .DAG (
294+ name = "vector-query" ,
295+ mode = Mode .REALTIME ,
296+ max_workers = num_workers ,
154297 )
155298 else :
156- r = ivf_query (
157- self .dtype ,
158- self .parts_db_uri ,
159- self ._centroids ,
160- targets_m ,
161- self ._index ,
162- self .ids_uri ,
163- nprobe = nprobe ,
164- k_nn = k ,
165- memory_budget = self .memory_budget ,
166- nth = True , # ??
167- nthreads = nthreads ,
168- ctx = self .ctx ,
169- use_nuv_implementation = use_nuv_implementation ,
299+ d = dag .DAG (
300+ name = "vector-query" ,
301+ mode = Mode .REALTIME ,
302+ max_workers = 1 ,
303+ namespace = "default" ,
170304 )
305+ submit = partial (submit_local , d )
306+ if mode == Mode .BATCH or mode == Mode .REALTIME :
307+ submit = d .submit
308+
309+ queries_m = array_to_matrix (np .transpose (queries ))
310+ active_partitions , active_queries = partition_ivf_index (
311+ centroids = self ._centroids , query = queries_m , nprobe = nprobe , nthreads = nthreads
312+ )
313+ num_parts = len (active_partitions )
314+
315+ parts_per_node = int (math .ceil (num_parts / num_partitions ))
316+ nodes = []
317+ for part in range (0 , num_parts , parts_per_node ):
318+ part_end = part + parts_per_node
319+ if part_end > num_parts :
320+ part_end = num_parts
321+ nodes .append (
322+ submit (
323+ dist_qv_udf ,
324+ dtype = self .dtype ,
325+ parts_uri = self .parts_db_uri ,
326+ ids_uri = self .ids_uri ,
327+ query_vectors = queries ,
328+ active_partitions = np .array (active_partitions )[part :part_end ],
329+ active_queries = np .array (
330+ active_queries [part :part_end ], dtype = object
331+ ),
332+ indices = np .array (self ._index ),
333+ k_nn = k ,
334+ resource_class = "large" ,
335+ image_name = "3.9-vectorsearch" ,
336+ )
337+ )
338+
339+ d .compute ()
340+ d .wait ()
341+ results = []
342+ for node in nodes :
343+ res = node .result ()
344+ results .append (res )
171345
172- return np .array (r )
346+ results_per_query = []
347+ for q in range (queries .shape [0 ]):
348+ tmp_results = []
349+ for j in range (k ):
350+ for r in results :
351+ if len (r [q ]) > j :
352+ if r [q ][j ][0 ] > 0 :
353+ tmp_results .append (r [q ][j ])
354+ results_per_query .append (sorted (tmp_results , key = lambda t : t [0 ])[0 :k ])
355+ return results_per_query
0 commit comments