1+ import time
12import numpy as np
23from common import *
34import pytest
@@ -737,6 +738,50 @@ def test_storage_versions(tmp_path):
737738 _ , result = index_ram .query (query_vectors , k = k )
738739 assert accuracy (result , gt_i ) > MINIMUM_ACCURACY
739740
def test_copy_centroids_uri(tmp_path):
    """Ingest an IVF_FLAT index with ``copy_centroids_uri`` and query it.

    A small dense TileDB array holding two hand-picked centroids is written
    to disk and its URI is passed to ``ingest`` through ``copy_centroids_uri``.
    The test then queries with the last input vector and expects that same
    vector back at distance 0.

    NOTE(review): presumably ``ingest`` copies these centroids instead of
    computing its own - confirm against the ``ingest`` implementation.
    """
    dataset_dir = os.path.join(tmp_path, "dataset")
    os.mkdir(dataset_dir)

    # Vectors to index: two at 1s, two at 2s, and a single one at 3s.
    vectors = np.array(
        [
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [3, 3, 3, 3],
        ],
        dtype=np.float32,
    )

    # Centroids array - the schema is based on ivf_flat_index.py.
    centroids = np.array([[1, 1, 1, 1], [2, 2, 2, 2]], dtype=np.float32)
    num_centroids, dimensions = centroids.shape

    rows_dim = tiledb.Dim(
        name="rows",
        domain=(0, dimensions - 1),
        tile=dimensions,
        dtype=np.dtype(np.int32),
    )
    cols_dim = tiledb.Dim(
        name="cols",
        domain=(0, np.iinfo(np.dtype("int32")).max),
        tile=100000,
        dtype=np.dtype(np.int32),
    )
    centroids_attr = tiledb.Attr(
        name="centroids",
        dtype="float32",
        filters=tiledb.FilterList([tiledb.ZstdFilter()]),
    )
    schema = tiledb.ArraySchema(
        domain=tiledb.Domain(rows_dim, cols_dim),
        sparse=False,
        attrs=[centroids_attr],
        cell_order="col-major",
        tile_order="col-major",
    )

    centroids_uri = os.path.join(dataset_dir, "centroids.tdb")
    tiledb.Array.create(centroids_uri, schema)

    # Write the centroids at a millisecond timestamp taken from the wall clock.
    index_timestamp = int(time.time() * 1000)
    with tiledb.open(centroids_uri, mode="w", timestamp=index_timestamp) as centroids_array:
        # The array is laid out as (rows=dimensions, cols=centroids), hence the transpose.
        centroids_array[0:dimensions, 0:num_centroids] = centroids.transpose()

    # Build the index, handing over the centroids array created above.
    index_uri = os.path.join(tmp_path, "array")
    index = ingest(
        index_type="IVF_FLAT",
        index_uri=index_uri,
        input_vectors=vectors,
        copy_centroids_uri=centroids_uri,
    )

    # Query with the last input vector; it should match itself exactly.
    query_vector_index = 4
    query_vectors = np.array([vectors[query_vector_index]], dtype=np.float32)
    result_d, result_i = index.query(query_vectors, k=1)
    check_equals(
        result_d=result_d,
        result_i=result_i,
        expected_result_d=[[0]],
        expected_result_i=[[query_vector_index]],
    )
740785
741786def test_kmeans ():
742787 k = 128
@@ -827,7 +872,7 @@ def test_ingest_with_training_source_uri_f32(tmp_path):
827872 result_d , result_i = index .query (query_vectors , k = 1 )
828873 check_equals (result_d = result_d , result_i = result_i , expected_result_d = [[0 ]], expected_result_i = [[query_vector_index ]])
829874
830- index_ram = FlatIndex (uri = index_uri )
875+ index_ram = IVFFlatIndex (uri = index_uri )
831876 result_d , result_i = index_ram .query (query_vectors , k = 1 )
832877 check_equals (result_d = result_d , result_i = result_i , expected_result_d = [[0 ]], expected_result_i = [[query_vector_index ]])
833878
@@ -837,7 +882,7 @@ def test_ingest_with_training_source_uri_f32(tmp_path):
837882 index_uri = os .path .join (tmp_path , "array_2" ),
838883 source_uri = os .path .join (dataset_dir , "data.f32bin" ),
839884 training_source_uri = os .path .join (dataset_dir , "training_data.f32bin" ),
840- training_source_type = "FVEC "
885+ training_source_type = "F32BIN "
841886 )
842887
843888def test_ingest_with_training_source_uri_tdb (tmp_path ):
0 commit comments