Skip to content

Commit f35750a

Browse files
author
Nikos Papailiou
committed
Merge branch 'main' into npapa/fix-update-uri-bug
2 parents ae44c50 + f7b8c0b commit f35750a

File tree

12 files changed

+610
-174
lines changed

12 files changed

+610
-174
lines changed

.github/workflows/build_wheels.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name: Build wheels
22

33
on:
4+
workflow_dispatch:
45
push:
56
branches:
67
- release-*

apis/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ classifiers = [
1919

2020
dependencies = [
2121
"tiledb-cloud>=0.11",
22-
"tiledb>=0.23.1",
22+
"tiledb>=0.25.0",
2323
"typing-extensions", # for tiledb-cloud indirect, x-ref https://github.com/TileDB-Inc/TileDB-Cloud-Py/pull/428
2424
"scikit-learn",
2525
]

apis/python/requirements-py.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
numpy==1.24.3
22
tiledb-cloud==0.10.24
3-
tiledb==0.23.1
3+
tiledb==0.25.0

apis/python/src/tiledb/vector_search/index.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def query(self, queries: np.ndarray, k, **kwargs):
136136
raise TypeError(f"A query in queries has {query_dimensions} dimensions, but the indexed data had {self.dimensions} dimensions")
137137

138138
with tiledb.scope_ctx(ctx_or_config=self.config):
139-
if not tiledb.array_exists(self.updates_array_uri):
139+
if not self.group.meta["has_updates"]:
140140
if self.query_base_array:
141141
return self.query_internal(queries, k, **kwargs)
142142
else:
@@ -269,6 +269,12 @@ def query_internal(self, queries: np.ndarray, k, **kwargs):
269269
raise NotImplementedError
270270

271271
def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None):
272+
if not self.group.meta["has_updates"]:
273+
self.group.close()
274+
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
275+
self.group.meta["has_updates"] = True
276+
self.group.close()
277+
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
272278
updates_array = self.open_updates_array(timestamp=timestamp)
273279
vectors = np.empty((1), dtype="O")
274280
vectors[0] = vector
@@ -279,12 +285,24 @@ def update(self, vector: np.array, external_id: np.uint64, timestamp: int = None
279285
def update_batch(
280286
self, vectors: np.ndarray, external_ids: np.array, timestamp: int = None
281287
):
288+
if not self.group.meta["has_updates"]:
289+
self.group.close()
290+
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
291+
self.group.meta["has_updates"] = True
292+
self.group.close()
293+
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
282294
updates_array = self.open_updates_array(timestamp=timestamp)
283295
updates_array[external_ids] = {"vector": vectors}
284296
updates_array.close()
285297
self.consolidate_update_fragments()
286298

287299
def delete(self, external_id: np.uint64, timestamp: int = None):
300+
if not self.group.meta["has_updates"]:
301+
self.group.close()
302+
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
303+
self.group.meta["has_updates"] = True
304+
self.group.close()
305+
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
288306
updates_array = self.open_updates_array(timestamp=timestamp)
289307
deletes = np.empty((1), dtype="O")
290308
deletes[0] = np.array([], dtype=self.dtype)
@@ -293,6 +311,12 @@ def delete(self, external_id: np.uint64, timestamp: int = None):
293311
self.consolidate_update_fragments()
294312

295313
def delete_batch(self, external_ids: np.array, timestamp: int = None):
314+
if not self.group.meta["has_updates"]:
315+
self.group.close()
316+
self.group = tiledb.Group(self.uri, "w", ctx=tiledb.Ctx(self.config))
317+
self.group.meta["has_updates"] = True
318+
self.group.close()
319+
self.group = tiledb.Group(self.uri, "r", ctx=tiledb.Ctx(self.config))
296320
updates_array = self.open_updates_array(timestamp=timestamp)
297321
deletes = np.empty((len(external_ids)), dtype="O")
298322
for i in range(len(external_ids)):
@@ -356,9 +380,22 @@ def open_updates_array(self, timestamp: int = None):
356380
timestamp = int(time.time() * 1000)
357381
return tiledb.open(self.updates_array_uri, mode="w", timestamp=timestamp)
358382

359-
def consolidate_updates(self, **kwargs):
383+
def consolidate_updates(
384+
self,
385+
retrain_index: bool = False,
386+
**kwargs
387+
):
388+
"""
389+
Parameters
390+
----------
391+
retrain_index: bool
392+
If true, retrain the index. If false, reuse data from the previous index.
393+
For IVF_FLAT, retraining means we will recompute the centroids - when doing so you can
394+
pass any ingest() arguments used to configure computing centroids and we will use them
395+
when recomputing the centroids. Otherwise, if false, we will reuse the centroids from
396+
the previous index.
397+
"""
360398
from tiledb.vector_search.ingestion import ingest
361-
362399
fragments_info = tiledb.array_fragments(
363400
self.updates_array_uri, ctx=tiledb.Ctx(self.config)
364401
)
@@ -373,6 +410,15 @@ def consolidate_updates(self, **kwargs):
373410
tiledb.consolidate(self.updates_array_uri, config=conf)
374411
tiledb.vacuum(self.updates_array_uri, config=conf)
375412

413+
# We don't copy the centroids if self.partitions=0 because this means our index was previously empty.
414+
should_pass_copy_centroids_uri = self.index_type == "IVF_FLAT" and not retrain_index and self.partitions > 0
415+
if should_pass_copy_centroids_uri:
416+
# Make sure the user didn't pass an incorrect number of partitions.
417+
if 'partitions' in kwargs and self.partitions != kwargs['partitions']:
418+
raise ValueError(f"The passed partitions={kwargs['partitions']} is different than the number of partitions ({self.partitions}) from when the index was created - this is an issue because with retrain_index=False, the partitions from the previous index will be reused; to fix, set retrain_index=True, don't pass partitions, or pass the correct number of partitions.")
419+
# We pass partitions through kwargs so that we don't pass it twice.
420+
kwargs['partitions'] = self.partitions
421+
376422
new_index = ingest(
377423
index_type=self.index_type,
378424
index_uri=self.uri,
@@ -383,6 +429,7 @@ def consolidate_updates(self, **kwargs):
383429
updates_uri=self.updates_array_uri,
384430
index_timestamp=max_timestamp,
385431
storage_version=self.storage_version,
432+
copy_centroids_uri=self.centroids_uri if should_pass_copy_centroids_uri else None,
386433
config=self.config,
387434
**kwargs,
388435
)
@@ -508,4 +555,5 @@ def create_metadata(
508555
group.meta["index_type"] = index_type
509556
group.meta["base_sizes"] = json.dumps([0])
510557
group.meta["ingestion_timestamps"] = json.dumps([0])
558+
group.meta["has_updates"] = False
511559
group.close()

0 commit comments

Comments
 (0)