Skip to content

Commit 41d36ee

Browse files
Author: Nikos Papailiou (committed)
Merge branch 'main' into npapa/numpy-ingestion
2 parents 5497a47 + 1756645 commit 41d36ee

File tree

9 files changed

+101
-124
lines changed

9 files changed

+101
-124
lines changed

.github/workflows/build_wheels.yml

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -4,7 +4,8 @@ on:
44
push:
55
branches:
66
- release-*
7-
- refs/tags/*
7+
tags:
8+
- '*'
89
pull_request:
910
branches:
1011
- '*wheel*' # must quote since "*" is a YAML reserved character; we want a string

apis/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "tiledb-vector-search"
3-
version = "0.0.9"
3+
version = "0.0.10"
44
#dynamic = ["version"]
55
description = "TileDB Vector Search Python client"
66
license = { text = "MIT" }

apis/python/src/tiledb/vector_search/index.py

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -134,7 +134,7 @@ def __init__(
134134
self.centroids_uri = group[
135135
storage_formats[self.storage_version]["CENTROIDS_ARRAY_NAME"]
136136
].uri
137-
self.index_uri = group[
137+
self.index_array_uri = group[
138138
storage_formats[self.storage_version]["INDEX_ARRAY_NAME"]
139139
].uri
140140
self.ids_uri = group[
@@ -145,7 +145,7 @@ def __init__(
145145
self._centroids = load_as_matrix(
146146
self.centroids_uri, ctx=self.ctx, config=config
147147
)
148-
self._index = read_vector_u64(self.ctx, self.index_uri)
148+
self._index = read_vector_u64(self.ctx, self.index_array_uri)
149149

150150
# TODO pass in a context
151151
if self.memory_budget == -1:

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 52 additions & 43 deletions
Original file line number · Diff line number · Diff line change
@@ -8,7 +8,7 @@
88

99
def ingest(
1010
index_type: str,
11-
array_uri: str,
11+
index_uri: str,
1212
*,
1313
input_vectors: np.ndarray = None,
1414
source_uri: str = None,
@@ -32,8 +32,8 @@ def ingest(
3232
----------
3333
index_type: str
3434
Type of vector index (FLAT, IVF_FLAT)
35-
array_uri: str
36-
Vector array URI
35+
index_uri: str
36+
Vector index URI (stored as TileDB group)
3737
input_vectors: numpy Array
3838
Input vectors, if this is provided it takes precedence over source_uri and source_type.
3939
source_uri: str
@@ -85,6 +85,9 @@ def ingest(
8585
from tiledb.cloud.utilities import set_aws_context
8686
from tiledb.vector_search.storage_formats import storage_formats, STORAGE_VERSION
8787

88+
# use index_group_uri for internal clarity
89+
index_group_uri = index_uri
90+
8891
CENTROIDS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["CENTROIDS_ARRAY_NAME"]
8992
INDEX_ARRAY_NAME = storage_formats[STORAGE_VERSION]["INDEX_ARRAY_NAME"]
9093
IDS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
@@ -300,7 +303,7 @@ def create_arrays(
300303

301304
elif index_type == "IVF_FLAT":
302305
centroids_uri = f"{group.uri}/{CENTROIDS_ARRAY_NAME}"
303-
index_uri = f"{group.uri}/{INDEX_ARRAY_NAME}"
306+
index_array_uri = f"{group.uri}/{INDEX_ARRAY_NAME}"
304307
ids_uri = f"{group.uri}/{IDS_ARRAY_NAME}"
305308
parts_uri = f"{group.uri}/{PARTS_ARRAY_NAME}"
306309
partial_write_array_dir_uri = f"{group.uri}/{PARTIAL_WRITE_ARRAY_DIR}"
@@ -348,7 +351,7 @@ def create_arrays(
348351
tiledb.Array.create(centroids_uri, centroids_schema)
349352
group.add(centroids_uri, name=CENTROIDS_ARRAY_NAME)
350353

351-
if not tiledb.array_exists(index_uri):
354+
if not tiledb.array_exists(index_array_uri):
352355
logger.debug("Creating index array")
353356
index_array_rows_dim = tiledb.Dim(
354357
name="rows",
@@ -371,8 +374,8 @@ def create_arrays(
371374
tile_order="col-major",
372375
)
373376
logger.debug(index_schema)
374-
tiledb.Array.create(index_uri, index_schema)
375-
group.add(index_uri, name=INDEX_ARRAY_NAME)
377+
tiledb.Array.create(index_array_uri, index_schema)
378+
group.add(index_array_uri, name=INDEX_ARRAY_NAME)
376379

377380
if not tiledb.array_exists(ids_uri):
378381
logger.debug("Creating ids array")
@@ -650,14 +653,14 @@ def read_input_vectors(
650653
# --------------------------------------------------------------------
651654

652655
def copy_centroids(
653-
array_uri: str,
656+
index_group_uri: str,
654657
copy_centroids_uri: str,
655658
config: Optional[Mapping[str, Any]] = None,
656659
verbose: bool = False,
657660
trace_id: Optional[str] = None,
658661
):
659662
logger = setup(config, verbose)
660-
group = tiledb.Group(array_uri)
663+
group = tiledb.Group(index_group_uri)
661664
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
662665
logger.debug(
663666
"Copying centroids from: %s, to: %s", copy_centroids_uri, centroids_uri
@@ -672,7 +675,7 @@ def copy_centroids(
672675
# centralised kmeans UDFs
673676
# --------------------------------------------------------------------
674677
def centralised_kmeans(
675-
array_uri: str,
678+
index_group_uri: str,
676679
source_uri: str,
677680
source_type: str,
678681
vector_type: np.dtype,
@@ -691,7 +694,7 @@ def centralised_kmeans(
691694

692695
with tiledb.scope_ctx(ctx_or_config=config):
693696
logger = setup(config, verbose)
694-
group = tiledb.Group(array_uri)
697+
group = tiledb.Group(index_group_uri)
695698
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
696699
sample_vectors = read_input_vectors(
697700
source_uri=source_uri,
@@ -879,7 +882,7 @@ def compute_new_centroids(*argv):
879882
return np.mean(argv, axis=0).astype(np.float32)
880883

881884
def ingest_flat(
882-
array_uri: str,
885+
index_group_uri: str,
883886
source_uri: str,
884887
source_type: str,
885888
vector_type: np.dtype,
@@ -897,7 +900,7 @@ def ingest_flat(
897900

898901
logger = setup(config, verbose)
899902
with tiledb.scope_ctx(ctx_or_config=config):
900-
group = tiledb.Group(array_uri)
903+
group = tiledb.Group(index_group_uri)
901904
parts_array_uri = group[PARTS_ARRAY_NAME].uri
902905
target = tiledb.open(parts_array_uri, mode="w")
903906
logger.debug("Input vectors start_pos: %d, end_pos: %d", start, end)
@@ -925,7 +928,7 @@ def ingest_flat(
925928

926929
def write_centroids(
927930
centroids: np.ndarray,
928-
array_uri: str,
931+
index_group_uri: str,
929932
partitions: int,
930933
dimensions: int,
931934
config: Optional[Mapping[str, Any]] = None,
@@ -934,7 +937,7 @@ def write_centroids(
934937
):
935938
with tiledb.scope_ctx(ctx_or_config=config):
936939
logger = setup(config, verbose)
937-
group = tiledb.Group(array_uri)
940+
group = tiledb.Group(index_group_uri)
938941
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
939942
logger.debug("Writing centroids to array %s", centroids_uri)
940943
with tiledb.open(centroids_uri, mode="w") as A:
@@ -944,7 +947,7 @@ def write_centroids(
944947
# vector ingestion UDFs
945948
# --------------------------------------------------------------------
946949
def ingest_vectors_udf(
947-
array_uri: str,
950+
index_group_uri: str,
948951
source_uri: str,
949952
source_type: str,
950953
vector_type: np.dtype,
@@ -966,7 +969,7 @@ def ingest_vectors_udf(
966969
)
967970

968971
logger = setup(config, verbose)
969-
group = tiledb.Group(array_uri)
972+
group = tiledb.Group(index_uri)
970973
centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
971974
partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
972975
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
@@ -997,7 +1000,7 @@ def ingest_vectors_udf(
9971000
db_uri=source_uri,
9981001
centroids_uri=centroids_uri,
9991002
parts_uri=partial_write_array_parts_uri,
1000-
index_uri=partial_write_array_index_uri,
1003+
index_array_uri=partial_write_array_index_uri,
10011004
id_uri=partial_write_array_ids_uri,
10021005
start=part,
10031006
end=part_end,
@@ -1022,7 +1025,7 @@ def ingest_vectors_udf(
10221025
db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
10231026
centroids_uri=centroids_uri,
10241027
parts_uri=partial_write_array_parts_uri,
1025-
index_uri=partial_write_array_index_uri,
1028+
index_array_uri=partial_write_array_index_uri,
10261029
id_uri=partial_write_array_ids_uri,
10271030
start=part,
10281031
end=part_end,
@@ -1031,15 +1034,15 @@ def ingest_vectors_udf(
10311034
)
10321035

10331036
def compute_partition_indexes_udf(
1034-
array_uri: str,
1037+
index_group_uri: str,
10351038
partitions: int,
10361039
config: Optional[Mapping[str, Any]] = None,
10371040
verbose: bool = False,
10381041
trace_id: Optional[str] = None,
10391042
):
10401043
logger = setup(config, verbose)
10411044
with tiledb.scope_ctx(ctx_or_config=config):
1042-
group = tiledb.Group(array_uri)
1045+
group = tiledb.Group(index_group_uri)
10431046
index_array_uri = group[INDEX_ARRAY_NAME].uri
10441047
partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
10451048
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
@@ -1075,7 +1078,7 @@ def compute_partition_indexes_udf(
10751078
index_array[:] = indexes
10761079

10771080
def consolidate_partition_udf(
1078-
array_uri: str,
1081+
index_group_uri: str,
10791082
partition_id_start: int,
10801083
partition_id_end: int,
10811084
batch: int,
@@ -1089,7 +1092,7 @@ def consolidate_partition_udf(
10891092
logger.debug(
10901093
"Consolidating partitions %d-%d", partition_id_start, partition_id_end
10911094
)
1092-
group = tiledb.Group(array_uri)
1095+
group = tiledb.Group(index_group_uri)
10931096
partial_write_array_dir_uri = group[PARTIAL_WRITE_ARRAY_DIR].uri
10941097
partial_write_array_group = tiledb.Group(partial_write_array_dir_uri)
10951098
partial_write_array_ids_uri = partial_write_array_group[IDS_ARRAY_NAME].uri
@@ -1187,7 +1190,7 @@ def submit_local(d, func, *args, **kwargs):
11871190

11881191
def create_ingestion_dag(
11891192
index_type: str,
1190-
array_uri: str,
1193+
index_group_uri: str,
11911194
source_uri: str,
11921195
source_type: str,
11931196
vector_type: np.dtype,
@@ -1242,7 +1245,7 @@ def create_ingestion_dag(
12421245
end = size
12431246
ingest_node = submit(
12441247
ingest_flat,
1245-
array_uri=array_uri,
1248+
index_group_uri=index_group_uri,
12461249
source_uri=source_uri,
12471250
source_type=source_type,
12481251
vector_type=vector_type,
@@ -1263,7 +1266,7 @@ def create_ingestion_dag(
12631266
if copy_centroids_uri is not None:
12641267
centroids_node = submit(
12651268
copy_centroids,
1266-
array_uri=array_uri,
1269+
index_group_uri=index_group_uri,
12671270
copy_centroids_uri=copy_centroids_uri,
12681271
config=config,
12691272
verbose=verbose,
@@ -1276,7 +1279,7 @@ def create_ingestion_dag(
12761279
if training_sample_size <= CENTRALISED_KMEANS_MAX_SAMPLE_SIZE:
12771280
centroids_node = submit(
12781281
centralised_kmeans,
1279-
array_uri=array_uri,
1282+
index_group_uri=index_group_uri,
12801283
source_uri=source_uri,
12811284
source_type=source_type,
12821285
vector_type=vector_type,
@@ -1359,7 +1362,7 @@ def create_ingestion_dag(
13591362
centroids_node = submit(
13601363
write_centroids,
13611364
centroids=internal_centroids_node,
1362-
array_uri=array_uri,
1365+
index_group_uri=index_group_uri,
13631366
partitions=partitions,
13641367
dimensions=dimensions,
13651368
config=config,
@@ -1372,7 +1375,7 @@ def create_ingestion_dag(
13721375

13731376
compute_indexes_node = submit(
13741377
compute_partition_indexes_udf,
1375-
array_uri=array_uri,
1378+
index_group_uri=index_group_uri,
13761379
partitions=partitions,
13771380
config=config,
13781381
verbose=verbose,
@@ -1390,7 +1393,7 @@ def create_ingestion_dag(
13901393
end = size
13911394
ingest_node = submit(
13921395
ingest_vectors_udf,
1393-
array_uri=array_uri,
1396+
index_group_uri=index_group_uri,
13941397
source_uri=source_uri,
13951398
source_type=source_type,
13961399
vector_type=vector_type,
@@ -1422,7 +1425,7 @@ def create_ingestion_dag(
14221425
end = partitions
14231426
consolidate_partition_node = submit(
14241427
consolidate_partition_udf,
1425-
array_uri=array_uri,
1428+
index_group_uri=index_group_uri,
14261429
partition_id_start=start,
14271430
partition_id_end=end,
14281431
batch=table_partitions_per_work_item,
@@ -1441,10 +1444,10 @@ def create_ingestion_dag(
14411444
raise ValueError(f"Not supported index_type {index_type}")
14421445

14431446
def consolidate_and_vacuum(
1444-
array_uri: str,
1447+
index_group_uri: str,
14451448
config: Optional[Mapping[str, Any]] = None,
14461449
):
1447-
group = tiledb.Group(array_uri, config=config)
1450+
group = tiledb.Group(index_group_uri, config=config)
14481451
if INPUT_VECTORS_ARRAY_NAME in group:
14491452
tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
14501453
modes = ["fragment_meta", "commits", "array_meta"]
@@ -1459,23 +1462,29 @@ def consolidate_and_vacuum(
14591462
tiledb.vacuum(group[IDS_ARRAY_NAME].uri, config=conf)
14601463

14611464
# TODO remove temp data for tiledb URIs
1462-
if not array_uri.startswith("tiledb://"):
1465+
if not index_group_uri.startswith("tiledb://"):
14631466
vfs = tiledb.VFS(config)
1464-
partial_write_array_dir_uri = array_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
1467+
partial_write_array_dir_uri = index_group_uri + "/" + PARTIAL_WRITE_ARRAY_DIR
14651468
if vfs.is_dir(partial_write_array_dir_uri):
14661469
vfs.remove_dir(partial_write_array_dir_uri)
14671470

1471+
1472+
# --------------------------------------------------------------------
1473+
# End internal function definitions
1474+
# --------------------------------------------------------------------
1475+
1476+
14681477
with tiledb.scope_ctx(ctx_or_config=config):
14691478
logger = setup(config, verbose)
1470-
logger.debug("Ingesting Vectors into %r", array_uri)
1479+
logger.debug("Ingesting Vectors into %r", index_uri)
14711480
try:
1472-
tiledb.group_create(array_uri)
1481+
tiledb.group_create(index_group_uri)
14731482
except tiledb.TileDBError as err:
14741483
message = str(err)
14751484
if "already exists" in message:
1476-
logger.debug(f"Group '{array_uri}' already exists")
1485+
logger.debug(f"Group '{index_group_uri}' already exists")
14771486
raise err
1478-
group = tiledb.Group(array_uri, "w")
1487+
group = tiledb.Group(index_group_uri, "w")
14791488

14801489
if input_vectors is not None:
14811490
in_size = input_vectors.shape[0]
@@ -1577,7 +1586,7 @@ def consolidate_and_vacuum(
15771586
logger.debug("Creating ingestion graph")
15781587
d = create_ingestion_dag(
15791588
index_type=index_type,
1580-
array_uri=array_uri,
1589+
index_group_uri=index_group_uri,
15811590
source_uri=source_uri,
15821591
source_type=source_type,
15831592
vector_type=vector_type,
@@ -1600,9 +1609,9 @@ def consolidate_and_vacuum(
16001609
d.compute()
16011610
logger.debug("Submitted ingestion graph")
16021611
d.wait()
1603-
consolidate_and_vacuum(array_uri=array_uri, config=config)
1612+
consolidate_and_vacuum(index_group_uri=index_group_uri, config=config)
16041613

16051614
if index_type == "FLAT":
1606-
return FlatIndex(uri=array_uri, config=config)
1615+
return FlatIndex(uri=index_group_uri, config=config)
16071616
elif index_type == "IVF_FLAT":
1608-
return IVFFlatIndex(uri=array_uri, memory_budget=1000000, config=config)
1617+
return IVFFlatIndex(uri=index_group_uri, memory_budget=1000000, config=config)

0 commit comments

Comments (0)