@@ -1,3 +1,5 @@
+import concurrent.futures as futures
+import os
 import numpy as np
 import sys
 
@@ -20,6 +22,7 @@ class Index:
     config: Optional[Mapping[str, Any]]
         config dictionary, defaults to None
     """
+
     def __init__(
         self,
         uri: str,
@@ -36,16 +39,28 @@ def __init__(
         self.storage_version = self.group.meta.get("storage_version", "0.1")
         self.update_arrays_uri = None
         self.index_version = self.group.meta.get("index_version", "")
-
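+        # Thread pool used to overlap the updates-array query with the main index query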
+        self.thread_executor = futures.ThreadPoolExecutor()
 
     def query(self, queries: np.ndarray, k, **kwargs):
-        updated_ids = set(self.read_updated_ids())
-        retrieval_k = k
-        if len(updated_ids) > 0:
-            retrieval_k = 2 * k
-        internal_results_d, internal_results_i = self.query_internal(queries, retrieval_k, **kwargs)
         if self.update_arrays_uri is None:
-            return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
+            return self.query_internal(queries, k, **kwargs)
+
+        # Query with updates
+        # Perform the queries in parallel
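+        # Over-fetch 2*k candidates so entries invalidated by updates can be dropped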
+        retrieval_k = 2 * k
+        kwargs["nthreads"] = int(os.cpu_count() / 2)
+        future = self.thread_executor.submit(
+            Index.query_additions,
+            queries,
+            k,
+            self.dtype,
+            self.update_arrays_uri,
+            int(os.cpu_count() / 2),
+        )
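+        # The main index query runs on this thread while the additions query runs in the pool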
+        internal_results_d, internal_results_i = self.query_internal(
+            queries, retrieval_k, **kwargs
+        )
+        addition_results_d, addition_results_i, updated_ids = future.result()
 
         # Filter updated vectors
         query_id = 0
@@ -55,119 +70,137 @@ def query(self, queries: np.ndarray, k, **kwargs):
                 if res in updated_ids:
                     internal_results_d[query_id, res_id] = MAX_FLOAT_32
                     internal_results_i[query_id, res_id] = MAX_UINT64
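+                # A (distance == 0, id == 0) pair marks an empty result slot; mask it too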
+                if (
+                    internal_results_d[query_id, res_id] == 0
+                    and internal_results_i[query_id, res_id] == 0
+                ):
+                    internal_results_d[query_id, res_id] = MAX_FLOAT_32
+                    internal_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1
         sort_index = np.argsort(internal_results_d, axis=1)
         internal_results_d = np.take_along_axis(internal_results_d, sort_index, axis=1)
         internal_results_i = np.take_along_axis(internal_results_i, sort_index, axis=1)
 
         # Merge update results
-        addition_results_d, addition_results_i = self.query_additions(queries, k)
         if addition_results_d is None:
             return internal_results_d[:, 0:k], internal_results_i[:, 0:k]
 
         query_id = 0
         for query in addition_results_d:
             res_id = 0
             for res in query:
-                if addition_results_d[query_id, res_id] == 0 and addition_results_i[query_id, res_id] == 0:
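+                # Mask empty (0, 0) slots returned by the additions query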
+                if (
+                    addition_results_d[query_id, res_id] == 0
+                    and addition_results_i[query_id, res_id] == 0
+                ):
                     addition_results_d[query_id, res_id] = MAX_FLOAT_32
                     addition_results_i[query_id, res_id] = MAX_UINT64
                 res_id += 1
             query_id += 1
 
-
         results_d = np.hstack((internal_results_d, addition_results_d))
         results_i = np.hstack((internal_results_i, addition_results_i))
         sort_index = np.argsort(results_d, axis=1)
         results_d = np.take_along_axis(results_d, sort_index, axis=1)
         results_i = np.take_along_axis(results_i, sort_index, axis=1)
         return results_d[:, 0:k], results_i[:, 0:k]
 
-    def query_internal(self, queries: np.ndarray, k, **kwargs):
-        raise NotImplementedError
-
-    def query_additions(self, queries: np.ndarray, k):
+    @staticmethod
+    def query_additions(
+        queries: np.ndarray, k, dtype, update_arrays_uri, nthreads=8
+    ):
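+        # Static so it can be submitted to the thread pool without sharing index state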
         assert queries.dtype == np.float32
-        additions_vectors, additions_external_ids = self.read_additions()
+        additions_vectors, additions_external_ids, updated_ids = Index.read_additions(
+            update_arrays_uri
+        )
         if additions_vectors is None:
-            return None, None
+            return None, None, updated_ids
+
         queries_m = array_to_matrix(np.transpose(queries))
         d, i = query_vq_heap_pyarray(
-            array_to_matrix(np.transpose(additions_vectors).astype(self.dtype)),
+            array_to_matrix(np.transpose(additions_vectors).astype(dtype)),
             queries_m,
             StdVector_u64(additions_external_ids),
             k,
-            8)
+            nthreads,
+        )
-        return np.transpose(np.array(d)), np.transpose(np.array(i))
+        return np.transpose(np.array(d)), np.transpose(np.array(i)), updated_ids
+
+    @staticmethod
+    def read_additions(update_arrays_uri) -> (np.ndarray, np.array, np.array):
+        if update_arrays_uri is None:
+            return None, None, np.array([], np.uint64)
+        updates_array = tiledb.open(update_arrays_uri, mode="r")
+        q = updates_array.query(attrs=("vector",), coords=True)
+        data = q[:]
+        updates_array.close()
+        updated_ids = data["external_id"]
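+        # Deletions are stored as empty vectors; keep only non-empty additions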
+        additions_filter = [len(item) > 0 for item in data["vector"]]
+        if len(data["external_id"][additions_filter]) > 0:
+            return (
+                np.vstack(data["vector"][additions_filter]),
+                data["external_id"][additions_filter],
+                updated_ids,
+            )
+        else:
+            return None, None, updated_ids
+
+    def query_internal(self, queries: np.ndarray, k, **kwargs):
+        raise NotImplementedError
 
     def update(self, vector: np.array, external_id: np.uint64):
         updates_array = self.open_updates_array()
-        vectors = np.empty((1), dtype='O')
+        vectors = np.empty((1), dtype="O")
         vectors[0] = vector
-        updates_array[external_id] = {'vector': vectors}
+        updates_array[external_id] = {"vector": vectors}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def update_batch(self, vectors: np.ndarray, external_ids: np.array):
         updates_array = self.open_updates_array()
-        updates_array[external_ids] = {'vector': vectors}
+        updates_array[external_ids] = {"vector": vectors}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def delete(self, external_id: np.uint64):
         updates_array = self.open_updates_array()
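+        # A delete is recorded as an update whose vector is empty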
-        deletes = np.empty((1), dtype='O')
+        deletes = np.empty((1), dtype="O")
         deletes[0] = np.array([], dtype=self.dtype)
-        updates_array[external_id] = {'vector': deletes}
+        updates_array[external_id] = {"vector": deletes}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def delete_batch(self, external_ids: np.array):
         updates_array = self.open_updates_array()
-        deletes = np.empty((len(external_ids)), dtype='O')
+        deletes = np.empty((len(external_ids)), dtype="O")
         for i in range(len(external_ids)):
             deletes[i] = np.array([], dtype=self.dtype)
-        updates_array[external_ids] = {'vector': deletes}
+        updates_array[external_ids] = {"vector": deletes}
         updates_array.close()
         self.consolidate_update_fragments()
 
     def consolidate_update_fragments(self):
         fragments_info = tiledb.array_fragments(self.update_arrays_uri)
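+        # Keep the updates array compact once more than 10 fragments accumulate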
-        if(len(fragments_info) > 10):
+        if len(fragments_info) > 10:
             tiledb.consolidate(self.update_arrays_uri)
             tiledb.vacuum(self.update_arrays_uri)
 
     def get_updates_uri(self):
         return self.update_arrays_uri
 
-    def read_additions(self) -> (np.ndarray, np.array):
-        if self.update_arrays_uri is None:
-            return None, None
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        additions_filter = [len(item) > 0 for item in data["vector"]]
-        if len(data["external_id"][additions_filter]) > 0:
-            return np.vstack(data["vector"][additions_filter]), data["external_id"][additions_filter]
-        else:
-            return None, None
-    def read_updated_ids(self) -> np.array:
-        if self.update_arrays_uri is None:
-            return np.array([], np.uint64)
-        updates_array = tiledb.open(self.update_arrays_uri, mode="r")
-        q = updates_array.query(attrs=('vector',), coords=True)
-        data = q[:]
-        return data["external_id"]
-
     def open_updates_array(self):
         if self.update_arrays_uri is None:
-            updates_array_name = storage_formats[self.storage_version]["UPDATES_ARRAY_NAME"]
+            updates_array_name = storage_formats[self.storage_version][
+                "UPDATES_ARRAY_NAME"
+            ]
             updates_array_uri = f"{self.group.uri}/{updates_array_name}"
             if tiledb.array_exists(updates_array_uri):
                 raise RuntimeError(f"Array {updates_array_uri} already exists.")
             external_id_dim = tiledb.Dim(
-                name="external_id", domain=(0, MAX_UINT64 - 1), dtype=np.dtype(np.uint64)
+                name="external_id",
+                domain=(0, MAX_UINT64 - 1),
+                dtype=np.dtype(np.uint64),
             )
             dom = tiledb.Domain(external_id_dim)
             vector_attr = tiledb.Attr(name="vector", dtype=self.dtype, var=True)
@@ -188,13 +221,30 @@ def open_updates_array(self):
 
     def consolidate_updates(self):
         from tiledb.vector_search.ingestion import ingest
+
         new_index = ingest(
             index_type=self.index_type,
             index_uri=self.uri,
             size=self.size,
             source_uri=self.db_uri,
             external_ids_uri=self.ids_uri,
-            updates_uri=self.update_arrays_uri
+            updates_uri=self.update_arrays_uri,
         )
         tiledb.Array.delete_array(self.update_arrays_uri)
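+        # Remove the now-consolidated updates array from the index group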
+        self.group.close()
+        self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
+        self.group.remove(self.update_arrays_uri)
+        self.group.close()
         return new_index
+
+    @staticmethod
+    def delete_index(uri, config):
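+        # Deleting an index that does not exist is treated as a no-op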
+        try:
+            group = tiledb.Group(uri, "m", config=config)
+        except tiledb.TileDBError as err:
+            message = str(err)
+            if "group does not exist" in message:
+                return
+            else:
+                raise err
+        group.delete()