Skip to content

Commit bb20254

Browse files
authored
Fix crash from passing copy_centroids_uri to ingest() (#172)
1 parent 3e5489c commit bb20254

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,8 @@ def read_input_vectors(
765765
def copy_centroids(
766766
index_group_uri: str,
767767
copy_centroids_uri: str,
768+
partitions: int,
769+
dimensions: int,
768770
config: Optional[Mapping[str, Any]] = None,
769771
verbose: bool = False,
770772
trace_id: Optional[str] = None,
@@ -777,8 +779,8 @@ def copy_centroids(
777779
)
778780
src = tiledb.open(copy_centroids_uri, mode="r")
779781
dest = tiledb.open(centroids_uri, mode="w", timestamp=index_timestamp)
780-
src_centroids = src[:, :]
781-
dest[:, :] = src_centroids
782+
src_centroids = src[0:dimensions, 0:partitions]
783+
dest[0:dimensions, 0:partitions] = src_centroids
782784
logger.debug(src_centroids)
783785

784786
# --------------------------------------------------------------------
@@ -1613,6 +1615,8 @@ def create_ingestion_dag(
16131615
copy_centroids,
16141616
index_group_uri=index_group_uri,
16151617
copy_centroids_uri=copy_centroids_uri,
1618+
partitions=partitions,
1619+
dimensions=dimensions,
16161620
config=config,
16171621
verbose=verbose,
16181622
trace_id=trace_id,

apis/python/test/test_ingestion.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
import numpy as np
23
from common import *
34
import pytest
@@ -737,6 +738,50 @@ def test_storage_versions(tmp_path):
737738
_, result = index_ram.query(query_vectors, k=k)
738739
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
739740

741+
def test_copy_centroids_uri(tmp_path):
742+
dataset_dir = os.path.join(tmp_path, "dataset")
743+
os.mkdir(dataset_dir)
744+
745+
# Create the index data.
746+
data = np.array([[1, 1, 1, 1], [1, 1, 1, 1], [2, 2, 2, 2], [2, 2, 2, 2], [3, 3, 3, 3]], dtype=np.float32)
747+
748+
# Create the centroids - this is based on ivf_flat_index.py.
749+
centroids = np.array([[1, 1, 1, 1], [2, 2, 2, 2]], dtype=np.float32)
750+
centroids_in_size = centroids.shape[0]
751+
dimensions = centroids.shape[1]
752+
schema = tiledb.ArraySchema(
753+
domain=tiledb.Domain(
754+
*[
755+
tiledb.Dim(name="rows", domain=(0, dimensions - 1), tile=dimensions, dtype=np.dtype(np.int32)),
756+
tiledb.Dim(name="cols", domain=(0, np.iinfo(np.dtype("int32")).max), tile=100000, dtype=np.dtype(np.int32)),
757+
]
758+
),
759+
sparse=False,
760+
attrs=[tiledb.Attr(name="centroids", dtype="float32", filters=tiledb.FilterList([tiledb.ZstdFilter()]))],
761+
cell_order="col-major",
762+
tile_order="col-major",
763+
)
764+
centroids_uri = os.path.join(dataset_dir, "centroids.tdb")
765+
tiledb.Array.create(centroids_uri, schema)
766+
index_timestamp = int(time.time() * 1000)
767+
with tiledb.open(centroids_uri, mode="w", timestamp=index_timestamp) as A:
768+
A[0:dimensions, 0:centroids_in_size] = centroids.transpose()
769+
770+
# Create the index.
771+
index_uri = os.path.join(tmp_path, "array")
772+
index = ingest(
773+
index_type="IVF_FLAT",
774+
index_uri=index_uri,
775+
input_vectors=data,
776+
copy_centroids_uri=centroids_uri
777+
)
778+
779+
# Query the index.
780+
query_vector_index = 4
781+
query_vectors = np.array([data[query_vector_index]], dtype=np.float32)
782+
result_d, result_i = index.query(query_vectors, k=1)
783+
check_equals(result_d=result_d, result_i=result_i, expected_result_d=[[0]], expected_result_i=[[query_vector_index]])
784+
740785

741786
def test_kmeans():
742787
k = 128
@@ -827,7 +872,7 @@ def test_ingest_with_training_source_uri_f32(tmp_path):
827872
result_d, result_i = index.query(query_vectors, k=1)
828873
check_equals(result_d=result_d, result_i=result_i, expected_result_d=[[0]], expected_result_i=[[query_vector_index]])
829874

830-
index_ram = FlatIndex(uri=index_uri)
875+
index_ram = IVFFlatIndex(uri=index_uri)
831876
result_d, result_i = index_ram.query(query_vectors, k=1)
832877
check_equals(result_d=result_d, result_i=result_i, expected_result_d=[[0]], expected_result_i=[[query_vector_index]])
833878

@@ -837,7 +882,7 @@ def test_ingest_with_training_source_uri_f32(tmp_path):
837882
index_uri=os.path.join(tmp_path, "array_2"),
838883
source_uri=os.path.join(dataset_dir, "data.f32bin"),
839884
training_source_uri=os.path.join(dataset_dir, "training_data.f32bin"),
840-
training_source_type="FVEC"
885+
training_source_type="F32BIN"
841886
)
842887

843888
def test_ingest_with_training_source_uri_tdb(tmp_path):

0 commit comments

Comments
 (0)