
Commit d563aa8

Merge pull request #120 from TileDB-Inc/npapa/numpy-ingestion
Add support for ingesting from in-memory numpy arrays
2 parents: 1756645 + 41d36ee

5 files changed: +158 -58

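For orientation before the diffs: with this change, in-memory vectors can be ingested directly, roughly like this (a minimal sketch; the URI and data are illustrative, not from the commit):

    import numpy as np
    from tiledb.vector_search.ingestion import ingest

    # Hypothetical data: 1000 vectors of dimension 128.
    vectors = np.random.rand(1000, 128).astype(np.float32)

    ingest(
        index_type="FLAT",
        index_uri="/tmp/my_flat_index",  # illustrative URI
        input_vectors=vectors,           # new: replaces source_uri/source_type
    )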

apis/python/src/tiledb/vector_search/index.py

Lines changed: 6 additions & 2 deletions
@@ -154,14 +154,18 @@ def __init__(
 
         dtype = group.meta.get("dtype", None)
         if dtype is None:
-            schema = tiledb.ArraySchema.load(self.parts_db_uri, ctx=tiledb.Ctx(self.config))
+            schema = tiledb.ArraySchema.load(
+                self.parts_db_uri, ctx=tiledb.Ctx(self.config)
+            )
             self.dtype = np.dtype(schema.attr("values").dtype)
         else:
             self.dtype = np.dtype(dtype)
 
         self.partitions = group.meta.get("partitions", -1)
         if self.partitions == -1:
-            schema = tiledb.ArraySchema.load(self.centroids_uri, ctx=tiledb.Ctx(self.config))
+            schema = tiledb.ArraySchema.load(
+                self.centroids_uri, ctx=tiledb.Ctx(self.config)
+            )
             self.partitions = schema.domain.dim("cols").domain[1] + 1
 
     def query(

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 97 additions & 12 deletions
@@ -3,14 +3,16 @@
 
 from tiledb.cloud.dag import Mode
 from tiledb.vector_search.index import FlatIndex, IVFFlatIndex, Index
+import numpy as np
 
 
 def ingest(
     index_type: str,
     index_uri: str,
-    source_uri: str,
-    source_type: str,
     *,
+    input_vectors: np.ndarray = None,
+    source_uri: str = None,
+    source_type: str = None,
     config=None,
     namespace: Optional[str] = None,
     size: int = -1,
@@ -32,10 +34,12 @@ def ingest(
         Type of vector index (FLAT, IVF_FLAT)
     index_uri: str
         Vector index URI (stored as TileDB group)
+    input_vectors: numpy Array
+        Input vectors, if this is provided it takes precedence over source_uri and source_type.
     source_uri: str
         Data source URI
     source_type: str
-        Type of the source data
+        Type of the source data. If left empty it is auto-detected from the suffix of source_uri
     config: None
         config dictionary, defaults to None
     namespace: str
@@ -88,6 +92,9 @@ def ingest(
     INDEX_ARRAY_NAME = storage_formats[STORAGE_VERSION]["INDEX_ARRAY_NAME"]
     IDS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
     PARTS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["PARTS_ARRAY_NAME"]
+    INPUT_VECTORS_ARRAY_NAME = storage_formats[STORAGE_VERSION][
+        "INPUT_VECTORS_ARRAY_NAME"
+    ]
     PARTIAL_WRITE_ARRAY_DIR = storage_formats[STORAGE_VERSION][
         "PARTIAL_WRITE_ARRAY_DIR"
     ]
@@ -139,8 +146,22 @@ def setup(
 
         return logger
 
+    def autodetect_source_type(source_uri: str) -> str:
+        if source_uri.endswith(".u8bin"):
+            return "U8BIN"
+        elif source_uri.endswith(".f32bin"):
+            return "F32BIN"
+        elif source_uri.endswith(".fvecs"):
+            return "FVEC"
+        elif source_uri.endswith(".ivecs"):
+            return "IVEC"
+        elif source_uri.endswith(".bvecs"):
+            return "BVEC"
+        else:
+            return "TILEDB_ARRAY"
+
     def read_source_metadata(
-        source_uri: str, source_type: str, logger: logging.Logger
+        source_uri: str, source_type: str = None
     ) -> Tuple[int, int, np.dtype]:
         if source_type == "TILEDB_ARRAY":
             schema = tiledb.ArraySchema.load(source_uri)
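For reference, the suffix mapping introduced above behaves like this (the helper is nested inside ingest(), so these calls are for exposition only; the URIs are hypothetical):

    autodetect_source_type("s3://bucket/base.u8bin")        # -> "U8BIN"
    autodetect_source_type("s3://bucket/queries.fvecs")     # -> "FVEC"
    autodetect_source_type("s3://bucket/groundtruth.ivecs") # -> "IVEC"
    autodetect_source_type("s3://bucket/my_tiledb_array")   # -> "TILEDB_ARRAY" (fallback)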
@@ -189,6 +210,53 @@ def read_source_metadata(
         else:
             raise ValueError(f"Not supported source_type {source_type}")
 
+    def write_input_vectors(
+        group: tiledb.Group,
+        input_vectors: np.ndarray,
+        size: int,
+        dimensions: int,
+        vector_type: np.dtype,
+    ) -> str:
+        input_vectors_array_uri = f"{group.uri}/{INPUT_VECTORS_ARRAY_NAME}"
+        if tiledb.array_exists(input_vectors_array_uri):
+            raise ValueError(f"Array exists {input_vectors_array_uri}")
+
+        logger.debug("Creating input vectors array")
+        input_vectors_array_rows_dim = tiledb.Dim(
+            name="rows",
+            domain=(0, dimensions - 1),
+            tile=dimensions,
+            dtype=np.dtype(np.int32),
+        )
+        input_vectors_array_cols_dim = tiledb.Dim(
+            name="cols",
+            domain=(0, size - 1),
+            tile=int(size / partitions),
+            dtype=np.dtype(np.int32),
+        )
+        input_vectors_array_dom = tiledb.Domain(
+            input_vectors_array_rows_dim, input_vectors_array_cols_dim
+        )
+        input_vectors_array_attr = tiledb.Attr(
+            name="values", dtype=vector_type, filters=DEFAULT_ATTR_FILTERS
+        )
+        input_vectors_array_schema = tiledb.ArraySchema(
+            domain=input_vectors_array_dom,
+            sparse=False,
+            attrs=[input_vectors_array_attr],
+            cell_order="col-major",
+            tile_order="col-major",
+        )
+        logger.debug(input_vectors_array_schema)
+        tiledb.Array.create(input_vectors_array_uri, input_vectors_array_schema)
+        group.add(input_vectors_array_uri, name=INPUT_VECTORS_ARRAY_NAME)
+
+        input_vectors_array = tiledb.open(input_vectors_array_uri, "w")
+        input_vectors_array[:, :] = np.transpose(input_vectors)
+        input_vectors_array.close()
+
+        return input_vectors_array_uri
+
     def create_arrays(
         group: tiledb.Group,
         index_type: str,
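Note the np.transpose on write in write_input_vectors: callers pass input_vectors with shape (size, dimensions), but the array is laid out col-major with one vector per column. A shape-only sketch (toy numbers, not from the commit):

    import numpy as np

    v = np.arange(6, dtype=np.float32).reshape(3, 2)  # (size=3, dimensions=2)
    stored = np.transpose(v)  # shape (2, 3): rows are dimensions, cols are vectors
    assert stored.shape == (2, 3)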
@@ -501,7 +569,7 @@ def read_input_vectors(
         config: Optional[Mapping[str, Any]] = None,
         verbose: bool = False,
         trace_id: Optional[str] = None,
-    ) -> np.array:
+    ) -> np.ndarray:
         logger = setup(config, verbose)
         logger.debug(
             "Reading input vectors start_pos: %i, end_pos: %i", start_pos, end_pos
@@ -669,7 +737,7 @@ def init_centroids(
         config: Optional[Mapping[str, Any]] = None,
         verbose: bool = False,
         trace_id: Optional[str] = None,
-    ) -> np.array:
+    ) -> np.ndarray:
         logger = setup(config, verbose)
         logger.debug(
             "Initialising centroids by reading the first vectors in the source data."
@@ -688,7 +756,7 @@ def init_centroids(
         )
 
     def assign_points_and_partial_new_centroids(
-        centroids: np.array,
+        centroids: np.ndarray,
         source_uri: str,
         source_type: str,
         vector_type: np.dtype,
@@ -859,7 +927,7 @@ def ingest_flat(
             target.close()
 
     def write_centroids(
-        centroids: np.array,
+        centroids: np.ndarray,
         index_group_uri: str,
         partitions: int,
         dimensions: int,
@@ -1379,12 +1447,14 @@ def consolidate_and_vacuum(
         index_group_uri: str,
         config: Optional[Mapping[str, Any]] = None,
     ):
+        group = tiledb.Group(index_group_uri, config=config)
+        if INPUT_VECTORS_ARRAY_NAME in group:
+            tiledb.Array.delete_array(group[INPUT_VECTORS_ARRAY_NAME].uri)
         modes = ["fragment_meta", "commits", "array_meta"]
         for mode in modes:
             conf = tiledb.Config(config)
             conf["sm.consolidation.mode"] = mode
             conf["sm.vacuum.mode"] = mode
-            group = tiledb.Group(index_group_uri, config=conf)
             tiledb.consolidate(group[PARTS_ARRAY_NAME].uri, config=conf)
             tiledb.vacuum(group[PARTS_ARRAY_NAME].uri, config=conf)
             if index_type == "IVF_FLAT":
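The cleanup added at the top of consolidate_and_vacuum uses standard TileDB group-membership checks; the same pattern in isolation (group URI hypothetical, array name taken from storage_formats.py below):

    import tiledb

    group = tiledb.Group("/tmp/my_flat_index")  # illustrative URI
    if "input_vectors" in group:
        # Remove the temporary copy written only for in-memory ingestion.
        tiledb.Array.delete_array(group["input_vectors"].uri)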
@@ -1416,9 +1486,24 @@ def consolidate_and_vacuum(
             raise err
         group = tiledb.Group(index_group_uri, "w")
 
-        in_size, dimensions, vector_type = read_source_metadata(
-            source_uri=source_uri, source_type=source_type, logger=logger
-        )
+        if input_vectors is not None:
+            in_size = input_vectors.shape[0]
+            dimensions = input_vectors.shape[1]
+            vector_type = input_vectors.dtype
+            source_uri = write_input_vectors(
+                group=group,
+                input_vectors=input_vectors,
+                size=in_size,
+                dimensions=dimensions,
+                vector_type=vector_type,
+            )
+            source_type = "TILEDB_ARRAY"
+        else:
+            if source_type is None:
+                source_type = autodetect_source_type(source_uri=source_uri)
+            in_size, dimensions, vector_type = read_source_metadata(
+                source_uri=source_uri, source_type=source_type
+            )
         if size == -1:
             size = in_size
         if size > in_size:
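And the file-based path after this change, where source_type may now be omitted (a sketch; URIs illustrative):

    # Suffix ".f32bin" is auto-detected, so source_type can be left out:
    ingest(
        index_type="IVF_FLAT",
        index_uri="/tmp/my_ivf_index",         # illustrative URI
        source_uri="s3://bucket/data.f32bin",  # -> source_type "F32BIN"
    )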

apis/python/src/tiledb/vector_search/storage_formats.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
         "INDEX_ARRAY_NAME": "index.tdb",
         "IDS_ARRAY_NAME": "ids.tdb",
         "PARTS_ARRAY_NAME": "parts.tdb",
+        "INPUT_VECTORS_ARRAY_NAME": "input_vectors",
         "PARTIAL_WRITE_ARRAY_DIR": "write_temp",
         "DEFAULT_ATTR_FILTERS": None,
     },
@@ -14,6 +15,7 @@
         "INDEX_ARRAY_NAME": "partition_indexes",
         "IDS_ARRAY_NAME": "shuffled_vector_ids",
         "PARTS_ARRAY_NAME": "shuffled_vectors",
+        "INPUT_VECTORS_ARRAY_NAME": "input_vectors",
         "PARTIAL_WRITE_ARRAY_DIR": "temp_data",
         "DEFAULT_ATTR_FILTERS": tiledb.FilterList([tiledb.ZstdFilter()]),
     },

apis/python/test/common.py

Lines changed: 2 additions & 24 deletions
@@ -11,28 +11,6 @@ def xbin_mmap(fname, dtype):
     return np.memmap(fname, dtype=dtype, mode="r", offset=8, shape=(n, d))
 
 
-def get_queries_fvec(file, dimensions, nqueries=None):
-    vfs = tiledb.VFS()
-    vector_values = 1 + dimensions
-    vector_size = vector_values * 4
-    read_size = nqueries
-    read_offset = 0
-    with vfs.open(file, "rb") as f:
-        f.seek(read_offset)
-        return np.delete(
-            np.reshape(
-                np.frombuffer(
-                    f.read(read_size * vector_size),
-                    count=read_size * vector_values,
-                    dtype=np.float32,
-                ).astype(np.float32),
-                (read_size, dimensions + 1),
-            ),
-            0,
-            axis=1,
-        )
-
-
 def get_groundtruth_ivec(file, k=None, nqueries=None):
     vfs = tiledb.VFS()
     vector_values = 1 + k
@@ -104,7 +82,7 @@ def create_random_dataset_f32(nb, d, nq, k, path):
         X, test_size=nq, random_state=1
     )
 
-    with open(os.path.join(path, "data"), "wb") as f:
+    with open(os.path.join(path, "data.f32bin"), "wb") as f:
         np.array([nb, d], dtype="uint32").tofile(f)
         data.astype("float32").tofile(f)
     with open(os.path.join(path, "queries"), "wb") as f:
@@ -138,7 +116,7 @@ def create_random_dataset_u8(nb, d, nq, k, path):
     data = data.astype("uint8")
     queries = queries.astype("uint8")
 
-    with open(os.path.join(path, "data"), "wb") as f:
+    with open(os.path.join(path, "data.u8bin"), "wb") as f:
         np.array([nb, d], dtype="uint32").tofile(f)
         data.tofile(f)
     with open(os.path.join(path, "queries"), "wb") as f:
