22from functools import partial
33
44from tiledb .cloud .dag import Mode
5- from tiledb .vector_search .index import FlatIndex
6- from tiledb .vector_search .index import IVFFlatIndex
5+ from tiledb .vector_search .index import FlatIndex , IVFFlatIndex , Index
76
87
98def ingest (
@@ -23,7 +22,7 @@ def ingest(
2322 verbose : bool = False ,
2423 trace_id : Optional [str ] = None ,
2524 mode : Mode = Mode .LOCAL ,
26- ) -> FlatIndex :
25+ ) -> Index :
2726 """
2827 Ingest vectors into TileDB.
2928
@@ -271,9 +270,7 @@ def create_arrays(
271270 )
272271 logger .debug (centroids_schema )
273272 tiledb .Array .create (centroids_uri , centroids_schema )
274- group .add (
275- centroids_uri , name = CENTROIDS_ARRAY_NAME
276- )
273+ group .add (centroids_uri , name = CENTROIDS_ARRAY_NAME )
277274
278275 if not tiledb .array_exists (index_uri ):
279276 logger .debug ("Creating index array" )
@@ -354,7 +351,9 @@ def create_arrays(
354351 except tiledb .TileDBError as err :
355352 message = str (err )
356353 if "already exists" in message :
357- logger .debug (f"Group '{ partial_write_array_dir_uri } ' already exists" )
354+ logger .debug (
355+ f"Group '{ partial_write_array_dir_uri } ' already exists"
356+ )
358357 raise err
359358 partial_write_array_group = tiledb .Group (partial_write_array_dir_uri , "w" )
360359 group .add (partial_write_array_dir_uri , name = PARTIAL_WRITE_ARRAY_DIR )
@@ -364,10 +363,16 @@ def create_arrays(
364363 except tiledb .TileDBError as err :
365364 message = str (err )
366365 if "already exists" in message :
367- logger .debug (f"Group '{ partial_write_array_index_uri } ' already exists" )
366+ logger .debug (
367+ f"Group '{ partial_write_array_index_uri } ' already exists"
368+ )
368369 raise err
369- partial_write_array_group .add (partial_write_array_index_uri , name = INDEX_ARRAY_NAME )
370- partial_write_array_index_group = tiledb .Group (partial_write_array_index_uri , "w" )
370+ partial_write_array_group .add (
371+ partial_write_array_index_uri , name = INDEX_ARRAY_NAME
372+ )
373+ partial_write_array_index_group = tiledb .Group (
374+ partial_write_array_index_uri , "w"
375+ )
371376
372377 if not tiledb .array_exists (partial_write_array_ids_uri ):
373378 logger .debug ("Creating temp ids array" )
@@ -389,7 +394,9 @@ def create_arrays(
389394 )
390395 logger .debug (ids_schema )
391396 tiledb .Array .create (partial_write_array_ids_uri , ids_schema )
392- partial_write_array_group .add (partial_write_array_ids_uri , name = IDS_ARRAY_NAME )
397+ partial_write_array_group .add (
398+ partial_write_array_ids_uri , name = IDS_ARRAY_NAME
399+ )
393400
394401 if not tiledb .array_exists (partial_write_array_parts_uri ):
395402 logger .debug ("Creating temp parts array" )
@@ -420,10 +427,12 @@ def create_arrays(
420427 logger .debug (parts_schema )
421428 logger .debug (partial_write_array_parts_uri )
422429 tiledb .Array .create (partial_write_array_parts_uri , parts_schema )
423- partial_write_array_group .add (partial_write_array_parts_uri , name = PARTS_ARRAY_NAME )
430+ partial_write_array_group .add (
431+ partial_write_array_parts_uri , name = PARTS_ARRAY_NAME
432+ )
424433
425434 for part in range (input_vectors_work_tasks ):
426- part_index_uri = partial_write_array_index_uri + "/" + str (part )
435+ part_index_uri = partial_write_array_index_uri + "/" + str (part )
427436 if not tiledb .array_exists (part_index_uri ):
428437 logger .debug (f"Creating part array { part_index_uri } " )
429438 index_array_rows_dim = tiledb .Dim (
@@ -867,8 +876,12 @@ def ingest_vectors_udf(
867876 partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
868877 partial_write_array_ids_uri = partial_write_array_group [IDS_ARRAY_NAME ].uri
869878 partial_write_array_parts_uri = partial_write_array_group [PARTS_ARRAY_NAME ].uri
870- partial_write_array_index_dir_uri = partial_write_array_group [INDEX_ARRAY_NAME ].uri
871- partial_write_array_index_group = tiledb .Group (partial_write_array_index_dir_uri )
879+ partial_write_array_index_dir_uri = partial_write_array_group [
880+ INDEX_ARRAY_NAME
881+ ].uri
882+ partial_write_array_index_group = tiledb .Group (
883+ partial_write_array_index_dir_uri
884+ )
872885
873886 for part in range (start , end , batch ):
874887 part_end = part + batch
@@ -877,7 +890,9 @@ def ingest_vectors_udf(
877890
878891 part_name = str (part ) + "-" + str (part_end )
879892
880- partial_write_array_index_uri = partial_write_array_index_group [str (int (start / batch ))].uri
893+ partial_write_array_index_uri = partial_write_array_index_group [
894+ str (int (start / batch ))
895+ ].uri
881896 logger .debug ("Input vectors start_pos: %d, end_pos: %d" , part , part_end )
882897 if source_type == "TILEDB_ARRAY" :
883898 logger .debug ("Start indexing" )
@@ -932,8 +947,12 @@ def compute_partition_indexes_udf(
932947 index_array_uri = group [INDEX_ARRAY_NAME ].uri
933948 partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
934949 partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
935- partial_write_array_index_dir_uri = partial_write_array_group [INDEX_ARRAY_NAME ].uri
936- partial_write_array_index_group = tiledb .Group (partial_write_array_index_dir_uri )
950+ partial_write_array_index_dir_uri = partial_write_array_group [
951+ INDEX_ARRAY_NAME
952+ ].uri
953+ partial_write_array_index_group = tiledb .Group (
954+ partial_write_array_index_dir_uri
955+ )
937956 partition_sizes = np .zeros (partitions )
938957 indexes = np .zeros (partitions + 1 ).astype (np .uint64 )
939958 for part in partial_write_array_index_group :
@@ -978,9 +997,15 @@ def consolidate_partition_udf(
978997 partial_write_array_dir_uri = group [PARTIAL_WRITE_ARRAY_DIR ].uri
979998 partial_write_array_group = tiledb .Group (partial_write_array_dir_uri )
980999 partial_write_array_ids_uri = partial_write_array_group [IDS_ARRAY_NAME ].uri
981- partial_write_array_parts_uri = partial_write_array_group [PARTS_ARRAY_NAME ].uri
982- partial_write_array_index_dir_uri = partial_write_array_group [INDEX_ARRAY_NAME ].uri
983- partial_write_array_index_group = tiledb .Group (partial_write_array_index_dir_uri )
1000+ partial_write_array_parts_uri = partial_write_array_group [
1001+ PARTS_ARRAY_NAME
1002+ ].uri
1003+ partial_write_array_index_dir_uri = partial_write_array_group [
1004+ INDEX_ARRAY_NAME
1005+ ].uri
1006+ partial_write_array_index_group = tiledb .Group (
1007+ partial_write_array_index_dir_uri
1008+ )
9841009 index_array_uri = group [INDEX_ARRAY_NAME ].uri
9851010 ids_array_uri = group [IDS_ARRAY_NAME ].uri
9861011 parts_array_uri = group [PARTS_ARRAY_NAME ].uri
@@ -1342,7 +1367,7 @@ def consolidate_and_vacuum(
13421367 if vfs .is_dir (partial_write_array_dir_uri ):
13431368 vfs .remove_dir (partial_write_array_dir_uri )
13441369
1345- with tiledb .scope_ctx (ctx_or_config = config ) as ctx :
1370+ with tiledb .scope_ctx (ctx_or_config = config ):
13461371 logger = setup (config , verbose )
13471372 logger .debug ("Ingesting Vectors into %r" , array_uri )
13481373 try :
@@ -1462,6 +1487,8 @@ def consolidate_and_vacuum(
14621487 consolidate_and_vacuum (array_uri = array_uri , config = config )
14631488
14641489 if index_type == "FLAT" :
1465- return FlatIndex (uri = array_uri , dtype = vector_type , ctx = ctx )
1490+ return FlatIndex (uri = array_uri , dtype = vector_type , config = config )
14661491 elif index_type == "IVF_FLAT" :
1467- return IVFFlatIndex (uri = array_uri , dtype = vector_type , memory_budget = 1000000 , ctx = ctx )
1492+ return IVFFlatIndex (
1493+ uri = array_uri , dtype = vector_type , memory_budget = 1000000 , config = config
1494+ )
0 commit comments