@@ -223,7 +223,7 @@ def create_arrays(
223223 )
224224 logger .debug (parts_schema )
225225 tiledb .Array .create (parts_uri , parts_schema )
226- group .add (PARTS_ARRAY_NAME , name = PARTS_ARRAY_NAME , relative = True )
226+ group .add (parts_uri , name = PARTS_ARRAY_NAME )
227227
228228 elif index_type == "IVF_FLAT" :
229229 centroids_uri = f"{ group .uri } /{ CENTROIDS_ARRAY_NAME } "
@@ -272,7 +272,7 @@ def create_arrays(
272272 logger .debug (centroids_schema )
273273 tiledb .Array .create (centroids_uri , centroids_schema )
274274 group .add (
275- CENTROIDS_ARRAY_NAME , name = CENTROIDS_ARRAY_NAME , relative = True
275+ centroids_uri , name = CENTROIDS_ARRAY_NAME
276276 )
277277
278278 if not tiledb .array_exists (index_uri ):
@@ -295,7 +295,7 @@ def create_arrays(
295295 )
296296 logger .debug (index_schema )
297297 tiledb .Array .create (index_uri , index_schema )
298- group .add (INDEX_ARRAY_NAME , name = INDEX_ARRAY_NAME , relative = True )
298+ group .add (index_uri , name = INDEX_ARRAY_NAME )
299299
300300 if not tiledb .array_exists (ids_uri ):
301301 logger .debug ("Creating ids array" )
@@ -317,7 +317,7 @@ def create_arrays(
317317 )
318318 logger .debug (ids_schema )
319319 tiledb .Array .create (ids_uri , ids_schema )
320- group .add (IDS_ARRAY_NAME , name = IDS_ARRAY_NAME , relative = True )
320+ group .add (ids_uri , name = IDS_ARRAY_NAME )
321321
322322 if not tiledb .array_exists (parts_uri ):
323323 logger .debug ("Creating parts array" )
@@ -347,15 +347,27 @@ def create_arrays(
347347 )
348348 logger .debug (parts_schema )
349349 tiledb .Array .create (parts_uri , parts_schema )
350- group .add (PARTS_ARRAY_NAME , name = PARTS_ARRAY_NAME , relative = True )
351-
352- vfs = tiledb .VFS ()
353- if vfs .is_dir (partial_write_array_dir_uri ):
354- vfs .remove_dir (partial_write_array_dir_uri )
355- vfs .create_dir (partial_write_array_dir_uri )
356- if vfs .is_dir (partial_write_array_index_uri ):
357- vfs .remove_dir (partial_write_array_index_uri )
358- vfs .create_dir (partial_write_array_index_uri )
350+ group .add (parts_uri , name = PARTS_ARRAY_NAME )
351+
352+ try :
353+ tiledb .group_create (partial_write_array_dir_uri )
354+ except tiledb .TileDBError as err :
355+ message = str (err )
356+ if "already exists" in message :
357+ logger .debug (f"Group '{ partial_write_array_dir_uri } ' already exists" )
358+ raise err
359+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri , "w" )
360+ group .add (partial_write_array_dir_uri , name = PARTIAL_WRITE_ARRAY_DIR )
361+
362+ try :
363+ tiledb .group_create (partial_write_array_index_uri )
364+ except tiledb .TileDBError as err :
365+ message = str (err )
366+ if "already exists" in message :
367+ logger .debug (f"Group '{ partial_write_array_index_uri } ' already exists" )
368+ raise err
369+ partial_write_array_group .add (partial_write_array_index_uri , name = INDEX_ARRAY_NAME )
370+ partial_write_array_index_group = tiledb .Group (partial_write_array_index_uri , "w" )
359371
360372 if not tiledb .array_exists (partial_write_array_ids_uri ):
361373 logger .debug ("Creating temp ids array" )
@@ -377,6 +389,7 @@ def create_arrays(
377389 )
378390 logger .debug (ids_schema )
379391 tiledb .Array .create (partial_write_array_ids_uri , ids_schema )
392+ partial_write_array_group .add (partial_write_array_ids_uri , name = IDS_ARRAY_NAME )
380393
381394 if not tiledb .array_exists (partial_write_array_parts_uri ):
382395 logger .debug ("Creating temp parts array" )
@@ -407,6 +420,34 @@ def create_arrays(
407420 logger .debug (parts_schema )
408421 logger .debug (partial_write_array_parts_uri )
409422 tiledb .Array .create (partial_write_array_parts_uri , parts_schema )
423+ partial_write_array_group .add (partial_write_array_parts_uri , name = PARTS_ARRAY_NAME )
424+
425+ for part in range (1 ):
426+ part_index_uri = partial_write_array_index_uri + "/" + str (part )
427+ if not tiledb .array_exists (part_index_uri ):
428+ logger .debug (f"Creating part array { part_index_uri } " )
429+ index_array_rows_dim = tiledb .Dim (
430+ name = "rows" ,
431+ domain = (0 , partitions ),
432+ tile = partitions ,
433+ dtype = np .dtype (np .int32 ),
434+ )
435+ index_array_dom = tiledb .Domain (index_array_rows_dim )
436+ index_attr = tiledb .Attr (name = "values" , dtype = np .dtype (np .uint64 ))
437+ index_schema = tiledb .ArraySchema (
438+ domain = index_array_dom ,
439+ sparse = False ,
440+ attrs = [index_attr ],
441+ capacity = partitions ,
442+ cell_order = "col-major" ,
443+ tile_order = "col-major" ,
444+ )
445+ logger .debug (index_schema )
446+ tiledb .Array .create (part_index_uri , index_schema )
447+ partial_write_array_index_group .add (part_index_uri , name = str (part ))
448+ partial_write_array_group .close ()
449+ partial_write_array_index_group .close ()
450+
410451 else :
411452 raise ValueError (f"Not supported index_type { index_type } " )
412453
@@ -822,21 +863,21 @@ def ingest_vectors_udf(
822863 logger = setup (config , verbose )
823864 group = tiledb .Group (array_uri )
824865 centroids_uri = group [CENTROIDS_ARRAY_NAME ].uri
825- partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
826- partial_write_array_ids_uri = partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
827- partial_write_array_parts_uri = (
828- partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
829- )
866+ partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
867+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
868+ partial_write_array_ids_uri = partial_write_array_group [IDS_ARRAY_NAME ].uri
869+ partial_write_array_parts_uri = partial_write_array_group [PARTS_ARRAY_NAME ].uri
870+ partial_write_array_index_dir_uri = partial_write_array_group [INDEX_ARRAY_NAME ].uri
871+ partial_write_array_index_group = tiledb .Group (partial_write_array_index_dir_uri )
830872
831873 for part in range (start , end , batch ):
832874 part_end = part + batch
833875 if part_end > end :
834876 part_end = end
835877
836878 part_name = str (part ) + "-" + str (part_end )
837- partial_write_array_index_uri = (
838- partial_write_array_dir_uri + "/" + INDEX_ARRAY_NAME + "/" + part_name
839- )
879+
880+ partial_write_array_index_uri = partial_write_array_index_group [str (int (start / batch ))].uri
840881 logger .debug ("Input vectors start_pos: %d, end_pos: %d" , part , part_end )
841882 if source_type == "TILEDB_ARRAY" :
842883 logger .debug ("Start indexing" )
@@ -889,12 +930,14 @@ def compute_partition_indexes_udf(
889930 with tiledb .scope_ctx (ctx_or_config = config ):
890931 group = tiledb .Group (array_uri )
891932 index_array_uri = group [INDEX_ARRAY_NAME ].uri
892- vfs = tiledb .VFS ()
933+ partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
934+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
935+ partial_write_array_index_dir_uri = partial_write_array_group [INDEX_ARRAY_NAME ].uri
936+ partial_write_array_index_group = tiledb .Group (partial_write_array_index_dir_uri )
893937 partition_sizes = np .zeros (partitions )
894938 indexes = np .zeros (partitions + 1 ).astype (np .uint64 )
895- for partial_index_array_uri in vfs .ls (
896- array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
897- ):
939+ for part in partial_write_array_index_group :
940+ partial_index_array_uri = part .uri
898941 if tiledb .array_exists (partial_index_array_uri ):
899942 partial_index_array = tiledb .open (partial_index_array_uri , mode = "r" )
900943 partial_indexes = partial_index_array [:]["values" ]
@@ -912,7 +955,7 @@ def compute_partition_indexes_udf(
912955 _sum += partition_size
913956 i += 1
914957 indexes [i ] = _sum
915- logger .debug ("Partition indexes: %d" , indexes )
958+ logger .debug (f "Partition indexes: { indexes } " )
916959 index_array = tiledb .open (index_array_uri , mode = "w" )
917960 index_array [:] = indexes
918961
@@ -932,23 +975,21 @@ def consolidate_partition_udf(
932975 "Consolidating partitions %d-%d" , partition_id_start , partition_id_end
933976 )
934977 group = tiledb .Group (array_uri )
935- partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
936- partial_write_array_ids_uri = (
937- partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
938- )
939- partial_write_array_parts_uri = (
940- partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
941- )
978+ partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
979+ partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
980+ partial_write_array_ids_uri = partial_write_array_group [IDS_ARRAY_NAME ].uri
981+ partial_write_array_parts_uri = partial_write_array_group [PARTS_ARRAY_NAME ].uri
982+ partial_write_array_index_dir_uri = partial_write_array_group [INDEX_ARRAY_NAME ].uri
983+ partial_write_array_index_group = tiledb .Group (partial_write_array_index_dir_uri )
942984 index_array_uri = group [INDEX_ARRAY_NAME ].uri
943985 ids_array_uri = group [IDS_ARRAY_NAME ].uri
944986 parts_array_uri = group [PARTS_ARRAY_NAME ].uri
945987 vfs = tiledb .VFS ()
946988 partition_slices = []
947989 for i in range (partitions ):
948990 partition_slices .append ([])
949- for partial_index_array_uri in vfs .ls (
950- array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
951- ):
991+ for part in partial_write_array_index_group :
992+ partial_index_array_uri = part .uri
952993 if tiledb .array_exists (partial_index_array_uri ):
953994 partial_index_array = tiledb .open (partial_index_array_uri , mode = "r" )
954995 partial_indexes = partial_index_array [:]["values" ]
@@ -1294,12 +1335,14 @@ def consolidate_and_vacuum(
12941335 tiledb .consolidate (group [IDS_ARRAY_NAME ].uri , config = conf )
12951336 tiledb .vacuum (group [IDS_ARRAY_NAME ].uri , config = conf )
12961337
1297- vfs = tiledb .VFS (config )
1298- partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1299- if vfs .is_dir (partial_write_array_dir_uri ):
1300- vfs .remove_dir (partial_write_array_dir_uri )
1338+ # TODO remove temp data for tiledb URIs
1339+ if not array_uri .startswith ("tiledb://" ):
1340+ vfs = tiledb .VFS (config )
1341+ partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1342+ if vfs .is_dir (partial_write_array_dir_uri ):
1343+ vfs .remove_dir (partial_write_array_dir_uri )
13011344
1302- with tiledb .scope_ctx (ctx_or_config = config ):
1345+ with tiledb .scope_ctx (ctx_or_config = config ) as ctx :
13031346 logger = setup (config , verbose )
13041347 logger .debug ("Ingesting Vectors into %r" , array_uri )
13051348 try :
@@ -1419,6 +1462,6 @@ def consolidate_and_vacuum(
14191462 consolidate_and_vacuum (array_uri = array_uri , config = config )
14201463
14211464 if index_type == "FLAT" :
1422- return FlatIndex (uri = array_uri , dtype = vector_type )
1465+ return FlatIndex (uri = array_uri , dtype = vector_type , ctx = ctx )
14231466 elif index_type == "IVF_FLAT" :
1424- return IVFFlatIndex (uri = array_uri , dtype = vector_type , memory_budget = 1000000 )
1467+ return IVFFlatIndex (uri = array_uri , dtype = vector_type , memory_budget = 1000000 , ctx = ctx )
0 commit comments