import json
from functools import partial
from typing import Any, Mapping, Optional, Tuple
-
+import enum
+import random
import numpy as np
-from tiledb.cloud.dag import Mode

+from tiledb.cloud.dag import Mode
from tiledb.vector_search._tiledbvspy import *
from tiledb.vector_search.storage_formats import STORAGE_VERSION, validate_storage_version


+class TrainingSamplingPolicy(enum.Enum):
+    FIRST_N = 1
+    RANDOM = 2
+
+    def __str__(self):
+        return self.name.replace("_", " ").title()
+
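A quick illustration (not part of the patch) of the string form produced by the __str__ defined above:

# Sketch only: __str__ title-cases the member name and replaces "_" with " ".
print(str(TrainingSamplingPolicy.FIRST_N))  # First N
print(str(TrainingSamplingPolicy.RANDOM))   # Random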
def ingest(
    index_type: str,
    index_uri: str,
@@ -25,6 +33,7 @@ def ingest(
    namespace: Optional[str] = None,
    size: int = -1,
    partitions: int = -1,
+    training_sampling_policy: TrainingSamplingPolicy = TrainingSamplingPolicy.FIRST_N,
    copy_centroids_uri: str = None,
    training_sample_size: int = -1,
    training_input_vectors: np.ndarray = None,
@@ -33,6 +42,8 @@ def ingest(
    workers: int = -1,
    input_vectors_per_work_item: int = -1,
    max_tasks_per_stage: int = -1,
+    input_vectors_per_work_item_during_sampling: int = -1,
+    max_sampling_tasks: int = -1,
    storage_version: str = STORAGE_VERSION,
    verbose: bool = False,
    trace_id: Optional[str] = None,
@@ -102,6 +113,14 @@ def ingest(
    max_tasks_per_stage: int = -1
        Max number of tasks per execution stage of ingestion,
        if not provided, is auto-configured
+    input_vectors_per_work_item_during_sampling: int = -1
+        Number of vectors per sample ingestion work item,
+        if not provided, is auto-configured;
+        only valid with training_sampling_policy=TrainingSamplingPolicy.RANDOM
+    max_sampling_tasks: int = -1
+        Max number of tasks per execution stage of sampling,
+        if not provided, is auto-configured;
+        only valid with training_sampling_policy=TrainingSamplingPolicy.RANDOM
    storage_version: str
        Vector index storage format version. If not provided, defaults to the latest version.
    verbose: bool
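For orientation, a minimal usage sketch of the new sampling knobs (the URIs, sizes, and source format are hypothetical; every other ingest() argument keeps its default):

from tiledb.vector_search.ingestion import TrainingSamplingPolicy, ingest

# Hypothetical call: train IVF_FLAT centroids on a 100k-vector random sample,
# reading the source in 1M-vector chunks across at most 20 sampling tasks.
ingest(
    index_type="IVF_FLAT",
    index_uri="s3://my-bucket/my-index",         # hypothetical URI
    source_uri="s3://my-bucket/vectors.f32bin",  # hypothetical URI
    source_type="F32BIN",
    training_sampling_policy=TrainingSamplingPolicy.RANDOM,
    training_sample_size=100_000,
    input_vectors_per_work_item_during_sampling=1_000_000,
    max_sampling_tasks=20,
)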
@@ -169,6 +188,10 @@ def ingest(
    for variable in ["copy_centroids_uri", "training_input_vectors", "training_source_uri", "training_source_type"]:
        if index_type != "IVF_FLAT" and locals().get(variable) is not None:
            raise ValueError(f"{variable} should only be provided with index_type IVF_FLAT")
+
+    for variable in ["copy_centroids_uri", "training_input_vectors", "training_source_uri", "training_source_type"]:
+        if training_sampling_policy != TrainingSamplingPolicy.FIRST_N and locals().get(variable) is not None:
+            raise ValueError(f"{variable} should not be provided alongside training_sampling_policy")

    # use index_group_uri for internal clarity
    index_group_uri = index_uri
@@ -191,6 +214,7 @@ def ingest(
    ]
    DEFAULT_ATTR_FILTERS = storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"]
    VECTORS_PER_WORK_ITEM = 20000000
+    VECTORS_PER_SAMPLE_WORK_ITEM = 1000000
    MAX_TASKS_PER_STAGE = 100
    CENTRALISED_KMEANS_MAX_SAMPLE_SIZE = 1000000
    DEFAULT_IMG_NAME = "3.9-vectorsearch"
@@ -302,9 +326,8 @@ def read_source_metadata(
        else:
            raise ValueError(f"Not supported source_type {source_type} - valid types are [TILEDB_ARRAY, U8BIN, F32BIN, FVEC, IVEC, BVEC]")

-    def write_input_vectors(
+    def create_array(
        group: tiledb.Group,
-        input_vectors: np.ndarray,
        size: int,
        dimensions: int,
        vector_type: np.dtype,
@@ -350,6 +373,18 @@ def write_input_vectors(
        tiledb.Array.create(input_vectors_array_uri, input_vectors_array_schema)
        group.add(input_vectors_array_uri, name=array_name)

+        return input_vectors_array_uri
+
+    def write_input_vectors(
+        group: tiledb.Group,
+        input_vectors: np.ndarray,
+        size: int,
+        dimensions: int,
+        vector_type: np.dtype,
+        array_name: str
+    ) -> str:
+        input_vectors_array_uri = create_array(group=group, size=size, dimensions=dimensions, vector_type=vector_type, array_name=array_name)
+
        input_vectors_array = tiledb.open(
            input_vectors_array_uri, "w", timestamp=index_timestamp
        )
@@ -797,6 +832,91 @@ def copy_centroids(
    # --------------------------------------------------------------------
    # centralised kmeans UDFs
    # --------------------------------------------------------------------
+    def random_sample_from_input_vectors(
+        source_uri: str,
+        source_type: str,
+        vector_type: np.dtype,
+        dimensions: int,
+        source_start_pos: int,
+        source_end_pos: int,
+        batch: int,
+        random_sample_size: int,
+        output_source_uri: str,
+        output_start_pos: int,
+        config: Optional[Mapping[str, Any]] = None,
+        verbose: bool = False,
+    ):
+        '''
+        Reads a random sample of vectors from the source data and appends them to the output array.
+
+        Parameters
+        ----------
+        source_uri: str
+            Data source URI.
+        source_type: str
+            Type of the source data.
+        vector_type: np.dtype
+            Type of the vectors.
+        dimensions: int
+            Number of dimensions in a vector.
+        source_start_pos: int
+            Start position of source_uri to read from.
+        source_end_pos: int
+            End position of source_uri to read to.
+        batch: int
+            Read the source data in batches of this size.
+        random_sample_size: int
+            Number of vectors to randomly sample from the source data.
+        output_source_uri: str
+            URI of the output array.
+        output_start_pos: int
+            Start position of the output array to write to.
+        '''
+        if random_sample_size == 0:
+            return
+
+        with tiledb.scope_ctx(ctx_or_config=config):
+            source_size = source_end_pos - source_start_pos
+            num_sampled = 0
+            for start in range(source_start_pos, source_end_pos, batch):
+                # What vectors to read from the source_uri.
+                end = start + batch
+                if end > source_end_pos:
+                    end = source_end_pos
+
+                # How many vectors to sample from the vectors read.
+                percent_of_data_to_read = (end - start) / source_size
+                num_to_sample = math.ceil(random_sample_size * percent_of_data_to_read)
+                if num_sampled + num_to_sample > random_sample_size:
+                    num_to_sample = random_sample_size - num_sampled
+                if num_to_sample == 0:
+                    continue
+                num_sampled += num_to_sample
+
+                # Read from the source data.
+                vectors = read_input_vectors(
+                    source_uri=source_uri,
+                    source_type=source_type,
+                    vector_type=vector_type,
+                    dimensions=dimensions,
+                    start_pos=start,
+                    end_pos=end,
+                    config=config,
+                    verbose=verbose,
+                    trace_id=trace_id,
+                )
+
+                # Randomly sample from the data we read.
+                row_indices = np.random.choice(vectors.shape[0], size=num_to_sample, replace=False)
+                sampled_vectors = vectors[row_indices]
+
+                # Append to the output array.
+                with tiledb.open(output_source_uri, mode="w", timestamp=index_timestamp) as A:
+                    A[0:dimensions, output_start_pos:output_start_pos + num_to_sample] = np.transpose(sampled_vectors)
+
+            if num_sampled != random_sample_size:
+                raise ValueError(f"The random sampling within a batch ran into an issue: num_sampled ({num_sampled}) != random_sample_size ({random_sample_size})")
+
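To make the per-batch allocation above concrete, here is a small standalone sketch of the same arithmetic (not part of the patch; the sizes are made up):

import math

# Hypothetical: a 2.5M-vector slice read in 1M-vector batches, from which
# 10k vectors are sampled in proportion to each batch's share of the slice.
source_start_pos, source_end_pos, batch, random_sample_size = 0, 2_500_000, 1_000_000, 10_000
source_size = source_end_pos - source_start_pos

num_sampled = 0
for start in range(source_start_pos, source_end_pos, batch):
    end = min(start + batch, source_end_pos)
    share = (end - start) / source_size
    num_to_sample = min(math.ceil(random_sample_size * share), random_sample_size - num_sampled)
    num_sampled += num_to_sample
    print(f"batch [{start}, {end}) -> sample {num_to_sample}")
# Batches of 1M, 1M, and 0.5M vectors yield 4000, 4000, and 2000 samples: 10000 in total.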
    def centralised_kmeans(
        index_group_uri: str,
        source_uri: str,
@@ -1564,6 +1684,8 @@ def create_ingestion_dag(
        training_source_type: Optional[str],
        input_vectors_per_work_item: int,
        input_vectors_work_items_per_worker: int,
+        input_vectors_per_work_item_during_sampling: int,
+        input_vectors_work_items_per_worker_during_sampling: int,
        table_partitions_per_work_item: int,
        table_partitions_work_items_per_worker: int,
        workers: int,
@@ -1600,6 +1722,15 @@ def create_ingestion_dag(
        input_vectors_batch_size = (
            input_vectors_per_work_item * input_vectors_work_items_per_worker
        )
+
+        # The number of vectors each task will read.
+        input_vectors_batch_size_during_sampling = (
+            # The number of vectors to read into memory in one batch within a task.
+            input_vectors_per_work_item_during_sampling *
+            # The number of batches that a single task will need to run.
+            input_vectors_work_items_per_worker_during_sampling
+        )
+
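A brief sketch (hypothetical sizes, not part of the patch) of how this per-task batch size slices the input range that the sampling loop below iterates over:

# Hypothetical: 1M vectors per in-memory read, 3 reads per task -> 3M vectors per task.
input_vectors_per_work_item_during_sampling = 1_000_000
input_vectors_work_items_per_worker_during_sampling = 3
in_size = 7_500_000

task_batch = (input_vectors_per_work_item_during_sampling
              * input_vectors_work_items_per_worker_during_sampling)
for start in range(0, in_size, task_batch):
    end = min(start + task_batch, in_size)
    print(f"sampling task over vectors [{start}, {end})")
# Three tasks: [0, 3M), [3M, 6M), [6M, 7.5M).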
        if index_type == "FLAT":
            ingest_node = submit(
                ingest_flat,
@@ -1637,6 +1768,59 @@ def create_ingestion_dag(
                image_name=DEFAULT_IMG_NAME,
            )
        else:
+            random_sample_nodes = []
+            if training_sampling_policy == TrainingSamplingPolicy.RANDOM:
+                # Create an empty array to write the sampled vectors to.
+                group = tiledb.Group(index_group_uri, "w")
+                training_source_uri = create_array(
+                    group=group,
+                    size=training_sample_size,
+                    dimensions=dimensions,
+                    vector_type=vector_type,
+                    array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME
+                )
+                training_source_type = "TILEDB_ARRAY"
+                group.close()
+
+                idx = 0
+                num_sampled = 0
+                for start in range(0, in_size, input_vectors_batch_size_during_sampling):
+                    # What vectors to read from the source_uri.
+                    end = start + input_vectors_batch_size_during_sampling
+                    if end > size:
+                        end = size
+
+                    # How many vectors to sample from the vectors read.
+                    percent_of_data_to_read = (end - start) / in_size
+                    num_to_sample = math.ceil(training_sample_size * percent_of_data_to_read)
+                    if num_sampled + num_to_sample > training_sample_size:
+                        num_to_sample = training_sample_size - num_sampled
+                    if num_to_sample == 0:
+                        continue
+
+                    random_sample_nodes.append(submit(
+                        random_sample_from_input_vectors,
+                        source_uri=source_uri,
+                        source_type=source_type,
+                        vector_type=vector_type,
+                        dimensions=dimensions,
+                        source_start_pos=start,
+                        source_end_pos=end,
+                        batch=input_vectors_per_work_item_during_sampling,
+                        random_sample_size=num_to_sample,
+                        output_source_uri=training_source_uri,
+                        output_start_pos=num_sampled,
+                        config=config,
+                        verbose=verbose,
+                        name="read-random-sample-" + str(idx),
+                        resources={"cpu": str(threads), "memory": "1Gi"},
+                        image_name=DEFAULT_IMG_NAME,
+                    ))
+                    num_sampled += num_to_sample
+                    idx += 1
+                if num_sampled != training_sample_size:
+                    raise ValueError(f"The random sampling ran into an issue: num_sampled ({num_sampled}) != training_sample_size ({training_sample_size})")
+
            if training_sample_size <= CENTRALISED_KMEANS_MAX_SAMPLE_SIZE:
                centroids_node = submit(
                    centralised_kmeans,
@@ -1657,6 +1841,9 @@ def create_ingestion_dag(
                    resources={"cpu": "8", "memory": "32Gi"},
                    image_name=DEFAULT_IMG_NAME,
                )
+
+                for random_sample_node in random_sample_nodes:
+                    centroids_node.depends_on(random_sample_node)
            else:
                internal_centroids_node = submit(
                    init_centroids,
@@ -1950,6 +2137,9 @@ def consolidate_and_vacuum(
    logger.debug("Input dataset size %d", size)
    logger.debug("Input dataset dimensions %d", dimensions)
    logger.debug("Vector dimension type %s", vector_type)
+    if training_sample_size > in_size:
+        raise ValueError(f"training_sample_size {training_sample_size} is larger than the input dataset size {in_size}")
+
    if partitions == -1:
        partitions = max(1, int(math.sqrt(size)))
    if training_sample_size == -1:
@@ -1976,6 +2166,7 @@ def consolidate_and_vacuum(
    if external_ids_type is None:
        external_ids_type = "U64BIN"

+    # Compute task parameters for main ingestion.
    if input_vectors_per_work_item == -1:
        input_vectors_per_work_item = VECTORS_PER_WORK_ITEM
    input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
@@ -1996,6 +2187,27 @@ def consolidate_and_vacuum(
        input_vectors_work_items_per_worker,
    )

+    # Compute task parameters for random sampling.
+    # How many input vectors to read into memory in one batch within a task.
+    if input_vectors_per_work_item_during_sampling == -1:
+        input_vectors_per_work_item_during_sampling = VECTORS_PER_SAMPLE_WORK_ITEM
+    # How many total batches we need to read all the data.
+    input_vectors_work_items_during_sampling = int(math.ceil(size / input_vectors_per_work_item_during_sampling))
+    # The maximum number of tasks to create.
+    if max_sampling_tasks == -1:
+        max_sampling_tasks = MAX_TASKS_PER_STAGE
+    # The number of batches a single task will run. If more batches are required than
+    # allowed tasks, each task will process multiple batches.
+    input_vectors_work_items_per_worker_during_sampling = 1
+    if input_vectors_work_items_during_sampling > max_sampling_tasks:
+        input_vectors_work_items_per_worker_during_sampling = int(
+            math.ceil(input_vectors_work_items_during_sampling / max_sampling_tasks)
+        )
+        input_vectors_work_items_during_sampling = max_sampling_tasks
+    logger.debug("input_vectors_per_work_item_during_sampling %d", input_vectors_per_work_item_during_sampling)
+    logger.debug("input_vectors_work_items_during_sampling %d", input_vectors_work_items_during_sampling)
+    logger.debug("input_vectors_work_items_per_worker_during_sampling %d", input_vectors_work_items_per_worker_during_sampling)
+
    vectors_per_table_partitions = max(1, size / partitions)
    table_partitions_per_work_item = max(
        1,
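As a worked example of this auto-configuration (a hypothetical dataset size, using the defaults VECTORS_PER_SAMPLE_WORK_ITEM = 1000000 and MAX_TASKS_PER_STAGE = 100 defined earlier in ingest):

import math

# Hypothetical: 250M input vectors with the default knobs.
size = 250_000_000
per_work_item = 1_000_000     # VECTORS_PER_SAMPLE_WORK_ITEM
max_sampling_tasks = 100      # MAX_TASKS_PER_STAGE

work_items = math.ceil(size / per_work_item)   # 250 batches needed to cover the data
per_worker = 1
if work_items > max_sampling_tasks:
    per_worker = math.ceil(work_items / max_sampling_tasks)  # 3 batches per task
    work_items = max_sampling_tasks                          # capped at 100 tasks
print(work_items, per_worker)  # 100 3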
@@ -2054,6 +2266,8 @@ def consolidate_and_vacuum(
        training_source_type=training_source_type,
        input_vectors_per_work_item=input_vectors_per_work_item,
        input_vectors_work_items_per_worker=input_vectors_work_items_per_worker,
+        input_vectors_per_work_item_during_sampling=input_vectors_per_work_item_during_sampling,
+        input_vectors_work_items_per_worker_during_sampling=input_vectors_work_items_per_worker_during_sampling,
        table_partitions_per_work_item=table_partitions_per_work_item,
        table_partitions_work_items_per_worker=table_partitions_work_items_per_worker,
        workers=workers,