@@ -136,7 +136,7 @@ def query(self, queries: np.ndarray, k, **kwargs):
136136 raise TypeError (f"A query in queries has { query_dimensions } dimensions, but the indexed data had { self .dimensions } dimensions" )
137137
138138 with tiledb .scope_ctx (ctx_or_config = self .config ):
139- if not self .group . meta [ " has_updates" ] :
139+ if not self .has_updates () :
140140 if self .query_base_array :
141141 return self .query_internal (queries , k , ** kwargs )
142142 else :
@@ -268,13 +268,19 @@ def get_dimensions(self):
268268 def query_internal (self , queries : np .ndarray , k , ** kwargs ):
269269 raise NotImplementedError
270270
271- def update (self , vector : np .array , external_id : np .uint64 , timestamp : int = None ):
271+ def has_updates (self ):
272+ return self .group .meta ["has_updates" ]
273+
274+ def set_has_updates (self , has_updates : bool = True ):
272275 if not self .group .meta ["has_updates" ]:
273276 self .group .close ()
274277 self .group = tiledb .Group (self .uri , "w" , ctx = tiledb .Ctx (self .config ))
275- self .group .meta ["has_updates" ] = True
278+ self .group .meta ["has_updates" ] = has_updates
276279 self .group .close ()
277280 self .group = tiledb .Group (self .uri , "r" , ctx = tiledb .Ctx (self .config ))
281+
282+ def update (self , vector : np .array , external_id : np .uint64 , timestamp : int = None ):
283+ self .set_has_updates ()
278284 updates_array = self .open_updates_array (timestamp = timestamp )
279285 vectors = np .empty ((1 ), dtype = "O" )
280286 vectors [0 ] = vector
@@ -285,24 +291,14 @@ def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None
285291 def update_batch (
286292 self , vectors : np .ndarray , external_ids : np .array , timestamp : int = None
287293 ):
288- if not self .group .meta ["has_updates" ]:
289- self .group .close ()
290- self .group = tiledb .Group (self .uri , "w" , ctx = tiledb .Ctx (self .config ))
291- self .group .meta ["has_updates" ] = True
292- self .group .close ()
293- self .group = tiledb .Group (self .uri , "r" , ctx = tiledb .Ctx (self .config ))
294+ self .set_has_updates ()
294295 updates_array = self .open_updates_array (timestamp = timestamp )
295296 updates_array [external_ids ] = {"vector" : vectors }
296297 updates_array .close ()
297298 self .consolidate_update_fragments ()
298299
299300 def delete (self , external_id : np .uint64 , timestamp : int = None ):
300- if not self .group .meta ["has_updates" ]:
301- self .group .close ()
302- self .group = tiledb .Group (self .uri , "w" , ctx = tiledb .Ctx (self .config ))
303- self .group .meta ["has_updates" ] = True
304- self .group .close ()
305- self .group = tiledb .Group (self .uri , "r" , ctx = tiledb .Ctx (self .config ))
301+ self .set_has_updates ()
306302 updates_array = self .open_updates_array (timestamp = timestamp )
307303 deletes = np .empty ((1 ), dtype = "O" )
308304 deletes [0 ] = np .array ([], dtype = self .dtype )
@@ -311,12 +307,7 @@ def delete(self, external_id: np.uint64, timestamp: int = None):
311307 self .consolidate_update_fragments ()
312308
313309 def delete_batch (self , external_ids : np .array , timestamp : int = None ):
314- if not self .group .meta ["has_updates" ]:
315- self .group .close ()
316- self .group = tiledb .Group (self .uri , "w" , ctx = tiledb .Ctx (self .config ))
317- self .group .meta ["has_updates" ] = True
318- self .group .close ()
319- self .group = tiledb .Group (self .uri , "r" , ctx = tiledb .Ctx (self .config ))
310+ self .set_has_updates ()
320311 updates_array = self .open_updates_array (timestamp = timestamp )
321312 deletes = np .empty ((len (external_ids )), dtype = "O" )
322313 for i in range (len (external_ids )):
0 commit comments