
Commit b9eb281

Support updates for object API (#310)
This adds support for updates in the object API ingestion. In more detail, it adds support for:

- Initial ingestion that can handle large inputs with distributed processing; this uses the `TILEDB_PARTITIONED_ARRAY` source type for input.
- Appending to an existing index using the index batch-update functionality.
- Updating the reader of an existing index.
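For context, a minimal usage sketch of the new flag. This is a hedged example, not part of the commit: it only uses the parameters visible in the diff below (`object_index_uri`, `use_updates_array`), assumes the module path matches the file location, and uses a placeholder index URI.

```python
# Hedged sketch: parameter names come from the signature shown in the diff
# below; the index URI is a placeholder, not a value from this commit.
from tiledb.vector_search.object_api.embeddings_ingestion import (
    ingest_embeddings_with_driver,
)

object_index_uri = "s3://my-bucket/my_object_index"  # placeholder

# Initial ingestion: embeddings are staged in the index's input array and
# indexed via ingest(source_type="TILEDB_PARTITIONED_ARRAY", ...).
ingest_embeddings_with_driver(
    object_index_uri=object_index_uri,
    use_updates_array=False,
)

# Appending to an existing index: embeddings go through the batch-update
# path and are merged with consolidate_updates().
ingest_embeddings_with_driver(
    object_index_uri=object_index_uri,
    use_updates_array=True,
)
```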
1 parent 0490b47 commit b9eb281

4 files changed: +347 −221 lines

apis/python/src/tiledb/vector_search/object_api/embeddings_ingestion.py

Lines changed: 84 additions & 62 deletions
```diff
@@ -6,7 +6,7 @@
 
 def ingest_embeddings_with_driver(
     object_index_uri: str,
-    embeddings_uri: str,
+    use_updates_array: bool,
     metadata_array_uri: str = None,
     index_timestamp: int = None,
     workers: int = -1,
@@ -30,7 +30,7 @@ def ingest_embeddings_with_driver(
 ):
     def ingest_embeddings(
         object_index_uri: str,
-        embeddings_uri: str,
+        use_updates_array: bool,
         metadata_array_uri: str = None,
         index_timestamp: int = None,
         workers: int = -1,
@@ -81,6 +81,7 @@ def install_extra_worker_modules():
         from tiledb.vector_search import ingest
         from tiledb.vector_search.object_api import ObjectIndex
         from tiledb.vector_search.object_readers import ObjectPartition
+        from tiledb.vector_search.storage_formats import storage_formats
 
         MAX_TASKS_PER_STAGE = 100
         DEFAULT_IMG_NAME = "3.9-vectorsearch"
@@ -121,14 +122,9 @@ def setup(
         # UDFs
         # --------------------------------------------------------------------
         def compute_embeddings_udf(
-            object_reader_source_code: str,
-            object_reader_class_name: str,
-            object_reader_kwargs: Dict,
-            object_embedding_source_code: str,
-            object_embedding_class_name: str,
-            object_embedding_kwargs: Dict,
+            object_index_uri: str,
             partition_dicts: List[Dict],
-            embeddings_uri: str,
+            use_updates_array: bool,
             metadata_array_uri: str = None,
             index_timestamp: int = None,
             verbose: bool = False,
@@ -182,16 +178,15 @@ def instantiate_object(code, class_name, **kwargs):
                 return class_(**kwargs)
 
             logger = setup(config, verbose)
-            object_reader = instantiate_object(
-                code=object_reader_source_code,
-                class_name=object_reader_class_name,
-                **object_reader_kwargs,
-            )
-            object_embedding = instantiate_object(
-                code=object_embedding_source_code,
-                class_name=object_embedding_class_name,
-                **object_embedding_kwargs,
+            obj_index = object_index.ObjectIndex(
+                object_index_uri,
+                config=config,
+                environment_variables=environment_variables,
+                load_embedding=False,
+                load_metadata_in_memory=False,
             )
+            object_reader = obj_index.object_reader
+            object_embedding = obj_index.embedding
             for var, val in environment_variables.items():
                 os.environ[var] = val
             with tiledb.scope_ctx(ctx_or_config=config):
@@ -201,17 +196,23 @@ def instantiate_object(code, class_name, **kwargs):
                 object_embedding.dimensions()
                 vector_type = object_embedding.vector_type()
 
-                logger.debug("embeddings_uri %s", embeddings_uri)
-                embeddings_array = tiledb.open(
-                    embeddings_uri, "w", timestamp=index_timestamp
-                )
+                if not use_updates_array:
+                    embeddings_array_name = storage_formats[
+                        obj_index.index.storage_version
+                    ]["INPUT_VECTORS_ARRAY_NAME"]
+                    embeddings_array_uri = f"{obj_index.uri}/{embeddings_array_name}"
+                    logger.debug("embeddings_uri %s", embeddings_array_uri)
+                    embeddings_array = tiledb.open(
+                        embeddings_array_uri, "w", timestamp=index_timestamp
+                    )
+
                 if metadata_array_uri is not None:
                     metadata_array = tiledb.open(
                         metadata_array_uri, "w", timestamp=index_timestamp
                     )
                 for partition_dict in partition_dicts:
                     partition = instantiate_object(
-                        code=object_reader_source_code,
+                        code=obj_index.object_reader_source_code,
                         class_name=object_reader.partition_class_name(),
                         **partition_dict,
                     )
@@ -224,22 +225,36 @@ def instantiate_object(code, class_name, **kwargs):
                     embeddings = object_embedding.embed(objects, metadata)
 
                     logger.debug("Write embeddings partition_id: %d", partition_id)
-                    embeddings_flattened = np.empty(1, dtype="O")
-                    embeddings_flattened[0] = embeddings.astype(vector_type).flatten()
-                    embeddings_shape = np.empty(1, dtype="O")
-                    embeddings_shape[0] = np.array(embeddings.shape, dtype=np.uint32)
-                    external_ids = np.empty(1, dtype="O")
-                    external_ids[0] = objects["external_id"].astype(np.uint64)
-                    embeddings_array[partition_id] = {
-                        "vectors": embeddings_flattened,
-                        "vectors_shape": embeddings_shape,
-                        "external_ids": external_ids,
-                    }
+                    if use_updates_array:
+                        vectors = np.empty(embeddings.shape[0], dtype="O")
+                        for i in range(embeddings.shape[0]):
+                            vectors[i] = embeddings[i].astype(vector_type)
+                        obj_index.index.update_batch(
+                            vectors=vectors,
+                            external_ids=objects["external_id"].astype(np.uint64),
+                        )
+                    else:
+                        embeddings_flattened = np.empty(1, dtype="O")
+                        embeddings_flattened[0] = embeddings.astype(
+                            vector_type
+                        ).flatten()
+                        embeddings_shape = np.empty(1, dtype="O")
+                        embeddings_shape[0] = np.array(
+                            embeddings.shape, dtype=np.uint32
+                        )
+                        external_ids = np.empty(1, dtype="O")
+                        external_ids[0] = objects["external_id"].astype(np.uint64)
+                        embeddings_array[partition_id] = {
+                            "vectors": embeddings_flattened,
+                            "vectors_shape": embeddings_shape,
+                            "external_ids": external_ids,
+                        }
                     if metadata_array_uri is not None:
                         external_ids = metadata.pop("external_id", None)
                         metadata_array[external_ids] = metadata
 
-                embeddings_array.close()
+                if not use_updates_array:
+                    embeddings_array.close()
                 if metadata_array_uri is not None:
                     metadata_array.close()
 
@@ -254,8 +269,8 @@ def submit_local(d, func, *args, **kwargs):
             return d.submit_local(func, *args, **kwargs)
 
         def create_dag(
-            ob_index: ObjectIndex,
-            embeddings_uri: str,
+            obj_index: ObjectIndex,
+            use_updates_array: bool,
             partitions: List[ObjectPartition],
             object_partitions_per_worker: int,
             object_work_tasks: int,
@@ -322,14 +337,9 @@ def create_dag(
                 ] = worker_access_credentials_name
                 submit(
                     compute_embeddings_udf,
-                    object_reader_source_code=ob_index.object_reader_source_code,
-                    object_reader_class_name=ob_index.object_reader_class_name,
-                    object_reader_kwargs=ob_index.object_reader_kwargs,
-                    object_embedding_source_code=ob_index.embedding_source_code,
-                    object_embedding_class_name=ob_index.embedding_class_name,
-                    object_embedding_kwargs=ob_index.embedding_kwargs,
+                    object_index_uri=obj_index.uri,
                     partition_dicts=partition_dicts,
-                    embeddings_uri=embeddings_uri,
+                    use_updates_array=use_updates_array,
                     metadata_array_uri=metadata_array_uri,
                     index_timestamp=index_timestamp,
                     verbose=verbose,
@@ -361,12 +371,14 @@ def create_dag(
 
         from tiledb.vector_search.object_api import object_index
 
-        ob_index = object_index.ObjectIndex(
+        obj_index = object_index.ObjectIndex(
            object_index_uri,
            config=config,
            environment_variables=environment_variables,
+            load_embedding=False,
+            load_metadata_in_memory=False,
        )
-        partitions = ob_index.object_reader.get_partitions(**kwargs)
+        partitions = obj_index.object_reader.get_partitions(**kwargs)
        object_partitions = len(partitions)
        object_partitions_per_worker = 1
        if max_tasks_per_stage == -1:
@@ -395,8 +407,8 @@ def create_dag(
 
        logger.debug("Creating ingestion graph")
        d = create_dag(
-            ob_index=ob_index,
-            embeddings_uri=embeddings_uri,
+            obj_index=obj_index,
+            use_updates_array=use_updates_array,
            partitions=partitions,
            object_partitions_per_worker=object_partitions_per_worker,
            object_work_tasks=object_work_tasks,
@@ -419,19 +431,29 @@ def create_dag(
        logger.debug("Submitted ingestion graph")
        d.wait()
 
-        ob_index.index = ingest(
-            index_type=ob_index.index_type,
-            index_uri=ob_index.uri,
-            source_uri=embeddings_uri,
-            source_type="TILEDB_PARTITIONED_ARRAY",
-            external_ids_uri=embeddings_uri,
-            external_ids_type="TILEDB_PARTITIONED_ARRAY",
-            index_timestamp=index_timestamp,
-            storage_version=ob_index.index.storage_version,
-            config=config,
-            mode=vector_indexing_mode,
-            **kwargs,
-        )
+        if use_updates_array:
+            obj_index.index.consolidate_updates(
+                mode=vector_indexing_mode,
+                **kwargs,
+            )
+        else:
+            embeddings_array_name = storage_formats[
+                obj_index.index.storage_version
+            ]["INPUT_VECTORS_ARRAY_NAME"]
+            embeddings_array_uri = f"{obj_index.uri}/{embeddings_array_name}"
+            obj_index.index = ingest(
+                index_type=obj_index.index_type,
+                index_uri=obj_index.uri,
+                source_uri=embeddings_array_uri,
+                source_type="TILEDB_PARTITIONED_ARRAY",
+                external_ids_uri=embeddings_array_uri,
+                external_ids_type="TILEDB_PARTITIONED_ARRAY",
+                index_timestamp=index_timestamp,
+                storage_version=obj_index.index.storage_version,
+                config=config,
+                mode=vector_indexing_mode,
+                **kwargs,
+            )
 
    def submit_local(d, func, *args, **kwargs):
        # Drop kwarg
@@ -468,7 +490,7 @@ def submit_local(d, func, *args, **kwargs):
    submit(
        ingest_embeddings,
        object_index_uri=object_index_uri,
-        embeddings_uri=embeddings_uri,
+        use_updates_array=use_updates_array,
        metadata_array_uri=metadata_array_uri,
        index_timestamp=index_timestamp,
        max_tasks_per_stage=max_tasks_per_stage,
```
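The heart of the change is the branch in `compute_embeddings_udf` that decides where a partition's embeddings are written. Below is that branch distilled from the diff above into a free-standing sketch; the helper function and its argument names are illustrative, not part of the library API.

```python
import numpy as np


def write_partition_embeddings(obj_index, embeddings, external_ids,
                               embeddings_array=None, partition_id=0,
                               use_updates_array=True, vector_type=np.float32):
    """Simplified restatement of the two write paths added in this commit.

    `obj_index` is an ObjectIndex, `embeddings` a 2-D (n, d) array and
    `external_ids` a 1-D array of ids; this free-standing helper is a
    sketch, not the actual UDF.
    """
    if use_updates_array:
        # Updates path: hand each vector to the index's batch-update API,
        # to be merged later by consolidate_updates().
        vectors = np.empty(embeddings.shape[0], dtype="O")
        for i in range(embeddings.shape[0]):
            vectors[i] = embeddings[i].astype(vector_type)
        obj_index.index.update_batch(
            vectors=vectors,
            external_ids=external_ids.astype(np.uint64),
        )
    else:
        # Initial-ingestion path: write one cell per partition into the
        # staged input array, carrying the flattened vectors, their shape,
        # and the external ids.
        flattened = np.empty(1, dtype="O")
        flattened[0] = embeddings.astype(vector_type).flatten()
        shape = np.empty(1, dtype="O")
        shape[0] = np.array(embeddings.shape, dtype=np.uint32)
        ids = np.empty(1, dtype="O")
        ids[0] = external_ids.astype(np.uint64)
        embeddings_array[partition_id] = {
            "vectors": flattened,
            "vectors_shape": shape,
            "external_ids": ids,
        }
```

On the updates path, the driver finishes with `consolidate_updates(mode=vector_indexing_mode)`; on the initial-ingestion path, it finishes by calling `ingest()` with `source_type="TILEDB_PARTITIONED_ARRAY"` pointed at the staged input array, as shown at the end of the diff.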
