Skip to content

Commit 8e4f847

Browse files
Merge pull request #126 from TileDB-Inc/npapa/support-external-ids
Add support for client provided IDs for vectors
2 parents d563aa8 + f9ee6eb commit 8e4f847

File tree

11 files changed

+325
-60
lines changed

apis/python/src/tiledb/vector_search/index.py

Lines changed: 17 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ def submit_local(d, func, *args, **kwargs):
1818

1919

2020
class Index:
21-
def query(self, targets: np.ndarray, k):
21+
def query(self, targets: np.ndarray, k, **kwargs):
2222
raise NotImplementedError
2323

2424

@@ -49,11 +49,23 @@ def __init__(
4949
self.config = config
5050
group = tiledb.Group(uri, ctx=tiledb.Ctx(config))
5151
self.storage_version = group.meta.get("storage_version", "0.1")
52+
self.index_uri = group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]].uri
5253
self._db = load_as_matrix(
53-
group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]].uri,
54+
self.index_uri,
5455
ctx=self.ctx,
5556
config=config,
5657
)
58+
self.ids_uri = group[
59+
storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
60+
].uri
61+
if tiledb.array_exists(self.ids_uri, self.ctx):
62+
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
63+
else:
64+
schema = tiledb.ArraySchema.load(
65+
self.index_uri, ctx=tiledb.Ctx(self.config)
66+
)
67+
self.size = schema.domain.dim(1).domain[1]
68+
self._ids = StdVector_u64(np.arange(self.size).astype(np.uint64))
5769

5870
dtype = group.meta.get("dtype", None)
5971
if dtype is None:
@@ -66,7 +78,6 @@ def query(
6678
targets: np.ndarray,
6779
k: int = 10,
6880
nthreads: int = 8,
69-
query_type="heap",
7081
):
7182
"""
7283
Query a flat index
@@ -89,13 +100,7 @@ def query(
89100
assert targets.dtype == np.float32
90101

91102
targets_m = array_to_matrix(np.transpose(targets))
92-
93-
if query_type == "heap":
94-
r = query_vq_heap(self._db, targets_m, k, nthreads)
95-
elif query_type == "nth":
96-
r = query_vq_nth(self._db, targets_m, k, nthreads)
97-
else:
98-
raise Exception("Unknown query type!")
103+
r = query_vq_heap(self._db, targets_m, self._ids, k, nthreads)
99104

100105
return np.transpose(np.array(r))
101106

@@ -145,12 +150,12 @@ def __init__(
145150
self._centroids = load_as_matrix(
146151
self.centroids_uri, ctx=self.ctx, config=config
147152
)
148-
self._index = read_vector_u64(self.ctx, self.index_array_uri)
153+
self._index = read_vector_u64(self.ctx, self.index_array_uri, 0, 0)
149154

150155
# TODO pass in a context
151156
if self.memory_budget == -1:
152157
self._db = load_as_matrix(self.parts_db_uri, ctx=self.ctx, config=config)
153-
self._ids = read_vector_u64(self.ctx, self.ids_uri)
158+
self._ids = read_vector_u64(self.ctx, self.ids_uri, 0, 0)
154159

155160
dtype = group.meta.get("dtype", None)
156161
if dtype is None:

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 170 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
from tiledb.cloud.dag import Mode
55
from tiledb.vector_search.index import FlatIndex, IVFFlatIndex, Index
6+
from tiledb.vector_search._tiledbvspy import *
67
import numpy as np
78

89

@@ -13,6 +14,9 @@ def ingest(
1314
input_vectors: np.ndarray = None,
1415
source_uri: str = None,
1516
source_type: str = None,
17+
external_ids: np.array = None,
18+
external_ids_uri: str = "",
19+
external_ids_type: str = None,
1620
config=None,
1721
namespace: Optional[str] = None,
1822
size: int = -1,
@@ -40,6 +44,12 @@ def ingest(
4044
Data source URI
4145
source_type: str
4246
Type of the source data. If left empty it is auto-detected from the suffix of source_uri
47+
external_ids: numpy Array
48+
Input vector external_ids, if this is provided it takes precedence over external_ids_uri and external_ids_type
49+
external_ids_uri: str
50+
Source URI for external_ids
51+
external_ids_type: str
52+
File type of external_ids_uri. If left empty it is auto-detected from the suffix of external_ids_uri
4353
config: None
4454
config dictionary, defaults to None
4555
namespace: str
@@ -95,6 +105,9 @@ def ingest(
95105
INPUT_VECTORS_ARRAY_NAME = storage_formats[STORAGE_VERSION][
96106
"INPUT_VECTORS_ARRAY_NAME"
97107
]
108+
EXTERNAL_IDS_ARRAY_NAME = storage_formats[STORAGE_VERSION][
109+
"EXTERNAL_IDS_ARRAY_NAME"
110+
]
98111
PARTIAL_WRITE_ARRAY_DIR = storage_formats[STORAGE_VERSION][
99112
"PARTIAL_WRITE_ARRAY_DIR"
100113
]
@@ -257,6 +270,47 @@ def write_input_vectors(
257270

258271
return input_vectors_array_uri
259272

273+
def write_external_ids(
274+
group: tiledb.Group,
275+
external_ids: np.array,
276+
size: int,
277+
partitions: int,
278+
) -> str:
279+
external_ids_array_uri = f"{group.uri}/{EXTERNAL_IDS_ARRAY_NAME}"
280+
if tiledb.array_exists(external_ids_array_uri):
281+
raise ValueError(f"Array exists {external_ids_array_uri}")
282+
283+
logger.debug("Creating external IDs array")
284+
ids_array_rows_dim = tiledb.Dim(
285+
name="rows",
286+
domain=(0, size - 1),
287+
tile=int(size / partitions),
288+
dtype=np.dtype(np.int32),
289+
)
290+
ids_array_dom = tiledb.Domain(ids_array_rows_dim)
291+
ids_attr = tiledb.Attr(
292+
name="values",
293+
dtype=np.dtype(np.uint64),
294+
filters=DEFAULT_ATTR_FILTERS,
295+
)
296+
ids_schema = tiledb.ArraySchema(
297+
domain=ids_array_dom,
298+
sparse=False,
299+
attrs=[ids_attr],
300+
capacity=int(size / partitions),
301+
cell_order="col-major",
302+
tile_order="col-major",
303+
)
304+
logger.debug(ids_schema)
305+
tiledb.Array.create(external_ids_array_uri, ids_schema)
306+
group.add(external_ids_array_uri, name=IDS_ARRAY_NAME)
307+
308+
external_ids_array = tiledb.open(external_ids_array_uri, "w")
309+
external_ids_array[:] = external_ids
310+
external_ids_array.close()
311+
312+
return external_ids_array_uri
313+
260314
def create_arrays(
261315
group: tiledb.Group,
262316
index_type: str,
@@ -268,7 +322,34 @@ def create_arrays(
268322
logger: logging.Logger,
269323
) -> None:
270324
if index_type == "FLAT":
325+
ids_uri = f"{group.uri}/{IDS_ARRAY_NAME}"
271326
parts_uri = f"{group.uri}/{PARTS_ARRAY_NAME}"
327+
if not tiledb.array_exists(ids_uri):
328+
logger.debug("Creating ids array")
329+
ids_array_rows_dim = tiledb.Dim(
330+
name="rows",
331+
domain=(0, size - 1),
332+
tile=int(size / partitions),
333+
dtype=np.dtype(np.int32),
334+
)
335+
ids_array_dom = tiledb.Domain(ids_array_rows_dim)
336+
ids_attr = tiledb.Attr(
337+
name="values",
338+
dtype=np.dtype(np.uint64),
339+
filters=DEFAULT_ATTR_FILTERS,
340+
)
341+
ids_schema = tiledb.ArraySchema(
342+
domain=ids_array_dom,
343+
sparse=False,
344+
attrs=[ids_attr],
345+
capacity=int(size / partitions),
346+
cell_order="col-major",
347+
tile_order="col-major",
348+
)
349+
logger.debug(ids_schema)
350+
tiledb.Array.create(ids_uri, ids_schema)
351+
group.add(ids_uri, name=IDS_ARRAY_NAME)
352+
272353
if not tiledb.array_exists(parts_uri):
273354
logger.debug("Creating parts array")
274355
parts_array_rows_dim = tiledb.Dim(
@@ -559,6 +640,39 @@ def create_arrays(
559640
else:
560641
raise ValueError(f"Not supported index_type {index_type}")
561642

643+
def read_external_ids(
644+
external_ids_uri: str,
645+
external_ids_type: str,
646+
start_pos: int,
647+
end_pos: int,
648+
config: Optional[Mapping[str, Any]] = None,
649+
verbose: bool = False,
650+
trace_id: Optional[str] = None,
651+
) -> np.array:
652+
logger = setup(config, verbose)
653+
logger.debug(
654+
"Reading external_ids start_pos: %i, end_pos: %i", start_pos, end_pos
655+
)
656+
if external_ids_uri == "":
657+
return np.arange(start_pos, end_pos).astype(np.uint64)
658+
if external_ids_type == "TILEDB_ARRAY":
659+
with tiledb.open(external_ids_uri, mode="r") as external_ids_array:
660+
return external_ids_array[start_pos:end_pos]["values"]
661+
elif external_ids_type == "U64BIN":
662+
vfs = tiledb.VFS()
663+
read_size = end_pos - start_pos
664+
read_offset = start_pos + 8
665+
with vfs.open(external_ids_uri, "rb") as f:
666+
f.seek(read_offset)
667+
return np.reshape(
668+
np.frombuffer(
669+
f.read(read_size),
670+
count=read_size,
671+
dtype=np.uint64,
672+
).astype(np.uint64),
673+
(read_size),
674+
)
675+
562676
def read_input_vectors(
563677
source_uri: str,
564678
source_type: str,
@@ -886,6 +1000,8 @@ def ingest_flat(
8861000
source_uri: str,
8871001
source_type: str,
8881002
vector_type: np.dtype,
1003+
external_ids_uri: str,
1004+
external_ids_type: str,
8891005
dimensions: int,
8901006
start: int,
8911007
end: int,
@@ -902,7 +1018,9 @@ def ingest_flat(
9021018
with tiledb.scope_ctx(ctx_or_config=config):
9031019
group = tiledb.Group(index_group_uri)
9041020
parts_array_uri = group[PARTS_ARRAY_NAME].uri
905-
target = tiledb.open(parts_array_uri, mode="w")
1021+
ids_array_uri = group[IDS_ARRAY_NAME].uri
1022+
parts_array = tiledb.open(parts_array_uri, mode="w")
1023+
ids_array = tiledb.open(ids_array_uri, mode="w")
9061024
logger.debug("Input vectors start_pos: %d, end_pos: %d", start, end)
9071025

9081026
for part in range(start, end, batch):
@@ -923,8 +1041,22 @@ def ingest_flat(
9231041

9241042
logger.debug("Vector read: %d", len(in_vectors))
9251043
logger.debug("Writing data to array %s", parts_array_uri)
926-
target[0:dimensions, start:end] = np.transpose(in_vectors)
927-
target.close()
1044+
parts_array[0:dimensions, start:end] = np.transpose(in_vectors)
1045+
1046+
external_ids = read_external_ids(
1047+
external_ids_uri=external_ids_uri,
1048+
external_ids_type=external_ids_type,
1049+
start_pos=part,
1050+
end_pos=part_end,
1051+
config=config,
1052+
verbose=verbose,
1053+
trace_id=trace_id,
1054+
)
1055+
logger.debug("External IDs read: %d", len(external_ids))
1056+
logger.debug("Writing data to array %s", ids_array_uri)
1057+
ids_array[start:end] = external_ids
1058+
parts_array.close()
1059+
ids_array.close()
9281060

9291061
def write_centroids(
9301062
centroids: np.ndarray,
@@ -951,6 +1083,8 @@ def ingest_vectors_udf(
9511083
source_uri: str,
9521084
source_type: str,
9531085
vector_type: np.dtype,
1086+
external_ids_uri: str,
1087+
external_ids_type: str,
9541088
partitions: int,
9551089
dimensions: int,
9561090
start: int,
@@ -998,6 +1132,7 @@ def ingest_vectors_udf(
9981132
ivf_index_tdb(
9991133
dtype=vector_type,
10001134
db_uri=source_uri,
1135+
external_ids_uri=external_ids_uri,
10011136
centroids_uri=centroids_uri,
10021137
parts_uri=partial_write_array_parts_uri,
10031138
index_array_uri=partial_write_array_index_uri,
@@ -1019,10 +1154,20 @@ def ingest_vectors_udf(
10191154
verbose=verbose,
10201155
trace_id=trace_id,
10211156
)
1157+
external_ids = read_external_ids(
1158+
external_ids_uri=external_ids_uri,
1159+
external_ids_type=external_ids_type,
1160+
start_pos=part,
1161+
end_pos=part_end,
1162+
config=config,
1163+
verbose=verbose,
1164+
trace_id=trace_id,
1165+
)
10221166
logger.debug("Start indexing")
10231167
ivf_index(
10241168
dtype=vector_type,
10251169
db=array_to_matrix(np.transpose(in_vectors).astype(vector_type)),
1170+
external_ids=StdVector_u64(external_ids),
10261171
centroids_uri=centroids_uri,
10271172
parts_uri=partial_write_array_parts_uri,
10281173
index_array_uri=partial_write_array_index_uri,
@@ -1194,6 +1339,8 @@ def create_ingestion_dag(
11941339
source_uri: str,
11951340
source_type: str,
11961341
vector_type: np.dtype,
1342+
external_ids_uri: str,
1343+
external_ids_type: str,
11971344
size: int,
11981345
partitions: int,
11991346
dimensions: int,
@@ -1249,6 +1396,8 @@ def create_ingestion_dag(
12491396
source_uri=source_uri,
12501397
source_type=source_type,
12511398
vector_type=vector_type,
1399+
external_ids_uri=external_ids_uri,
1400+
external_ids_type=external_ids_type,
12521401
dimensions=dimensions,
12531402
start=start,
12541403
end=end,
@@ -1397,6 +1546,8 @@ def create_ingestion_dag(
13971546
source_uri=source_uri,
13981547
source_type=source_type,
13991548
vector_type=vector_type,
1549+
external_ids_uri=external_ids_uri,
1550+
external_ids_type=external_ids_type,
14001551
partitions=partitions,
14011552
dimensions=dimensions,
14021553
start=start,
@@ -1450,6 +1601,8 @@ def consolidate_and_vacuum(
14501601
group = tiledb.Group(index_group_uri, config=config)
14511602
if INPUT_VECTORS_ARRAY_NAME in group:
14521603
tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
1604+
if EXTERNAL_IDS_ARRAY_NAME in group:
1605+
tiledb.Array.delete_array(group[EXTERNAL_IDS_ARRAY_NAME].uri)
14531606
modes = ["fragment_meta", "commits", "array_meta"]
14541607
for mode in modes:
14551608
conf = tiledb.Config(config)
@@ -1528,6 +1681,18 @@ def consolidate_and_vacuum(
15281681
group.meta["partitions"] = partitions
15291682
group.meta["storage_version"] = STORAGE_VERSION
15301683

1684+
if external_ids is not None:
1685+
external_ids_uri = write_external_ids(
1686+
group=group,
1687+
external_ids=external_ids,
1688+
size=in_size,
1689+
partitions=partitions
1690+
)
1691+
external_ids_type = "TILEDB_ARRAY"
1692+
else:
1693+
if external_ids_type is None:
1694+
external_ids_type = "U64BIN"
1695+
15311696
if input_vectors_per_work_item == -1:
15321697
input_vectors_per_work_item = VECTORS_PER_WORK_ITEM
15331698
input_vectors_work_items = int(math.ceil(size / input_vectors_per_work_item))
@@ -1590,6 +1755,8 @@ def consolidate_and_vacuum(
15901755
source_uri=source_uri,
15911756
source_type=source_type,
15921757
vector_type=vector_type,
1758+
external_ids_uri=external_ids_uri,
1759+
external_ids_type=external_ids_type,
15931760
size=size,
15941761
partitions=partitions,
15951762
dimensions=dimensions,

0 commit comments

Comments (0)