Skip to content

Commit 9edde59

Browse files
author
Nikos Papailiou
committed
Add default image and fix udf resources
1 parent d2032d2 commit 9edde59

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Optional, Tuple
2+
from functools import partial
23

34
from tiledb.cloud.dag import Mode
45
from tiledb.vector_search.index import FlatIndex
@@ -49,7 +50,7 @@ def ingest(
4950
copy_centroids_uri: str
5051
TileDB array URI to copy centroids from,
5152
if not provided, centroids are built by running kmeans
52-
training_sample_size: int = 1
53+
training_sample_size: int = -1
5354
vector sample size to train centroids with,
5455
if not provided, is auto-configured based on the dataset size
5556
workers: int = -1
@@ -86,9 +87,10 @@ def ingest(
8687
IDS_ARRAY_NAME = "ids.tdb"
8788
PARTS_ARRAY_NAME = "parts.tdb"
8889
PARTIAL_WRITE_ARRAY_DIR = "write_temp"
89-
VECTORS_PER_WORK_ITEM = 1000000
90+
VECTORS_PER_WORK_ITEM = 10000000
9091
MAX_TASKS_PER_STAGE = 100
9192
CENTRALISED_KMEANS_MAX_SAMPLE_SIZE = 1000000
93+
DEFAULT_IMG_NAME = "3.9-vectorsearch"
9294

9395
class SourceType(enum.Enum):
9496
"""SourceType of input vectors"""
@@ -959,6 +961,11 @@ def consolidate_partition_udf(
959961
# --------------------------------------------------------------------
960962
# DAG
961963
# --------------------------------------------------------------------
964+
def submit_local(d, func, *args, **kwargs):
965+
# Drop kwarg
966+
kwargs.pop("image_name", None)
967+
kwargs.pop("resources", None)
968+
return d.submit_local(func, *args, **kwargs)
962969

963970
def create_ingestion_dag(
964971
index_type: str,
@@ -1001,7 +1008,7 @@ def create_ingestion_dag(
10011008
)
10021009
threads = multiprocessing.cpu_count()
10031010

1004-
submit = d.submit_local
1011+
submit = partial(submit_local, d)
10051012
if mode == Mode.BATCH or mode == Mode.REALTIME:
10061013
submit = d.submit
10071014

@@ -1029,7 +1036,8 @@ def create_ingestion_dag(
10291036
verbose=verbose,
10301037
trace_id=trace_id,
10311038
name="ingest-" + str(task_id),
1032-
resources={"cpu": "6", "memory": "32Gi"},
1039+
resources={"cpu": str(threads), "memory": "8Gi"},
1040+
image_name=DEFAULT_IMG_NAME,
10331041
)
10341042
task_id += 1
10351043
return d
@@ -1044,6 +1052,7 @@ def create_ingestion_dag(
10441052
trace_id=trace_id,
10451053
name="copy-centroids",
10461054
resources={"cpu": "1", "memory": "2Gi"},
1055+
image_name=DEFAULT_IMG_NAME,
10471056
)
10481057
else:
10491058
if training_sample_size <= CENTRALISED_KMEANS_MAX_SAMPLE_SIZE:
@@ -1062,6 +1071,7 @@ def create_ingestion_dag(
10621071
trace_id=trace_id,
10631072
name="kmeans",
10641073
resources={"cpu": "8", "memory": "32Gi"},
1074+
image_name=DEFAULT_IMG_NAME,
10651075
)
10661076
else:
10671077
internal_centroids_node = submit(
@@ -1076,6 +1086,7 @@ def create_ingestion_dag(
10761086
trace_id=trace_id,
10771087
name="init-centroids",
10781088
resources={"cpu": "1", "memory": "1Gi"},
1089+
image_name=DEFAULT_IMG_NAME,
10791090
)
10801091

10811092
for it in range(5):
@@ -1105,6 +1116,7 @@ def create_ingestion_dag(
11051116
trace_id=trace_id,
11061117
name="k-means-part-" + str(task_id),
11071118
resources={"cpu": str(threads), "memory": "12Gi"},
1119+
image_name=DEFAULT_IMG_NAME,
11081120
)
11091121
)
11101122
task_id += 1
@@ -1116,13 +1128,15 @@ def create_ingestion_dag(
11161128
*kmeans_workers[i: i + 10],
11171129
name="update-centroids-" + str(i),
11181130
resources={"cpu": "1", "memory": "8Gi"},
1131+
image_name=DEFAULT_IMG_NAME,
11191132
)
11201133
)
11211134
internal_centroids_node = submit(
11221135
compute_new_centroids,
11231136
*reducers,
11241137
name="update-centroids",
11251138
resources={"cpu": "1", "memory": "8Gi"},
1139+
image_name=DEFAULT_IMG_NAME,
11261140
)
11271141
centroids_node = submit(
11281142
write_centroids,
@@ -1135,6 +1149,7 @@ def create_ingestion_dag(
11351149
trace_id=trace_id,
11361150
name="write-centroids",
11371151
resources={"cpu": "1", "memory": "2Gi"},
1152+
image_name=DEFAULT_IMG_NAME,
11381153
)
11391154

11401155
compute_indexes_node = submit(
@@ -1146,6 +1161,7 @@ def create_ingestion_dag(
11461161
trace_id=trace_id,
11471162
name="compute-indexes",
11481163
resources={"cpu": "1", "memory": "2Gi"},
1164+
image_name=DEFAULT_IMG_NAME,
11491165
)
11501166

11511167
task_id = 0
@@ -1170,7 +1186,8 @@ def create_ingestion_dag(
11701186
verbose=verbose,
11711187
trace_id=trace_id,
11721188
name="ingest-" + str(task_id),
1173-
resources={"cpu": str(threads), "memory": "32Gi"},
1189+
resources={"cpu": str(threads), "memory": "8Gi"},
1190+
image_name=DEFAULT_IMG_NAME,
11741191
)
11751192
ingest_node.depends_on(centroids_node)
11761193
compute_indexes_node.depends_on(ingest_node)
@@ -1196,7 +1213,8 @@ def create_ingestion_dag(
11961213
verbose=verbose,
11971214
trace_id=trace_id,
11981215
name="consolidate-partition-" + str(task_id),
1199-
resources={"cpu": "2", "memory": "24Gi"},
1216+
resources={"cpu": str(threads), "memory": "8Gi"},
1217+
image_name=DEFAULT_IMG_NAME,
12001218
)
12011219
consolidate_partition_node.depends_on(compute_indexes_node)
12021220
task_id += 1
@@ -1258,11 +1276,16 @@ def consolidate_and_vacuum(
12581276
logger.info(f"Vector dimension type {vector_type}")
12591277
if partitions == -1:
12601278
partitions = int(math.sqrt(size))
1279+
if training_sample_size == -1:
1280+
training_sample_size = min(size, 100 * partitions)
12611281
if mode == Mode.BATCH:
12621282
if workers == -1:
12631283
workers = 10
12641284
else:
12651285
workers = 1
1286+
logger.info(f"Partitions {partitions}")
1287+
logger.info(f"Training sample size {training_sample_size}")
1288+
logger.info(f"Number of workers {workers}")
12661289

12671290
if input_vectors_per_work_item == -1:
12681291
input_vectors_per_work_item = VECTORS_PER_WORK_ITEM

apis/python/test/test_ingestion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from tiledb.vector_search.ingestion import ingest
44

5+
from tiledb.cloud.dag import Mode
6+
57

68
def test_flat_ingestion_u8(tmp_path):
79
dataset_dir = os.path.join(tmp_path, "dataset")

0 commit comments

Comments
 (0)