Skip to content

Commit d36e125

Browse files
author
Nikos Papailiou
committed
Make vector ingestion work with tiledb:// uris
1 parent 89c58f9 commit d36e125

File tree

4 files changed

+99
-52
lines changed

4 files changed

+99
-52
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,15 @@ class FlatIndex(Index):
3939
Optional name of partitions
4040
"""
4141

42-
def __init__(self, uri: str, dtype: np.dtype, parts_name: str = "parts.tdb"):
42+
def __init__(self, uri: str, dtype: np.dtype, parts_name: str = "parts.tdb", ctx: "Ctx" = None):
4343
self.uri = uri
4444
self.dtype = dtype
4545
self._index = None
46+
self.ctx = ctx
47+
if ctx is None:
48+
self.ctx = Ctx({})
4649

47-
self._db = load_as_matrix(os.path.join(uri, parts_name))
50+
self._db = load_as_matrix(os.path.join(uri, parts_name), ctx=self.ctx)
4851

4952
def query(
5053
self,
@@ -118,10 +121,10 @@ def __init__(
118121

119122
# TODO pass in a context
120123
if self.memory_budget == -1:
121-
self._db = load_as_matrix(self.parts_db_uri)
124+
self._db = load_as_matrix(self.parts_db_uri, ctx=self.ctx)
122125
self._ids = read_vector_u64(self.ctx, self.ids_uri)
123126

124-
self._centroids = load_as_matrix(self.centroids_uri)
127+
self._centroids = load_as_matrix(self.centroids_uri, ctx=self.ctx)
125128
self._index = read_vector_u64(self.ctx, self.index_uri)
126129

127130
def query(

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 86 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def create_arrays(
223223
)
224224
logger.debug(parts_schema)
225225
tiledb.Array.create(parts_uri, parts_schema)
226-
group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)
226+
group.add(parts_uri, name=PARTS_ARRAY_NAME)
227227

228228
elif index_type == "IVF_FLAT":
229229
centroids_uri = f"{group.uri}/{CENTROIDS_ARRAY_NAME}"
@@ -272,7 +272,7 @@ def create_arrays(
272272
logger.debug(centroids_schema)
273273
tiledb.Array.create(centroids_uri, centroids_schema)
274274
group.add(
275-
CENTROIDS_ARRAY_NAME, name=CENTROIDS_ARRAY_NAME, relative=True
275+
centroids_uri, name=CENTROIDS_ARRAY_NAME
276276
)
277277

278278
if not tiledb.array_exists(index_uri):
@@ -295,7 +295,7 @@ def create_arrays(
295295
)
296296
logger.debug(index_schema)
297297
tiledb.Array.create(index_uri, index_schema)
298-
group.add(INDEX_ARRAY_NAME, name=INDEX_ARRAY_NAME, relative=True)
298+
group.add(index_uri, name=INDEX_ARRAY_NAME)
299299

300300
if not tiledb.array_exists(ids_uri):
301301
logger.debug("Creating ids array")
@@ -317,7 +317,7 @@ def create_arrays(
317317
)
318318
logger.debug(ids_schema)
319319
tiledb.Array.create(ids_uri, ids_schema)
320-
group.add(IDS_ARRAY_NAME, name=IDS_ARRAY_NAME, relative=True)
320+
group.add(ids_uri, name=IDS_ARRAY_NAME)
321321

322322
if not tiledb.array_exists(parts_uri):
323323
logger.debug("Creating parts array")
@@ -347,15 +347,27 @@ def create_arrays(
347347
)
348348
logger.debug(parts_schema)
349349
tiledb.Array.create(parts_uri, parts_schema)
350-
group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)
351-
352-
vfs = tiledb.VFS()
353-
if vfs.is_dir(partial_write_array_dir_uri):
354-
vfs.remove_dir(partial_write_array_dir_uri)
355-
vfs.create_dir(partial_write_array_dir_uri)
356-
if vfs.is_dir(partial_write_array_index_uri):
357-
vfs.remove_dir(partial_write_array_index_uri)
358-
vfs.create_dir(partial_write_array_index_uri)
350+
group.add(parts_uri, name=PARTS_ARRAY_NAME)
351+
352+
try:
353+
tiledb.group_create(partial_write_array_dir_uri)
354+
except tiledb.TileDBError as err:
355+
message = str(err)
356+
if "already exists" in message:
357+
logger.debug(f"Group '{partial_write_array_dir_uri}' already exists")
358+
raise err
359+
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri, "w")
360+
group.add(partial_write_array_dir_uri, name=PARTIAL_WRITE_ARRAY_DIR)
361+
362+
try:
363+
tiledb.group_create(partial_write_array_index_uri)
364+
except tiledb.TileDBError as err:
365+
message = str(err)
366+
if "already exists" in message:
367+
logger.debug(f"Group '{partial_write_array_index_uri}' already exists")
368+
raise err
369+
partial_write_array_group.add(partial_write_array_index_uri, name=INDEX_ARRAY_NAME)
370+
partial_write_array_index_group = tiledb.Group(partial_write_array_index_uri, "w")
359371

360372
if not tiledb.array_exists(partial_write_array_ids_uri):
361373
logger.debug("Creating temp ids array")
@@ -377,6 +389,7 @@ def create_arrays(
377389
)
378390
logger.debug(ids_schema)
379391
tiledb.Array.create(partial_write_array_ids_uri, ids_schema)
392+
partial_write_array_group.add(partial_write_array_ids_uri, name=IDS_ARRAY_NAME)
380393

381394
if not tiledb.array_exists(partial_write_array_parts_uri):
382395
logger.debug("Creating temp parts array")
@@ -407,6 +420,34 @@ def create_arrays(
407420
logger.debug(parts_schema)
408421
logger.debug(partial_write_array_parts_uri)
409422
tiledb.Array.create(partial_write_array_parts_uri, parts_schema)
423+
partial_write_array_group.add(partial_write_array_parts_uri, name=PARTS_ARRAY_NAME)
424+
425+
for part in range(1):
426+
part_index_uri = partial_write_array_index_uri+"/"+str(part)
427+
if not tiledb.array_exists(part_index_uri):
428+
logger.debug(f"Creating part array {part_index_uri}")
429+
index_array_rows_dim = tiledb.Dim(
430+
name="rows",
431+
domain=(0, partitions),
432+
tile=partitions,
433+
dtype=np.dtype(np.int32),
434+
)
435+
index_array_dom = tiledb.Domain(index_array_rows_dim)
436+
index_attr = tiledb.Attr(name="values", dtype=np.dtype(np.uint64))
437+
index_schema = tiledb.ArraySchema(
438+
domain=index_array_dom,
439+
sparse=False,
440+
attrs=[index_attr],
441+
capacity=partitions,
442+
cell_order="col-major",
443+
tile_order="col-major",
444+
)
445+
logger.debug(index_schema)
446+
tiledb.Array.create(part_index_uri, index_schema)
447+
partial_write_array_index_group.add(part_index_uri, name=str(part))
448+
partial_write_array_group.close()
449+
partial_write_array_index_group.close()
450+
410451
else:
411452
raise ValueError(f"Not supported index_type {index_type}")
412453

@@ -822,21 +863,21 @@ def ingest_vectors_udf(
822863
logger = setup(config, verbose)
823864
group = tiledb.Group(array_uri)
824865
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
825-
partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
826-
partial_write_array_ids_uri = partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
827-
partial_write_array_parts_uri = (
828-
partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
829-
)
866+
partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
867+
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
868+
partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
869+
partial_write_array_parts_uri = partial_write_array_group[PARTS_ARRAY_NAME].uri
870+
partial_write_array_index_dir_uri = partial_write_array_group[INDEX_ARRAY_NAME].uri
871+
partial_write_array_index_group = tiledb.Group(partial_write_array_index_dir_uri)
830872

831873
for part in range(start, end, batch):
832874
part_end = part + batch
833875
if part_end > end:
834876
part_end = end
835877

836878
part_name = str(part) + "-" + str(part_end)
837-
partial_write_array_index_uri = (
838-
partial_write_array_dir_uri + "/" + INDEX_ARRAY_NAME + "/" + part_name
839-
)
879+
880+
partial_write_array_index_uri = partial_write_array_index_group[str(int(start / batch))].uri
840881
logger.debug("Input vectors start_pos: %d, end_pos: %d", part, part_end)
841882
if source_type == "TILEDB_ARRAY":
842883
logger.debug("Start indexing")
@@ -889,12 +930,14 @@ def compute_partition_indexes_udf(
889930
with tiledb.scope_ctx(ctx_or_config=config):
890931
group = tiledb.Group(array_uri)
891932
index_array_uri = group[INDEX_ARRAY_NAME].uri
892-
vfs = tiledb.VFS()
933+
partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
934+
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
935+
partial_write_array_index_dir_uri = partial_write_array_group[INDEX_ARRAY_NAME].uri
936+
partial_write_array_index_group = tiledb.Group(partial_write_array_index_dir_uri)
893937
partition_sizes = np.zeros(partitions)
894938
indexes = np.zeros(partitions + 1).astype(np.uint64)
895-
for partial_index_array_uri in vfs.ls(
896-
array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
897-
):
939+
for part in partial_write_array_index_group:
940+
partial_index_array_uri = part.uri
898941
if tiledb.array_exists(partial_index_array_uri):
899942
partial_index_array = tiledb.open(partial_index_array_uri, mode="r")
900943
partial_indexes = partial_index_array[:]["values"]
@@ -912,7 +955,7 @@ def compute_partition_indexes_udf(
912955
_sum += partition_size
913956
i += 1
914957
indexes[i] = _sum
915-
logger.debug("Partition indexes: %d", indexes)
958+
logger.debug(f"Partition indexes: {indexes}")
916959
index_array = tiledb.open(index_array_uri, mode="w")
917960
index_array[:] = indexes
918961

@@ -932,23 +975,21 @@ def consolidate_partition_udf(
932975
"Consolidating partitions %d-%d", partition_id_start, partition_id_end
933976
)
934977
group = tiledb.Group(array_uri)
935-
partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
936-
partial_write_array_ids_uri = (
937-
partial_write_array_dir_uri + "/" + IDS_ARRAY_NAME
938-
)
939-
partial_write_array_parts_uri = (
940-
partial_write_array_dir_uri + "/" + PARTS_ARRAY_NAME
941-
)
978+
partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
979+
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
980+
partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
981+
partial_write_array_parts_uri = partial_write_array_group[PARTS_ARRAY_NAME].uri
982+
partial_write_array_index_dir_uri = partial_write_array_group[INDEX_ARRAY_NAME].uri
983+
partial_write_array_index_group = tiledb.Group(partial_write_array_index_dir_uri)
942984
index_array_uri = group[INDEX_ARRAY_NAME].uri
943985
ids_array_uri = group[IDS_ARRAY_NAME].uri
944986
parts_array_uri = group[PARTS_ARRAY_NAME].uri
945987
vfs = tiledb.VFS()
946988
partition_slices = []
947989
for i in range(partitions):
948990
partition_slices.append([])
949-
for partial_index_array_uri in vfs.ls(
950-
array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR + "/" + INDEX_ARRAY_NAME
951-
):
991+
for part in partial_write_array_index_group:
992+
partial_index_array_uri = part.uri
952993
if tiledb.array_exists(partial_index_array_uri):
953994
partial_index_array = tiledb.open(partial_index_array_uri, mode="r")
954995
partial_indexes = partial_index_array[:]["values"]
@@ -1294,12 +1335,14 @@ def consolidate_and_vacuum(
12941335
tiledb.consolidate(group[IDS_ARRAY_NAME].uri, config=conf)
12951336
tiledb.vacuum(group[IDS_ARRAY_NAME].uri, config=conf)
12961337

1297-
vfs = tiledb.VFS(config)
1298-
partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1299-
if vfs.is_dir(partial_write_array_dir_uri):
1300-
vfs.remove_dir(partial_write_array_dir_uri)
1338+
# TODO remove temp data for tiledb URIs
1339+
if not array_uri.startswith("tiledb://"):
1340+
vfs = tiledb.VFS(config)
1341+
partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1342+
if vfs.is_dir(partial_write_array_dir_uri):
1343+
vfs.remove_dir(partial_write_array_dir_uri)
13011344

1302-
with tiledb.scope_ctx(ctx_or_config=config):
1345+
with tiledb.scope_ctx(ctx_or_config=config) as ctx:
13031346
logger = setup(config, verbose)
13041347
logger.debug("Ingesting Vectors into %r", array_uri)
13051348
try:
@@ -1419,6 +1462,6 @@ def consolidate_and_vacuum(
14191462
consolidate_and_vacuum(array_uri=array_uri, config=config)
14201463

14211464
if index_type == "FLAT":
1422-
return FlatIndex(uri=array_uri, dtype=vector_type)
1465+
return FlatIndex(uri=array_uri, dtype=vector_type, ctx=ctx)
14231466
elif index_type == "IVF_FLAT":
1424-
return IVFFlatIndex(uri=array_uri, dtype=vector_type, memory_budget=1000000)
1467+
return IVFFlatIndex(uri=array_uri, dtype=vector_type, memory_budget=1000000, ctx=ctx)

apis/python/src/tiledb/vector_search/module.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Optional
99

1010

11-
def load_as_matrix(path: str, nqueries: int = 0, config: Dict = {}):
11+
def load_as_matrix(path: str, nqueries: int = 0, ctx: "Ctx" = None):
1212
"""
1313
Load array as Matrix class
1414
@@ -21,7 +21,8 @@ def load_as_matrix(path: str, nqueries: int = 0, config: Dict = {}):
2121
config: Dict
2222
TileDB configuration parameters
2323
"""
24-
ctx = Ctx(config)
24+
if ctx is None:
25+
ctx = Ctx({})
2526

2627
a = tiledb.ArraySchema.load(path)
2728
dtype = a.attr(0).dtype
@@ -43,7 +44,7 @@ def load_as_matrix(path: str, nqueries: int = 0, config: Dict = {}):
4344
return m
4445

4546

46-
def load_as_array(path, return_matrix: bool = False, config: Dict = {}):
47+
def load_as_array(path, return_matrix: bool = False, ctx: "Ctx" = None):
4748
"""
4849
Load array as array class
4950
@@ -56,7 +57,7 @@ def load_as_array(path, return_matrix: bool = False, config: Dict = {}):
5657
config: Dict
5758
TileDB configuration parameters
5859
"""
59-
m = load_as_matrix(path, config=config)
60+
m = load_as_matrix(path, ctx=ctx)
6061
r = np.array(m, copy=False)
6162

6263
# hang on to a copy for testing purposes, for now

src/include/detail/ivf/index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ int ivf_index(
136136
ctx, shuffled_db, parts_uri, start_pos, false);
137137
}
138138
if (index_uri != "") {
139-
write_vector<ids_type>(ctx, indices, index_uri, 0, true);
139+
write_vector<ids_type>(ctx, indices, index_uri, 0, false);
140140
}
141141
if (id_uri != "") {
142142
write_vector<ids_type>(ctx, shuffled_ids, id_uri, start_pos, false);

0 commit comments

Comments
 (0)