Skip to content

Commit 7894656

Browse files
Merge pull request #118 from TileDB-Inc/npapa/storage-format
Add storage format versions and store more metadata in group metadata
2 parents c66aadc + 199f201 commit 7894656

File tree

8 files changed

+133
-87
lines changed

8 files changed

+133
-87
lines changed

apis/python/setup.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ def get_cmake_overrides():
3535
conf.append("-DUSE_MKL_CBLAS={}".format(val))
3636

3737
try:
38-
# Make sure we use pybind11 from this python environment if available,
39-
# required for windows wheels due to:
40-
# https://github.com/pybind/pybind11/issues/3445
41-
import pybind11
42-
pb11_path = pybind11.get_cmake_dir()
43-
conf.append(f"-Dpybind11_DIR={pb11_path}")
38+
# Make sure we use pybind11 from this python environment if available,
39+
# required for windows wheels due to:
40+
# https://github.com/pybind/pybind11/issues/3445
41+
import pybind11
42+
43+
pb11_path = pybind11.get_cmake_dir()
44+
conf.append(f"-Dpybind11_DIR={pb11_path}")
4445
except ImportError:
45-
pass
46+
pass
4647

4748
return conf
4849

@@ -62,5 +63,5 @@ def get_cmake_overrides():
6263
cmake_args=cmake_args,
6364
cmake_install_target="install-libtiledbvectorsearch",
6465
cmake_install_dir="src/tiledb/vector_search",
65-
use_scm_version={"root": "../../", "relative_to": __file__},
66+
use_scm_version={"root": "../../", "relative_to": __file__},
6667
)

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from . import utils
22
from .index import FlatIndex, IVFFlatIndex
33
from .ingestion import ingest
4+
from .storage_formats import storage_formats, STORAGE_VERSION
45
from .module import load_as_array
56
from .module import load_as_matrix
67
from .module import (
@@ -34,5 +35,5 @@
3435
"ivf_index_tdb",
3536
"array_to_matrix",
3637
"partition_ivf_index",
37-
"utils"
38+
"utils",
3839
]

apis/python/src/tiledb/vector_search/index.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,10 @@
55

66
import numpy as np
77
from tiledb.vector_search.module import *
8+
from tiledb.vector_search.storage_formats import storage_formats
89
from tiledb.cloud.dag import Mode
910
from typing import Any, Mapping
1011

11-
CENTROIDS_ARRAY_NAME = "centroids.tdb"
12-
INDEX_ARRAY_NAME = "index.tdb"
13-
IDS_ARRAY_NAME = "ids.tdb"
14-
PARTS_ARRAY_NAME = "parts.tdb"
15-
1612

1713
def submit_local(d, func, *args, **kwargs):
1814
# Drop kwarg
@@ -22,7 +18,7 @@ def submit_local(d, func, *args, **kwargs):
2218

2319

2420
class Index:
25-
def query(self, targets: np.ndarray, k=10, nqueries=10, nthreads=8, nprobe=1):
21+
def query(self, targets: np.ndarray, k):
2622
raise NotImplementedError
2723

2824

@@ -36,40 +32,40 @@ class FlatIndex(Index):
3632
URI of dataset
3733
dtype: numpy.dtype
3834
datatype float32 or uint8
39-
parts_name: str
40-
Optional name of partitions
4135
"""
4236

4337
def __init__(
4438
self,
4539
uri: str,
46-
dtype: Optional[np.dtype] = None,
47-
parts_name: str = "parts.tdb",
4840
config: Optional[Mapping[str, Any]] = None,
4941
):
5042
# If the user passes a tiledb python Config object convert to a dictionary
5143
if isinstance(config, tiledb.Config):
5244
config = dict(config)
5345

5446
self.uri = uri
55-
self.dtype = dtype
5647
self._index = None
5748
self.ctx = Ctx(config)
5849
self.config = config
50+
group = tiledb.Group(uri, ctx=tiledb.Ctx(config))
51+
self.storage_version = group.meta.get("storage_version", "0.1")
52+
self._db = load_as_matrix(
53+
group[storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]].uri,
54+
ctx=self.ctx,
55+
config=config,
56+
)
5957

60-
self._db = load_as_matrix(os.path.join(uri, parts_name), ctx=self.ctx, config=config)
61-
58+
dtype = group.meta.get("dtype", None)
6259
if dtype is None:
6360
self.dtype = self._db.dtype
6461
else:
65-
self.dtype = dtype
62+
self.dtype = np.dtype(dtype)
6663

6764
def query(
6865
self,
6966
targets: np.ndarray,
7067
k: int = 10,
7168
nthreads: int = 8,
72-
nprobe: int = 1,
7369
query_type="heap",
7470
):
7571
"""
@@ -84,9 +80,7 @@ def query(
8480
nqueries: int
8581
Number of queries
8682
nthreads: int
87-
Number of threads to use for queyr
88-
nprobe: int
89-
number of probes
83+
Number of threads to use for query
9084
"""
9185
# TODO:
9286
# - typecheck targets
@@ -123,7 +117,6 @@ class IVFFlatIndex(Index):
123117
def __init__(
124118
self,
125119
uri,
126-
dtype: np.dtype = None,
127120
memory_budget: int = -1,
128121
config: Optional[Mapping[str, Any]] = None,
129122
):
@@ -134,31 +127,48 @@ def __init__(
134127
self.config = config
135128
self.ctx = Ctx(config)
136129
group = tiledb.Group(uri, ctx=tiledb.Ctx(config))
137-
self.parts_db_uri = group[PARTS_ARRAY_NAME].uri
138-
self.centroids_uri = group[CENTROIDS_ARRAY_NAME].uri
139-
self.index_uri = group[INDEX_ARRAY_NAME].uri
140-
self.ids_uri = group[IDS_ARRAY_NAME].uri
130+
self.storage_version = group.meta.get("storage_version", "0.1")
131+
self.parts_db_uri = group[
132+
storage_formats[self.storage_version]["PARTS_ARRAY_NAME"]
133+
].uri
134+
self.centroids_uri = group[
135+
storage_formats[self.storage_version]["CENTROIDS_ARRAY_NAME"]
136+
].uri
137+
self.index_uri = group[
138+
storage_formats[self.storage_version]["INDEX_ARRAY_NAME"]
139+
].uri
140+
self.ids_uri = group[
141+
storage_formats[self.storage_version]["IDS_ARRAY_NAME"]
142+
].uri
141143
self.memory_budget = memory_budget
142144

145+
self._centroids = load_as_matrix(
146+
self.centroids_uri, ctx=self.ctx, config=config
147+
)
148+
self._index = read_vector_u64(self.ctx, self.index_uri)
149+
143150
# TODO pass in a context
144151
if self.memory_budget == -1:
145152
self._db = load_as_matrix(self.parts_db_uri, ctx=self.ctx, config=config)
146153
self._ids = read_vector_u64(self.ctx, self.ids_uri)
147154

148-
self._centroids = load_as_matrix(self.centroids_uri, ctx=self.ctx, config=config)
149-
150-
# TODO this should always be available
155+
dtype = group.meta.get("dtype", None)
151156
if dtype is None:
152-
self.dtype = self._centroids.dtype
157+
schema = tiledb.ArraySchema.load(self.parts_db_uri, ctx=tiledb.Ctx(self.config))
158+
self.dtype = np.dtype(schema.attr("values").dtype)
153159
else:
154-
self.dtype = dtype
155-
self._index = read_vector_u64(self.ctx, self.index_uri)
160+
self.dtype = np.dtype(dtype)
161+
162+
self.partitions = group.meta.get("partitions", -1)
163+
if self.partitions == -1:
164+
schema = tiledb.ArraySchema.load(self.centroids_uri, ctx=tiledb.Ctx(self.config))
165+
self.partitions = schema.domain.dim("cols").domain[1] + 1
156166

157167
def query(
158168
self,
159169
queries: np.ndarray,
160170
k: int = 10,
161-
nprobe: int = 10,
171+
nprobe: int = 1,
162172
nthreads: int = -1,
163173
use_nuv_implementation: bool = False,
164174
mode: Mode = None,
@@ -198,6 +208,8 @@ def query(
198208

199209
if nthreads == -1:
200210
nthreads = multiprocessing.cpu_count()
211+
212+
nprobe = min(nprobe, self.partitions)
201213
if mode is None:
202214
queries_m = array_to_matrix(np.transpose(queries))
203215
if self.memory_budget == -1:
@@ -313,7 +325,7 @@ def dist_qv_udf(
313325
active_queries=active_queries,
314326
indices=indices,
315327
k_nn=k_nn,
316-
ctx=Ctx(config)
328+
ctx=Ctx(config),
317329
)
318330
results = []
319331
for q in range(len(r)):
@@ -377,9 +389,7 @@ def dist_qv_udf(
377389
ids_uri=self.ids_uri,
378390
query_vectors=queries,
379391
active_partitions=np.array(active_partitions)[part:part_end],
380-
active_queries=np.array(
381-
aq, dtype=object
382-
),
392+
active_queries=np.array(aq, dtype=object),
383393
indices=np.array(self._index),
384394
k_nn=k,
385395
config=config,
@@ -406,5 +416,5 @@ def dist_qv_udf(
406416
tmp = sorted(tmp_results, key=lambda t: t[0])[0:k]
407417
for j in range(len(tmp), k):
408418
tmp.append((float(0.0), int(0)))
409-
results_per_query.append(np.array(tmp, dtype=np.dtype('float,int'))['f1'])
419+
results_per_query.append(np.array(tmp, dtype=np.dtype("float,int"))["f1"])
410420
return results_per_query

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,15 @@ def ingest(
7979
from tiledb.cloud.rest_api import models
8080
from tiledb.cloud.utilities import get_logger
8181
from tiledb.cloud.utilities import set_aws_context
82-
83-
CENTROIDS_ARRAY_NAME = "centroids.tdb"
84-
INDEX_ARRAY_NAME = "index.tdb"
85-
IDS_ARRAY_NAME = "ids.tdb"
86-
PARTS_ARRAY_NAME = "parts.tdb"
87-
PARTIAL_WRITE_ARRAY_DIR = "write_temp"
82+
from tiledb.vector_search.storage_formats import storage_formats, STORAGE_VERSION
83+
84+
CENTROIDS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["CENTROIDS_ARRAY_NAME"]
85+
INDEX_ARRAY_NAME = storage_formats[STORAGE_VERSION]["INDEX_ARRAY_NAME"]
86+
IDS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["IDS_ARRAY_NAME"]
87+
PARTS_ARRAY_NAME = storage_formats[STORAGE_VERSION]["PARTS_ARRAY_NAME"]
88+
PARTIAL_WRITE_ARRAY_DIR = storage_formats[STORAGE_VERSION][
89+
"PARTIAL_WRITE_ARRAY_DIR"
90+
]
8891
VECTORS_PER_WORK_ITEM = 20000000
8992
MAX_TASKS_PER_STAGE = 100
9093
CENTRALISED_KMEANS_MAX_SAMPLE_SIZE = 1000000
@@ -1378,7 +1381,6 @@ def consolidate_and_vacuum(
13781381
logger.debug(f"Group '{array_uri}' already exists")
13791382
raise err
13801383
group = tiledb.Group(array_uri, "w")
1381-
group.meta["dataset_type"] = "vector_search"
13821384

13831385
in_size, dimensions, vector_type = read_source_metadata(
13841386
source_uri=source_uri, source_type=source_type, logger=logger
@@ -1402,6 +1404,10 @@ def consolidate_and_vacuum(
14021404
logger.debug("Partitions %d", partitions)
14031405
logger.debug("Training sample size %d", training_sample_size)
14041406
logger.debug("Number of workers %d", workers)
1407+
group.meta["dataset_type"] = "vector_search"
1408+
group.meta["dtype"] = np.dtype(vector_type).name
1409+
group.meta["partitions"] = partitions
1410+
group.meta["storage_version"] = STORAGE_VERSION
14051411

14061412
if input_vectors_per_work_item == -1:
14071413
input_vectors_per_work_item = VECTORS_PER_WORK_ITEM
@@ -1487,8 +1493,6 @@ def consolidate_and_vacuum(
14871493
consolidate_and_vacuum(array_uri=array_uri, config=config)
14881494

14891495
if index_type == "FLAT":
1490-
return FlatIndex(uri=array_uri, dtype=vector_type, config=config)
1496+
return FlatIndex(uri=array_uri, config=config)
14911497
elif index_type == "IVF_FLAT":
1492-
return IVFFlatIndex(
1493-
uri=array_uri, dtype=vector_type, memory_budget=1000000, config=config
1494-
)
1498+
return IVFFlatIndex(uri=array_uri, memory_budget=1000000, config=config)

apis/python/src/tiledb/vector_search/module.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from typing import Optional, Mapping, Any
99

1010

11-
def load_as_matrix(path: str, nqueries: int = 0, ctx: "Ctx" = None, config: Optional[Mapping[str, Any]] = None):
11+
def load_as_matrix(
12+
path: str,
13+
nqueries: int = 0,
14+
ctx: "Ctx" = None,
15+
config: Optional[Mapping[str, Any]] = None,
16+
):
1217
"""
1318
Load array as Matrix class
1419
@@ -48,7 +53,12 @@ def load_as_matrix(path: str, nqueries: int = 0, ctx: "Ctx" = None, config: Opti
4853
return m
4954

5055

51-
def load_as_array(path, return_matrix: bool = False, ctx: "Ctx" = None, config: Optional[Mapping[str, Any]] = None):
56+
def load_as_array(
57+
path,
58+
return_matrix: bool = False,
59+
ctx: "Ctx" = None,
60+
config: Optional[Mapping[str, Any]] = None,
61+
):
5262
"""
5363
Load array as array class
5464
apis/python/src/tiledb/vector_search/storage_formats.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
storage_formats = {
2+
"0.1": {
3+
"CENTROIDS_ARRAY_NAME": "centroids.tdb",
4+
"INDEX_ARRAY_NAME": "index.tdb",
5+
"IDS_ARRAY_NAME": "ids.tdb",
6+
"PARTS_ARRAY_NAME": "parts.tdb",
7+
"PARTIAL_WRITE_ARRAY_DIR": "write_temp",
8+
},
9+
"0.2": {
10+
"CENTROIDS_ARRAY_NAME": "partition_centroids",
11+
"INDEX_ARRAY_NAME": "partition_indexes",
12+
"IDS_ARRAY_NAME": "shuffled_vector_ids",
13+
"PARTS_ARRAY_NAME": "shuffled_vectors",
14+
"PARTIAL_WRITE_ARRAY_DIR": "temp_data",
15+
},
16+
}
17+
18+
STORAGE_VERSION = "0.2"

apis/python/src/tiledb/vector_search/utils.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,39 @@
22
import numpy as np
33
import io
44

5+
56
def _load_vecs_t(uri, dtype, ctx_or_config=None):
6-
with tiledb.scope_ctx(ctx_or_config) as ctx:
7-
dtype = np.dtype(dtype)
8-
vfs = tiledb.VFS(ctx.config())
9-
with vfs.open(uri, "rb") as f:
10-
d = f.read(-1)
11-
raw = np.frombuffer(d, dtype=np.uint8)
12-
ndim = raw[:4].view(np.int32)[0]
13-
14-
elem_nbytes = int(4 + ndim * dtype.itemsize)
15-
if raw.size % elem_nbytes != 0:
16-
raise ValueError(
17-
f"Mismatched dims to bytes in file {uri}: {raw.size}, elem_nbytes"
18-
)
19-
# take a view on the whole array as
20-
# (ndim, sizeof(t)*ndim), and return the actual elements
21-
#return raw.view(np.uint8).reshape((elem_nbytes,-1))[4:,:].view(dtype).reshape((ndim,-1))
22-
23-
if dtype != np.uint8:
24-
return raw.view(np.int32).reshape((-1,ndim + 1))[:,1:].view(dtype)
25-
else:
26-
return raw.view(np.uint8).reshape((-1,ndim + 1))[:,1:].view(dtype)
27-
#return raw
7+
with tiledb.scope_ctx(ctx_or_config) as ctx:
8+
dtype = np.dtype(dtype)
9+
vfs = tiledb.VFS(ctx.config())
10+
with vfs.open(uri, "rb") as f:
11+
d = f.read(-1)
12+
raw = np.frombuffer(d, dtype=np.uint8)
13+
ndim = raw[:4].view(np.int32)[0]
14+
15+
elem_nbytes = int(4 + ndim * dtype.itemsize)
16+
if raw.size % elem_nbytes != 0:
17+
raise ValueError(
18+
f"Mismatched dims to bytes in file {uri}: {raw.size}, elem_nbytes"
19+
)
20+
# take a view on the whole array as
21+
# (ndim, sizeof(t)*ndim), and return the actual elements
22+
# return raw.view(np.uint8).reshape((elem_nbytes,-1))[4:,:].view(dtype).reshape((ndim,-1))
23+
24+
if dtype != np.uint8:
25+
return raw.view(np.int32).reshape((-1, ndim + 1))[:, 1:].view(dtype)
26+
else:
27+
return raw.view(np.uint8).reshape((-1, ndim + 1))[:, 1:].view(dtype)
28+
# return raw
29+
2830

2931
def load_ivecs(uri, ctx_or_config=None):
30-
return _load_vecs_t(uri, np.int32, ctx_or_config)
32+
return _load_vecs_t(uri, np.int32, ctx_or_config)
33+
3134

3235
def load_fvecs(uri, ctx_or_config=None):
33-
return _load_vecs_t(uri, np.float32, ctx_or_config)
36+
return _load_vecs_t(uri, np.float32, ctx_or_config)
37+
3438

3539
def load_bvecs(uri, ctx_or_config=None):
36-
return _load_vecs_t(uri, np.uint8, ctx_or_config)
40+
return _load_vecs_t(uri, np.uint8, ctx_or_config)

0 commit comments

Comments
 (0)