Commit a6bfc30

Merge pull request #98 from TileDB-Inc/npapa/ingestion
Make vector ingestion work with tiledb:// uris
2 parents a930d41 + 2989635 commit a6bfc30

6 files changed: +153 -78 lines changed
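The end-to-end effect of the change is roughly the following usage. This is a sketch only: the config keys and the ingest() parameters other than index_type, array_uri, and config are assumptions, not taken from this diff.

from tiledb.vector_search.ingestion import ingest

# Hypothetical TileDB Cloud credentials / storage settings.
config = {"rest.token": "<TILEDB_CLOUD_TOKEN>"}

# array_uri may now be a tiledb:// group URI; the returned object is a
# FlatIndex or IVFFlatIndex depending on index_type.
index = ingest(
    index_type="IVF_FLAT",
    array_uri="tiledb://my_namespace/my_vector_index",
    source_uri="s3://my-bucket/vectors.tdb",  # assumed parameter name
    config=config,
)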

apis/python/src/tiledb/vector_search/index.py

Lines changed: 18 additions & 8 deletions
@@ -6,6 +6,7 @@
 import numpy as np
 from tiledb.vector_search.module import *
 from tiledb.cloud.dag import Mode
+from typing import Any, Mapping

 CENTROIDS_ARRAY_NAME = "centroids.tdb"
 INDEX_ARRAY_NAME = "index.tdb"
@@ -39,12 +40,19 @@ class FlatIndex(Index):
         Optional name of partitions
     """

-    def __init__(self, uri: str, dtype: np.dtype, parts_name: str = "parts.tdb"):
+    def __init__(
+        self,
+        uri: str,
+        dtype: np.dtype,
+        parts_name: str = "parts.tdb",
+        config: Optional[Mapping[str, Any]] = None,
+    ):
         self.uri = uri
         self.dtype = dtype
         self._index = None
+        self.ctx = Ctx(config)

-        self._db = load_as_matrix(os.path.join(uri, parts_name))
+        self._db = load_as_matrix(os.path.join(uri, parts_name), ctx=self.ctx)

     def query(
         self,
@@ -103,7 +111,11 @@ class IVFFlatIndex(Index):
     """

     def __init__(
-        self, uri, dtype: np.dtype, memory_budget: int = -1, ctx: "Ctx" = None
+        self,
+        uri,
+        dtype: np.dtype,
+        memory_budget: int = -1,
+        config: Optional[Mapping[str, Any]] = None,
     ):
         group = tiledb.Group(uri)
         self.parts_db_uri = group[PARTS_ARRAY_NAME].uri
@@ -112,16 +124,14 @@ def __init__(
         self.ids_uri = group[IDS_ARRAY_NAME].uri
         self.dtype = dtype
         self.memory_budget = memory_budget
-        self.ctx = ctx
-        if ctx is None:
-            self.ctx = Ctx({})
+        self.ctx = Ctx(config)

         # TODO pass in a context
         if self.memory_budget == -1:
-            self._db = load_as_matrix(self.parts_db_uri)
+            self._db = load_as_matrix(self.parts_db_uri, ctx=self.ctx)
             self._ids = read_vector_u64(self.ctx, self.ids_uri)

-        self._centroids = load_as_matrix(self.centroids_uri)
+        self._centroids = load_as_matrix(self.centroids_uri, ctx=self.ctx)
         self._index = read_vector_u64(self.ctx, self.index_uri)

     def query(
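With this change both index classes accept an optional config mapping and build a Ctx from it, so an index stored under a tiledb:// URI can be opened by passing credentials and storage settings through config. A minimal sketch, assuming the index group already exists; the config keys shown here are illustrative, not mandated by this diff:

import numpy as np
from tiledb.vector_search.index import IVFFlatIndex

config = {
    "rest.token": "<TILEDB_CLOUD_TOKEN>",  # hypothetical REST credentials
    "vfs.s3.region": "us-east-1",          # hypothetical S3 setting
}
index = IVFFlatIndex(
    uri="tiledb://my_namespace/my_vector_index",
    dtype=np.float32,
    memory_budget=1_000_000,
    config=config,
)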

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 117 additions & 47 deletions
@@ -2,8 +2,7 @@
 from functools import partial

 from tiledb.cloud.dag import Mode
-from tiledb.vector_search.index import FlatIndex
-from tiledb.vector_search.index import IVFFlatIndex
+from tiledb.vector_search.index import FlatIndex, IVFFlatIndex, Index


 def ingest(
@@ -23,7 +22,7 @@ def ingest(
     verbose: bool = False,
     trace_id: Optional[str] = None,
     mode: Mode = Mode.LOCAL,
-) -> FlatIndex:
+) -> Index:
     """
     Ingest vectors into TileDB.
@@ -189,7 +188,7 @@ def create_arrays(
         size: int,
         dimensions: int,
         partitions: int,
-        input_vectors_work_items: int,
+        input_vectors_work_tasks: int,
         vector_type: np.dtype,
         logger: logging.Logger,
     ) -> None:
@@ -223,7 +222,7 @@ def create_arrays(
                 )
                 logger.debug(parts_schema)
                 tiledb.Array.create(parts_uri, parts_schema)
-                group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)
+                group.add(parts_uri, name=PARTS_ARRAY_NAME)

         elif index_type == "IVF_FLAT":
             centroids_uri = f"{group.uri}/{CENTROIDS_ARRAY_NAME}"
@@ -271,9 +270,7 @@ def create_arrays(
                 )
                 logger.debug(centroids_schema)
                 tiledb.Array.create(centroids_uri, centroids_schema)
-                group.add(
-                    CENTROIDS_ARRAY_NAME, name=CENTROIDS_ARRAY_NAME, relative=True
-                )
+                group.add(centroids_uri, name=CENTROIDS_ARRAY_NAME)

             if not tiledb.array_exists(index_uri):
                 logger.debug("Creating index array")
@@ -295,7 +292,7 @@ def create_arrays(
                 )
                 logger.debug(index_schema)
                 tiledb.Array.create(index_uri, index_schema)
-                group.add(INDEX_ARRAY_NAME, name=INDEX_ARRAY_NAME, relative=True)
+                group.add(index_uri, name=INDEX_ARRAY_NAME)

             if not tiledb.array_exists(ids_uri):
                 logger.debug("Creating ids array")
@@ -317,7 +314,7 @@ def create_arrays(
                 )
                 logger.debug(ids_schema)
                 tiledb.Array.create(ids_uri, ids_schema)
-                group.add(IDS_ARRAY_NAME, name=IDS_ARRAY_NAME, relative=True)
+                group.add(ids_uri, name=IDS_ARRAY_NAME)

             if not tiledb.array_exists(parts_uri):
                 logger.debug("Creating parts array")
@@ -347,15 +344,35 @@ def create_arrays(
                 )
                 logger.debug(parts_schema)
                 tiledb.Array.create(parts_uri, parts_schema)
-                group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)
-
-            vfs = tiledb.VFS()
-            if vfs.is_dir(partial_write_array_dir_uri):
-                vfs.remove_dir(partial_write_array_dir_uri)
-            vfs.create_dir(partial_write_array_dir_uri)
-            if vfs.is_dir(partial_write_array_index_uri):
-                vfs.remove_dir(partial_write_array_index_uri)
-            vfs.create_dir(partial_write_array_index_uri)
+                group.add(parts_uri, name=PARTS_ARRAY_NAME)
+
+            try:
+                tiledb.group_create(partial_write_array_dir_uri)
+            except tiledb.TileDBError as err:
+                message = str(err)
+                if "already exists" in message:
+                    logger.debug(
+                        f"Group '{partial_write_array_dir_uri}' already exists"
+                    )
+                raise err
+            partial_write_array_group = tiledb.Group(partial_write_array_dir_uri, "w")
+            group.add(partial_write_array_dir_uri, name=PARTIAL_WRITE_ARRAY_DIR)
+
+            try:
+                tiledb.group_create(partial_write_array_index_uri)
+            except tiledb.TileDBError as err:
+                message = str(err)
+                if "already exists" in message:
+                    logger.debug(
+                        f"Group '{partial_write_array_index_uri}' already exists"
+                    )
+                raise err
+            partial_write_array_group.add(
+                partial_write_array_index_uri, name=INDEX_ARRAY_NAME
+            )
+            partial_write_array_index_group = tiledb.Group(
+                partial_write_array_index_uri, "w"
+            )

             if not tiledb.array_exists(partial_write_array_ids_uri):
                 logger.debug("Creating temp ids array")
@@ -377,6 +394,9 @@ def create_arrays(
                 )
                 logger.debug(ids_schema)
                 tiledb.Array.create(partial_write_array_ids_uri, ids_schema)
+                partial_write_array_group.add(
+                    partial_write_array_ids_uri, name=IDS_ARRAY_NAME
+                )

             if not tiledb.array_exists(partial_write_array_parts_uri):
                 logger.debug("Creating temp parts array")
@@ -407,6 +427,36 @@ def create_arrays(
                 logger.debug(parts_schema)
                 logger.debug(partial_write_array_parts_uri)
                 tiledb.Array.create(partial_write_array_parts_uri, parts_schema)
+                partial_write_array_group.add(
+                    partial_write_array_parts_uri, name=PARTS_ARRAY_NAME
+                )
+
+            for part in range(input_vectors_work_tasks):
+                part_index_uri = partial_write_array_index_uri + "/" + str(part)
+                if not tiledb.array_exists(part_index_uri):
+                    logger.debug(f"Creating part array {part_index_uri}")
+                    index_array_rows_dim = tiledb.Dim(
+                        name="rows",
+                        domain=(0, partitions),
+                        tile=partitions,
+                        dtype=np.dtype(np.int32),
+                    )
+                    index_array_dom = tiledb.Domain(index_array_rows_dim)
+                    index_attr = tiledb.Attr(name="values", dtype=np.dtype(np.uint64))
+                    index_schema = tiledb.ArraySchema(
+                        domain=index_array_dom,
+                        sparse=False,
+                        attrs=[index_attr],
+                        capacity=partitions,
+                        cell_order="col-major",
+                        tile_order="col-major",
+                    )
+                    logger.debug(index_schema)
+                    tiledb.Array.create(part_index_uri, index_schema)
+                    partial_write_array_index_group.add(part_index_uri, name=str(part))
+            partial_write_array_group.close()
+            partial_write_array_index_group.close()
+
         else:
             raise ValueError(f"Not supported index_type {index_type}")

@@ -822,10 +872,15 @@ def ingest_vectors_udf(
            logger = setup(config, verbose)
            group = tiledb.Group(array_uri)
            centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
-            partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
-            partial_write_array_ids_uri = partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
-            partial_write_array_parts_uri = (
-                partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
+            partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
+            partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
+            partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
+            partial_write_array_parts_uri = partial_write_array_group[PARTS_ARRAY_NAME].uri
+            partial_write_array_index_dir_uri = partial_write_array_group[
+                INDEX_ARRAY_NAME
+            ].uri
+            partial_write_array_index_group = tiledb.Group(
+                partial_write_array_index_dir_uri
             )

            for part in range(start, end, batch):
@@ -834,9 +889,10 @@ def ingest_vectors_udf(
                    part_end = end

                part_name = str(part) + "-" + str(part_end)
-                partial_write_array_index_uri = (
-                    partial_write_array_dir_uri + "/" + INDEX_ARRAY_NAME + "/" + part_name
-                )
+
+                partial_write_array_index_uri = partial_write_array_index_group[
+                    str(int(start / batch))
+                ].uri
                logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
                if source_type == "TILEDB_ARRAY":
                    logger.debug("Start indexing")
@@ -889,12 +945,18 @@ def compute_partition_indexes_udf(
        with tiledb.scope_ctx(ctx_or_config=config):
            group = tiledb.Group(array_uri)
            index_array_uri = group[INDEX_ARRAY_NAME].uri
-            vfs = tiledb.VFS()
+            partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
+            partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
+            partial_write_array_index_dir_uri = partial_write_array_group[
+                INDEX_ARRAY_NAME
+            ].uri
+            partial_write_array_index_group = tiledb.Group(
+                partial_write_array_index_dir_uri
+            )
            partition_sizes = np.zeros(partitions)
            indexes = np.zeros(partitions + 1).astype(np.uint64)
-            for partial_index_array_uri in vfs.ls(
-                array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
-            ):
+            for part in partial_write_array_index_group:
+                partial_index_array_uri = part.uri
                if tiledb.array_exists(partial_index_array_uri):
                    partial_index_array = tiledb.open(partial_index_array_uri, mode="r")
                    partial_indexes = partial_index_array[:]["values"]
@@ -912,7 +974,7 @@ def compute_partition_indexes_udf(
                _sum += partition_size
                i += 1
                indexes[i] = _sum
-            logger.debug("Partition indexes: %d", indexes)
+            logger.debug(f"Partition indexes: {indexes}")
            index_array = tiledb.open(index_array_uri, mode="w")
            index_array[:] = indexes

@@ -932,12 +994,17 @@ def consolidate_partition_udf(
                "Consolidating partitions %d-%d", partition_id_start, partition_id_end
            )
            group = tiledb.Group(array_uri)
-            partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
-            partial_write_array_ids_uri = (
-                partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
-            )
-            partial_write_array_parts_uri = (
-                partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
+            partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
+            partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
+            partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
+            partial_write_array_parts_uri = partial_write_array_group[
+                PARTS_ARRAY_NAME
+            ].uri
+            partial_write_array_index_dir_uri = partial_write_array_group[
+                INDEX_ARRAY_NAME
+            ].uri
+            partial_write_array_index_group = tiledb.Group(
+                partial_write_array_index_dir_uri
            )
            index_array_uri = group[INDEX_ARRAY_NAME].uri
            ids_array_uri = group[IDS_ARRAY_NAME].uri
@@ -946,9 +1013,8 @@ def consolidate_partition_udf(
            partition_slices = []
            for i in range(partitions):
                partition_slices.append([])
-            for partial_index_array_uri in vfs.ls(
-                array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
-            ):
+            for part in partial_write_array_index_group:
+                partial_index_array_uri = part.uri
                if tiledb.array_exists(partial_index_array_uri):
                    partial_index_array = tiledb.open(partial_index_array_uri, mode="r")
                    partial_indexes = partial_index_array[:]["values"]
@@ -1294,10 +1360,12 @@ def consolidate_and_vacuum(
        tiledb.consolidate(group[IDS_ARRAY_NAME].uri, config=conf)
        tiledb.vacuum(group[IDS_ARRAY_NAME].uri, config=conf)

-        vfs = tiledb.VFS(config)
-        partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
-        if vfs.is_dir(partial_write_array_dir_uri):
-            vfs.remove_dir(partial_write_array_dir_uri)
+        # TODO remove temp data for tiledb URIs
+        if not array_uri.startswith("tiledb://"):
+            vfs = tiledb.VFS(config)
+            partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
+            if vfs.is_dir(partial_write_array_dir_uri):
+                vfs.remove_dir(partial_write_array_dir_uri)

    with tiledb.scope_ctx(ctx_or_config=config):
        logger = setup(config, verbose)
@@ -1384,7 +1452,7 @@ def consolidate_and_vacuum(
            size=size,
            dimensions=dimensions,
            partitions=partitions,
-            input_vectors_work_items=input_vectors_work_items,
+            input_vectors_work_tasks=input_vectors_work_tasks,
            vector_type=vector_type,
            logger=logger,
        )
@@ -1419,6 +1487,8 @@ def consolidate_and_vacuum(
        consolidate_and_vacuum(array_uri=array_uri, config=config)

        if index_type == "FLAT":
-            return FlatIndex(uri=array_uri, dtype=vector_type)
+            return FlatIndex(uri=array_uri, dtype=vector_type, config=config)
        elif index_type == "IVF_FLAT":
-            return IVFFlatIndex(uri=array_uri, dtype=vector_type, memory_budget=1000000)
+            return IVFFlatIndex(
+                uri=array_uri, dtype=vector_type, memory_budget=1000000, config=config
+            )
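The recurring pattern in this diff is to stop deriving child URIs by string concatenation and vfs.ls(), and instead register every array (and the temporary write area) as a named member of a TileDB group, which also works for tiledb:// URIs. A standalone sketch of that pattern, using only the group APIs exercised above; the URIs and the "temp_data" name are illustrative, not the constants used by the module:

import tiledb

array_uri = "tiledb://my_namespace/my_vector_index"  # hypothetical index group URI
temp_uri = array_uri + "/temp_data"                  # hypothetical child-group URI

tiledb.group_create(temp_uri)                        # nested group for partial writes
group = tiledb.Group(array_uri, "w")
group.add(temp_uri, name="temp_data")                # add by full URI, no relative=True
group.close()

# Later (e.g. inside a UDF) members are resolved by name instead of path math,
# and iterating over the group replaces vfs.ls() over a directory.
group = tiledb.Group(array_uri)
temp_group = tiledb.Group(group["temp_data"].uri)
for member in temp_group:
    print(member.uri)
temp_group.close()
group.close()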
