@@ -21,6 +21,7 @@ class Index:
2121    config: Optional[Mapping[str, Any]] 
2222        config dictionary, defaults to None 
2323    """ 
24+ 
2425    def  __init__ (
2526        self ,
2627        uri : str ,
@@ -43,20 +44,26 @@ def query(self, queries: np.ndarray, k, **kwargs):
4344            return  self .query_internal (queries , k , ** kwargs )
4445
4546        # Query with updates 
46-         updated_ids  =  set (self .read_updated_ids ())
47-         retrieval_k  =  k 
48-         if  len (updated_ids ) >  0 :
49-             retrieval_k  =  2 * k 
50- 
5147        # Perform the queries in parallel 
52-         kwargs ["nthreads" ] =  int (mp .cpu_count ()/ 2 )
48+         retrieval_k  =  2  *  k 
49+         kwargs ["nthreads" ] =  int (mp .cpu_count () /  2 )
5350        parent_conn , child_conn  =  mp .Pipe ()
5451        p  =  mp .Process (
5552            target = Index .query_additions ,
56-             args = (child_conn , queries , k , self .dtype , self .update_arrays_uri , int (mp .cpu_count ()/ 2 )))
53+             args = (
54+                 child_conn ,
55+                 queries ,
56+                 k ,
57+                 self .dtype ,
58+                 self .update_arrays_uri ,
59+                 int (mp .cpu_count () /  2 ),
60+             ),
61+         )
5762        p .start ()
58-         internal_results_d , internal_results_i  =  self .query_internal (queries , retrieval_k , ** kwargs )
59-         addition_results_d , addition_results_i  =  parent_conn .recv ()
63+         internal_results_d , internal_results_i  =  self .query_internal (
64+             queries , retrieval_k , ** kwargs 
65+         )
66+         addition_results_d , addition_results_i , updated_ids  =  parent_conn .recv ()
6067        p .join ()
6168
6269        # Filter updated vectors 
@@ -81,13 +88,15 @@ def query(self, queries: np.ndarray, k, **kwargs):
8188        for  query  in  addition_results_d :
8289            res_id  =  0 
8390            for  res  in  query :
84-                 if  addition_results_d [query_id , res_id ] ==  0  and  addition_results_i [query_id , res_id ] ==  0 :
91+                 if  (
92+                     addition_results_d [query_id , res_id ] ==  0 
93+                     and  addition_results_i [query_id , res_id ] ==  0 
94+                 ):
8595                    addition_results_d [query_id , res_id ] =  MAX_FLOAT_32 
8696                    addition_results_i [query_id , res_id ] =  MAX_UINT64 
8797                res_id  +=  1 
8898            query_id  +=  1 
8999
90- 
91100        results_d  =  np .hstack ((internal_results_d , addition_results_d ))
92101        results_i  =  np .hstack ((internal_results_i , addition_results_i ))
93102        sort_index  =  np .argsort (results_d , axis = 1 )
@@ -96,95 +105,103 @@ def query(self, queries: np.ndarray, k, **kwargs):
96105        return  results_d [:, 0 :k ], results_i [:, 0 :k ]
97106
98107    @staticmethod  
99-     def  query_additions (conn , queries : np .ndarray , k , dtype , update_arrays_uri , nthreads = 8 ):
108+     def  query_additions (
109+         conn , queries : np .ndarray , k , dtype , update_arrays_uri , nthreads = 8 
110+     ):
100111        assert  queries .dtype  ==  np .float32 
101-         additions_vectors , additions_external_ids  =  Index .read_additions (update_arrays_uri )
112+         additions_vectors , additions_external_ids , updated_ids  =  Index .read_additions (
113+             update_arrays_uri 
114+         )
102115        if  additions_vectors  is  None :
103-             return  None , None 
116+             conn .send (None , None , updated_ids )
117+             conn .close ()
118+             return 
119+ 
104120        queries_m  =  array_to_matrix (np .transpose (queries ))
105121        d , i  =  query_vq_heap_pyarray (
106122            array_to_matrix (np .transpose (additions_vectors ).astype (dtype )),
107123            queries_m ,
108124            StdVector_u64 (additions_external_ids ),
109125            k ,
110-             nthreads )
111-         conn .send ((np .transpose (np .array (d )), np .transpose (np .array (i ))))
126+             nthreads ,
127+         )
128+         conn .send ((np .transpose (np .array (d )), np .transpose (np .array (i )), updated_ids ))
112129        conn .close ()
113130
114131    @staticmethod  
115132    def  read_additions (update_arrays_uri ) ->  (np .ndarray , np .array ):
116133        if  update_arrays_uri  is  None :
117-             return  None , None 
134+             return  None , None ,  np . array ([],  np . uint64 ) 
118135        updates_array  =  tiledb .open (update_arrays_uri , mode = "r" )
119-         q  =  updates_array .query (attrs = (' vector'  ,), coords = True )
136+         q  =  updates_array .query (attrs = (" vector"  ,), coords = True )
120137        data  =  q [:]
121138        updates_array .close ()
139+         updated_ids  =  data ["external_id" ]
122140        additions_filter  =  [len (item ) >  0  for  item  in  data ["vector" ]]
123141        if  len (data ["external_id" ][additions_filter ]) >  0 :
124-             return  np .vstack (data ["vector" ][additions_filter ]), data ["external_id" ][additions_filter ]
142+             return  (
143+                 np .vstack (data ["vector" ][additions_filter ]),
144+                 data ["external_id" ][additions_filter ],
145+                 updated_ids 
146+             )
125147        else :
126-             return  None , None 
148+             return  None , None ,  updated_ids 
127149
128150    def  query_internal (self , queries : np .ndarray , k , ** kwargs ):
129151        raise  NotImplementedError 
130152
131153    def  update (self , vector : np .array , external_id : np .uint64 ):
132154        updates_array  =  self .open_updates_array ()
133-         vectors  =  np .empty ((1 ), dtype = 'O' )
155+         vectors  =  np .empty ((1 ), dtype = "O" )
134156        vectors [0 ] =  vector 
135-         updates_array [external_id ] =  {' vector'  : vectors }
157+         updates_array [external_id ] =  {" vector"  : vectors }
136158        updates_array .close ()
137159        self .consolidate_update_fragments ()
138160
139161    def  update_batch (self , vectors : np .ndarray , external_ids : np .array ):
140162        updates_array  =  self .open_updates_array ()
141-         updates_array [external_ids ] =  {' vector'  : vectors }
163+         updates_array [external_ids ] =  {" vector"  : vectors }
142164        updates_array .close ()
143165        self .consolidate_update_fragments ()
144166
145167    def  delete (self , external_id : np .uint64 ):
146168        updates_array  =  self .open_updates_array ()
147-         deletes  =  np .empty ((1 ), dtype = 'O' )
169+         deletes  =  np .empty ((1 ), dtype = "O" )
148170        deletes [0 ] =  np .array ([], dtype = self .dtype )
149-         updates_array [external_id ] =   { ' vector'  : deletes }
171+         updates_array [external_id ] =  { " vector"  : deletes }
150172        updates_array .close ()
151173        self .consolidate_update_fragments ()
152174
153175    def  delete_batch (self , external_ids : np .array ):
154176        updates_array  =  self .open_updates_array ()
155-         deletes  =  np .empty ((len (external_ids )), dtype = 'O' )
177+         deletes  =  np .empty ((len (external_ids )), dtype = "O" )
156178        for  i  in  range (len (external_ids )):
157179            deletes [i ] =  np .array ([], dtype = self .dtype )
158-         updates_array [external_ids ] =  {' vector'  : deletes }
180+         updates_array [external_ids ] =  {" vector"  : deletes }
159181        updates_array .close ()
160182        self .consolidate_update_fragments ()
161183
162184    def  consolidate_update_fragments (self ):
163185        fragments_info  =  tiledb .array_fragments (self .update_arrays_uri )
164-         if ( len (fragments_info ) >  10 ) :
186+         if   len (fragments_info ) >  10 :
165187            tiledb .consolidate (self .update_arrays_uri )
166188            tiledb .vacuum (self .update_arrays_uri )
167189
168190    def  get_updates_uri (self ):
169191        return  self .update_arrays_uri 
170192
171-     def  read_updated_ids (self ) ->  np .array :
172-         if  self .update_arrays_uri  is  None :
173-             return  np .array ([], np .uint64 )
174-         updates_array  =  tiledb .open (self .update_arrays_uri , mode = "r" )
175-         q  =  updates_array .query (attrs = ('vector' ,), coords = True )
176-         data  =  q [:]
177-         updates_array .close ()
178-         return  data ["external_id" ]
179- 
180193    def  open_updates_array (self ):
181194        if  self .update_arrays_uri  is  None :
182-             updates_array_name  =  storage_formats [self .storage_version ]["UPDATES_ARRAY_NAME" ]
195+             updates_array_name  =  storage_formats [self .storage_version ][
196+                 "UPDATES_ARRAY_NAME" 
197+             ]
183198            updates_array_uri  =  f"{ self .group .uri }  /{ updates_array_name }  " 
184199            if  tiledb .array_exists (updates_array_uri ):
185200                raise  RuntimeError (f"Array { updates_array_uri }   already exists." )
186201            external_id_dim  =  tiledb .Dim (
187-                 name = "external_id" , domain = (0 , MAX_UINT64 - 1 ), dtype = np .dtype (np .uint64 )
202+                 name = "external_id" ,
203+                 domain = (0 , MAX_UINT64  -  1 ),
204+                 dtype = np .dtype (np .uint64 ),
188205            )
189206            dom  =  tiledb .Domain (external_id_dim )
190207            vector_attr  =  tiledb .Attr (name = "vector" , dtype = self .dtype , var = True )
@@ -205,13 +222,14 @@ def open_updates_array(self):
205222
206223    def  consolidate_updates (self ):
207224        from  tiledb .vector_search .ingestion  import  ingest 
225+ 
208226        new_index  =  ingest (
209227            index_type = self .index_type ,
210228            index_uri = self .uri ,
211229            size = self .size ,
212230            source_uri = self .db_uri ,
213231            external_ids_uri = self .ids_uri ,
214-             updates_uri = self .update_arrays_uri 
232+             updates_uri = self .update_arrays_uri , 
215233        )
216234        tiledb .Array .delete_array (self .update_arrays_uri )
217235        self .group .close ()
0 commit comments