@@ -29,6 +29,7 @@ def ingest(
     training_sample_size: int = -1,
     workers: int = -1,
     input_vectors_per_work_item: int = -1,
+    max_tasks_per_stage: int = -1,
     storage_version: str = STORAGE_VERSION,
     verbose: bool = False,
     trace_id: Optional[str] = None,
@@ -83,6 +84,9 @@ def ingest(
     input_vectors_per_work_item: int = -1
         number of vectors per ingestion work item,
         if not provided, is auto-configured
+    max_tasks_per_stage: int = -1
+        Max number of tasks per execution stage of ingestion,
+        if not provided, is auto-configured
     storage_version: str
         Vector index storage format version.
     verbose: bool
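
For context, a minimal sketch of how the new knob might be passed. This is a hypothetical invocation: the URIs, the source argument names, and the value 20 are illustrative placeholders, and every other argument keeps its auto-configured default from the signature above.

    # Hypothetical call; index_group_uri / source_uri values are placeholders.
    from tiledb.vector_search.ingestion import ingest

    index = ingest(
        index_type="IVF_FLAT",
        index_group_uri="s3://bucket/index",      # placeholder
        source_uri="s3://bucket/vectors.u8bin",   # placeholder
        max_tasks_per_stage=20,  # cap each ingestion DAG stage at 20 tasks; -1 keeps the default
    )
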
@@ -348,7 +352,7 @@ def create_arrays(
         index_type: str,
         size: int,
         dimensions: int,
-        input_vectors_work_tasks: int,
+        input_vectors_work_items: int,
         vector_type: np.dtype,
         logger: logging.Logger,
     ) -> None:
@@ -475,10 +479,7 @@ def create_arrays(
         partial_write_array_group.add(
             partial_write_array_parts_uri, name=PARTS_ARRAY_NAME
         )
-        partial_write_arrays = input_vectors_work_tasks
-        if updates_uri is not None:
-            partial_write_arrays += 1
-        for part in range(partial_write_arrays):
+        for part in range(input_vectors_work_items):
             part_index_uri = partial_write_array_index_uri + "/" + str(part)
             if not tiledb.array_exists(part_index_uri):
                 logger.debug(f"Creating part array {part_index_uri}")
@@ -505,6 +506,33 @@ def create_arrays(
                 logger.debug(index_schema)
                 tiledb.Array.create(part_index_uri, index_schema)
                 partial_write_array_index_group.add(part_index_uri, name=str(part))
+        if updates_uri is not None:
+            part_index_uri = partial_write_array_index_uri + "/additions"
+            if not tiledb.array_exists(part_index_uri):
+                logger.debug(f"Creating part array {part_index_uri}")
+                index_array_rows_dim = tiledb.Dim(
+                    name="rows",
+                    domain=(0, partitions),
+                    tile=partitions,
+                    dtype=np.dtype(np.int32),
+                )
+                index_array_dom = tiledb.Domain(index_array_rows_dim)
+                index_attr = tiledb.Attr(
+                    name="values",
+                    dtype=np.dtype(np.uint64),
+                    filters=DEFAULT_ATTR_FILTERS,
+                )
+                index_schema = tiledb.ArraySchema(
+                    domain=index_array_dom,
+                    sparse=False,
+                    attrs=[index_attr],
+                    capacity=partitions,
+                    cell_order="col-major",
+                    tile_order="col-major",
+                )
+                logger.debug(index_schema)
+                tiledb.Array.create(part_index_uri, index_schema)
+                partial_write_array_index_group.add(part_index_uri, name="additions")
         partial_write_array_group.close()
         partial_write_array_index_group.close()
 
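With this change the per-work-item index arrays are numbered 0 through input_vectors_work_items - 1, and the additions from updates get a dedicated array under the stable member name "additions" instead of occupying an extra numeric slot. A rough sketch of resolving that member later, mirroring the group lookups in the hunks below (the group URI is a placeholder):

    import tiledb

    # Placeholder URI for the partial-writes index group created above.
    group = tiledb.Group("file:///tmp/index/partial_writes/index", "r")
    additions_uri = group["additions"].uri  # lookup by stable name, not task id
    with tiledb.open(additions_uri) as a:
        print(a.schema)  # dense array: int32 "rows" dim, uint64 "values" attr
    group.close()
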
@@ -1098,7 +1126,7 @@ def ingest_vectors_udf(
             part_name = str(part) + "-" + str(part_end)
 
             partial_write_array_index_uri = partial_write_array_index_group[
-                str(int(start / batch))
+                str(int(part / batch))
             ].uri
             logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
             updated_ids = read_updated_ids(
@@ -1203,7 +1231,6 @@ def ingest_additions_udf(
         updates_uri: str,
         vector_type: np.dtype,
         write_offset: int,
-        task_id: int,
         threads: int,
         config: Optional[Mapping[str, Any]] = None,
         verbose: bool = False,
@@ -1228,7 +1255,7 @@ def ingest_additions_udf(
             partial_write_array_index_dir_uri
         )
         partial_write_array_index_uri = partial_write_array_index_group[
-            str(task_id)
+            "additions"
         ].uri
         additions_vectors, additions_external_ids = read_additions(
             updates_uri=updates_uri,
@@ -1678,7 +1705,6 @@ def create_ingestion_dag(
                 updates_uri=updates_uri,
                 vector_type=vector_type,
                 write_offset=size,
-                task_id=task_id,
                 threads=threads,
                 config=config,
                 verbose=verbose,
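
Dropping `task_id` from this call site matches the UDF signature change above: the additions index array is now found under the fixed group member name "additions", so `ingest_additions_udf` no longer needs a per-task index to locate its output.
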
@@ -1744,15 +1770,22 @@ def consolidate_and_vacuum(
         tiledb.vacuum(parts_uri, config=conf)
         tiledb.consolidate(ids_uri, config=conf)
         tiledb.vacuum(ids_uri, config=conf)
+        group.close()
 
         # TODO remove temp data for tiledb URIs
         if not index_group_uri.startswith("tiledb://"):
-            vfs = tiledb.VFS(config)
-            partial_write_array_dir_uri = (
-                index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
-            )
-            if vfs.is_dir(partial_write_array_dir_uri):
-                vfs.remove_dir(partial_write_array_dir_uri)
+            group = tiledb.Group(index_group_uri, "r")
+            if PARTIAL_WRITE_ARRAY_DIR in group:
+                group.close()
+                group = tiledb.Group(index_group_uri, "w")
+                group.remove(PARTIAL_WRITE_ARRAY_DIR)
+            vfs = tiledb.VFS(config)
+            partial_write_array_dir_uri = (
+                index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
+            )
+            if vfs.is_dir(partial_write_array_dir_uri):
+                vfs.remove_dir(partial_write_array_dir_uri)
+            group.close()
 
     # --------------------------------------------------------------------
     # End internal function definitions
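
The reworked cleanup detaches the temp member from the index group before deleting the underlying directory, so the group metadata never points at data that has already been removed. A standalone sketch of the same two-step pattern, assuming a local index at a placeholder path with a placeholder member name "temp_data":

    import tiledb

    uri = "file:///tmp/my_index"  # placeholder
    group = tiledb.Group(uri, "r")
    has_member = "temp_data" in group
    group.close()
    if has_member:
        group = tiledb.Group(uri, "w")
        group.remove("temp_data")  # step 1: detach from group metadata
        group.close()
    vfs = tiledb.VFS()
    if vfs.is_dir(uri + "/temp_data"):
        vfs.remove_dir(uri + "/temp_data")  # step 2: delete the physical data
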
@@ -1852,11 +1885,13 @@ def consolidate_and_vacuum(
     input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
     input_vectors_work_tasks = input_vectors_work_items
     input_vectors_work_items_per_worker = 1
-    if input_vectors_work_tasks > MAX_TASKS_PER_STAGE:
+    if max_tasks_per_stage == -1:
+        max_tasks_per_stage = MAX_TASKS_PER_STAGE
+    if input_vectors_work_tasks > max_tasks_per_stage:
         input_vectors_work_items_per_worker = int(
-            math.ceil(input_vectors_work_items / MAX_TASKS_PER_STAGE)
+            math.ceil(input_vectors_work_items / max_tasks_per_stage)
         )
-        input_vectors_work_tasks = MAX_TASKS_PER_STAGE
+        input_vectors_work_tasks = max_tasks_per_stage
     logger.debug("input_vectors_per_work_item %d", input_vectors_per_work_item)
     logger.debug("input_vectors_work_items %d", input_vectors_work_items)
     logger.debug("input_vectors_work_tasks %d", input_vectors_work_tasks)
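
A worked example of the capping arithmetic, with illustrative numbers: 1,000,000 vectors at 20,000 vectors per work item yields 50 work items; a cap of 10 tasks per stage packs ceil(50 / 10) = 5 items into each of 10 tasks.

    import math

    size = 1_000_000                      # illustrative vector count
    input_vectors_per_work_item = 20_000  # illustrative batch size
    max_tasks_per_stage = 10              # user-supplied cap

    work_items = int(math.ceil(size / input_vectors_per_work_item))  # 50
    work_tasks = work_items
    items_per_worker = 1
    if work_tasks > max_tasks_per_stage:
        items_per_worker = int(math.ceil(work_items / max_tasks_per_stage))  # 5
        work_tasks = max_tasks_per_stage                                     # 10
    print(work_items, items_per_worker, work_tasks)  # 50 5 10
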
@@ -1875,11 +1910,11 @@ def consolidate_and_vacuum(
     )
     table_partitions_work_tasks = table_partitions_work_items
     table_partitions_work_items_per_worker = 1
-    if table_partitions_work_tasks > MAX_TASKS_PER_STAGE:
+    if table_partitions_work_tasks > max_tasks_per_stage:
         table_partitions_work_items_per_worker = int(
-            math.ceil(table_partitions_work_items / MAX_TASKS_PER_STAGE)
+            math.ceil(table_partitions_work_items / max_tasks_per_stage)
         )
-        table_partitions_work_tasks = MAX_TASKS_PER_STAGE
+        table_partitions_work_tasks = max_tasks_per_stage
     logger.debug(
         "table_partitions_per_work_item %d", table_partitions_per_work_item
     )
@@ -1897,7 +1932,7 @@ def consolidate_and_vacuum(
         index_type=index_type,
         size=size,
         dimensions=dimensions,
-        input_vectors_work_tasks=input_vectors_work_tasks,
+        input_vectors_work_items=input_vectors_work_items,
         vector_type=vector_type,
         logger=logger,
     )