@@ -69,6 +69,7 @@ def ingest(
     import logging
     import math
     from typing import Any, Mapping, Optional
+    import multiprocessing

     import numpy as np

@@ -832,11 +833,11 @@ def consolidate_partition_udf(
     ):
         logger = setup(config, verbose)
         with tiledb.scope_ctx(ctx_or_config=config):
+            logger.info(f"Consolidating partitions {partition_id_start}-{partition_id_end}")
             group = tiledb.Group(array_uri)
             partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
             partial_write_array_ids_uri = partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
             partial_write_array_parts_uri = partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
-            logger.info("Consolidating array")
             index_array_uri = group[INDEX_ARRAY_NAME].uri
             ids_array_uri = group[IDS_ARRAY_NAME].uri
             parts_array_uri = group[PARTS_ARRAY_NAME].uri
@@ -862,26 +863,36 @@ def consolidate_partition_udf(
             index_array = tiledb.open(index_array_uri, mode="r")
             ids_array = tiledb.open(ids_array_uri, mode="w")
             parts_array = tiledb.open(parts_array_uri, mode="w")
-            read_slices = []
-            for part in range(partition_id_start, partition_id_end):
-                for partition_slice in partition_slices[part]:
-                    read_slices.append(partition_slice)
-
-            logger.debug(f"Read slices: {read_slices}")
-            ids = partial_write_array_ids_array.multi_index[read_slices]["values"]
-            vectors = partial_write_array_parts_array.multi_index[:, read_slices]["values"]
-            start_pos = int(index_array[partition_id_start]["values"])
-            end_pos = int(index_array[partition_id_end]["values"])
-
-            logger.debug(
-                f"Ids shape {ids.shape}, expected size: {end_pos - start_pos} expected range:({start_pos},{end_pos})")
-            if ids.shape[0] != end_pos - start_pos:
-                raise ValueError("Incorrect partition size.")
-
-            logger.info(f"Writing data to array: {parts_array_uri}")
-            parts_array[:, start_pos:end_pos] = vectors
-            logger.info(f"Writing data to array: {ids_array_uri}")
-            ids_array[start_pos:end_pos] = ids
+            logger.info(
+                f"Partitions start: {partition_id_start} end: {partition_id_end}"
+            )
+            for part in range(partition_id_start, partition_id_end, batch):
+                part_end = part + batch
+                if part_end > partition_id_end:
+                    part_end = partition_id_end
+                logger.info(f"Consolidating partitions start: {part} end: {part_end}")
+                read_slices = []
+                for p in range(part, part_end):
+                    for partition_slice in partition_slices[p]:
+                        read_slices.append(partition_slice)
+
+                logger.debug(f"Read slices: {read_slices}")
+                ids = partial_write_array_ids_array.multi_index[read_slices]["values"]
+                vectors = partial_write_array_parts_array.multi_index[:, read_slices]["values"]
+                start_pos = int(index_array[part]["values"])
+                end_pos = int(index_array[part_end]["values"])
+
+                logger.debug(
+                    f"Ids shape {ids.shape}, expected size: {end_pos - start_pos} expected range:({start_pos},{end_pos})")
+                if ids.shape[0] != end_pos - start_pos:
+                    raise ValueError("Incorrect partition size.")
+
+                logger.info(f"Writing data to array: {parts_array_uri}")
+                parts_array[:, start_pos:end_pos] = vectors
+                logger.info(f"Writing data to array: {ids_array_uri}")
+                ids_array[start_pos:end_pos] = ids
+            parts_array.close()
+            ids_array.close()

     # --------------------------------------------------------------------
     # DAG
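The rewritten consolidate_partition_udf above walks its partition range in chunks of `batch` and clamps the final chunk so it never runs past partition_id_end. A minimal standalone sketch of that chunking pattern (the helper name and example values are illustrative, not part of the change):

# Hypothetical helper illustrating the chunked iteration used above:
# walk [start, end) in steps of `batch`, clamping the last chunk to `end`.
def iter_chunks(start: int, end: int, batch: int):
    for chunk_start in range(start, end, batch):
        yield chunk_start, min(chunk_start + batch, end)

# list(iter_chunks(0, 10, 4)) -> [(0, 4), (4, 8), (8, 10)]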
@@ -918,13 +929,15 @@ def create_ingestion_dag(
                     retry_policy="Always",
                 ),
             )
+            threads = 8
         else:
             d = dag.DAG(
                 name="vector-ingestion",
                 mode=Mode.REALTIME,
                 max_workers=workers,
                 namespace="default",
             )
+            threads = multiprocessing.cpu_count()

         submit = d.submit_local
         if mode == Mode.BATCH or mode == Mode.REALTIME:
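The hunk above derives a per-task thread count from the execution mode instead of hard-coding it at each submit call: 8 threads for BATCH and the local core count for REALTIME. A hedged sketch of that selection, using a stand-in Mode enum since the real one comes from the library:

import multiprocessing
from enum import Enum

class Mode(Enum):  # stand-in for the library's Mode enum (assumption)
    BATCH = 1
    REALTIME = 2

def pick_threads(mode: Mode) -> int:
    # Mirrors the diff: a fixed 8 threads for BATCH workers,
    # all available local cores for REALTIME.
    return 8 if mode == Mode.BATCH else multiprocessing.cpu_count()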
@@ -1024,12 +1037,12 @@ def create_ingestion_dag(
                     dimensions=dimensions,
                     vector_start_pos=start,
                     vector_end_pos=end,
-                    threads=8,
+                    threads=threads,
                     config=config,
                     verbose=verbose,
                     trace_id=trace_id,
                     name="k-means-part-" + str(task_id),
-                    resources={"cpu": "8", "memory": "12Gi"},
+                    resources={"cpu": str(threads), "memory": "12Gi"},
                 )
             )
             task_id += 1
@@ -1090,12 +1103,12 @@ def create_ingestion_dag(
                 start=start,
                 end=end,
                 batch=input_vectors_per_work_item,
-                threads=6,
+                threads=threads,
                 config=config,
                 verbose=verbose,
                 trace_id=trace_id,
                 name="ingest-" + str(task_id),
-                resources={"cpu": "6", "memory": "32Gi"},
+                resources={"cpu": str(threads), "memory": "32Gi"},
             )
             ingest_node.depends_on(centroids_node)
             compute_indexes_node.depends_on(ingest_node)
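Both resource hunks keep the scheduler's "cpu" request in step with the thread count passed to the UDF. A small illustrative helper (not part of the change) showing that pairing:

def task_resources(threads: int, memory: str) -> dict:
    # Keep the requested CPUs in sync with the threads the UDF will use.
    return {"cpu": str(threads), "memory": memory}

# task_resources(8, "12Gi") -> {"cpu": "8", "memory": "12Gi"}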