+import concurrent.futures as futures
+import os
 import numpy as np
 import sys
 
@@ -20,6 +22,7 @@ class Index:
     config: Optional[Mapping[str, Any]]
         config dictionary, defaults to None
     """
+
     def __init__(
         self,
         uri: str,
@@ -36,16 +39,28 @@ def __init__(
         self.storage_version = self.group.meta.get("storage_version", "0.1")
         self.update_arrays_uri = None
         self.index_version = self.group.meta.get("index_version", "")
-
+        self.thread_executor = futures.ThreadPoolExecutor()
 
     def query(self, queries: np.ndarray, k, **kwargs):
-        updated_ids = set(self.read_updated_ids())
-        retrieval_k = k
-        if len(updated_ids) > 0:
-            retrieval_k = 2 * k
-        internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
         if self.update_arrays_uri is None:
-            return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
+            return self.query_internal(queries, k, **kwargs)
+
+        # Query with updates
+        # Perform the queries in parallel
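+        # Retrieve 2*k candidates so that hits later invalidated by updates
+        # can be dropped while still leaving k valid results to return.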
+        retrieval_k = 2 * k
+        kwargs["nthreads"] = int(os.cpu_count() / 2)
+        future = self.thread_executor.submit(
+            Index.query_additions,
+            queries,
+            k,
+            self.dtype,
+            self.update_arrays_uri,
+            int(os.cpu_count() / 2),
+        )
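+        # Run the main index query on this thread while the additions query
+        # executes on the thread pool.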
+        internal_results_d, internal_results_i = self.query_internal(
+            queries, retrieval_k, **kwargs
+        )
+        addition_results_d, addition_results_i, updated_ids = future.result()
 
         # Filter updated vectors
         query_id = 0
@@ -62,112 +77,124 @@ def query(self, queries: np.ndarray, k, **kwargs):
         internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)
 
         # Merge update results
-        addition_results_d, addition_results_i = self.query_additions(queries, k)
         if addition_results_d is None:
             return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
 
         query_id = 0
         for query in addition_results_d:
             res_id = 0
             for res in query:
-                if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
+                if (
+                    addition_results_d[query_id, res_id] == 0
+                    and addition_results_i[query_id, res_id] == 0
+                ):
                     addition_results_d[query_id, res_id] = MAX_FLOAT_32
                     addition_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1
 
-
         results_d = np.hstack((internal_results_d, addition_results_d))
         results_i = np.hstack((internal_results_i, addition_results_i))
         sort_index = np.argsort(results_d, axis=1)
         results_d = np.take_along_axis(results_d, sort_index, axis=1)
         results_i = np.take_along_axis(results_i, sort_index, axis=1)
         return results_d[:, 0:k], results_i[:, 0:k]
 
-    def query_internal(self, queries: np.ndarray, k, **kwargs):
-        raise NotImplementedError
-
-    def query_additions(self, queries: np.ndarray, k):
+    @staticmethod
+    def query_additions(
+        queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
+    ):
         assert queries.dtype == np.float32
-        additions_vectors, additions_external_ids = self.read_additions()
+        additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
+            update_arrays_uri
+        )
         if additions_vectors is None:
-            return None, None
+            return None, None, updated_ids
+
         queries_m = array_to_matrix(np.transpose(queries))
         d, i = query_vq_heap_pyarray(
-            array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
+            array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
             queries_m,
             StdVector_u64(additions_external_ids),
             k,
-            8)
-        return np.transpose(np.array(d)), np.transpose(np.array(i))
+            nthreads,
+        )
+        return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
+
+    @staticmethod
+    def read_additions(update_arrays_uri) -> (np.ndarray, np.array, np.array):
+        if update_arrays_uri is None:
+            return None, None, np.array([], np.uint64)
+        updates_array = tiledb.open(update_arrays_uri, mode="r")
+        q = updates_array.query(attrs=("vector",), coords=True)
+        data = q[:]
+        updates_array.close()
+        updated_ids = data["external_id"]
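+        # Deletes are stored as empty vectors, so only non-empty entries
+        # count as additions.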
+        additions_filter = [len(item) > 0 for item in data["vector"]]
+        if len(data["external_id"][additions_filter]) > 0:
+            return (
+                np.vstack(data["vector"][additions_filter]),
+                data["external_id"][additions_filter],
+                updated_ids,
+            )
+        else:
+            return None, None, updated_ids
+
+    def query_internal(self, queries: np.ndarray, k, **kwargs):
+        raise NotImplementedError
 
     def update(self, vector: np.array, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        vectors = np.empty((1), dtype='O')
+        vectors = np.empty((1), dtype="O")
         vectors[0] = vector
-        updates_array[external_id] = {'vector': vectors}
+        updates_array[external_id] = {"vector": vectors}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def update_batch(self, vectors: np.ndarray, external_ids: np.array):
         updates_array = self.open_updates_array()
-        updates_array[external_ids] = {'vector': vectors}
+        updates_array[external_ids] = {"vector": vectors}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def delete(self, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        deletes = np.empty((1), dtype='O')
+        deletes = np.empty((1), dtype="O")
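+        # An empty vector is the tombstone that marks external_id as deleted.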
         deletes[0] = np.array([], dtype=self.dtype)
-        updates_array[external_id] = {'vector': deletes}
+        updates_array[external_id] = {"vector": deletes}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def delete_batch(self, external_ids: np.array):
         updates_array = self.open_updates_array()
-        deletes = np.empty((len(external_ids)), dtype='O')
+        deletes = np.empty((len(external_ids)), dtype="O")
         for i in range(len(external_ids)):
             deletes[i] = np.array([], dtype=self.dtype)
-        updates_array[external_ids] = {'vector': deletes}
+        updates_array[external_ids] = {"vector": deletes}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def consolidate_update_fragments(self):
         fragments_info = tiledb.array_fragments(self.update_arrays_uri)
-        if(len(fragments_info) > 10):
+        if len(fragments_info) > 10:
             tiledb.consolidate(self.update_arrays_uri)
             tiledb.vacuum(self.update_arrays_uri)
 
     def get_updates_uri(self):
         return self.update_arrays_uri
 
-    def read_additions(self) -> (np.ndarray, np.array):
-        if self.update_arrays_uri is None:
-            return None, None
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        additions_filter = [len(item) > 0 for item in data["vector"]]
-        if len(data["external_id"][additions_filter]) > 0:
-            return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
-        else:
-            return None, None
-    def read_updated_ids(self) -> np.array:
-        if self.update_arrays_uri is None:
-            return np.array([], np.uint64)
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        return data["external_id"]
-
     def open_updates_array(self):
         if self.update_arrays_uri is None:
-            updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
+            updates_array_name = storage_formats[self.storage_version][
+                "UPDATES_ARRAY_NAME"
+            ]
             updates_array_uri = f"{self.group.uri}/{updates_array_name}"
             if tiledb.array_exists(updates_array_uri):
                 raise RuntimeError(f"Array {updates_array_uri} already exists.")
             external_id_dim = tiledb.Dim(
-                name="external_id", domain=(0, MAX_UINT64 - 1), dtype=np.dtype(np.uint64)
+                name="external_id",
+                domain=(0, MAX_UINT64 - 1),
+                dtype=np.dtype(np.uint64),
             )
             dom = tiledb.Domain(external_id_dim)
             vector_attr = tiledb.Attr(name="vector", dtype=self.dtype, var=True)
@@ -188,13 +215,18 @@ def open_updates_array(self):
 
     def consolidate_updates(self):
         from tiledb.vector_search.ingestion import ingest
+
         new_index = ingest(
             index_type=self.index_type,
             index_uri=self.uri,
             size=self.size,
             source_uri=self.db_uri,
             external_ids_uri=self.ids_uri,
-            updates_uri=self.update_arrays_uri
+            updates_uri=self.update_arrays_uri,
         )
         tiledb.Array.delete_array(self.update_arrays_uri)
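+        # Reopen the group in write mode to drop the consumed updates array
+        # from the group's members.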
+        self.group.close()
+        self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
+        self.group.remove(self.update_arrays_uri)
+        self.group.close()
         return new_index
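
For context, a minimal usage sketch of the update path this diff introduces. It is not part of the change: it assumes a concrete subclass such as FlatIndex (which implements query_internal), and index_uri, the import path, and the vector values are placeholders.

    # Usage sketch (assumed API surface; adjust the import to your install).
    import numpy as np
    from tiledb.vector_search.flat_index import FlatIndex

    index = FlatIndex(uri=index_uri)

    # Updates and deletes accumulate in the updates array.
    index.update(vector=np.random.rand(3).astype(np.float32), external_id=np.uint64(42))
    index.delete(external_id=np.uint64(7))

    # query() now runs the main index query and the additions query in
    # parallel, masks ids present in updated_ids, and merges the results.
    distances, ids = index.query(np.random.rand(4, 3).astype(np.float32), k=10)

    # Fold the accumulated updates back into the index.
    index = index.consolidate_updates()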