Commit b679c08

Merge pull request #81 from TileDB-Inc/npapa/fix-groups
Use groups instead of os.path.join and register arrays with relative …
2 parents cae0cab + fd7b205

File tree

2 files changed, +63 -54 lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 10 additions & 4 deletions

@@ -3,6 +3,11 @@
import numpy as np
from tiledb.vector_search.module import *

+CENTROIDS_ARRAY_NAME = "centroids.tdb"
+INDEX_ARRAY_NAME = "index.tdb"
+IDS_ARRAY_NAME = "ids.tdb"
+PARTS_ARRAY_NAME = "parts.tdb"
+

class Index:
    def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
@@ -83,10 +88,11 @@ class IVFFlatIndex(Index):
    def __init__(
        self, uri, dtype: np.dtype, memory_budget: int = -1, ctx: "Ctx" = None
    ):
-        self.parts_db_uri = os.path.join(uri, "parts.tdb")
-        self.centroids_uri = os.path.join(uri, "centroids.tdb")
-        self.index_uri = os.path.join(uri, "index.tdb")
-        self.ids_uri = os.path.join(uri, "ids.tdb")
+        group = tiledb.Group(uri)
+        self.parts_db_uri = group[PARTS_ARRAY_NAME].uri
+        self.centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
+        self.index_uri = group[INDEX_ARRAY_NAME].uri
+        self.ids_uri = group[IDS_ARRAY_NAME].uri
        self.dtype = dtype
        self.memory_budget = memory_budget
        self.ctx = ctx
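Note on the read path: `IVFFlatIndex.__init__` now resolves each array through the TileDB group's named members instead of string-joining paths onto the index URI, so lookups keep working when members are registered relative to the group. A minimal sketch of that lookup pattern, using a hypothetical group URI (the member names match the constants added above):

    import tiledb

    uri = "s3://my-bucket/my-vector-index"   # hypothetical index group URI

    group = tiledb.Group(uri)                # opens in read mode by default
    parts_uri = group["parts.tdb"].uri       # member lookup by registered name
    centroids_uri = group["centroids.tdb"].uri
    group.close()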

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 53 additions & 50 deletions

@@ -197,7 +197,7 @@ def create_arrays(
    if index_type == "FLAT":
        parts_uri = f"{group.uri}/{PARTS_ARRAY_NAME}"
        if not tiledb.array_exists(parts_uri):
-            logger.info("Creating parts array")
+            logger.debug("Creating parts array")
            parts_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, dimensions - 1),
@@ -222,9 +222,9 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(parts_schema)
+            logger.debug(parts_schema)
            tiledb.Array.create(parts_uri, parts_schema)
-            group.add(parts_uri, name=PARTS_ARRAY_NAME)
+            group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)

    elif index_type == "IVF_FLAT":
        centroids_uri = f"{group.uri}/{CENTROIDS_ARRAY_NAME}"
@@ -243,7 +243,7 @@ def create_arrays(
        )

        if not tiledb.array_exists(centroids_uri):
-            logger.info("Creating centroids array")
+            logger.debug("Creating centroids array")
            centroids_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, dimensions - 1),
@@ -270,12 +270,14 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(centroids_schema)
+            logger.debug(centroids_schema)
            tiledb.Array.create(centroids_uri, centroids_schema)
-            group.add(centroids_uri, name=CENTROIDS_ARRAY_NAME)
+            group.add(
+                CENTROIDS_ARRAY_NAME, name=CENTROIDS_ARRAY_NAME, relative=True
+            )

        if not tiledb.array_exists(index_uri):
-            logger.info("Creating index array")
+            logger.debug("Creating index array")
            index_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, partitions),
@@ -292,12 +294,12 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(index_schema)
+            logger.debug(index_schema)
            tiledb.Array.create(index_uri, index_schema)
-            group.add(index_uri, name=INDEX_ARRAY_NAME)
+            group.add(INDEX_ARRAY_NAME, name=INDEX_ARRAY_NAME, relative=True)

        if not tiledb.array_exists(ids_uri):
-            logger.info("Creating ids array")
+            logger.debug("Creating ids array")
            ids_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, size - 1),
@@ -314,12 +316,12 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(ids_schema)
+            logger.debug(ids_schema)
            tiledb.Array.create(ids_uri, ids_schema)
-            group.add(ids_uri, name=IDS_ARRAY_NAME)
+            group.add(IDS_ARRAY_NAME, name=IDS_ARRAY_NAME, relative=True)

        if not tiledb.array_exists(parts_uri):
-            logger.info("Creating parts array")
+            logger.debug("Creating parts array")
            parts_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, dimensions - 1),
@@ -344,9 +346,9 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(parts_schema)
+            logger.debug(parts_schema)
            tiledb.Array.create(parts_uri, parts_schema)
-            group.add(parts_uri, name=PARTS_ARRAY_NAME)
+            group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)

        vfs = tiledb.VFS()
        if vfs.is_dir(partial_write_array_dir_uri):
@@ -357,7 +359,7 @@ def create_arrays(
        vfs.create_dir(partial_write_array_index_uri)

        if not tiledb.array_exists(partial_write_array_ids_uri):
-            logger.info("Creating temp ids array")
+            logger.debug("Creating temp ids array")
            ids_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, size - 1),
@@ -374,11 +376,11 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(ids_schema)
+            logger.debug(ids_schema)
            tiledb.Array.create(partial_write_array_ids_uri, ids_schema)

        if not tiledb.array_exists(partial_write_array_parts_uri):
-            logger.info("Creating temp parts array")
+            logger.debug("Creating temp parts array")
            parts_array_rows_dim = tiledb.Dim(
                name="rows",
                domain=(0, dimensions - 1),
@@ -403,8 +405,8 @@ def create_arrays(
                cell_order="col-major",
                tile_order="col-major",
            )
-            logger.info(parts_schema)
-            logger.info(partial_write_array_parts_uri)
+            logger.debug(parts_schema)
+            logger.debug(partial_write_array_parts_uri)
            tiledb.Array.create(partial_write_array_parts_uri, parts_schema)
    else:
        raise ValueError(f"Not supported index_type {index_type}")
@@ -517,7 +519,7 @@ def copy_centroids(
    dest = tiledb.open(centroids_uri, mode="w")
    src_centroids = src[:, :]
    dest[:, :] = src_centroids
-    logger.info(src_centroids)
+    logger.debug(src_centroids)

# --------------------------------------------------------------------
# centralised kmeans UDFs
@@ -630,7 +632,7 @@ def generate_new_centroid_per_thread(
    new_centroid_count = np.ones(len(cents_t))
    for vector_id in range(start, end):
        if vector_id % 100000 == 0:
-            logger.info(f"Vectors computed: {vector_id}")
+            logger.debug(f"Vectors computed: {vector_id}")
        c_id = assignments_t[vector_id]
        if new_centroid_count[c_id] == 1:
            new_centroid_sums[c_id] = vectors_t[vector_id]
@@ -640,7 +642,7 @@ def generate_new_centroid_per_thread(
            new_centroid_count[c_id] += 1
    new_centroid_sums_queue.put(new_centroid_sums)
    new_centroid_counts_queue.put(new_centroid_count)
-    logger.info(f"Finished thread: {thread_id}")
+    logger.debug(f"Finished thread: {thread_id}")

def update_centroids():
    import multiprocessing as mp
@@ -713,15 +715,15 @@ def update_centroids():
        verbose=verbose,
        trace_id=trace_id,
    )
-    logger.info(f"Input centroids: {centroids[0:5]}")
+    logger.debug(f"Input centroids: {centroids[0:5]}")
    logger.info("Assigning vectors to centroids")
    km = KMeans()
    km._n_threads = threads
    km.cluster_centers_ = centroids
    assignments = km.predict(vectors)
-    logger.info(f"Assignments: {assignments[0:100]}")
+    logger.debug(f"Assignments: {assignments[0:100]}")
    partial_new_centroids = update_centroids()
-    logger.info(f"New centroids: {partial_new_centroids[0:5]}")
+    logger.debug(f"New centroids: {partial_new_centroids[0:5]}")
    return partial_new_centroids

def compute_new_centroids(*argv):
@@ -769,7 +771,7 @@ def ingest_flat(
            trace_id=trace_id,
        )

-        logger.info(f"Vector read:{len(in_vectors)}")
+        logger.debug(f"Vector read:{len(in_vectors)}")
        logger.info(f"Writing data to array {parts_array_uri}")
        target[0:dimensions, start:end] = np.transpose(in_vectors)
        target.close()
@@ -836,7 +838,7 @@ def ingest_vectors_udf(
        )
        logger.info(f"Input vectors start_pos: {part}, end_pos: {part_end}")
        if source_type == "TILEDB_ARRAY":
-            logger.info("Start indexing")
+            logger.debug("Start indexing")
            ivf_index_tdb(
                dtype=vector_type,
                db_uri=source_uri,
@@ -861,7 +863,7 @@ def ingest_vectors_udf(
                verbose=verbose,
                trace_id=trace_id,
            )
-            logger.info("Start indexing")
+            logger.debug("Start indexing")
            ivf_index(
                dtype=vector_type,
                db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
@@ -909,7 +911,7 @@ def compute_partition_indexes_udf(
        sum += partition_size
        i += 1
        indexes[i] = sum
-    logger.info(f"Partition indexes: {indexes}")
+    logger.debug(f"Partition indexes: {indexes}")
    index_array = tiledb.open(index_array_uri, mode="w")
    index_array[:] = indexes

@@ -967,7 +969,7 @@ def consolidate_partition_udf(
    index_array = tiledb.open(index_array_uri, mode="r")
    ids_array = tiledb.open(ids_array_uri, mode="w")
    parts_array = tiledb.open(parts_array_uri, mode="w")
-    logger.info(
+    logger.debug(
        f"Partitions start: {partition_id_start} end: {partition_id_end}"
    )
    for part in range(partition_id_start, partition_id_end, batch):
@@ -1306,8 +1308,7 @@ def consolidate_and_vacuum(
        message = str(err)
        if "already exists" in message:
            logger.info(f"Group '{array_uri}' already exists")
-        else:
-            raise err
+        raise err
    group = tiledb.Group(array_uri, "w")
    group.meta["dataset_type"] = "vector_search"

@@ -1318,9 +1319,9 @@ def consolidate_and_vacuum(
        size = in_size
    if size > in_size:
        size = in_size
-    logger.info("Input dataset size %d", size)
-    logger.info("Input dataset dimensions %d", dimensions)
-    logger.info(f"Vector dimension type {vector_type}")
+    logger.debug("Input dataset size %d", size)
+    logger.debug("Input dataset dimensions %d", dimensions)
+    logger.debug(f"Vector dimension type {vector_type}")
    if partitions == -1:
        partitions = int(math.sqrt(size))
    if training_sample_size == -1:
@@ -1330,9 +1331,9 @@ def consolidate_and_vacuum(
        workers = 10
    else:
        workers = 1
-    logger.info(f"Partitions {partitions}")
-    logger.info(f"Training sample size {training_sample_size}")
-    logger.info(f"Number of workers {workers}")
+    logger.debug(f"Partitions {partitions}")
+    logger.debug(f"Training sample size {training_sample_size}")
+    logger.debug(f"Number of workers {workers}")

    if input_vectors_per_work_item == -1:
        input_vectors_per_work_item = VECTORS_PER_WORK_ITEM
@@ -1344,10 +1345,10 @@ def consolidate_and_vacuum(
            math.ceil(input_vectors_work_items / MAX_TASKS_PER_STAGE)
        )
        input_vectors_work_tasks = MAX_TASKS_PER_STAGE
-    logger.info("input_vectors_per_work_item %d", input_vectors_per_work_item)
-    logger.info("input_vectors_work_items %d", input_vectors_work_items)
-    logger.info("input_vectors_work_tasks %d", input_vectors_work_tasks)
-    logger.info(
+    logger.debug("input_vectors_per_work_item %d", input_vectors_per_work_item)
+    logger.debug("input_vectors_work_items %d", input_vectors_work_items)
+    logger.debug("input_vectors_work_tasks %d", input_vectors_work_tasks)
+    logger.debug(
        "input_vectors_work_items_per_worker %d",
        input_vectors_work_items_per_worker,
    )
@@ -1366,15 +1367,17 @@ def consolidate_and_vacuum(
            math.ceil(table_partitions_work_items / MAX_TASKS_PER_STAGE)
        )
        table_partitions_work_tasks = MAX_TASKS_PER_STAGE
-    logger.info("table_partitions_per_work_item %d", table_partitions_per_work_item)
-    logger.info("table_partitions_work_items %d", table_partitions_work_items)
-    logger.info("table_partitions_work_tasks %d", table_partitions_work_tasks)
-    logger.info(
+    logger.debug(
+        "table_partitions_per_work_item %d", table_partitions_per_work_item
+    )
+    logger.debug("table_partitions_work_items %d", table_partitions_work_items)
+    logger.debug("table_partitions_work_tasks %d", table_partitions_work_tasks)
+    logger.debug(
        "table_partitions_work_items_per_worker %d",
        table_partitions_work_items_per_worker,
    )

-    logger.info("Creating arrays")
+    logger.debug("Creating arrays")
    create_arrays(
        group=group,
        index_type=index_type,
@@ -1387,7 +1390,7 @@ def consolidate_and_vacuum(
    )
    group.close()

-    logger.info("Creating ingestion graph")
+    logger.debug("Creating ingestion graph")
    d = create_ingestion_dag(
        index_type=index_type,
        array_uri=array_uri,
@@ -1409,7 +1412,7 @@ def consolidate_and_vacuum(
        trace_id=trace_id,
        mode=mode,
    )
-    logger.info("Submitting ingestion graph")
+    logger.debug("Submitting ingestion graph")
    d.compute()
    logger.info("Submitted ingestion graph")
    d.wait()
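Note on the write path: the substantive change in create_arrays is the relative registration. Arrays are still created under the group's URI, but the group member now records only the relative name (group.add(PARTS_ARRAY_NAME, name=PARTS_ARRAY_NAME, relative=True)), so the whole group can be copied or mounted under a different base URI without breaking member links; the remaining lines only demote verbose logs from info to debug. A minimal, self-contained sketch of the relative-registration pattern (the group location, dimension sizes, and attribute name below are illustrative, not taken from the commit):

    import numpy as np
    import tiledb

    group_uri = "example_vector_index"      # hypothetical local group location
    tiledb.group_create(group_uri)

    # Create a tiny dense array inside the group directory.
    dom = tiledb.Domain(tiledb.Dim(name="rows", domain=(0, 9), tile=10, dtype=np.int32))
    schema = tiledb.ArraySchema(domain=dom, attrs=[tiledb.Attr(name="values", dtype=np.float32)])
    tiledb.Array.create(f"{group_uri}/parts.tdb", schema)

    # Register the array by its relative name, as the commit does for each index array.
    group = tiledb.Group(group_uri, "w")
    group.add("parts.tdb", name="parts.tdb", relative=True)
    group.close()

    # Readers resolve the member back to a full URI through the group.
    group = tiledb.Group(group_uri)
    print(group["parts.tdb"].uri)
    group.close()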
