Skip to content

Commit 6ff2a86

Browse files
authored
Type erased Vamana index (#285)
1 parent b9eb281 commit 6ff2a86

File tree

10 files changed

+669
-178
lines changed

10 files changed

+669
-178
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 345 additions & 135 deletions
Large diffs are not rendered by default.

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Any, Mapping
23

34
import numpy as np
@@ -62,6 +63,8 @@ def query_internal(
6263
self,
6364
queries: np.ndarray,
6465
k: int = 10,
66+
opt_l: Optional[int] = 1,
67+
**kwargs,
6568
):
6669
"""
6770
Query a VAMANA index
@@ -72,7 +75,10 @@ def query_internal(
7275
ND Array of queries
7376
k: int
7477
Number of top results to return per query
78+
opt_l: int
79+
How deep to search
7580
"""
81+
warnings.warn("The Vamana index is not yet supported, please use with caution.")
7682
if self.size == 0:
7783
return np.full((queries.shape[0], k), index.MAX_FLOAT_32), np.full(
7884
(queries.shape[0], k), index.MAX_UINT64
@@ -83,8 +89,10 @@ def query_internal(
8389
if queries.ndim == 1:
8490
queries = np.array([queries])
8591

86-
# TODO(paris): Actually run the query.
87-
return [], []
92+
queries_feature_vector_array = vspy.FeatureVectorArray(np.transpose(queries))
93+
distances, ids = self.index.query(queries_feature_vector_array, k, opt_l)
94+
95+
return np.array(distances, copy=False), np.array(ids, copy=False)
8896

8997

9098
# TODO(paris): Pass more arguments to C++, i.e. storage_version.
@@ -94,24 +102,23 @@ def create(
94102
vector_type: np.dtype,
95103
id_type: np.dtype = np.uint32,
96104
adjacency_row_index_type: np.dtype = np.uint32,
97-
group_exists: bool = False,
98105
config: Optional[Mapping[str, Any]] = None,
99106
storage_version: str = STORAGE_VERSION,
100107
**kwargs,
101108
) -> VamanaIndex:
102-
if not group_exists:
103-
ctx = vspy.Ctx(config)
104-
index = vspy.IndexVamana(
105-
feature_type=np.dtype(vector_type).name,
106-
id_type=np.dtype(id_type).name,
107-
adjacency_row_index_type=np.dtype(adjacency_row_index_type).name,
108-
dimension=dimensions,
109-
)
110-
# TODO(paris): Run all of this with a single C++ call.
111-
empty_vector = vspy.FeatureVectorArray(
112-
dimensions, 0, np.dtype(vector_type).name, np.dtype(id_type).name
113-
)
114-
index.train(empty_vector)
115-
index.add(empty_vector)
116-
index.write_index(ctx, uri)
109+
warnings.warn("The Vamana index is not yet supported, please use with caution.")
110+
ctx = vspy.Ctx(config)
111+
index = vspy.IndexVamana(
112+
feature_type=np.dtype(vector_type).name,
113+
id_type=np.dtype(id_type).name,
114+
adjacency_row_index_type=np.dtype(adjacency_row_index_type).name,
115+
dimension=dimensions,
116+
)
117+
# TODO(paris): Run all of this with a single C++ call.
118+
empty_vector = vspy.FeatureVectorArray(
119+
dimensions, 0, np.dtype(vector_type).name, np.dtype(id_type).name
120+
)
121+
index.train(empty_vector)
122+
index.add(empty_vector)
123+
index.write_index(ctx, uri)
117124
return VamanaIndex(uri=uri, config=config, memory_budget=1000000)

apis/python/test/test_index.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@
1919
from tiledb.vector_search.vamana_index import VamanaIndex
2020

2121

22+
def query_and_check_distances(
23+
index, queries, k, expected_distances, expected_ids, **kwargs
24+
):
25+
for _ in range(1):
26+
distances, ids = index.query(queries, k=k, **kwargs)
27+
assert np.array_equal(ids, expected_ids)
28+
assert np.array_equal(distances, expected_distances)
29+
30+
2231
def query_and_check(index, queries, k, expected, **kwargs):
2332
for _ in range(3):
2433
result_d, result_i = index.query(queries, k=k, **kwargs)
@@ -167,7 +176,7 @@ def test_ivf_flat_index(tmp_path):
167176
)
168177

169178

170-
def test_vamana_index(tmp_path):
179+
def test_vamana_index_simple(tmp_path):
171180
uri = os.path.join(tmp_path, "array")
172181
dimensions = 3
173182
vector_type = np.dtype(np.uint8)
@@ -188,14 +197,68 @@ def test_vamana_index(tmp_path):
188197
query_and_check(index, np.array([[2, 2, 2]], dtype=np.float32), 3, {ind.MAX_UINT64})
189198

190199

200+
def test_vamana_index(tmp_path):
201+
uri = os.path.join(tmp_path, "array")
202+
if os.path.exists(uri):
203+
os.rmdir(uri)
204+
vector_type = np.float32
205+
206+
index = vamana_index.create(
207+
uri=uri,
208+
dimensions=3,
209+
vector_type=np.dtype(vector_type),
210+
id_type=np.dtype(np.uint32),
211+
)
212+
213+
queries = np.array([[2, 2, 2]], dtype=np.float32)
214+
distances, ids = index.query(queries, k=1)
215+
assert distances.shape == (1, 1)
216+
assert ids.shape == (1, 1)
217+
assert distances[0][0] == ind.MAX_FLOAT_32
218+
assert ids[0][0] == ind.MAX_UINT64
219+
query_and_check_distances(
220+
index, queries, 1, [[ind.MAX_FLOAT_32]], [[ind.MAX_UINT64]]
221+
)
222+
223+
update_vectors = np.empty([5], dtype=object)
224+
update_vectors[0] = np.array([0, 0, 0], dtype=np.dtype(np.float32))
225+
update_vectors[1] = np.array([1, 1, 1], dtype=np.dtype(np.float32))
226+
update_vectors[2] = np.array([2, 2, 2], dtype=np.dtype(np.float32))
227+
update_vectors[3] = np.array([3, 3, 3], dtype=np.dtype(np.float32))
228+
update_vectors[4] = np.array([4, 4, 4], dtype=np.dtype(np.float32))
229+
index.update_batch(
230+
vectors=update_vectors,
231+
external_ids=np.array([0, 1, 2, 3, 4], dtype=np.dtype(np.uint32)),
232+
)
233+
query_and_check_distances(
234+
index, np.array([[2, 2, 2]], dtype=np.float32), 2, [[0, 3]], [[2, 1]]
235+
)
236+
237+
index = index.consolidate_updates()
238+
239+
# TODO(paris): Does not work with k > 1 or with [0, 0, 0] as the query.
240+
query_and_check_distances(
241+
index, np.array([[1, 1, 1]], dtype=np.float32), 1, [[0]], [[1]]
242+
)
243+
query_and_check_distances(
244+
index, np.array([[2, 2, 2]], dtype=np.float32), 1, [[0]], [[2]]
245+
)
246+
query_and_check_distances(
247+
index, np.array([[3, 3, 3]], dtype=np.float32), 1, [[0]], [[3]]
248+
)
249+
query_and_check_distances(
250+
index, np.array([[4, 4, 4]], dtype=np.float32), 1, [[0]], [[4]]
251+
)
252+
253+
191254
def test_delete_invalid_index(tmp_path):
192255
# We don't throw with an invalid uri.
193256
Index.delete_index(uri="invalid_uri", config=tiledb.cloud.Config())
194257

195258

196259
def test_delete_index(tmp_path):
197-
indexes = ["FLAT", "IVF_FLAT"]
198-
index_classes = [FlatIndex, IVFFlatIndex]
260+
indexes = ["FLAT", "IVF_FLAT", "VAMANA"]
261+
index_classes = [FlatIndex, IVFFlatIndex, VamanaIndex]
199262
data = np.array([[1.0, 1.1, 1.2, 1.3], [2.0, 2.1, 2.2, 2.3]], dtype=np.float32)
200263
for index_type, index_class in zip(indexes, index_classes):
201264
index_uri = os.path.join(tmp_path, f"array_{index_type}")
@@ -229,7 +292,7 @@ def test_index_with_incorrect_dimensions(tmp_path):
229292
def test_index_with_incorrect_num_of_query_columns_simple(tmp_path):
230293
siftsmall_uri = siftsmall_inputs_file
231294
queries_uri = siftsmall_query_file
232-
indexes = ["FLAT", "IVF_FLAT"]
295+
indexes = ["FLAT", "IVF_FLAT", "VAMANA"]
233296
for index_type in indexes:
234297
index_uri = os.path.join(tmp_path, f"sift10k_flat_{index_type}")
235298
index = ingest(
@@ -253,7 +316,7 @@ def test_index_with_incorrect_num_of_query_columns_complex(tmp_path):
253316
# Tests that we raise a TypeError if the number of columns in the query is not the same as the
254317
# number of columns in the indexed data.
255318
size = 1000
256-
indexes = ["FLAT", "IVF_FLAT"]
319+
indexes = ["FLAT", "IVF_FLAT", "VAMANA"]
257320
num_columns_in_vector = [1, 2, 3, 4, 5, 10]
258321
for index_type in indexes:
259322
for num_columns in num_columns_in_vector:
@@ -298,7 +361,7 @@ def test_index_with_incorrect_num_of_query_columns_in_single_vector_query(tmp_pa
298361
# Tests that we raise a TypeError if the number of columns in the query is not the same as the
299362
# number of columns in the indexed data, specifically for a single vector query.
300363
# i.e. queries = [1, 2, 3] instead of queries = [[1, 2, 3], [4, 5, 6]].
301-
indexes = [flat_index, ivf_flat_index]
364+
indexes = [flat_index, ivf_flat_index, vamana_index]
302365
for index_type in indexes:
303366
uri = os.path.join(tmp_path, f"array_{index_type.__name__}")
304367
index = index_type.create(uri=uri, dimensions=3, vector_type=np.dtype(np.uint8))

apis/python/test/test_ingestion.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from tiledb.vector_search.module import kmeans_fit
1616
from tiledb.vector_search.module import kmeans_predict
1717
from tiledb.vector_search.utils import load_fvecs
18+
from tiledb.vector_search.vamana_index import VamanaIndex
1819

1920
MINIMUM_ACCURACY = 0.85
2021
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
@@ -30,6 +31,34 @@ def query_and_check_equals(index, queries, expected_result_d, expected_result_i)
3031
)
3132

3233

34+
def test_vamana_ingestion_u8(tmp_path):
35+
dataset_dir = os.path.join(tmp_path, "dataset")
36+
index_uri = os.path.join(tmp_path, "array")
37+
if os.path.exists(index_uri):
38+
shutil.rmtree(index_uri)
39+
create_random_dataset_u8(nb=10000, d=100, nq=100, k=10, path=dataset_dir)
40+
dtype = np.dtype(np.uint8)
41+
k = 10
42+
43+
queries = get_queries(dataset_dir, dtype=dtype)
44+
gt_i, gt_d = get_groundtruth(dataset_dir, k)
45+
46+
index = ingest(
47+
index_type="VAMANA",
48+
index_uri=index_uri,
49+
source_uri=os.path.join(dataset_dir, "data.u8bin"),
50+
)
51+
_, result = index.query(queries, k=k)
52+
# TODO(paris): Fix IDs and re-enable.
53+
# assert accuracy(result, gt_i) > MINIMUM_ACCURACY
54+
55+
index_uri = move_local_index_to_new_location(index_uri)
56+
index_ram = VamanaIndex(uri=index_uri)
57+
_, result = index_ram.query(queries, k=k)
58+
# TODO(paris): Fix IDs and re-enable.
59+
# assert accuracy(result, gt_i) > MINIMUM_ACCURACY
60+
61+
3362
def test_flat_ingestion_u8(tmp_path):
3463
dataset_dir = os.path.join(tmp_path, "dataset")
3564
index_uri = os.path.join(tmp_path, "array")

apis/python/test/test_type_erased_module.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from tiledb.vector_search import _tiledbvspy as vspy
77

8-
# ctx = tiledb.Ctx()
98
ctx = vspy.Ctx({})
109

1110

@@ -187,6 +186,45 @@ def test_construct_IndexVamana():
187186
assert a.dimension() == 0
188187

189188

189+
def test_construct_IndexVamana_with_empty_vector(tmp_path):
190+
opt_l = 100
191+
k_nn = 10
192+
index_uri = os.path.join(tmp_path, "array")
193+
dimensions = 128
194+
feature_type = "float32"
195+
id_type = "uint64"
196+
adjacency_row_index_type = "uint64"
197+
198+
# First create an empty index.
199+
a = vspy.IndexVamana(
200+
feature_type=feature_type,
201+
id_type=id_type,
202+
adjacency_row_index_type=adjacency_row_index_type,
203+
dimension=dimensions,
204+
)
205+
empty_vector = vspy.FeatureVectorArray(dimensions, 0, feature_type, id_type)
206+
a.train(empty_vector)
207+
a.write_index(ctx, index_uri)
208+
209+
# Then load it again, retrain, and query.
210+
a = vspy.IndexVamana(ctx, index_uri)
211+
training_set = vspy.FeatureVectorArray(ctx, siftsmall_inputs_uri)
212+
assert training_set.feature_type_string() == "float32"
213+
query_set = vspy.FeatureVectorArray(ctx, siftsmall_query_uri)
214+
assert query_set.feature_type_string() == "float32"
215+
groundtruth_set = vspy.FeatureVectorArray(ctx, siftsmall_groundtruth_uri)
216+
assert groundtruth_set.feature_type_string() == "uint64"
217+
218+
a.train(training_set)
219+
220+
s, t = a.query(query_set, k_nn, opt_l)
221+
222+
intersections = vspy.count_intersections(t, groundtruth_set, k_nn)
223+
nt = np.double(t.num_vectors()) * np.double(k_nn)
224+
recall = intersections / nt
225+
assert recall == 1.0
226+
227+
190228
def test_inplace_build_query_IndexVamana():
191229
opt_l = 100
192230
k_nn = 10

src/include/detail/linalg/tdb_io.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ std::vector<T> read_vector_helper(
9494

9595
// Create a subarray that reads the array up to the specified subset.
9696
std::vector<int32_t> subarray_vals = {
97-
(int32_t)start_pos, (int32_t)end_pos - 1};
97+
(int32_t)start_pos, std::max(0, (int32_t)end_pos - 1)};
9898
tiledb::Subarray subarray(ctx, *array_);
9999
subarray.set_subarray(subarray_vals);
100100

@@ -136,9 +136,9 @@ void create_empty_for_matrix(
136136
tiledb::Domain domain(ctx);
137137
domain
138138
.add_dimension(tiledb::Dimension::create<int>(
139-
ctx, "rows", {{0, (int)rows - 1}}, row_extent))
139+
ctx, "rows", {{0, std::max(0, (int)rows - 1)}}, row_extent))
140140
.add_dimension(tiledb::Dimension::create<int>(
141-
ctx, "cols", {{0, (int)cols - 1}}, col_extent));
141+
ctx, "cols", {{0, std::max(0, (int)cols - 1)}}, col_extent));
142142

143143
tiledb::ArraySchema schema(ctx, TILEDB_DENSE);
144144

@@ -218,10 +218,9 @@ void write_matrix(
218218

219219
std::vector<int32_t> subarray_vals{
220220
0,
221-
(int)A.num_rows() - 1,
222-
(int)start_pos,
223-
(int)start_pos + (int)A.num_cols() - 1};
224-
221+
std::max(0, (int)A.num_rows() - 1),
222+
std::max(0, (int)start_pos),
223+
std::max(0, (int)start_pos + (int)A.num_cols() - 1)};
225224
// Open array for writing
226225
auto array = tiledb_helpers::open_array(
227226
tdb_func__, ctx, uri, TILEDB_WRITE, temporal_policy);
@@ -265,7 +264,7 @@ void create_empty_for_vector(
265264
std::optional<tiledb_filter_type_t> filter = std::nullopt) {
266265
tiledb::Domain domain(ctx);
267266
domain.add_dimension(tiledb::Dimension::create<int>(
268-
ctx, "rows", {{0, (int)rows - 1}}, row_extent));
267+
ctx, "rows", {{0, std::max(0, (int)rows - 1)}}, row_extent));
269268

270269
// The array will be dense.
271270
tiledb::ArraySchema schema(ctx, TILEDB_DENSE);

src/include/index/ivf_flat_group.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@
4545
{
4646
{"centroids_array_name", "partition_centroids"},
4747
{"index_array_name", "partition_indexes"},
48-
{"ids_array_name", "shuffled_vector_ids"},
49-
{"parts_array_name", "shuffled_vectors"},
5048
}}};
5149

5250
template <class Index>

src/include/index/vamana_group.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,26 @@
3737
#include "index/vamana_metadata.h"
3838

3939
/**
40-
* The vamana index group needs to store
41-
* * vectors
42-
* * graph (basically CSR)
43-
* * neighbor lists
44-
* * neighbor scores (distances)
45-
* * "row" index
46-
* * centroids (for the case of partitioned vamana)
40+
* The vamana index group stores:
41+
* - feature_vectors: the original set of vectors which we copy.
42+
* - Example: [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
43+
* - feature_vectors_ids: the IDs of the vectors in feature_vectors_array_name.
44+
* - Example: [99, 100, 101]
45+
* - The graph (basically a CSR)
46+
* - adjacency_ids: These are indexes into feature_vectors. Vertices go from 0
47+
* -> n-1 and each of those vertices indexes into feature_vectors. Then those
48+
* IDs correspond to the indexes. You can also think of it as holding the R
49+
* nearest neighbors in the graph for each vertex.
50+
* - Example: Here we have 100 and 101 connected, 99 and 101 connected, and
51+
* 99 and 100 connected. Logically you can think of it like: [[1, 2], [0, 2], [0,
52+
* 1]], but it's stored as [1, 2, 0, 2, 0, 1]
53+
* - adjacency_scores: This holds the neighbor scores (i.e. the distances)
54+
* - Example: [[distance between 0 and 1, distance between 0 and 2], etc.]
55+
* - adjacency_row_index: Each entry in the row index indicates where the
56+
* neighbors for that index start. 0 because that's where neighbors for vertex
57+
* 0 start, then 2 b/c that's where neighbors for vertex 1 start, then 4 b/c
58+
* that's where neighbors for vertex 2 start, then 6 b/c that's the end.
59+
* - Example: [0, 2, 4, 6]
4760
*/
4861
[[maybe_unused]] static StorageFormat vamana_storage_formats = {
4962
{"0.3",

0 commit comments

Comments
 (0)