Skip to content

Commit bb433db

Browse files
authored
Have consolidate_updates() reuse existing centroids by default with an option to re-compute them (#178)
1 parent 9c029e9 commit bb433db

File tree

3 files changed

+111
-8
lines changed

3 files changed

+111
-8
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,9 +354,22 @@ def open_updates_array(self, timestamp: int = None):
354354
timestamp = int(time.time() * 1000)
355355
return tiledb.open(self.updates_array_uri, mode="w", timestamp=timestamp)
356356

357-
def consolidate_updates(self, **kwargs):
357+
def consolidate_updates(
358+
self,
359+
retrain_index: bool = False,
360+
**kwargs
361+
):
362+
"""
363+
Parameters
364+
----------
365+
retrain_index: bool
366+
If true, retrain the index. If false, reuse data from the previous index.
367+
For IVF_FLAT, retraining means we will recompute the centroids - when doing so you can
368+
pass any ingest() arguments used to configure computing centroids and we will use them
369+
when recomputing the centroids. Otherwise, if false, we will reuse the centroids from
370+
the previous index.
371+
"""
358372
from tiledb.vector_search.ingestion import ingest
359-
360373
fragments_info = tiledb.array_fragments(
361374
self.updates_array_uri, ctx=tiledb.Ctx(self.config)
362375
)
@@ -371,6 +384,15 @@ def consolidate_updates(self, **kwargs):
371384
tiledb.consolidate(self.updates_array_uri, config=conf)
372385
tiledb.vacuum(self.updates_array_uri, config=conf)
373386

387+
# We don't copy the centroids if self.partitions=0 because this means our index was previously empty.
388+
should_pass_copy_centroids_uri = self.index_type == "IVF_FLAT" and not retrain_index and self.partitions > 0
389+
if should_pass_copy_centroids_uri:
390+
# Make sure the user didn't pass an incorrect number of partitions.
391+
if 'partitions' in kwargs and self.partitions != kwargs['partitions']:
392+
raise ValueError(f"The passed partitions={kwargs['partitions']} is different from the number of partitions ({self.partitions}) from when the index was created - this is an issue because with retrain_index=False, the partitions from the previous index will be used; to fix, set retrain_index=True, don't pass partitions, or pass the correct number of partitions.")
393+
# We pass partitions through kwargs so that we don't pass it twice.
394+
kwargs['partitions'] = self.partitions
395+
374396
new_index = ingest(
375397
index_type=self.index_type,
376398
index_uri=self.uri,
@@ -381,6 +403,7 @@ def consolidate_updates(self, **kwargs):
381403
updates_uri=self.updates_array_uri,
382404
index_timestamp=max_timestamp,
383405
storage_version=self.storage_version,
406+
copy_centroids_uri=self.centroids_uri if should_pass_copy_centroids_uri else None,
384407
config=self.config,
385408
**kwargs,
386409
)

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ def ingest(
159159
if training_sample_size < -1:
160160
raise ValueError("training_sample_size should either be positive or -1 (to auto-configure based on the dataset sizes)")
161161

162+
if copy_centroids_uri is not None and training_sample_size != -1:
163+
raise ValueError("training_sample_size should not be provided alongside copy_centroids_uri")
164+
if copy_centroids_uri is not None and partitions == -1:
165+
raise ValueError("partitions should be provided if copy_centroids_uri is provided (set partitions to the number of centroids in copy_centroids_uri)")
166+
162167
if index_type != "IVF_FLAT" and training_sample_size != -1:
163168
raise ValueError("training_sample_size should only be provided with index_type IVF_FLAT")
164169
for variable in ["copy_centroids_uri", "training_input_vectors", "training_source_uri", "training_source_type"]:

apis/python/test/test_ingestion.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def test_ivf_flat_ingestion_with_updates(tmp_path):
365365
_, result = index.query(queries, k=k, nprobe=nprobe)
366366
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
367367

368-
index = index.consolidate_updates(partitions=20)
368+
index = index.consolidate_updates(retrain_index=True, partitions=20)
369369
_, result = index.query(queries, k=k, nprobe=20)
370370
assert accuracy(result, gt_i, updated_ids=updated_ids) == 1.0
371371

@@ -733,7 +733,7 @@ def test_storage_versions(tmp_path):
733733
_, result = index.query(queries, k=k)
734734
assert accuracy(result, gt_i, updated_ids=updated_ids) >= MINIMUM_ACCURACY
735735

736-
index = index.consolidate_updates(partitions=20)
736+
index = index.consolidate_updates(retrain_index=True, partitions=20)
737737
_, result = index.query(queries, k=k)
738738
assert accuracy(result, gt_i, updated_ids=updated_ids) >= MINIMUM_ACCURACY
739739

@@ -776,7 +776,8 @@ def test_copy_centroids_uri(tmp_path):
776776
index_type="IVF_FLAT",
777777
index_uri=index_uri,
778778
input_vectors=data,
779-
copy_centroids_uri=centroids_uri
779+
copy_centroids_uri=centroids_uri,
780+
partitions=centroids_in_size
780781
)
781782

782783
# Query the index.
@@ -931,8 +932,48 @@ def test_ingest_with_training_source_uri_tdb(tmp_path):
931932
queries = np.array([data.transpose()[1]], dtype=np.float32)
932933
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1]])
933934

935+
update_vectors = np.empty([3], dtype=object)
936+
update_vectors[0] = np.array([6.0, 6.1, 6.2, 6.3], dtype=np.dtype(np.float32))
937+
update_vectors[1] = np.array([7.0, 7.1, 7.2, 7.3], dtype=np.dtype(np.float32))
938+
update_vectors[2] = np.array([8.0, 8.1, 8.2, 8.3], dtype=np.dtype(np.float32))
939+
index.update_batch(vectors=update_vectors, external_ids=np.array([1000, 1001, 1002]))
940+
941+
index = index.consolidate_updates()
942+
943+
queries = np.array([update_vectors[2]], dtype=np.float32)
944+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1002]])
945+
946+
################################################################################################
947+
# Test we can load the index again and query, update, and consolidate.
948+
################################################################################################
949+
# Load the index again and query.
934950
index = IVFFlatIndex(uri=index_uri)
935-
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1]])
951+
952+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1002]])
953+
954+
# Update the index and query.
955+
update_vectors = np.empty([2], dtype=object)
956+
update_vectors[0] = np.array([9.0, 9.1, 9.2, 9.3], dtype=np.dtype(np.float32))
957+
update_vectors[1] = np.array([10.0, 10.1, 10.2, 10.3], dtype=np.dtype(np.float32))
958+
index.update_batch(vectors=update_vectors, external_ids=np.array([1003, 1004]))
959+
index = index.consolidate_updates()
960+
961+
queries = np.array([update_vectors[0]], dtype=np.float32)
962+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1003]])
963+
964+
# Clear the index history, load, update, and query.
965+
Index.clear_history(uri=index_uri, timestamp=index.latest_ingestion_timestamp - 1)
966+
967+
index = IVFFlatIndex(uri=index_uri)
968+
969+
update_vectors = np.empty([2], dtype=object)
970+
update_vectors[0] = np.array([11.0, 11.1, 11.2, 11.3], dtype=np.dtype(np.float32))
971+
update_vectors[1] = np.array([12.0, 12.1, 12.2, 12.3], dtype=np.dtype(np.float32))
972+
index.update_batch(vectors=update_vectors, external_ids=np.array([1003, 1004]))
973+
index = index.consolidate_updates()
974+
975+
queries = np.array([update_vectors[0]], dtype=np.float32)
976+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1003]])
936977

937978
###############################################################################################
938979
# Also test that we can ingest with training_source_type.
@@ -984,5 +1025,39 @@ def test_ingest_with_training_source_uri_numpy(tmp_path):
9841025
queries = np.array([data[1]], dtype=np.float32)
9851026
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1]])
9861027

987-
index = IVFFlatIndex(uri=index_uri)
988-
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1]])
1028+
update_vectors = np.empty([3], dtype=object)
1029+
update_vectors[0] = np.array([6.0, 6.1, 6.2, 6.3], dtype=np.dtype(np.float32))
1030+
update_vectors[1] = np.array([7.0, 7.1, 7.2, 7.3], dtype=np.dtype(np.float32))
1031+
update_vectors[2] = np.array([8.0, 8.1, 8.2, 8.3], dtype=np.dtype(np.float32))
1032+
index.update_batch(vectors=update_vectors, external_ids=np.array([1000, 1001, 1002]))
1033+
1034+
index = index.consolidate_updates()
1035+
1036+
queries = np.array([update_vectors[2]], dtype=np.float32)
1037+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1002]])
1038+
1039+
################################################################################################
1040+
# Test we can load the index again and query, update, and consolidate.
1041+
################################################################################################
1042+
index_ram = IVFFlatIndex(uri=index_uri)
1043+
1044+
queries = np.array([data[1]], dtype=np.float32)
1045+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1]])
1046+
1047+
update_vectors = np.empty([2], dtype=object)
1048+
update_vectors[0] = np.array([9.0, 9.1, 9.2, 9.3], dtype=np.dtype(np.float32))
1049+
update_vectors[1] = np.array([10.0, 10.1, 10.2, 10.3], dtype=np.dtype(np.float32))
1050+
index.update_batch(vectors=update_vectors, external_ids=np.array([1003, 1004]))
1051+
index_ram = index_ram.consolidate_updates()
1052+
1053+
queries = np.array([update_vectors[0]], dtype=np.float32)
1054+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1003]])
1055+
1056+
update_vectors = np.empty([2], dtype=object)
1057+
update_vectors[0] = np.array([11.0, 11.1, 11.2, 11.3], dtype=np.dtype(np.float32))
1058+
update_vectors[1] = np.array([12.0, 12.1, 12.2, 12.3], dtype=np.dtype(np.float32))
1059+
index.update_batch(vectors=update_vectors, external_ids=np.array([1003, 1004]))
1060+
index_ram = index_ram.consolidate_updates(retrain_index=True, training_sample_size=3)
1061+
1062+
queries = np.array([update_vectors[0]], dtype=np.float32)
1063+
query_and_check_equals(index=index, queries=queries, expected_result_d=[[0]], expected_result_i=[[1003]])

0 commit comments

Comments
 (0)