1+ import multiprocessing as mp
12import numpy as np
23import sys
34
@@ -37,15 +38,26 @@ def __init__(
3738 self .update_arrays_uri = None
3839 self .index_version = self .group .meta .get ("index_version" , "" )
3940
40-
4141 def query (self , queries : np .ndarray , k , ** kwargs ):
42+ if self .update_arrays_uri is None :
43+ return self .query_internal (queries , k , ** kwargs )
44+
45+ # Query with updates
4246 updated_ids = set (self .read_updated_ids ())
4347 retrieval_k = k
4448 if len (updated_ids ) > 0 :
4549 retrieval_k = 2 * k
50+
51+ # Perform the queries in parallel
52+ kwargs ["nthreads" ] = int (mp .cpu_count ()/ 2 )
53+ parent_conn , child_conn = mp .Pipe ()
54+ p = mp .Process (
55+ target = Index .query_additions ,
56+ args = (child_conn , queries , k , self .dtype , self .update_arrays_uri , int (mp .cpu_count ()/ 2 )))
57+ p .start ()
4658 internal_results_d , internal_results_i = self .query_internal (queries , retrieval_k , ** kwargs )
47- if self . update_arrays_uri is None :
48- return internal_results_d [:, 0 : k ], internal_results_i [:, 0 : k ]
59+ addition_results_d , addition_results_i = parent_conn . recv ()
60+ p . join ()
4961
5062 # Filter updated vectors
5163 query_id = 0
@@ -62,7 +74,6 @@ def query(self, queries: np.ndarray, k, **kwargs):
6274 internal_results_i = np .take_along_axis (internal_results_i , sort_index , axis = 1 )
6375
6476 # Merge update results
65- addition_results_d , addition_results_i = self .query_additions (queries , k )
6677 if addition_results_d is None :
6778 return internal_results_d [:, 0 :k ], internal_results_i [:, 0 :k ]
6879
@@ -84,22 +95,38 @@ def query(self, queries: np.ndarray, k, **kwargs):
8495 results_i = np .take_along_axis (results_i , sort_index , axis = 1 )
8596 return results_d [:, 0 :k ], results_i [:, 0 :k ]
8697
87- def query_internal (self , queries : np .ndarray , k , ** kwargs ):
88- raise NotImplementedError
89-
90- def query_additions (self , queries : np .ndarray , k ):
98+ @staticmethod
99+ def query_additions (conn , queries : np .ndarray , k , dtype , update_arrays_uri , nthreads = 8 ):
91100 assert queries .dtype == np .float32
92- additions_vectors , additions_external_ids = self .read_additions ()
101+ additions_vectors , additions_external_ids = Index .read_additions (update_arrays_uri )
93102 if additions_vectors is None :
94103 return None , None
95104 queries_m = array_to_matrix (np .transpose (queries ))
96105 d , i = query_vq_heap_pyarray (
97- array_to_matrix (np .transpose (additions_vectors ).astype (self . dtype )),
106+ array_to_matrix (np .transpose (additions_vectors ).astype (dtype )),
98107 queries_m ,
99108 StdVector_u64 (additions_external_ids ),
100109 k ,
101- 8 )
102- return np .transpose (np .array (d )), np .transpose (np .array (i ))
110+ nthreads )
111+ conn .send ((np .transpose (np .array (d )), np .transpose (np .array (i ))))
112+ conn .close ()
113+
114+ @staticmethod
115+ def read_additions (update_arrays_uri ) -> (np .ndarray , np .array ):
116+ if update_arrays_uri is None :
117+ return None , None
118+ updates_array = tiledb .open (update_arrays_uri , mode = "r" )
119+ q = updates_array .query (attrs = ('vector' ,), coords = True )
120+ data = q [:]
121+ updates_array .close ()
122+ additions_filter = [len (item ) > 0 for item in data ["vector" ]]
123+ if len (data ["external_id" ][additions_filter ]) > 0 :
124+ return np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
125+ else :
126+ return None , None
127+
    def query_internal(self, queries: np.ndarray, k, **kwargs):
        """Query the underlying index for the k nearest neighbors of each query.

        Abstract hook: concrete index subclasses must implement this. The base
        ``query`` method calls it and merges its results with pending updates.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
103130
104131 def update (self , vector : np .array , external_id : np .uint64 ):
105132 updates_array = self .open_updates_array ()
@@ -141,23 +168,13 @@ def consolidate_update_fragments(self):
    def get_updates_uri(self):
        """Return the URI of the updates array (None when no updates exist yet)."""
        return self.update_arrays_uri
143170
144- def read_additions (self ) -> (np .ndarray , np .array ):
145- if self .update_arrays_uri is None :
146- return None , None
147- updates_array = tiledb .open (self .update_arrays_uri , mode = "r" )
148- q = updates_array .query (attrs = ('vector' ,), coords = True )
149- data = q [:]
150- additions_filter = [len (item ) > 0 for item in data ["vector" ]]
151- if len (data ["external_id" ][additions_filter ]) > 0 :
152- return np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
153- else :
154- return None , None
155171 def read_updated_ids (self ) -> np .array :
156172 if self .update_arrays_uri is None :
157173 return np .array ([], np .uint64 )
158174 updates_array = tiledb .open (self .update_arrays_uri , mode = "r" )
159175 q = updates_array .query (attrs = ('vector' ,), coords = True )
160176 data = q [:]
177+ updates_array .close ()
161178 return data ["external_id" ]
162179
163180 def open_updates_array (self ):
@@ -197,4 +214,8 @@ def consolidate_updates(self):
197214 updates_uri = self .update_arrays_uri
198215 )
199216 tiledb .Array .delete_array (self .update_arrays_uri )
217+ self .group .close ()
218+ self .group = tiledb .Group (self .uri , "w" , ctx = tiledb .Ctx (self .config ))
219+ self .group .remove (self .update_arrays_uri )
220+ self .group .close ()
200221 return new_index
0 commit comments