1+ import  multiprocessing 
12import  os 
23import  math 
34
@@ -74,10 +75,10 @@ def query(
7475
7576        assert  targets .dtype  ==  np .float32 
7677
77-         targets_m  =  array_to_matrix (targets )
78+         targets_m  =  array_to_matrix (np . transpose ( targets ) )
7879
7980        r  =  query_vq (self ._db , targets_m , k , nqueries , nthreads )
80-         return  np .array (r )
81+         return  np .transpose ( np . array (r ) )
8182
8283
8384class  IVFFlatIndex (Index ):
@@ -118,92 +119,124 @@ def __init__(
118119
119120    def  query (
120121        self ,
121-         targets : np .ndarray ,
122-         k = 10 ,
123-         nqueries = 10 ,
124-         nthreads = 8 ,
125-         nprobe = 1 ,
122+         queries : np .ndarray ,
123+         k : int  =  10 ,
124+         nprobe : int  =  10 ,
125+         nthreads : int  =  - 1 ,
126126        use_nuv_implementation : bool  =  False ,
127+         mode : Mode  =  None ,
128+         num_partitions : int  =  - 1 ,
129+         num_workers : int  =  - 1 ,
127130    ):
128131        """ 
129132        Query an IVF_FLAT index 
130133
131134        Parameters 
132135        ---------- 
133-         targets : numpy.ndarray 
134-             ND Array of query targets  
136+         queries : numpy.ndarray 
137+             ND Array of queries  
135138        k: int 
136139            Number of top results to return per target 
137-         nqueries: int 
138-             Number of queries 
139-         nthreads: int 
140-             Number of threads to use for query 
141140        nprobe: int 
142141            number of probes 
142+         nthreads: int 
143+             Number of threads to use for query 
143144        use_nuv_implementation: bool 
144145            wether to use the nuv query implementation. Default: False 
146+         mode: Mode 
147+             If provided the query will be executed using TileDB cloud taskgraphs. 
148+             For distributed execution you can use REALTIME or BATCH mode 
149+         num_partitions: int 
150+             Only relevant for taskgraph based execution. 
151+             If provided, we split the query execution in that many partitions. 
152+         num_workers: int 
153+             Only relevant for taskgraph based execution. 
154+             If provided, this is the number of workers to use for the query execution. 
155+ 
145156        """ 
146-         assert  targets .dtype  ==  np .float32 
157+         assert  queries .dtype  ==  np .float32 
158+         if  nthreads  ==  - 1 :
159+             nthreads  =  multiprocessing .cpu_count ()
160+         if  mode  is  None :
161+             queries_m  =  array_to_matrix (np .transpose (queries ))
162+             if  self .memory_budget  ==  - 1 :
163+                 r  =  ivf_query_ram (
164+                     self .dtype ,
165+                     self ._db ,
166+                     self ._centroids ,
167+                     queries_m ,
168+                     self ._index ,
169+                     self ._ids ,
170+                     nprobe = nprobe ,
171+                     k_nn = k ,
172+                     nth = True ,  # ?? 
173+                     nthreads = nthreads ,
174+                     ctx = self .ctx ,
175+                     use_nuv_implementation = use_nuv_implementation ,
176+                 )
177+             else :
178+                 r  =  ivf_query (
179+                     self .dtype ,
180+                     self .parts_db_uri ,
181+                     self ._centroids ,
182+                     queries_m ,
183+                     self ._index ,
184+                     self .ids_uri ,
185+                     nprobe = nprobe ,
186+                     k_nn = k ,
187+                     memory_budget = self .memory_budget ,
188+                     nth = True ,  # ?? 
189+                     nthreads = nthreads ,
190+                     ctx = self .ctx ,
191+                     use_nuv_implementation = use_nuv_implementation ,
192+                 )
147193
148-         targets_m  =  array_to_matrix (targets )
149-         if  self .memory_budget  ==  - 1 :
150-             r  =  ivf_query_ram (
151-                 self .dtype ,
152-                 self ._db ,
153-                 self ._centroids ,
154-                 targets_m ,
155-                 self ._index ,
156-                 self ._ids ,
157-                 nprobe = nprobe ,
158-                 k_nn = k ,
159-                 nth = True ,  # ?? 
160-                 nthreads = nthreads ,
161-                 ctx = self .ctx ,
162-                 use_nuv_implementation = use_nuv_implementation ,
163-             )
194+             return  np .transpose (np .array (r ))
164195        else :
165-             r  =  ivf_query (
166-                 self .dtype ,
167-                 self .parts_db_uri ,
168-                 self ._centroids ,
169-                 targets_m ,
170-                 self ._index ,
171-                 self .ids_uri ,
172-                 nprobe = nprobe ,
173-                 k_nn = k ,
174-                 memory_budget = self .memory_budget ,
175-                 nth = True ,  # ?? 
196+             return  self .taskgraph_query (
197+                 queries = queries ,
198+                 k = k ,
176199                nthreads = nthreads ,
177-                 ctx = self .ctx ,
178-                 use_nuv_implementation = use_nuv_implementation ,
200+                 nprobe = nprobe ,
201+                 mode = mode ,
202+                 num_partitions = num_partitions ,
203+                 num_workers = num_workers ,
179204            )
180205
181-         return  np .array (r )
182- 
183-     def  distributed_query (
206+     def  taskgraph_query (
184207        self ,
185-         targets : np .ndarray ,
186-         k = 10 ,
187-         nthreads = 8 ,
188-         nprobe = 1 ,
189-         num_nodes = 5 ,
190-         mode : Mode  =  Mode .REALTIME ,
208+         queries : np .ndarray ,
209+         k : int  =  10 ,
210+         nprobe : int  =  10 ,
211+         nthreads : int  =  - 1 ,
212+         mode : Mode  =  None ,
213+         num_partitions : int  =  - 1 ,
214+         num_workers : int  =  - 1 ,
191215    ):
192216        """ 
193-         Distributed  Query on top of  an IVF_FLAT index 
217+         Query an IVF_FLAT index using TileDB cloud taskgraphs  
194218
195219        Parameters 
196220        ---------- 
197-         targets : numpy.ndarray 
198-             ND Array of query targets  
221+         queries : numpy.ndarray 
222+             ND Array of queries  
199223        k: int 
200224            Number of top results to return per target 
201-         nqueries: int 
202-             Number of queries 
203-         nthreads: int 
204-             Number of threads to use for query 
205225        nprobe: int 
206226            number of probes 
227+         nthreads: int 
228+             Number of threads to use for query 
229+         use_nuv_implementation: bool 
 230+             whether to use the nuv query implementation. Default: False 
231+         mode: Mode 
232+             If provided the query will be executed using TileDB cloud taskgraphs. 
233+             For distributed execution you can use REALTIME or BATCH mode 
234+         num_partitions: int 
235+             Only relevant for taskgraph based execution. 
236+             If provided, we split the query execution in that many partitions. 
237+         num_workers: int 
238+             Only relevant for taskgraph based execution. 
239+             If provided, this is the number of workers to use for the query execution. 
207240        """ 
208241        from  tiledb .cloud  import  dag 
209242        from  tiledb .cloud .dag  import  Mode 
@@ -226,12 +259,12 @@ def dist_qv_udf(
226259            indices : np .array ,
227260            k_nn : int ,
228261        ):
229-             targets_m  =  array_to_matrix (query_vectors )
262+             queries_m  =  array_to_matrix (np . transpose ( query_vectors ) )
230263            r  =  dist_qv (
231264                dtype = dtype ,
232265                parts_uri = parts_uri ,
233266                ids_uri = ids_uri ,
234-                 query_vectors = targets_m ,
267+                 query_vectors = queries_m ,
235268                active_partitions = active_partitions ,
236269                active_queries = active_queries ,
237270                indices = indices ,
@@ -245,18 +278,22 @@ def dist_qv_udf(
245278                results .append (tmp_results )
246279            return  results 
247280
248-         assert  targets .dtype  ==  self .dtype 
281+         assert  queries .dtype  ==  np .float32 
282+         if  num_partitions  ==  - 1 :
283+             num_partitions  =  5 
284+         if  num_workers  ==  - 1 :
285+             num_workers  =  num_partitions 
249286        if  mode  ==  Mode .BATCH :
250287            d  =  dag .DAG (
251288                name = "vector-query" ,
252289                mode = Mode .BATCH ,
253-                 max_workers = num_nodes ,
290+                 max_workers = num_workers ,
254291            )
255292        if  mode  ==  Mode .REALTIME :
256293            d  =  dag .DAG (
257294                name = "vector-query" ,
258295                mode = Mode .REALTIME ,
259-                 max_workers = num_nodes ,
296+                 max_workers = num_workers ,
260297            )
261298        else :
262299            d  =  dag .DAG (
@@ -269,13 +306,13 @@ def dist_qv_udf(
269306        if  mode  ==  Mode .BATCH  or  mode  ==  Mode .REALTIME :
270307            submit  =  d .submit 
271308
272-         targets_m  =  array_to_matrix (targets )
309+         queries_m  =  array_to_matrix (np . transpose ( queries ) )
273310        active_partitions , active_queries  =  partition_ivf_index (
274-             centroids = self ._centroids , query = targets_m , nprobe = nprobe , nthreads = nthreads 
311+             centroids = self ._centroids , query = queries_m , nprobe = nprobe , nthreads = nthreads 
275312        )
276313        num_parts  =  len (active_partitions )
277314
278-         parts_per_node  =  int (math .ceil (num_parts  /  num_nodes ))
315+         parts_per_node  =  int (math .ceil (num_parts  /  num_partitions ))
279316        nodes  =  []
280317        for  part  in  range (0 , num_parts , parts_per_node ):
281318            part_end  =  part  +  parts_per_node 
@@ -287,7 +324,7 @@ def dist_qv_udf(
287324                    dtype = self .dtype ,
288325                    parts_uri = self .parts_db_uri ,
289326                    ids_uri = self .ids_uri ,
290-                     query_vectors = targets ,
327+                     query_vectors = queries ,
291328                    active_partitions = np .array (active_partitions )[part :part_end ],
292329                    active_queries = np .array (
293330                        active_queries [part :part_end ], dtype = object 
@@ -307,7 +344,7 @@ def dist_qv_udf(
307344            results .append (res )
308345
309346        results_per_query  =  []
310-         for  q  in  range (targets .shape [1 ]):
347+         for  q  in  range (queries .shape [0 ]):
311348            tmp_results  =  []
312349            for  j  in  range (k ):
313350                for  r  in  results :
0 commit comments