Skip to content

Commit 6d908ce

Browse files
authored
Expose Vamana graph building params (#423)
1 parent bd22702 commit 6d908ce

File tree

5 files changed

+64
-3
lines changed

5 files changed

+64
-3
lines changed

apis/python/src/tiledb/vector_search/ingestion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def ingest(
5252
size: int = -1,
5353
partitions: int = -1,
5454
num_subspaces: int = -1,
55+
l_build: int = -1,
56+
r_max_degree: int = -1,
5557
training_sampling_policy: TrainingSamplingPolicy = TrainingSamplingPolicy.FIRST_N,
5658
copy_centroids_uri: str = None,
5759
training_sample_size: int = -1,
@@ -120,6 +122,12 @@ def ingest(
120122
For PQ encoded indexes, the number of subspaces to use in the PQ encoding. We will divide the dimensions into
121123
num_subspaces parts, and PQ encode each part separately. This means dimensions must
122124
be divisible by num_subspaces.
125+
l_build: int
126+
For Vamana indexes, the number of neighbors considered for each node during construction of the graph. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. l_build should be >= r_max_degree unless you need to build indices quickly and can compromise on quality.
127+
Typically between 75 and 200. If not provided, use the default value of 100.
128+
r_max_degree: int
129+
For Vamana indexes, the maximum degree for each node in the final graph. Larger values will result in larger indices and longer indexing times, but better search quality.
130+
Typically between 60 and 150. If not provided, use the default value of 64.
123131
copy_centroids_uri: str
124132
TileDB array URI to copy centroids from. If not provided, centroids are built by running `k-means`.
125133
training_sample_size: int
@@ -2671,6 +2679,8 @@ def consolidate_and_vacuum(
26712679
dimensions=dimensions,
26722680
vector_type=vector_type,
26732681
config=config,
2682+
l_build=l_build,
2683+
r_max_degree=r_max_degree,
26742684
storage_version=storage_version,
26752685
)
26762686
elif index_type == "IVF_PQ":

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ void init_type_erased_module(py::module_& m) {
365365
.def("feature_type_string", &IndexVamana::feature_type_string)
366366
.def("id_type_string", &IndexVamana::id_type_string)
367367
.def("dimensions", &IndexVamana::dimensions)
368+
.def("l_build", &IndexVamana::l_build)
369+
.def("r_max_degree", &IndexVamana::r_max_degree)
368370
.def_static(
369371
"clear_history",
370372
[](const tiledb::Context& ctx,

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727

2828
INDEX_TYPE = "VAMANA"
2929

30+
L_BUILD_DEFAULT = 100
31+
R_MAX_DEGREE_DEFAULT = 64
32+
L_SEARCH_DEFAULT = 100
33+
3034

3135
class VamanaIndex(index.Index):
3236
"""
@@ -97,7 +101,7 @@ def query_internal(
97101
self,
98102
queries: np.ndarray,
99103
k: int = 10,
100-
l_search: Optional[int] = 100,
104+
l_search: Optional[int] = L_SEARCH_DEFAULT,
101105
**kwargs,
102106
):
103107
"""
@@ -110,7 +114,8 @@ def query_internal(
110114
k: int
111115
Number of results to return per query vector.
112116
l_search: int
113-
How deep to search. Should be >= k, and if it's not, we will set it to k.
117+
How deep to search. Larger values result in higher query latency but better accuracy.
118+
Should be >= k, and if it's not, we will set it to k.
114119
"""
115120
if self.size == 0:
116121
return np.full((queries.shape[0], k), MAX_FLOAT32), np.full(
@@ -137,6 +142,8 @@ def create(
137142
uri: str,
138143
dimensions: int,
139144
vector_type: np.dtype,
145+
l_build: int = L_BUILD_DEFAULT,
146+
r_max_degree: int = R_MAX_DEGREE_DEFAULT,
140147
config: Optional[Mapping[str, Any]] = None,
141148
storage_version: str = STORAGE_VERSION,
142149
**kwargs,
@@ -152,6 +159,12 @@ def create(
152159
vector_type: np.dtype
153160
Datatype of vectors.
154161
Supported values (uint8, int8, float32).
162+
l_build: int
163+
The number of neighbors considered for each node during construction of the graph. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. l_build should be >= r_max_degree unless you need to build indices quickly and can compromise on quality.
164+
Typically between 75 and 200. If not provided, use the default value of 100.
165+
r_max_degree: int
166+
The maximum degree for each node in the final graph. Larger values will result in larger indices and longer indexing times, but better search quality.
167+
Typically between 60 and 150. If not provided, use the default value of 64.
155168
config: Optional[Mapping[str, Any]]
156169
TileDB config dictionary.
157170
storage_version: str
@@ -169,6 +182,8 @@ def create(
169182
feature_type=np.dtype(vector_type).name,
170183
id_type=np.dtype(np.uint64).name,
171184
dimensions=dimensions,
185+
l_build=l_build if l_build > 0 else L_BUILD_DEFAULT,
186+
# Validate r_max_degree against itself (not l_build) so each parameter falls back to its own default independently.
r_max_degree=r_max_degree if r_max_degree > 0 else R_MAX_DEGREE_DEFAULT,
172187
)
173188
# TODO(paris): Run all of this with a single C++ call.
174189
empty_vector = vspy.FeatureVectorArray(

apis/python/test/test_ingestion.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from common import load_metadata
88

99
from tiledb.cloud.dag import Mode
10+
from tiledb.vector_search import _tiledbvspy as vspy
1011
from tiledb.vector_search.index import Index
1112
from tiledb.vector_search.ingestion import TrainingSamplingPolicy
1213
from tiledb.vector_search.ingestion import ingest
@@ -40,7 +41,11 @@ def test_vamana_ingestion_u8(tmp_path):
4041
index_uri = os.path.join(tmp_path, "array")
4142
if os.path.exists(index_uri):
4243
shutil.rmtree(index_uri)
43-
create_random_dataset_u8(nb=10000, d=100, nq=100, k=10, path=dataset_dir)
44+
45+
l_build = 101
46+
r_max_degree = 65
47+
dimensions = 100
48+
create_random_dataset_u8(nb=10000, d=dimensions, nq=100, k=10, path=dataset_dir)
4449
dtype = np.dtype(np.uint8)
4550
k = 10
4651

@@ -51,7 +56,18 @@ def test_vamana_ingestion_u8(tmp_path):
5156
index_type="VAMANA",
5257
index_uri=index_uri,
5358
source_uri=os.path.join(dataset_dir, "data.u8bin"),
59+
l_build=l_build,
60+
r_max_degree=r_max_degree,
5461
)
62+
63+
# This is not a public API, but we directly load the C++ type-erased index to test it. If you
64+
# are a library user, you should not do this yourself, as the API may change.
65+
ctx = vspy.Ctx({})
66+
type_erased_index = vspy.IndexVamana(ctx, index_uri, None)
67+
assert type_erased_index.dimensions() == dimensions
68+
assert type_erased_index.l_build() == l_build
69+
assert type_erased_index.r_max_degree() == r_max_degree
70+
5571
_, result = index.query(queries, k=k)
5672
assert accuracy(result, gt_i) > MINIMUM_ACCURACY
5773

apis/python/test/test_type_erased_module.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,18 @@ def test_construct_IndexVamana():
283283
assert a.id_type_string() == "int64"
284284
assert a.dimensions() == 0
285285

286+
a = vspy.IndexVamana(feature_type="float32", id_type="int64", l_build=11)
287+
assert a.l_build() == 11
288+
289+
a = vspy.IndexVamana(feature_type="float32", id_type="int64", r_max_degree=22)
290+
assert a.r_max_degree() == 22
291+
292+
a = vspy.IndexVamana(
293+
feature_type="float32", id_type="int64", l_build=11, r_max_degree=22
294+
)
295+
assert a.l_build() == 11
296+
assert a.r_max_degree() == 22
297+
286298

287299
def test_construct_IndexVamana_with_empty_vector(tmp_path):
288300
l_search = 100
@@ -291,12 +303,16 @@ def test_construct_IndexVamana_with_empty_vector(tmp_path):
291303
dimensions = 128
292304
feature_type = "float32"
293305
id_type = "uint64"
306+
l_build = 100
307+
r_max_degree = 101
294308

295309
# First create an empty index.
296310
a = vspy.IndexVamana(
297311
feature_type=feature_type,
298312
id_type=id_type,
299313
dimensions=dimensions,
314+
l_build=l_build,
315+
r_max_degree=r_max_degree,
300316
)
301317
empty_vector = vspy.FeatureVectorArray(dimensions, 0, feature_type, id_type)
302318
a.train(empty_vector)
@@ -310,6 +326,8 @@ def test_construct_IndexVamana_with_empty_vector(tmp_path):
310326
assert query_set.feature_type_string() == "float32"
311327
groundtruth_set = vspy.FeatureVectorArray(ctx, siftsmall_groundtruth_uri)
312328
assert groundtruth_set.feature_type_string() == "uint64"
329+
assert a.l_build() == l_build
330+
assert a.r_max_degree() == r_max_degree
313331

314332
a.train(training_set)
315333

0 commit comments

Comments
 (0)