
Commit d0f8405

Author: Nikos Papailiou (committed)
Commit message: WIP
1 parent: d56902f

File tree: 1 file changed (+38, −25 lines)

apis/python/src/tiledb/vector_search/ingestion.py

@@ -69,6 +69,7 @@ def ingest(
     import logging
     import math
     from typing import Any, Mapping, Optional
+    import multiprocessing

     import numpy as np

@@ -832,11 +833,11 @@ def consolidate_partition_udf(
     ):
         logger = setup(config, verbose)
         with tiledb.scope_ctx(ctx_or_config=config):
+            logger.info(f"Consolidating partitions {partition_id_start}-{partition_id_end}")
             group = tiledb.Group(array_uri)
             partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
             partial_write_array_ids_uri = partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
             partial_write_array_parts_uri = partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
-            logger.info("Consolidating array")
             index_array_uri = group[INDEX_ARRAY_NAME].uri
             ids_array_uri = group[IDS_ARRAY_NAME].uri
             parts_array_uri = group[PARTS_ARRAY_NAME].uri
@@ -862,26 +863,36 @@ def consolidate_partition_udf(
             index_array = tiledb.open(index_array_uri, mode="r")
             ids_array = tiledb.open(ids_array_uri, mode="w")
             parts_array = tiledb.open(parts_array_uri, mode="w")
-            read_slices = []
-            for part in range(partition_id_start, partition_id_end):
-                for partition_slice in partition_slices[part]:
-                    read_slices.append(partition_slice)
-
-            logger.debug(f"Read slices: {read_slices}")
-            ids = partial_write_array_ids_array.multi_index[read_slices]["values"]
-            vectors = partial_write_array_parts_array.multi_index[:, read_slices]["values"]
-            start_pos = int(index_array[partition_id_start]["values"])
-            end_pos = int(index_array[partition_id_end]["values"])
-
-            logger.debug(
-                f"Ids shape {ids.shape}, expected size: {end_pos - start_pos} expected range:({start_pos},{end_pos})")
-            if ids.shape[0] != end_pos - start_pos:
-                raise ValueError("Incorrect partition size.")
-
-            logger.info(f"Writing data to array: {parts_array_uri}")
-            parts_array[:, start_pos:end_pos] = vectors
-            logger.info(f"Writing data to array: {ids_array_uri}")
-            ids_array[start_pos:end_pos] = ids
+            logger.info(
+                f"Partitions start: {partition_id_start} end: {partition_id_end}"
+            )
+            for part in range(partition_id_start, partition_id_end, batch):
+                part_end = part + batch
+                if part_end > partition_id_end:
+                    part_end = partition_id_end
+                logger.info(f"Consolidating partitions start: {part} end: {part_end}")
+                read_slices = []
+                for p in range(part, part_end):
+                    for partition_slice in partition_slices[p]:
+                        read_slices.append(partition_slice)
+
+                logger.debug(f"Read slices: {read_slices}")
+                ids = partial_write_array_ids_array.multi_index[read_slices]["values"]
+                vectors = partial_write_array_parts_array.multi_index[:, read_slices]["values"]
+                start_pos = int(index_array[part]["values"])
+                end_pos = int(index_array[part_end]["values"])
+
+                logger.debug(
+                    f"Ids shape {ids.shape}, expected size: {end_pos - start_pos} expected range:({start_pos},{end_pos})")
+                if ids.shape[0] != end_pos - start_pos:
+                    raise ValueError("Incorrect partition size.")
+
+                logger.info(f"Writing data to array: {parts_array_uri}")
+                parts_array[:, start_pos:end_pos] = vectors
+                logger.info(f"Writing data to array: {ids_array_uri}")
+                ids_array[start_pos:end_pos] = ids
+            parts_array.close()
+            ids_array.close()

     # --------------------------------------------------------------------
     # DAG
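
The heart of this commit is the loop above: rather than materializing every partition slice in [partition_id_start, partition_id_end) in a single read, the UDF now walks the range in steps of `batch`, clamps the final step, and reads, validates, and writes one batch at a time, bounding peak memory. A minimal standalone sketch of the clamped-stepping idiom (the `process_range` callback is a hypothetical stand-in for the per-batch read/validate/write body, not part of the TileDB code):

    def batched_ranges(start: int, end: int, batch: int):
        # Walk [start, end) in steps of `batch`, clamping the final
        # step so the last range never runs past `end`.
        for lo in range(start, end, batch):
            yield lo, min(lo + batch, end)

    def consolidate(start: int, end: int, batch: int, process_range):
        # Apply the per-batch body to each clamped sub-range.
        for lo, hi in batched_ranges(start, end, batch):
            process_range(lo, hi)

    # Example: partitions 0..10 in batches of 4 -> [(0, 4), (4, 8), (8, 10)]
    print(list(batched_ranges(0, 10, 4)))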
@@ -918,13 +929,15 @@ def create_ingestion_dag(
                 retry_policy="Always",
             ),
         )
+        threads = 8
     else:
         d = dag.DAG(
             name="vector-ingestion",
             mode=Mode.REALTIME,
             max_workers=workers,
             namespace="default",
         )
+        threads = multiprocessing.cpu_count()

     submit = d.submit_local
     if mode == Mode.BATCH or mode == Mode.REALTIME:
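
This hunk derives the UDF thread count from the execution mode: a fixed 8 threads for BATCH tasks, whose containers request a matching CPU count, and `multiprocessing.cpu_count()` otherwise, so realtime/local workers use every core of the host they land on. A self-contained sketch of the selection logic, with `Mode` as a hypothetical stand-in for the package's enum:

    import multiprocessing
    from enum import Enum

    class Mode(Enum):
        # Hypothetical stand-in for the TileDB Cloud dag Mode enum.
        BATCH = 1
        REALTIME = 2

    def pick_threads(mode: Mode) -> int:
        if mode == Mode.BATCH:
            # Batch containers request an explicit CPU count, so pin
            # the thread pool to that request.
            return 8
        # Realtime/local workers run on the host machine: use all cores.
        return multiprocessing.cpu_count()

    print(pick_threads(Mode.BATCH))     # 8
    print(pick_threads(Mode.REALTIME))  # machine-dependent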
@@ -1024,12 +1037,12 @@ def create_ingestion_dag(
                     dimensions=dimensions,
                     vector_start_pos=start,
                     vector_end_pos=end,
-                    threads=8,
+                    threads=threads,
                     config=config,
                     verbose=verbose,
                     trace_id=trace_id,
                     name="k-means-part-" + str(task_id),
-                    resources={"cpu": "8", "memory": "12Gi"},
+                    resources={"cpu": str(threads), "memory": "12Gi"},
                 )
             )
             task_id += 1
@@ -1090,12 +1103,12 @@ def create_ingestion_dag(
                 start=start,
                 end=end,
                 batch=input_vectors_per_work_item,
-                threads=6,
+                threads=threads,
                 config=config,
                 verbose=verbose,
                 trace_id=trace_id,
                 name="ingest-" + str(task_id),
-                resources={"cpu": "6", "memory": "32Gi"},
+                resources={"cpu": str(threads), "memory": "32Gi"},
             )
         ingest_node.depends_on(centroids_node)
         compute_indexes_node.depends_on(ingest_node)
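
The last two hunks replace the hard-coded `threads=8` / `threads=6` values and their matching `resources={"cpu": ...}` literals with the single `threads` variable, so the thread pool size and the container CPU request can no longer drift apart. A hedged sketch of that coupling, with `submit_task` as a hypothetical stand-in for the DAG submit call:

    def submit_task(fn, *, threads: int, resources: dict):
        # Hypothetical stub: a real scheduler would launch `fn` in a
        # container sized by `resources`; here we only check the
        # invariant that the CPU request matches the thread count.
        assert int(resources["cpu"]) == threads
        fn()

    threads = 8
    submit_task(lambda: print("ingesting with", threads, "threads"),
                threads=threads,
                resources={"cpu": str(threads), "memory": "32Gi"})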
