
Commit b0f664c

Use unique path for temp_data folder (#357)
Vector ingestion uses a `temp_data` group to store any temporary data needed during ingestion processing. This PR adds a random string to the `temp_data` URI so that different ingestion jobs do not conflict with each other. Sharing one `temp_data` group across all ingestion jobs caused errors in consecutive jobs whenever a previous job failed or didn't clean up its temporary data.
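
At its core, the fix suffixes the shared directory name with a random token before any temporary arrays are created. A minimal sketch of the naming scheme (the `"temp_data"` base value here is an illustrative stand-in; the real code looks it up via `storage_formats[storage_version]["PARTIAL_WRITE_ARRAY_DIR"]`, as shown in the diff below):

    import random
    import string

    # Base name of the temp group; hypothetical stand-in for the
    # storage_formats lookup used in ingestion.py.
    base = "temp_data"
    suffix = "".join(random.choices(string.ascii_letters, k=10))
    PARTIAL_WRITE_ARRAY_DIR = f"{base}_{suffix}"  # e.g. "temp_data_kXbQwZrTpL"

With 10 characters drawn from 52 ASCII letters there are 52**10 (roughly 1.4 × 10^17) possible suffixes, so concurrent or consecutive ingestion jobs are vanishingly unlikely to pick the same temp path.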
1 parent 6936b77 commit b0f664c

File tree

5 files changed: +200 −178 lines

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 80 additions & 83 deletions
@@ -174,6 +174,8 @@ def ingest(
     import math
     import multiprocessing
     import os
+    import random
+    import string
     import time
     from typing import Any, Mapping

@@ -280,9 +282,11 @@ def ingest(
     EXTERNAL_IDS_ARRAY_NAME = storage_formats[storage_version][
         "EXTERNAL_IDS_ARRAY_NAME"
     ]
-    PARTIAL_WRITE_ARRAY_DIR = storage_formats[storage_version][
-        "PARTIAL_WRITE_ARRAY_DIR"
-    ]
+    PARTIAL_WRITE_ARRAY_DIR = (
+        storage_formats[storage_version]["PARTIAL_WRITE_ARRAY_DIR"]
+        + "_"
+        + "".join(random.choices(string.ascii_letters, k=10))
+    )
     DEFAULT_ATTR_FILTERS = storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"]
     VECTORS_PER_WORK_ITEM = 20000000
     VECTORS_PER_SAMPLE_WORK_ITEM = 1000000
@@ -535,36 +539,32 @@ def write_external_ids(

         return external_ids_array_uri

+    def create_temp_data_group(
+        group: tiledb.Group,
+    ) -> tiledb.Group:
+        partial_write_array_dir_uri = f"{group.uri}/{PARTIAL_WRITE_ARRAY_DIR}"
+        try:
+            tiledb.group_create(partial_write_array_dir_uri)
+            add_to_group(group, partial_write_array_dir_uri, PARTIAL_WRITE_ARRAY_DIR)
+        except tiledb.TileDBError as err:
+            message = str(err)
+            if "already exists" not in message:
+                raise err
+        return tiledb.Group(partial_write_array_dir_uri, "w")
+
     def create_partial_write_array_group(
-        index_group_uri: str,
+        temp_data_group: tiledb.Group,
         vector_type: np.dtype,
         dimensions: int,
         filters: Any,
         create_index_array: bool,
-    ) -> (tiledb.Group, str):
-        group = tiledb.Group(index_group_uri, "w")
+    ) -> str:
         tile_size = int(
             ivf_flat_index.TILE_SIZE_BYTES / np.dtype(vector_type).itemsize / dimensions
         )
-        partial_write_array_dir_uri = f"{group.uri}/{PARTIAL_WRITE_ARRAY_DIR}"
-        partial_write_array_index_uri = (
-            f"{partial_write_array_dir_uri}/{INDEX_ARRAY_NAME}"
-        )
-        partial_write_array_ids_uri = f"{partial_write_array_dir_uri}/{IDS_ARRAY_NAME}"
-        partial_write_array_parts_uri = (
-            f"{partial_write_array_dir_uri}/{PARTS_ARRAY_NAME}"
-        )
-
-        try:
-            tiledb.group_create(partial_write_array_dir_uri)
-        except tiledb.TileDBError as err:
-            message = str(err)
-            if "already exists" in message:
-                logger.debug(f"Group '{partial_write_array_dir_uri}' already exists")
-            raise err
-        partial_write_array_group = tiledb.Group(partial_write_array_dir_uri, "w")
-        add_to_group(group, partial_write_array_dir_uri, PARTIAL_WRITE_ARRAY_DIR)
-
+        partial_write_array_index_uri = f"{temp_data_group.uri}/{INDEX_ARRAY_NAME}"
+        partial_write_array_ids_uri = f"{temp_data_group.uri}/{IDS_ARRAY_NAME}"
+        partial_write_array_parts_uri = f"{temp_data_group.uri}/{PARTS_ARRAY_NAME}"
         if create_index_array:
             try:
                 tiledb.group_create(partial_write_array_index_uri)
@@ -576,7 +576,7 @@ def create_partial_write_array_group(
                 )
                 raise err
             add_to_group(
-                partial_write_array_group,
+                temp_data_group,
                 partial_write_array_index_uri,
                 INDEX_ARRAY_NAME,
             )
@@ -606,7 +606,7 @@ def create_partial_write_array_group(
         logger.debug(ids_schema)
         tiledb.Array.create(partial_write_array_ids_uri, ids_schema)
         add_to_group(
-            partial_write_array_group,
+            temp_data_group,
             partial_write_array_ids_uri,
             IDS_ARRAY_NAME,
         )
@@ -638,18 +638,17 @@ def create_partial_write_array_group(
         logger.debug(partial_write_array_parts_uri)
         tiledb.Array.create(partial_write_array_parts_uri, parts_schema)
         add_to_group(
-            partial_write_array_group,
+            temp_data_group,
             partial_write_array_parts_uri,
             PARTS_ARRAY_NAME,
         )
-        group.close()
-        return partial_write_array_group, partial_write_array_index_uri
+        return partial_write_array_index_uri

     def create_arrays(
         group: tiledb.Group,
+        temp_data_group: tiledb.Group,
         arrays_created: bool,
         index_type: str,
-        size: int,
         dimensions: int,
         input_vectors_work_items: int,
         vector_type: np.dtype,
@@ -676,11 +675,8 @@ def create_arrays(
             config=config,
             storage_version=storage_version,
         )
-        (
-            partial_write_array_group,
-            partial_write_array_index_uri,
-        ) = create_partial_write_array_group(
-            index_group_uri=group.uri,
+        partial_write_array_index_uri = create_partial_write_array_group(
+            temp_data_group=temp_data_group,
             vector_type=vector_type,
             dimensions=dimensions,
             filters=DEFAULT_ATTR_FILTERS,
@@ -748,7 +744,6 @@ def create_arrays(
             add_to_group(
                 partial_write_array_index_group, part_index_uri, "additions"
             )
-        partial_write_array_group.close()
         partial_write_array_index_group.close()

     # Note that we don't create type-erased indexes (i.e. Vamana) here. Instead we create them
@@ -1534,21 +1529,20 @@ def ingest_vamana(
             trace_id=trace_id,
         )

-        partial_write_array_group, _ = create_partial_write_array_group(
-            index_group_uri=index_group_uri,
+        temp_data_group_uri = f"{index_group_uri}/{PARTIAL_WRITE_ARRAY_DIR}"
+        temp_data_group = tiledb.Group(temp_data_group_uri, "w")
+        create_partial_write_array_group(
+            temp_data_group=temp_data_group,
             vector_type=vector_type,
             dimensions=dimensions,
             filters=storage_formats[storage_version]["DEFAULT_ATTR_FILTERS"],
             create_index_array=False,
         )
-        partial_write_array_group.close()
-
-        group = tiledb.Group(index_group_uri, mode="r")
-        partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
-        partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
-        ids_array_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
-        parts_array_uri = partial_write_array_group[PARTS_ARRAY_NAME].uri
-        group.close()
+        temp_data_group.close()
+        temp_data_group = tiledb.Group(temp_data_group_uri)
+        ids_array_uri = temp_data_group[IDS_ARRAY_NAME].uri
+        parts_array_uri = temp_data_group[PARTS_ARRAY_NAME].uri
+        temp_data_group.close()

         parts_array = tiledb.open(
             parts_array_uri, mode="w", timestamp=index_timestamp
@@ -2612,6 +2606,8 @@ def consolidate_and_vacuum(
     partitions = int(group.meta.get("partitions", "-1"))

     previous_ingestion_timestamp = 0
+    if index_timestamp is None:
+        index_timestamp = int(time.time() * 1000)
     if len(ingestion_timestamps) > 0:
         previous_ingestion_timestamp = ingestion_timestamps[
             len(ingestion_timestamps) - 1
@@ -2626,28 +2622,6 @@ def consolidate_and_vacuum(
         )

     group.close()
-    group = tiledb.Group(index_group_uri, "w")
-
-    if training_input_vectors is not None:
-        training_source_uri = write_input_vectors(
-            group=group,
-            input_vectors=training_input_vectors,
-            size=training_input_vectors.shape[0],
-            dimensions=training_input_vectors.shape[1],
-            vector_type=training_input_vectors.dtype,
-            array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME,
-        )
-        training_source_type = "TILEDB_ARRAY"
-
-    if input_vectors is not None:
-        source_uri = write_input_vectors(
-            group=group,
-            input_vectors=input_vectors,
-            size=in_size,
-            dimensions=dimensions,
-            vector_type=vector_type,
-            array_name=INPUT_VECTORS_ARRAY_NAME,
-        )

     if size == -1:
         size = int(in_size)
@@ -2679,17 +2653,6 @@ def consolidate_and_vacuum(
     )
     logger.debug("Number of workers %d", workers)

-    if external_ids is not None:
-        external_ids_uri = write_external_ids(
-            group=group,
-            external_ids=external_ids,
-            size=size,
-            partitions=partitions,
-        )
-        external_ids_type = "TILEDB_ARRAY"
-    else:
-        if external_ids_type is None:
-            external_ids_type = "U64BIN"
     # Compute task parameters for main ingestion.
     if input_vectors_per_work_item == -1:
         input_vectors_per_work_item = VECTORS_PER_WORK_ITEM
@@ -2769,17 +2732,53 @@ def consolidate_and_vacuum(
     )

     logger.debug("Creating arrays")
+    group = tiledb.Group(index_group_uri, "w")
+    temp_data_group = create_temp_data_group(group=group)
     create_arrays(
         group=group,
+        temp_data_group=temp_data_group,
         arrays_created=arrays_created,
         index_type=index_type,
-        size=size,
         dimensions=dimensions,
         input_vectors_work_items=input_vectors_work_items,
         vector_type=vector_type,
         logger=logger,
         storage_version=storage_version,
     )
+
+    if training_input_vectors is not None:
+        training_source_uri = write_input_vectors(
+            group=temp_data_group,
+            input_vectors=training_input_vectors,
+            size=training_input_vectors.shape[0],
+            dimensions=training_input_vectors.shape[1],
+            vector_type=training_input_vectors.dtype,
+            array_name=TRAINING_INPUT_VECTORS_ARRAY_NAME,
+        )
+        training_source_type = "TILEDB_ARRAY"
+
+    if input_vectors is not None:
+        source_uri = write_input_vectors(
+            group=temp_data_group,
+            input_vectors=input_vectors,
+            size=in_size,
+            dimensions=dimensions,
+            vector_type=vector_type,
+            array_name=INPUT_VECTORS_ARRAY_NAME,
+        )
+
+    if external_ids is not None:
+        external_ids_uri = write_external_ids(
+            group=temp_data_group,
+            external_ids=external_ids,
+            size=size,
+            partitions=partitions,
+        )
+        external_ids_type = "TILEDB_ARRAY"
+    else:
+        if external_ids_type is None:
+            external_ids_type = "U64BIN"
+    temp_data_group.close()
     group.meta["temp_size"] = size
     group.close()

@@ -2836,8 +2835,6 @@ def consolidate_and_vacuum(
     # For type-erased indexes (i.e. Vamana), we update this metadata in the write_index()
     # call during create_ingestion_dag(), so don't do it here.
     group = tiledb.Group(index_group_uri, "w")
-    if index_timestamp is None:
-        index_timestamp = int(time.time() * 1000)
     ingestion_timestamps.append(index_timestamp)
     base_sizes.append(temp_size)
     partition_history.append(partitions)
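
The new `create_temp_data_group` helper tolerates an already-existing group, so retries don't fail on re-creation. A self-contained distillation of that pattern, with the vector-search-specific `add_to_group` bookkeeping omitted and a hypothetical helper name:

    import tiledb

    def create_or_open_group(parent_uri: str, name: str) -> tiledb.Group:
        # Create the child group if needed, swallowing only the
        # "already exists" error, then open it for writing.
        uri = f"{parent_uri}/{name}"
        try:
            tiledb.group_create(uri)
        except tiledb.TileDBError as err:
            if "already exists" not in str(err):
                raise
        return tiledb.Group(uri, "w")

Because the randomized `PARTIAL_WRITE_ARRAY_DIR` is baked into the group name, each ingestion job gets its own temp group, and a failed job's leftovers can no longer collide with the next run.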

apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py

Lines changed: 7 additions & 10 deletions
@@ -7,6 +7,7 @@
 def ingest_embeddings_with_driver(
     object_index_uri: str,
     use_updates_array: bool,
+    embeddings_array_uri: str = None,
     metadata_array_uri: str = None,
     index_timestamp: int = None,
     workers: int = -1,
@@ -31,6 +32,7 @@ def ingest_embeddings_with_driver(
     def ingest_embeddings(
         object_index_uri: str,
         use_updates_array: bool,
+        embeddings_array_uri: str = None,
         metadata_array_uri: str = None,
         index_timestamp: int = None,
         workers: int = -1,
@@ -81,7 +83,6 @@ def install_extra_worker_modules():
         from tiledb.vector_search import ingest
         from tiledb.vector_search.object_api import ObjectIndex
         from tiledb.vector_search.object_readers import ObjectPartition
-        from tiledb.vector_search.storage_formats import storage_formats

         MAX_TASKS_PER_STAGE = 100
         DEFAULT_IMG_NAME = "3.9-vectorsearch"
@@ -125,6 +126,7 @@ def compute_embeddings_udf(
         object_index_uri: str,
         partition_dicts: List[Dict],
         use_updates_array: bool,
+        embeddings_array_uri: str = None,
         metadata_array_uri: str = None,
         index_timestamp: int = None,
         verbose: bool = False,
@@ -155,7 +157,6 @@ def install_extra_driver_modules():

         import tiledb
         from tiledb.vector_search.object_api import ObjectIndex
-        from tiledb.vector_search.storage_formats import storage_formats

         def instantiate_object(code, class_name, **kwargs):
             import importlib.util
@@ -199,10 +200,6 @@ def instantiate_object(code, class_name, **kwargs):
         vector_type = object_embedding.vector_type()

         if not use_updates_array:
-            embeddings_array_name = storage_formats[
-                obj_index.index.storage_version
-            ]["INPUT_VECTORS_ARRAY_NAME"]
-            embeddings_array_uri = f"{obj_index.uri}/{embeddings_array_name}"
             logger.debug("embeddings_uri %s", embeddings_array_uri)
             embeddings_array = tiledb.open(
                 embeddings_array_uri, "w", timestamp=index_timestamp
@@ -276,6 +273,7 @@ def create_dag(
         partitions: List[ObjectPartition],
         object_partitions_per_worker: int,
         object_work_tasks: int,
+        embeddings_array_uri: str = None,
         metadata_array_uri: str = None,
         index_timestamp: int = None,
         workers: int = -1,
@@ -345,6 +343,7 @@ def create_dag(
             object_index_uri=obj_index.uri,
             partition_dicts=partition_dicts,
             use_updates_array=use_updates_array,
+            embeddings_array_uri=embeddings_array_uri,
             metadata_array_uri=metadata_array_uri,
             index_timestamp=index_timestamp,
             verbose=verbose,
@@ -417,6 +416,7 @@ def create_dag(
             partitions=partitions,
             object_partitions_per_worker=object_partitions_per_worker,
             object_work_tasks=object_work_tasks,
+            embeddings_array_uri=embeddings_array_uri,
             metadata_array_uri=metadata_array_uri,
             index_timestamp=index_timestamp,
             workers=workers,
@@ -442,10 +442,6 @@ def create_dag(
             **kwargs,
         )
     else:
-        embeddings_array_name = storage_formats[
-            obj_index.index.storage_version
-        ]["INPUT_VECTORS_ARRAY_NAME"]
-        embeddings_array_uri = f"{obj_index.uri}/{embeddings_array_name}"
         obj_index.index = ingest(
             index_type=obj_index.index_type,
             index_uri=obj_index.uri,
@@ -500,6 +496,7 @@ def submit_local(d, func, *args, **kwargs):
         ingest_embeddings,
         object_index_uri=object_index_uri,
         use_updates_array=use_updates_array,
+        embeddings_array_uri=embeddings_array_uri,
         metadata_array_uri=metadata_array_uri,
         index_timestamp=index_timestamp,
         max_tasks_per_stage=max_tasks_per_stage,
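
With `embeddings_array_uri` threaded through as an explicit parameter, callers now supply the embeddings array location instead of having it derived inside the UDFs from `storage_formats`. A hypothetical invocation (URIs are placeholders for illustration only):

    from tiledb.vector_search.object_api.embeddings_ingestion import (
        ingest_embeddings_with_driver,
    )

    ingest_embeddings_with_driver(
        object_index_uri="s3://bucket/my_object_index",
        use_updates_array=False,
        # Placeholder URI; the caller decides where the embeddings live.
        embeddings_array_uri="s3://bucket/my_object_index/embeddings",
    )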
