Commit 2ad1ae4

Merge pull request #156 from TileDB-Inc/npapa/fix-ingestion
Fix corner cases for ingestion
2 parents bcfdaa1 + a331d3b commit 2ad1ae4

File tree

2 files changed (+102, -22 lines):
  apis/python/src/tiledb/vector_search/ingestion.py
  apis/python/test/test_ingestion.py

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 57 additions & 22 deletions
@@ -29,6 +29,7 @@ def ingest(
     training_sample_size: int = -1,
     workers: int = -1,
     input_vectors_per_work_item: int = -1,
+    max_tasks_per_stage: int = -1,
     storage_version: str = STORAGE_VERSION,
     verbose: bool = False,
     trace_id: Optional[str] = None,

@@ -83,6 +84,9 @@ def ingest(
     input_vectors_per_work_item: int = -1
         number of vectors per ingestion work item,
         if not provided, is auto-configured
+    max_tasks_per_stage: int = -1
+        Max number of tasks per execution stage of ingestion,
+        if not provided, is auto-configured
     storage_version: str
         Vector index storage format version.
     verbose: bool

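For context when reading the diff: the new knob is passed straight through ingest(). A minimal usage sketch; the URIs and the cap of 8 are placeholder values, not taken from this commit:

    from tiledb.vector_search.ingestion import ingest

    # Cap every ingestion stage at 8 tasks instead of the auto-configured
    # MAX_TASKS_PER_STAGE default; the index and source URIs are hypothetical.
    index = ingest(
        index_type="IVF_FLAT",
        index_uri="/tmp/demo_index",
        source_uri="/tmp/vectors.fvecs",
        max_tasks_per_stage=8,
    )
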
@@ -348,7 +352,7 @@ def create_arrays(
         index_type: str,
         size: int,
         dimensions: int,
-        input_vectors_work_tasks: int,
+        input_vectors_work_items: int,
         vector_type: np.dtype,
         logger: logging.Logger,
     ) -> None:

@@ -475,10 +479,7 @@ def create_arrays(
             partial_write_array_group.add(
                 partial_write_array_parts_uri, name=PARTS_ARRAY_NAME
             )
-            partial_write_arrays = input_vectors_work_tasks
-            if updates_uri is not None:
-                partial_write_arrays += 1
-            for part in range(partial_write_arrays):
+            for part in range(input_vectors_work_items):
                 part_index_uri = partial_write_array_index_uri + "/" + str(part)
                 if not tiledb.array_exists(part_index_uri):
                     logger.debug(f"Creating part array {part_index_uri}")

@@ -505,6 +506,33 @@ def create_arrays(
                     logger.debug(index_schema)
                     tiledb.Array.create(part_index_uri, index_schema)
                     partial_write_array_index_group.add(part_index_uri, name=str(part))
+            if updates_uri is not None:
+                part_index_uri = partial_write_array_index_uri + "/additions"
+                if not tiledb.array_exists(part_index_uri):
+                    logger.debug(f"Creating part array {part_index_uri}")
+                    index_array_rows_dim = tiledb.Dim(
+                        name="rows",
+                        domain=(0, partitions),
+                        tile=partitions,
+                        dtype=np.dtype(np.int32),
+                    )
+                    index_array_dom = tiledb.Domain(index_array_rows_dim)
+                    index_attr = tiledb.Attr(
+                        name="values",
+                        dtype=np.dtype(np.uint64),
+                        filters=DEFAULT_ATTR_FILTERS,
+                    )
+                    index_schema = tiledb.ArraySchema(
+                        domain=index_array_dom,
+                        sparse=False,
+                        attrs=[index_attr],
+                        capacity=partitions,
+                        cell_order="col-major",
+                        tile_order="col-major",
+                    )
+                    logger.debug(index_schema)
+                    tiledb.Array.create(part_index_uri, index_schema)
+                    partial_write_array_index_group.add(part_index_uri, name="additions")
             partial_write_array_group.close()
             partial_write_array_index_group.close()

@@ -1098,7 +1126,7 @@ def ingest_vectors_udf(
             part_name = str(part) + "-" + str(part_end)

             partial_write_array_index_uri = partial_write_array_index_group[
-                str(int(start / batch))
+                str(int(part / batch))
             ].uri
             logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
             updated_ids = read_updated_ids(

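Why the lookup key changes, in a self-contained sketch (the slice boundaries below are hypothetical, sized only to match the 421-vector work items used in the new test): when one UDF call covers several work items, int(start / batch) is constant across the loop, so every work item would resolve to the same partial index array, while int(part / batch) advances by one per work item.

    # Hypothetical worker slice: vectors [1684, 3368) with work-item size 421.
    start, end, batch = 1684, 3368, 421
    for part in range(start, end, batch):
        old_key = int(start / batch)  # 4 on every iteration
        new_key = int(part / batch)   # 4, 5, 6, 7: one partial array per work item
        print(part, old_key, new_key)
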
@@ -1203,7 +1231,6 @@ def ingest_additions_udf(
         updates_uri: str,
         vector_type: np.dtype,
         write_offset: int,
-        task_id: int,
         threads: int,
         config: Optional[Mapping[str, Any]] = None,
         verbose: bool = False,

@@ -1228,7 +1255,7 @@ def ingest_additions_udf(
             partial_write_array_index_dir_uri
         )
         partial_write_array_index_uri = partial_write_array_index_group[
-            str(task_id)
+            "additions"
         ].uri
         additions_vectors, additions_external_ids = read_additions(
             updates_uri=updates_uri,

@@ -1678,7 +1705,6 @@ def create_ingestion_dag(
             updates_uri=updates_uri,
             vector_type=vector_type,
             write_offset=size,
-            task_id=task_id,
             threads=threads,
             config=config,
             verbose=verbose,

@@ -1744,15 +1770,22 @@ def consolidate_and_vacuum(
         tiledb.vacuum(parts_uri, config=conf)
         tiledb.consolidate(ids_uri, config=conf)
         tiledb.vacuum(ids_uri, config=conf)
+        group.close()

         # TODO remove temp data for tiledb URIs
         if not index_group_uri.startswith("tiledb://"):
-            vfs = tiledb.VFS(config)
-            partial_write_array_dir_uri = (
-                index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
-            )
-            if vfs.is_dir(partial_write_array_dir_uri):
-                vfs.remove_dir(partial_write_array_dir_uri)
+            group = tiledb.Group(index_group_uri, "r")
+            if PARTIAL_WRITE_ARRAY_DIR in group:
+                group.close()
+                group = tiledb.Group(index_group_uri, "w")
+                group.remove(PARTIAL_WRITE_ARRAY_DIR)
+            vfs = tiledb.VFS(config)
+            partial_write_array_dir_uri = (
+                index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
+            )
+            if vfs.is_dir(partial_write_array_dir_uri):
+                vfs.remove_dir(partial_write_array_dir_uri)
+            group.close()

     # --------------------------------------------------------------------
     # End internal function definitions

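The cleanup now deregisters the temp member from the group before deleting its directory. A standalone sketch of the same pattern; the URI is hypothetical and the PARTIAL_WRITE_ARRAY_DIR value below is a placeholder, since the real constant lives in ingestion.py and is not shown in this diff:

    import tiledb

    index_group_uri = "/tmp/demo_index"        # hypothetical local index group
    PARTIAL_WRITE_ARRAY_DIR = "write_temp"     # placeholder for the real constant

    # Remove the group member first so the group keeps no dangling reference,
    # then delete the underlying directory with VFS.
    group = tiledb.Group(index_group_uri, "r")
    if PARTIAL_WRITE_ARRAY_DIR in group:
        group.close()
        group = tiledb.Group(index_group_uri, "w")
        group.remove(PARTIAL_WRITE_ARRAY_DIR)
    vfs = tiledb.VFS()
    temp_dir = index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
    if vfs.is_dir(temp_dir):
        vfs.remove_dir(temp_dir)
    group.close()
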
@@ -1852,11 +1885,13 @@ def consolidate_and_vacuum(
     input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
     input_vectors_work_tasks = input_vectors_work_items
     input_vectors_work_items_per_worker = 1
-    if input_vectors_work_tasks > MAX_TASKS_PER_STAGE:
+    if max_tasks_per_stage == -1:
+        max_tasks_per_stage = MAX_TASKS_PER_STAGE
+    if input_vectors_work_tasks > max_tasks_per_stage:
         input_vectors_work_items_per_worker = int(
-            math.ceil(input_vectors_work_items / MAX_TASKS_PER_STAGE)
+            math.ceil(input_vectors_work_items / max_tasks_per_stage)
         )
-        input_vectors_work_tasks = MAX_TASKS_PER_STAGE
+        input_vectors_work_tasks = max_tasks_per_stage
     logger.debug("input_vectors_per_work_item %d", input_vectors_per_work_item)
     logger.debug("input_vectors_work_items %d", input_vectors_work_items)
     logger.debug("input_vectors_work_tasks %d", input_vectors_work_tasks)

@@ -1875,11 +1910,11 @@ def consolidate_and_vacuum(
     )
     table_partitions_work_tasks = table_partitions_work_items
     table_partitions_work_items_per_worker = 1
-    if table_partitions_work_tasks > MAX_TASKS_PER_STAGE:
+    if table_partitions_work_tasks > max_tasks_per_stage:
         table_partitions_work_items_per_worker = int(
-            math.ceil(table_partitions_work_items / MAX_TASKS_PER_STAGE)
+            math.ceil(table_partitions_work_items / max_tasks_per_stage)
         )
-        table_partitions_work_tasks = MAX_TASKS_PER_STAGE
+        table_partitions_work_tasks = max_tasks_per_stage
     logger.debug(
         "table_partitions_per_work_item %d", table_partitions_per_work_item
     )

@@ -1897,7 +1932,7 @@ def consolidate_and_vacuum(
         index_type=index_type,
         size=size,
         dimensions=dimensions,
-        input_vectors_work_tasks=input_vectors_work_tasks,
+        input_vectors_work_items=input_vectors_work_items,
         vector_type=vector_type,
         logger=logger,
     )

apis/python/test/test_ingestion.py

Lines changed: 45 additions & 0 deletions
@@ -262,6 +262,51 @@ def test_ivf_flat_ingestion_numpy(tmp_path):
     assert accuracy(result, gt_i) > MINIMUM_ACCURACY


+def test_ivf_flat_ingestion_multiple_workers(tmp_path):
+    source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
+    queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
+    gt_uri = "test/data/siftsmall/siftsmall_groundtruth.ivecs"
+    index_uri = os.path.join(tmp_path, "array")
+    k = 100
+    partitions = 100
+    nqueries = 100
+    nprobe = 20
+
+    query_vectors = load_fvecs(queries_uri)
+    gt_i, gt_d = get_groundtruth_ivec(gt_uri, k=k, nqueries=nqueries)
+
+    index = ingest(
+        index_type="IVF_FLAT",
+        index_uri=index_uri,
+        source_uri=source_uri,
+        partitions=partitions,
+        input_vectors_per_work_item=421,
+        max_tasks_per_stage=4,
+    )
+    _, result = index.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+    # Test single query vector handling
+    _, result1 = index.query(query_vectors[10], k=k, nprobe=nprobe)
+    assert accuracy(result1, np.array([gt_i[10]])) > MINIMUM_ACCURACY
+
+    index_ram = IVFFlatIndex(uri=index_uri)
+    _, result = index_ram.query(query_vectors, k=k, nprobe=nprobe)
+    assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+    _, result = index_ram.query(
+        query_vectors,
+        k=k,
+        nprobe=nprobe,
+        use_nuv_implementation=True,
+    )
+    assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+    # NB: local mode currently does not return distances
+    _, result = index_ram.query(query_vectors, k=k, nprobe=nprobe, mode=Mode.LOCAL)
+    assert accuracy(result, gt_i) > MINIMUM_ACCURACY
+
+
 def test_ivf_flat_ingestion_external_ids_numpy(tmp_path):
     source_uri = "test/data/siftsmall/siftsmall_base.fvecs"
     queries_uri = "test/data/siftsmall/siftsmall_query.fvecs"
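The new test can be run on its own; a sketch, assuming a repository checkout with the siftsmall fixtures under test/data and the working directory set to apis/python:

    import pytest

    # Run only the new multi-worker ingestion test.
    pytest.main(["-k", "test_ivf_flat_ingestion_multiple_workers", "test/test_ingestion.py"])
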
