@@ -27,6 +27,9 @@ def ingest(
     partitions: int = -1,
     copy_centroids_uri: str = None,
     training_sample_size: int = -1,
+    training_input_vectors: np.ndarray = None,
+    training_source_uri: str = None,
+    training_source_type: str = None,
     workers: int = -1,
     input_vectors_per_work_item: int = -1,
     max_tasks_per_stage: int = -1,
@@ -78,6 +81,18 @@ def ingest(
7881 training_sample_size: int = -1
7982 vector sample size to train centroids with,
8083 if not provided, is auto-configured based on the dataset sizes
84+ should not be provided if training_source_uri is provided
85+ training_input_vectors: numpy Array
86+ Training input vectors, if this is provided it takes precedence over training_source_uri and training_source_type
87+ should not be provided if training_sample_size or training_source_uri is provided
88+ training_source_uri: str = None
89+ The source URI to use for training centroids when building a IVF_FLAT vector index,
90+ if not provided, the first training_sample_size vectors from source_uri are used
91+ should not be provided if training_sample_size or training_input_vectors is provided
92+ training_source_type: str = None
93+ Type of the training source data in training_source_uri
94+ if left empty, is auto-detected from the suffix of training_source_type
95+ should only be provided when training_source_uri is provided
8196 workers: int = -1
8297 number of workers for vector ingestion,
8398 if not provided, is auto-configured based on the dataset size
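The three new parameters are mutually exclusive ways of choosing the k-means training set. A hedged usage sketch follows; the `ingest` import path, the `index_type` keyword, and all URIs are illustrative assumptions, not taken from this diff:

import numpy as np
from tiledb.vector_search.ingestion import ingest  # assumed import path

# 1) Default behavior: train on the first training_sample_size vectors of source_uri.
ingest(index_type="IVF_FLAT", index_uri="s3://bucket/index",
       source_uri="s3://bucket/data.fvec", training_sample_size=100_000)

# 2) Train on a separate, pre-built sample; its type is auto-detected from ".fvec".
ingest(index_type="IVF_FLAT", index_uri="s3://bucket/index",
       source_uri="s3://bucket/data.fvec",
       training_source_uri="s3://bucket/train.fvec")

# 3) Train on an in-memory array; it is persisted as a TileDB array first.
train = np.random.rand(10_000, 128).astype(np.float32)
ingest(index_type="IVF_FLAT", index_uri="s3://bucket/index",
       source_uri="s3://bucket/data.fvec", training_input_vectors=train)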
@@ -121,6 +136,29 @@ def ingest(
 
     validate_storage_version(storage_version)
 
+    if source_type and not source_uri:
+        raise ValueError("source_type should not be provided without source_uri")
+    if source_uri and input_vectors is not None:
+        raise ValueError("source_uri should not be provided alongside input_vectors")
+    if source_type and input_vectors is not None:
+        raise ValueError("source_type should not be provided alongside input_vectors")
+
+    if training_source_uri and training_sample_size != -1:
+        raise ValueError("training_source_uri and training_sample_size should not both be provided")
+    if training_source_uri and training_input_vectors is not None:
+        raise ValueError("training_source_uri and training_input_vectors should not both be provided")
+
+    if training_input_vectors is not None and training_sample_size != -1:
+        raise ValueError("training_input_vectors and training_sample_size should not both be provided")
+    if training_input_vectors is not None and training_source_type:
+        raise ValueError("training_input_vectors and training_source_type should not both be provided")
+
+    if training_source_type and not training_source_uri:
+        raise ValueError("training_source_type should not be provided without training_source_uri")
+
+    if training_sample_size < -1:
+        raise ValueError("training_sample_size should either be positive or -1 to auto-configure based on the dataset sizes")
+
     # use index_group_uri for internal clarity
     index_group_uri = index_uri
 
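These guards fail fast, before any arrays are created. Continuing the hypothetical call from above, supplying two competing training inputs raises immediately:

try:
    ingest(index_type="IVF_FLAT", index_uri="s3://bucket/index",
           source_uri="s3://bucket/data.fvec",
           training_source_uri="s3://bucket/train.fvec",
           training_sample_size=50_000)
except ValueError as err:
    print(err)  # training_source_uri and training_sample_size should not both be provided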
@@ -131,6 +169,9 @@ def ingest(
     INPUT_VECTORS_ARRAY_NAME = storage_formats[storage_version][
         "INPUT_VECTORS_ARRAY_NAME"
     ]
+    TRAINING_INPUT_VECTORS_ARRAY_NAME = storage_formats[storage_version][
+        "TRAINING_INPUT_VECTORS_ARRAY_NAME"
+    ]
     EXTERNAL_IDS_ARRAY_NAME = storage_formats[storage_version][
         "EXTERNAL_IDS_ARRAY_NAME"
     ]
@@ -248,16 +289,17 @@ def read_source_metadata(
             size = int(file_size / vector_size)
             return size, dimensions, np.uint8
         else:
-            raise ValueError(f"Not supported source_type {source_type}")
+            raise ValueError(f"Unsupported source_type {source_type} - valid types are [TILEDB_ARRAY, U8BIN, F32BIN, FVEC, IVEC, BVEC]")
 
     def write_input_vectors(
         group: tiledb.Group,
         input_vectors: np.ndarray,
         size: int,
         dimensions: int,
         vector_type: np.dtype,
+        array_name: str,
     ) -> str:
-        input_vectors_array_uri = f"{group.uri}/{INPUT_VECTORS_ARRAY_NAME}"
+        input_vectors_array_uri = f"{group.uri}/{array_name}"
         if tiledb.array_exists(input_vectors_array_uri):
             raise ValueError(f"Array exists {input_vectors_array_uri}")
         tile_size = min(
@@ -295,7 +337,7 @@ def write_input_vectors(
         )
         logger.debug(input_vectors_array_schema)
         tiledb.Array.create(input_vectors_array_uri, input_vectors_array_schema)
-        group.add(input_vectors_array_uri, name=INPUT_VECTORS_ARRAY_NAME)
+        group.add(input_vectors_array_uri, name=array_name)
 
         input_vectors_array = tiledb.open(
             input_vectors_array_uri, "w", timestamp=index_timestamp
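The new `array_name` parameter lets the same helper back two distinct group members, which is how the later hunks use it. A minimal sketch of the two call shapes, assuming `group`, the vector arrays, and the `*_ARRAY_NAME` constants from `ingest`'s scope:

# Main dataset, registered under INPUT_VECTORS_ARRAY_NAME (existing behavior):
source_uri = write_input_vectors(
    group=group, input_vectors=input_vectors,
    size=input_vectors.shape[0], dimensions=input_vectors.shape[1],
    vector_type=input_vectors.dtype, array_name=INPUT_VECTORS_ARRAY_NAME,
)
# Training sample, registered under the new TRAINING_INPUT_VECTORS_ARRAY_NAME:
training_source_uri = write_input_vectors(
    group=group, input_vectors=training_input_vectors,
    size=training_input_vectors.shape[0], dimensions=training_input_vectors.shape[1],
    vector_type=training_input_vectors.dtype, array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME,
)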
@@ -749,8 +791,9 @@ def centralised_kmeans(
         vector_type: np.dtype,
         partitions: int,
         dimensions: int,
-        sample_start_pos: int,
-        sample_end_pos: int,
+        training_sample_size: int,
+        training_source_uri: Optional[str],
+        training_source_type: Optional[str],
         init: str = "random",
         max_iter: int = 10,
         n_init: int = 1,
@@ -765,45 +808,61 @@ def centralised_kmeans(
             array_to_matrix,
             kmeans_fit,
         )
+
         with tiledb.scope_ctx(ctx_or_config=config):
             logger = setup(config, verbose)
             group = tiledb.Group(index_group_uri)
             centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
-            verb = 0
-            if verbose:
-                verb = 3
-
-            if sample_end_pos - sample_start_pos >= partitions:
-                sample_vectors = read_input_vectors(
-                    source_uri=source_uri,
-                    source_type=source_type,
-                    vector_type=vector_type,
-                    dimensions=dimensions,
-                    start_pos=sample_start_pos,
-                    end_pos=sample_end_pos,
-                    config=config,
-                    verbose=verbose,
-                    trace_id=trace_id,
-                ).astype(np.float32)
+            if training_sample_size >= partitions:
+                if training_source_uri:
+                    if training_source_type is None:
+                        training_source_type = autodetect_source_type(source_uri=training_source_uri)
+                    training_in_size, training_dimensions, training_vector_type = read_source_metadata(source_uri=training_source_uri, source_type=training_source_type)
+                    dimensions = training_dimensions
+                    sample_vectors = read_input_vectors(
+                        source_uri=training_source_uri,
+                        source_type=training_source_type,
+                        vector_type=training_vector_type,
+                        dimensions=training_dimensions,
+                        start_pos=0,
+                        end_pos=training_in_size,
+                        config=config,
+                        verbose=verbose,
+                        trace_id=trace_id,
+                    ).astype(np.float32)
+                else:
+                    sample_vectors = read_input_vectors(
+                        source_uri=source_uri,
+                        source_type=source_type,
+                        vector_type=vector_type,
+                        dimensions=dimensions,
+                        start_pos=0,
+                        end_pos=training_sample_size,
+                        config=config,
+                        verbose=verbose,
+                        trace_id=trace_id,
+                    ).astype(np.float32)
 
+                logger.debug("Start kmeans training")
                 if use_sklearn:
                     km = KMeans(
                         n_clusters=partitions,
                         init=init,
                         max_iter=max_iter,
-                        verbose=verb,
+                        verbose=3 if verbose else 0,
                         n_init=n_init,
+                        random_state=0,
                     )
                     km.fit_predict(sample_vectors)
                     centroids = np.transpose(np.array(km.cluster_centers_))
                 else:
                     centroids = kmeans_fit(partitions, init, max_iter, verbose, n_init, array_to_matrix(np.transpose(sample_vectors)))
                     centroids = np.array(centroids)  # TODO: why is this here?
             else:
+                # TODO(paris): Should we instead take the first training_sample_size vectors and then fill in random for the rest? Or raise an error like this:
+                # raise ValueError(f"We have a training_sample_size of {training_sample_size} but {partitions} partitions - training_sample_size must be >= partitions")
                 centroids = np.random.rand(dimensions, partitions)
 
-            logger.debug("Start kmeans training")
-
             logger.debug("Writing centroids to array %s", centroids_uri)
             with tiledb.open(centroids_uri, mode="w", timestamp=index_timestamp) as A:
                 A[0:dimensions, 0:partitions] = centroids
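A self-contained sketch of the centroid layout written above, assuming scikit-learn: `cluster_centers_` has shape (partitions, dimensions), and the transpose stores one centroid per column, which is what the `A[0:dimensions, 0:partitions]` write expects. The `random_state=0` added in this diff makes training deterministic across runs.

import numpy as np
from sklearn.cluster import KMeans

sample_vectors = np.random.rand(1_000, 128).astype(np.float32)
km = KMeans(n_clusters=16, init="random", max_iter=10, n_init=1,
            verbose=0, random_state=0)
km.fit_predict(sample_vectors)
centroids = np.transpose(np.array(km.cluster_centers_))
print(centroids.shape)  # (128, 16): dimensions x partitions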
@@ -1487,6 +1546,8 @@ def create_ingestion_dag(
         dimensions: int,
         copy_centroids_uri: str,
         training_sample_size: int,
+        training_source_uri: Optional[str],
+        training_source_type: Optional[str],
         input_vectors_per_work_item: int,
         input_vectors_work_items_per_worker: int,
         table_partitions_per_work_item: int,
@@ -1569,8 +1630,9 @@ def create_ingestion_dag(
                 vector_type=vector_type,
                 partitions=partitions,
                 dimensions=dimensions,
-                sample_start_pos=0,
-                sample_end_pos=training_sample_size,
+                training_sample_size=training_sample_size,
+                training_source_uri=training_source_uri,
+                training_source_type=training_source_type,
                 config=config,
                 verbose=verbose,
                 trace_id=trace_id,
@@ -1835,6 +1897,17 @@ def consolidate_and_vacuum(
         group.close()
         group = tiledb.Group(index_group_uri, "w")
 
+        if training_input_vectors is not None:
+            training_source_uri = write_input_vectors(
+                group=group,
+                input_vectors=training_input_vectors,
+                size=training_input_vectors.shape[0],
+                dimensions=training_input_vectors.shape[1],
+                vector_type=training_input_vectors.dtype,
+                array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME,
+            )
+            training_source_type = "TILEDB_ARRAY"
+
         if input_vectors is not None:
             in_size = input_vectors.shape[0]
             dimensions = input_vectors.shape[1]
@@ -1845,6 +1918,7 @@ def consolidate_and_vacuum(
                 size=in_size,
                 dimensions=dimensions,
                 vector_type=vector_type,
+                array_name=INPUT_VECTORS_ARRAY_NAME,
             )
             source_type = "TILEDB_ARRAY"
         else:
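Taken together, these two blocks normalize every user-facing option into a (training_source_uri, training_source_type) pair before the ingestion DAG is built. A standalone, illustrative restatement of that precedence (not the repo's code; "<persisted-uri>" stands in for the URI returned by write_input_vectors, and suffix auto-detection happens later in centralised_kmeans):

from typing import Optional, Tuple
import numpy as np

def resolve_training_source(
    training_input_vectors: Optional[np.ndarray],
    training_source_uri: Optional[str],
    training_source_type: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
    # In-memory vectors win: they are persisted and re-read as a TileDB array.
    if training_input_vectors is not None:
        return "<persisted-uri>", "TILEDB_ARRAY"
    # Otherwise the user-supplied URI is used; a None type is auto-detected
    # later via autodetect_source_type(source_uri=...).
    return training_source_uri, training_source_type

print(resolve_training_source(None, "s3://bucket/train.fvec", "FVEC"))
# ('s3://bucket/train.fvec', 'FVEC')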
@@ -1871,6 +1945,7 @@ def consolidate_and_vacuum(
             workers = 1
         logger.debug("Partitions %d", partitions)
         logger.debug("Training sample size %d", training_sample_size)
+        logger.debug("Training source uri %s and type %s", training_source_uri, training_source_type)
         logger.debug("Number of workers %d", workers)
 
         if external_ids is not None:
@@ -1959,6 +2034,8 @@ def consolidate_and_vacuum(
             dimensions=dimensions,
             copy_centroids_uri=copy_centroids_uri,
             training_sample_size=training_sample_size,
+            training_source_uri=training_source_uri,
+            training_source_type=training_source_type,
             input_vectors_per_work_item=input_vectors_per_work_item,
             input_vectors_work_items_per_worker=input_vectors_work_items_per_worker,
             table_partitions_per_work_item=table_partitions_per_work_item,