Skip to content

Commit 94a0b8a

Browse files
author
Nikos Papailiou
committed
Add util functions for checking updates existence
1 parent f35750a commit 94a0b8a

File tree

1 file changed

+12
-21
lines changed

1 file changed

+12
-21
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 12 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ def query(self, queries: np.ndarray, k, **kwargs):
136136
raise TypeError(f"A query in queries has {query_dimensions} dimensions, but the indexed data had {self.dimensions} dimensions")
137137

138138
with tiledb.scope_ctx(ctx_or_config=self.config):
139-
if not self.group.meta["has_updates"]:
139+
if not self.has_updates():
140140
if self.query_base_array:
141141
return self.query_internal(queries, k, **kwargs)
142142
else:
@@ -268,13 +268,19 @@ def get_dimensions(self):
268268
def query_internal(self, queries: np.ndarray, k, **kwargs):
269269
raise NotImplementedError
270270

271-
def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None):
271+
def has_updates(self):
272+
return self.group.meta["has_updates"]
273+
274+
def set_has_updates(self, has_updates: bool = True):
272275
if not self.group.meta["has_updates"]:
273276
self.group.close()
274277
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
275-
self.group.meta["has_updates"] = True
278+
self.group.meta["has_updates"] = has_updates
276279
self.group.close()
277280
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
281+
282+
def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None):
283+
self.set_has_updates()
278284
updates_array = self.open_updates_array(timestamp=timestamp)
279285
vectors = np.empty((1), dtype="O")
280286
vectors[0] = vector
@@ -285,24 +291,14 @@ def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None
285291
def update_batch(
286292
self, vectors: np.ndarray, external_ids: np.array, timestamp: int = None
287293
):
288-
if not self.group.meta["has_updates"]:
289-
self.group.close()
290-
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
291-
self.group.meta["has_updates"] = True
292-
self.group.close()
293-
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
294+
self.set_has_updates()
294295
updates_array = self.open_updates_array(timestamp=timestamp)
295296
updates_array[external_ids] = {"vector": vectors}
296297
updates_array.close()
297298
self.consolidate_update_fragments()
298299

299300
def delete(self, external_id: np.uint64, timestamp: int = None):
300-
if not self.group.meta["has_updates"]:
301-
self.group.close()
302-
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
303-
self.group.meta["has_updates"] = True
304-
self.group.close()
305-
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
301+
self.set_has_updates()
306302
updates_array = self.open_updates_array(timestamp=timestamp)
307303
deletes = np.empty((1), dtype="O")
308304
deletes[0] = np.array([], dtype=self.dtype)
@@ -311,12 +307,7 @@ def delete(self, external_id: np.uint64, timestamp: int = None):
311307
self.consolidate_update_fragments()
312308

313309
def delete_batch(self, external_ids: np.array, timestamp: int = None):
314-
if not self.group.meta["has_updates"]:
315-
self.group.close()
316-
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
317-
self.group.meta["has_updates"] = True
318-
self.group.close()
319-
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
310+
self.set_has_updates()
320311
updates_array = self.open_updates_array(timestamp=timestamp)
321312
deletes = np.empty((len(external_ids)), dtype="O")
322313
for i in range(len(external_ids)):

0 commit comments

Comments
 (0)