22from functools import partial
33
44from tiledb .cloud .dag import Mode
5- from tiledb .vector_search .index import FlatIndex
6- from tiledb .vector_search .index import IVFFlatIndex
5+ from tiledb .vector_search .index import FlatIndex , IVFFlatIndex , Index
76
87
98def ingest (
@@ -23,7 +22,7 @@ def ingest(
2322 verbose : bool = False ,
2423 trace_id : Optional [str ] = None ,
2524 mode : Mode = Mode .LOCAL ,
26- ) -> FlatIndex :
25+ ) -> Index :
2726 """
2827 Ingest vectors into TileDB.
2928
@@ -189,7 +188,7 @@ def create_arrays(
189188 size : int ,
190189 dimensions : int ,
191190 partitions : int ,
192- input_vectors_work_items : int ,
191+ input_vectors_work_tasks : int ,
193192 vector_type : np .dtype ,
194193 logger : logging .Logger ,
195194 ) -> None :
@@ -223,7 +222,7 @@ def create_arrays(
223222 )
224223 logger .debug (parts_schema )
225224 tiledb .Array .create (parts_uri , parts_schema )
226- group .add (PARTS_ARRAY_NAME , name = PARTS_ARRAY_NAME , relative = True )
225+ group .add (parts_uri , name = PARTS_ARRAY_NAME )
227226
228227 elif index_type == "IVF_FLAT" :
229228 centroids_uri = f"{ group .uri } /{ CENTROIDS_ARRAY_NAME } "
@@ -271,9 +270,7 @@ def create_arrays(
271270 )
272271 logger .debug (centroids_schema )
273272 tiledb .Array .create (centroids_uri , centroids_schema )
274- group .add (
275- CENTROIDS_ARRAY_NAME , name = CENTROIDS_ARRAY_NAME , relative = True
276- )
273+ group .add (centroids_uri , name = CENTROIDS_ARRAY_NAME )
277274
278275 if not tiledb .array_exists (index_uri ):
279276 logger .debug ("Creating index array" )
@@ -295,7 +292,7 @@ def create_arrays(
295292 )
296293 logger .debug (index_schema )
297294 tiledb .Array .create (index_uri , index_schema )
298- group .add (INDEX_ARRAY_NAME , name = INDEX_ARRAY_NAME , relative = True )
295+ group .add (index_uri , name = INDEX_ARRAY_NAME )
299296
300297 if not tiledb .array_exists (ids_uri ):
301298 logger .debug ("Creating ids array" )
@@ -317,7 +314,7 @@ def create_arrays(
317314 )
318315 logger .debug (ids_schema )
319316 tiledb .Array .create (ids_uri , ids_schema )
320- group .add (IDS_ARRAY_NAME , name = IDS_ARRAY_NAME , relative = True )
317+ group .add (ids_uri , name = IDS_ARRAY_NAME )
321318
322319 if not tiledb .array_exists (parts_uri ):
323320 logger .debug ("Creating parts array" )
@@ -347,15 +344,35 @@ def create_arrays(
347344 )
348345 logger .debug (parts_schema )
349346 tiledb .Array .create (parts_uri , parts_schema )
350- group .add (PARTS_ARRAY_NAME , name = PARTS_ARRAY_NAME , relative = True )
351-
352- vfs = tiledb .VFS ()
353- if vfs .is_dir (partial_write_array_dir_uri ):
354- vfs .remove_dir (partial_write_array_dir_uri )
355- vfs .create_dir (partial_write_array_dir_uri )
356- if vfs .is_dir (partial_write_array_index_uri ):
357- vfs .remove_dir (partial_write_array_index_uri )
358- vfs .create_dir (partial_write_array_index_uri )
347+ group .add (parts_uri , name = PARTS_ARRAY_NAME )
348+
349+ try :
350+ tiledb .group_create (partial_write_array_dir_uri )
351+ except tiledb .TileDBError as err :
352+ message = str (err )
353+ if "already exists" in message :
354+ logger .debug (
355+ f"Group '{ partial_write_array_dir_uri } ' already exists"
356+ )
357+ raise err
358+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri , "w" )
359+ group .add (partial_write_array_dir_uri , name = PARTIAL_WRITE_ARRAY_DIR )
360+
361+ try :
362+ tiledb .group_create (partial_write_array_index_uri )
363+ except tiledb .TileDBError as err :
364+ message = str (err )
365+ if "already exists" in message :
366+ logger .debug (
367+ f"Group '{ partial_write_array_index_uri } ' already exists"
368+ )
369+ raise err
370+ partial_write_array_group .add (
371+ partial_write_array_index_uri , name = INDEX_ARRAY_NAME
372+ )
373+ partial_write_array_index_group = tiledb .Group (
374+ partial_write_array_index_uri , "w"
375+ )
359376
360377 if not tiledb .array_exists (partial_write_array_ids_uri ):
361378 logger .debug ("Creating temp ids array" )
@@ -377,6 +394,9 @@ def create_arrays(
377394 )
378395 logger .debug (ids_schema )
379396 tiledb .Array .create (partial_write_array_ids_uri , ids_schema )
397+ partial_write_array_group .add (
398+ partial_write_array_ids_uri , name = IDS_ARRAY_NAME
399+ )
380400
381401 if not tiledb .array_exists (partial_write_array_parts_uri ):
382402 logger .debug ("Creating temp parts array" )
@@ -407,6 +427,36 @@ def create_arrays(
407427 logger .debug (parts_schema )
408428 logger .debug (partial_write_array_parts_uri )
409429 tiledb .Array .create (partial_write_array_parts_uri , parts_schema )
430+ partial_write_array_group .add (
431+ partial_write_array_parts_uri , name = PARTS_ARRAY_NAME
432+ )
433+
434+ for part in range (input_vectors_work_tasks ):
435+ part_index_uri = partial_write_array_index_uri + "/" + str (part )
436+ if not tiledb .array_exists (part_index_uri ):
437+ logger .debug (f"Creating part array { part_index_uri } " )
438+ index_array_rows_dim = tiledb .Dim (
439+ name = "rows" ,
440+ domain = (0 , partitions ),
441+ tile = partitions ,
442+ dtype = np .dtype (np .int32 ),
443+ )
444+ index_array_dom = tiledb .Domain (index_array_rows_dim )
445+ index_attr = tiledb .Attr (name = "values" , dtype = np .dtype (np .uint64 ))
446+ index_schema = tiledb .ArraySchema (
447+ domain = index_array_dom ,
448+ sparse = False ,
449+ attrs = [index_attr ],
450+ capacity = partitions ,
451+ cell_order = "col-major" ,
452+ tile_order = "col-major" ,
453+ )
454+ logger .debug (index_schema )
455+ tiledb .Array .create (part_index_uri , index_schema )
456+ partial_write_array_index_group .add (part_index_uri , name = str (part ))
457+ partial_write_array_group .close ()
458+ partial_write_array_index_group .close ()
459+
410460 else :
411461 raise ValueError (f"Not supported index_type { index_type } " )
412462
@@ -822,10 +872,15 @@ def ingest_vectors_udf(
822872 logger = setup (config , verbose )
823873 group = tiledb .Group (array_uri )
824874 centroids_uri = group [CENTROIDS_ARRAY_NAME ].uri
825- partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
826- partial_write_array_ids_uri = partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
827- partial_write_array_parts_uri = (
828- partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
875+ partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
876+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
877+ partial_write_array_ids_uri = partial_write_array_group [IDS_ARRAY_NAME ].uri
878+ partial_write_array_parts_uri = partial_write_array_group [PARTS_ARRAY_NAME ].uri
879+ partial_write_array_index_dir_uri = partial_write_array_group [
880+ INDEX_ARRAY_NAME
881+ ].uri
882+ partial_write_array_index_group = tiledb .Group (
883+ partial_write_array_index_dir_uri
829884 )
830885
831886 for part in range (start , end , batch ):
@@ -834,9 +889,10 @@ def ingest_vectors_udf(
834889 part_end = end
835890
836891 part_name = str (part ) + "-" + str (part_end )
837- partial_write_array_index_uri = (
838- partial_write_array_dir_uri + "/" + INDEX_ARRAY_NAME + "/" + part_name
839- )
892+
893+ partial_write_array_index_uri = partial_write_array_index_group [
894+ str (int (start / batch ))
895+ ].uri
840896 logger .debug ("Input vectors start_pos: %d, end_pos: %d" , part , part_end )
841897 if source_type == "TILEDB_ARRAY" :
842898 logger .debug ("Start indexing" )
@@ -889,12 +945,18 @@ def compute_partition_indexes_udf(
889945 with tiledb .scope_ctx (ctx_or_config = config ):
890946 group = tiledb .Group (array_uri )
891947 index_array_uri = group [INDEX_ARRAY_NAME ].uri
892- vfs = tiledb .VFS ()
948+ partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
949+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
950+ partial_write_array_index_dir_uri = partial_write_array_group [
951+ INDEX_ARRAY_NAME
952+ ].uri
953+ partial_write_array_index_group = tiledb .Group (
954+ partial_write_array_index_dir_uri
955+ )
893956 partition_sizes = np .zeros (partitions )
894957 indexes = np .zeros (partitions + 1 ).astype (np .uint64 )
895- for partial_index_array_uri in vfs .ls (
896- array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
897- ):
958+ for part in partial_write_array_index_group :
959+ partial_index_array_uri = part .uri
898960 if tiledb .array_exists (partial_index_array_uri ):
899961 partial_index_array = tiledb .open (partial_index_array_uri , mode = "r" )
900962 partial_indexes = partial_index_array [:]["values" ]
@@ -912,7 +974,7 @@ def compute_partition_indexes_udf(
912974 _sum += partition_size
913975 i += 1
914976 indexes [i ] = _sum
915- logger .debug ("Partition indexes: %d" , indexes )
977+ logger .debug (f "Partition indexes: { indexes } " )
916978 index_array = tiledb .open (index_array_uri , mode = "w" )
917979 index_array [:] = indexes
918980
@@ -932,12 +994,17 @@ def consolidate_partition_udf(
932994 "Consolidating partitions %d-%d" , partition_id_start , partition_id_end
933995 )
934996 group = tiledb .Group (array_uri )
935- partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
936- partial_write_array_ids_uri = (
937- partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
938- )
939- partial_write_array_parts_uri = (
940- partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
997+ partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
998+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
999+ partial_write_array_ids_uri = partial_write_array_group [IDS_ARRAY_NAME ].uri
1000+ partial_write_array_parts_uri = partial_write_array_group [
1001+ PARTS_ARRAY_NAME
1002+ ].uri
1003+ partial_write_array_index_dir_uri = partial_write_array_group [
1004+ INDEX_ARRAY_NAME
1005+ ].uri
1006+ partial_write_array_index_group = tiledb .Group (
1007+ partial_write_array_index_dir_uri
9411008 )
9421009 index_array_uri = group [INDEX_ARRAY_NAME ].uri
9431010 ids_array_uri = group [IDS_ARRAY_NAME ].uri
@@ -946,9 +1013,8 @@ def consolidate_partition_udf(
9461013 partition_slices = []
9471014 for i in range (partitions ):
9481015 partition_slices .append ([])
949- for partial_index_array_uri in vfs .ls (
950- array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
951- ):
1016+ for part in partial_write_array_index_group :
1017+ partial_index_array_uri = part .uri
9521018 if tiledb .array_exists (partial_index_array_uri ):
9531019 partial_index_array = tiledb .open (partial_index_array_uri , mode = "r" )
9541020 partial_indexes = partial_index_array [:]["values" ]
@@ -1294,10 +1360,12 @@ def consolidate_and_vacuum(
12941360 tiledb .consolidate (group [IDS_ARRAY_NAME ].uri , config = conf )
12951361 tiledb .vacuum (group [IDS_ARRAY_NAME ].uri , config = conf )
12961362
1297- vfs = tiledb .VFS (config )
1298- partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1299- if vfs .is_dir (partial_write_array_dir_uri ):
1300- vfs .remove_dir (partial_write_array_dir_uri )
1363+ # TODO remove temp data for tiledb URIs
1364+ if not array_uri .startswith ("tiledb://" ):
1365+ vfs = tiledb .VFS (config )
1366+ partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1367+ if vfs .is_dir (partial_write_array_dir_uri ):
1368+ vfs .remove_dir (partial_write_array_dir_uri )
13011369
13021370 with tiledb .scope_ctx (ctx_or_config = config ):
13031371 logger = setup (config , verbose )
@@ -1384,7 +1452,7 @@ def consolidate_and_vacuum(
13841452 size = size ,
13851453 dimensions = dimensions ,
13861454 partitions = partitions ,
1387- input_vectors_work_items = input_vectors_work_items ,
1455+ input_vectors_work_tasks = input_vectors_work_tasks ,
13881456 vector_type = vector_type ,
13891457 logger = logger ,
13901458 )
@@ -1419,6 +1487,8 @@ def consolidate_and_vacuum(
14191487 consolidate_and_vacuum (array_uri = array_uri , config = config )
14201488
14211489 if index_type == "FLAT" :
1422- return FlatIndex (uri = array_uri , dtype = vector_type )
1490+ return FlatIndex (uri = array_uri , dtype = vector_type , config = config )
14231491 elif index_type == "IVF_FLAT" :
1424- return IVFFlatIndex (uri = array_uri , dtype = vector_type , memory_budget = 1000000 )
1492+ return IVFFlatIndex (
1493+ uri = array_uri , dtype = vector_type , memory_budget = 1000000 , config = config
1494+ )
0 commit comments