Skip to content

Commit 7b5f96e

Browse files
authored
Add a type-erased Vamana index (which only supports creation) (#286)
1 parent 19e2fd8 commit 7b5f96e

File tree

11 files changed

+220
-32
lines changed

11 files changed

+220
-32
lines changed

apis/python/src/tiledb/vector_search/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"Index",
3636
"FlatIndex",
3737
"IVFFlatIndex",
38+
"VamanaIndex",
3839
"Mode",
3940
"load_as_array",
4041
"load_as_matrix",

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,11 @@ void init_type_erased_module(py::module_& m) {
166166

167167
py::class_<FeatureVectorArray>(m, "FeatureVectorArray", py::buffer_protocol())
168168
.def(py::init<const tiledb::Context&, const std::string&>())
169-
// .def(py::init<size_t, size_t, const std::string&>())
170-
// .def(py::init<size_t, size_t void*, const std::string&>())
169+
.def(py::init<
170+
const tiledb::Context&,
171+
const std::string&,
172+
const std::string&>())
173+
.def(py::init<size_t, size_t, const std::string&, const std::string&>())
171174
.def("dimension", &FeatureVectorArray::dimension)
172175
.def("num_vectors", &FeatureVectorArray::num_vectors)
173176
.def("feature_type", &FeatureVectorArray::feature_type)
@@ -277,6 +280,15 @@ void init_type_erased_module(py::module_& m) {
277280
py::arg("vectors"),
278281
py::arg("top_k"),
279282
py::arg("opt_l"))
283+
.def(
284+
"write_index",
285+
[](IndexVamana& index,
286+
const tiledb::Context& ctx,
287+
const std::string& group_uri,
288+
bool overwrite) { index.write_index(ctx, group_uri, overwrite); },
289+
py::arg("ctx"),
290+
py::arg("group_uri"),
291+
py::arg_v("overwrite", true))
280292
.def("feature_type_string", &IndexVamana::feature_type_string)
281293
.def("id_type_string", &IndexVamana::id_type_string)
282294
.def(
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import json
2+
import multiprocessing
3+
from typing import Any, Mapping
4+
5+
import numpy as np
6+
from tiledb.cloud.dag import Mode
7+
8+
from tiledb.vector_search import index
9+
from tiledb.vector_search.module import *
10+
from tiledb.vector_search.storage_formats import (STORAGE_VERSION,
11+
storage_formats,
12+
validate_storage_version)
13+
from tiledb.vector_search.utils import add_to_group
14+
from tiledb.vector_search import _tiledbvspy as vspy
15+
16+
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
17+
INDEX_TYPE = "VAMANA"
18+
19+
class VamanaIndex(index.Index):
    """
    Open a Vamana index

    Parameters
    ----------
    uri: str
        URI of the index
    config: Optional[Mapping[str, Any]]
        config dictionary, defaults to None
    timestamp:
        Forwarded unchanged to the base ``index.Index`` constructor.
    **kwargs:
        Accepted for signature compatibility; not used by this constructor.
    """

    def __init__(
        self,
        uri: str,
        config: Optional[Mapping[str, Any]] = None,
        timestamp=None,
        **kwargs,
    ):
        super().__init__(uri=uri, config=config, timestamp=timestamp)
        self.index_type = INDEX_TYPE
        self.index = vspy.IndexVamana(vspy.Ctx(config), uri)

        array_names = storage_formats[self.storage_version]
        self.db_uri = self.group[array_names["PARTS_ARRAY_NAME"]].uri
        self.ids_uri = self.group[array_names["IDS_ARRAY_NAME"]].uri

        schema = tiledb.ArraySchema.load(self.db_uri, ctx=tiledb.Ctx(self.config))
        self.dimensions = self.index.dimension()

        # Check the raw metadata value *before* wrapping it in np.dtype():
        # np.dtype(None) evaluates to float64, so testing the wrapped value
        # against None would never take the schema fallback.
        dtype_name = self.group.meta.get("dtype", None)
        if dtype_name is None:
            self.dtype = np.dtype(schema.attr("values").dtype)
        else:
            self.dtype = np.dtype(dtype_name)

        if self.base_size == -1:
            # No recorded base size: use the upper bound of the schema's
            # dimension 1 domain, plus one.
            self.size = schema.domain.dim(1).domain[1] + 1
        else:
            self.size = self.base_size

    def get_dimensions(self):
        """Return the dimension reported by the underlying index at open time."""
        return self.dimensions

    def query_internal(
        self,
        queries: np.ndarray,
        k: int = 10,
    ):
        """
        Query a VAMANA index

        Parameters
        ----------
        queries: numpy.ndarray
            Array of queries; a 1-D array is treated as a single query.
            Must have dtype float32 when the index is not empty.
        k: int
            Number of top results to return per query

        Returns
        -------
        A pair of results. For an empty index these are two arrays of shape
        (number of queries, k) filled with ``index.MAX_FLOAT_32`` and
        ``index.MAX_UINT64`` respectively. Otherwise two empty lists are
        returned, because querying is not implemented yet.
        """
        # Promote a single 1-D query to a one-row 2-D array first, so that
        # shape[0] below counts queries rather than vector components.
        if queries.ndim == 1:
            queries = np.array([queries])

        if self.size == 0:
            num_queries = queries.shape[0]
            return (
                np.full((num_queries, k), index.MAX_FLOAT_32),
                np.full((num_queries, k), index.MAX_UINT64),
            )

        assert queries.dtype == np.float32

        # TODO(paris): Actually run the query.
        return [], []
88+
89+
# TODO(paris): Pass more arguments to C++, e.g. storage_version.
90+
def create(
    uri: str,
    dimensions: int,
    vector_type: np.dtype,
    id_type: np.dtype = np.uint32,
    adjacency_row_index_type: np.dtype = np.uint32,
    group_exists: bool = False,
    config: Optional[Mapping[str, Any]] = None,
    storage_version: str = STORAGE_VERSION,
    **kwargs,
) -> VamanaIndex:
    """
    Create an empty Vamana index at ``uri`` and open it.

    Parameters
    ----------
    uri: str
        URI where the index is written and then opened from.
    dimensions: int
        Number of dimensions of the vectors.
    vector_type: np.dtype
        Datatype of the feature vectors.
    id_type: np.dtype
        Datatype of the vector ids, defaults to uint32.
    adjacency_row_index_type: np.dtype
        Datatype passed as the adjacency row index type, defaults to uint32.
    group_exists: bool
        When True, nothing is written and the index at ``uri`` is only opened.
    config: Optional[Mapping[str, Any]]
        config dictionary, defaults to None
    storage_version: str
        Currently not used by this function.
    **kwargs:
        Currently not used by this function.

    Returns
    -------
    A ``VamanaIndex`` opened on ``uri``.
    """
    if not group_exists:
        ctx = vspy.Ctx(config)

        # Resolve the dtype names once and reuse them for both C++ objects.
        feature_type_name = np.dtype(vector_type).name
        id_type_name = np.dtype(id_type).name
        adjacency_type_name = np.dtype(adjacency_row_index_type).name

        vamana = vspy.IndexVamana(
            feature_type=feature_type_name,
            id_type=id_type_name,
            adjacency_row_index_type=adjacency_type_name,
            dimension=dimensions,
        )

        # TODO(paris): Run all of this with a single C++ call.
        # Train on and add a zero-vector array, then write the index to uri.
        no_vectors = vspy.FeatureVectorArray(
            dimensions, 0, feature_type_name, id_type_name
        )
        vamana.train(no_vectors)
        vamana.add(no_vectors)
        vamana.write_index(ctx, uri)

    return VamanaIndex(uri=uri, config=config, memory_budget=1000000)

apis/python/test/test_index.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from tiledb.vector_search import Index
99
from tiledb.vector_search import flat_index
1010
from tiledb.vector_search import ivf_flat_index
11+
from tiledb.vector_search import vamana_index
1112
from tiledb.vector_search.index import create_metadata
1213
from tiledb.vector_search.index import DATASET_TYPE
1314
from tiledb.vector_search.flat_index import FlatIndex
1415
from tiledb.vector_search.ingestion import ingest
1516
from tiledb.vector_search.ivf_flat_index import IVFFlatIndex
17+
from tiledb.vector_search.vamana_index import VamanaIndex
1618
from tiledb.vector_search.utils import load_fvecs
1719

1820

@@ -159,6 +161,30 @@ def test_ivf_flat_index(tmp_path):
159161
index, np.array([[2, 2, 2]], dtype=np.float32), 3, {0, 2, 4}, nprobe=partitions
160162
)
161163

164+
def test_vamana_index(tmp_path):
    """A freshly created Vamana index, and the same index reopened, report the
    requested dimensions and answer a query with only MAX_UINT64 ids."""
    uri = os.path.join(tmp_path, "array")
    dimensions = 3
    vector_type = np.dtype(np.uint8)

    def check_empty_index(candidate):
        assert candidate.get_dimensions() == dimensions
        query_and_check(
            candidate,
            np.array([[2, 2, 2]], dtype=np.float32),
            3,
            {ind.MAX_UINT64},
        )

    # Create the index.
    check_empty_index(
        vamana_index.create(
            uri=uri,
            dimensions=dimensions,
            vector_type=vector_type,
            id_type=np.dtype(np.uint32),
        )
    )

    # Open the index.
    check_empty_index(VamanaIndex(uri=uri))
162188

163189
def test_delete_invalid_index(tmp_path):
164190
# We don't throw with an invalid uri.
@@ -179,7 +205,7 @@ def test_delete_index(tmp_path):
179205

180206

181207
def test_index_with_incorrect_dimensions(tmp_path):
182-
indexes = [flat_index, ivf_flat_index]
208+
indexes = [flat_index, ivf_flat_index, vamana_index]
183209
for index_type in indexes:
184210
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
185211
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))

src/include/detail/linalg/tdb_io.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ void write_matrix(
212212
create_matrix<T, LayoutPolicy, I>(ctx, A, uri);
213213
}
214214

215+
if (A.num_rows() == 0 || A.num_cols() == 0) {
216+
return;
217+
}
218+
215219
std::vector<int32_t> subarray_vals{
216220
0,
217221
(int)A.num_rows() - 1,
@@ -324,6 +328,11 @@ void write_vector(
324328
if (create) {
325329
create_vector(ctx, v, uri);
326330
}
331+
332+
if (size(v) == 0) {
333+
return;
334+
}
335+
327336
// Set the subarray to write into
328337
std::vector<int32_t> subarray_vals{
329338
(int)start_pos, (int)start_pos + (int)size(v) - 1};

src/include/index/index_defs.h

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,11 @@
4949
enum class IndexKind {
5050
FlatL2,
5151
IVFFlat,
52-
FlatPQ,
53-
IVFPQ,
5452
Vamana,
55-
VamanaPQ,
56-
NNDescent,
57-
Last
5853
};
5954

6055
[[maybe_unused]] static std::vector<std::string> index_kind_strings{
61-
"FlatL2",
62-
"IVFFlat",
63-
"FlatPQ",
64-
"IVFPQ",
65-
"Vamana",
66-
"VamanaPQ",
67-
"NNDescent",
68-
"Last"};
56+
"FLAT", "IVF_FLAT", "VAMANA"};
6957

7058
[[maybe_unused]] static inline auto str(IndexKind kind) {
7159
return index_kind_strings[static_cast<int>(kind)];

src/include/index/index_metadata.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
* "base_sizes", // (json) list
3434
* "dataset_type", // "vector_search"
3535
* "dtype", // "float32", etc (Python dtype names)
36-
* "index_type", // "FLAT", "IVF_FLAT", "Vamana"
36+
* "index_type", // "FLAT", "IVF_FLAT", "VAMANA"
3737
* "ingestion_timestamps", // (json) list
3838
* "storage_version", // "0.3"
3939
* "temp_size", // TILEDB_INT64 or TILEDB_FLOAT64

src/include/index/vamana_group.h

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
[[maybe_unused]] static StorageFormat vamana_storage_formats = {
4949
{"0.3",
5050
{
51-
{"feature_vectors_array_name", "feature_vectors"},
5251
{"adjacency_scores_array_name", "adjacency_scores"},
5352
{"adjacency_ids_array_name", "adjacency_ids"},
5453
{"adjacency_row_index_array_name", "adjacency_row_index"},
@@ -169,7 +168,10 @@ class vamana_index_group : public base_index_group<vamana_index_group<Index>> {
169168
}
170169

171170
[[nodiscard]] auto feature_vectors_uri() const {
172-
return this->array_key_to_uri("feature_vectors_array_name");
171+
return this->array_key_to_uri("parts_array_name");
172+
}
173+
[[nodiscard]] auto feature_vector_ids_uri() const {
174+
return this->array_key_to_uri("ids_array_name");
173175
}
174176
[[nodiscard]] auto adjacency_scores_uri() const {
175177
return this->array_key_to_uri("adjacency_scores_array_name");
@@ -181,7 +183,10 @@ class vamana_index_group : public base_index_group<vamana_index_group<Index>> {
181183
return this->array_key_to_uri("adjacency_row_index_array_name");
182184
}
183185
[[nodiscard]] auto feature_vectors_array_name() const {
184-
return this->array_key_to_array_name("feature_vectors_array_name");
186+
return this->array_key_to_array_name("parts_array_name");
187+
}
188+
[[nodiscard]] auto feature_vector_ids_name() const {
189+
return this->array_key_to_array_name("ids_array_name");
185190
}
186191
[[nodiscard]] auto adjacency_scores_array_name() const {
187192
return this->array_key_to_array_name("adjacency_scores_array_name");
@@ -248,8 +253,9 @@ class vamana_index_group : public base_index_group<vamana_index_group<Index>> {
248253
metadata_.dimension_ = this->get_dimension();
249254

250255
/**
251-
* Create the arrays: feature_vectors (matrix), adjacency_scores (vector),
252-
* adjacency_ids (vector), adjacency_row_index (vector).
256+
* Create the arrays: feature_vectors (matrix), feature_vector_ids
257+
* (vector), adjacency_scores (vector), adjacency_ids (vector),
258+
* adjacency_row_index (vector).
253259
*/
254260
create_empty_for_matrix<
255261
typename index_type::feature_type,
@@ -264,6 +270,15 @@ class vamana_index_group : public base_index_group<vamana_index_group<Index>> {
264270
write_group.add_member(
265271
feature_vectors_array_name(), true, feature_vectors_array_name());
266272

273+
create_empty_for_vector<typename index_type::id_type>(
274+
cached_ctx_,
275+
feature_vector_ids_uri(),
276+
default_domain,
277+
tile_size,
278+
default_compression);
279+
write_group.add_member(
280+
feature_vector_ids_name(), true, feature_vector_ids_name());
281+
267282
create_empty_for_vector<typename index_type::score_type>(
268283
cached_ctx_,
269284
adjacency_scores_uri(),

0 commit comments

Comments
 (0)