Skip to content

Commit 1bab417

Browse files
authored
Fix bug with adjacency_row_index_uri() array type being incorrect, remove adjacency_row_index_type from API (#405)
1 parent 434b33b commit 1bab417

File tree

10 files changed

+69
-184
lines changed

10 files changed

+69
-184
lines changed

apis/python/src/tiledb/vector_search/type_erased_module.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,6 @@ void init_type_erased_module(py::module_& m) {
364364
py::arg("storage_version") = "")
365365
.def("feature_type_string", &IndexVamana::feature_type_string)
366366
.def("id_type_string", &IndexVamana::id_type_string)
367-
.def(
368-
"adjacency_row_index_type_string",
369-
&IndexVamana::adjacency_row_index_type_string)
370367
.def("dimensions", &IndexVamana::dimensions)
371368
.def_static(
372369
"clear_history",

apis/python/src/tiledb/vector_search/vamana_index.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def create(
163163
index = vspy.IndexVamana(
164164
feature_type=np.dtype(vector_type).name,
165165
id_type=np.dtype(np.uint64).name,
166-
adjacency_row_index_type=np.dtype(np.uint64).name,
167166
dimensions=dimensions,
168167
)
169168
# TODO(paris): Run all of this with a single C++ call.

apis/python/test/test_type_erased_module.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -266,29 +266,21 @@ def test_construct_IndexVamana():
266266
a = vspy.IndexVamana()
267267
assert a.feature_type_string() == "any"
268268
assert a.id_type_string() == "uint32"
269-
assert a.adjacency_row_index_type_string() == "uint32"
270269
assert a.dimensions() == 0
271270

272271
a = vspy.IndexVamana(feature_type="float32")
273272
assert a.feature_type_string() == "float32"
274273
assert a.id_type_string() == "uint32"
275-
assert a.adjacency_row_index_type_string() == "uint32"
276274
assert a.dimensions() == 0
277275

278-
a = vspy.IndexVamana(
279-
feature_type="uint8", id_type="uint64", adjacency_row_index_type="int64"
280-
)
276+
a = vspy.IndexVamana(feature_type="uint8", id_type="uint64")
281277
assert a.feature_type_string() == "uint8"
282278
assert a.id_type_string() == "uint64"
283-
assert a.adjacency_row_index_type_string() == "int64"
284279
assert a.dimensions() == 0
285280

286-
a = vspy.IndexVamana(
287-
feature_type="float32", id_type="int64", adjacency_row_index_type="uint64"
288-
)
281+
a = vspy.IndexVamana(feature_type="float32", id_type="int64")
289282
assert a.feature_type_string() == "float32"
290283
assert a.id_type_string() == "int64"
291-
assert a.adjacency_row_index_type_string() == "uint64"
292284
assert a.dimensions() == 0
293285

294286

@@ -299,13 +291,11 @@ def test_construct_IndexVamana_with_empty_vector(tmp_path):
299291
dimensions = 128
300292
feature_type = "float32"
301293
id_type = "uint64"
302-
adjacency_row_index_type = "uint64"
303294

304295
# First create an empty index.
305296
a = vspy.IndexVamana(
306297
feature_type=feature_type,
307298
id_type=id_type,
308-
adjacency_row_index_type=adjacency_row_index_type,
309299
dimensions=dimensions,
310300
)
311301
empty_vector = vspy.FeatureVectorArray(dimensions, 0, feature_type, id_type)
@@ -335,9 +325,7 @@ def test_inplace_build_query_IndexVamana():
335325
opt_l = 100
336326
k_nn = 10
337327

338-
a = vspy.IndexVamana(
339-
id_type="uint32", adjacency_row_index_type="uint32", feature_type="float32"
340-
)
328+
a = vspy.IndexVamana(id_type="uint32", feature_type="float32")
341329

342330
training_set = vspy.FeatureVectorArray(ctx, siftsmall_inputs_uri)
343331
assert training_set.feature_type_string() == "float32"

src/include/api/vamana_index.h

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ class IndexVamana {
9090
const std::optional<IndexOptions>& config = std::nullopt) {
9191
feature_datatype_ = TILEDB_ANY;
9292
id_datatype_ = TILEDB_UINT32;
93-
adjacency_row_index_datatype_ = TILEDB_UINT32;
9493

9594
if (config) {
9695
for (auto&& c : *config) {
@@ -108,8 +107,6 @@ class IndexVamana {
108107
feature_datatype_ = string_to_datatype(value);
109108
} else if (key == "id_type") {
110109
id_datatype_ = string_to_datatype(value);
111-
} else if (key == "adjacency_row_index_type") {
112-
adjacency_row_index_datatype_ = string_to_datatype(value);
113110
} else {
114111
throw std::runtime_error("Invalid index config key: " + key);
115112
}
@@ -135,12 +132,7 @@ class IndexVamana {
135132
const tiledb::Context& ctx,
136133
const URI& group_uri,
137134
std::optional<TemporalPolicy> temporal_policy = std::nullopt) {
138-
read_types(
139-
ctx,
140-
group_uri,
141-
&feature_datatype_,
142-
&id_datatype_,
143-
&adjacency_row_index_datatype_);
135+
read_types(ctx, group_uri, &feature_datatype_, &id_datatype_);
144136

145137
auto type = std::tuple{
146138
feature_datatype_, id_datatype_, adjacency_row_index_datatype_};
@@ -250,16 +242,10 @@ class IndexVamana {
250242
uint64_t timestamp) {
251243
tiledb_datatype_t feature_datatype{TILEDB_ANY};
252244
tiledb_datatype_t id_datatype{TILEDB_ANY};
253-
tiledb_datatype_t adjacency_row_index_datatype{TILEDB_ANY};
254-
read_types(
255-
ctx,
256-
group_uri,
257-
&feature_datatype,
258-
&id_datatype,
259-
&adjacency_row_index_datatype);
260-
261-
auto type =
262-
std::tuple{feature_datatype, id_datatype, adjacency_row_index_datatype};
245+
read_types(ctx, group_uri, &feature_datatype, &id_datatype);
246+
247+
auto type = std::tuple{
248+
feature_datatype, id_datatype, adjacency_row_index_datatype_};
263249
if (clear_history_dispatch_table.find(type) ==
264250
clear_history_dispatch_table.end()) {
265251
throw std::runtime_error("Unsupported datatype combination");
@@ -307,29 +293,17 @@ class IndexVamana {
307293
return datatype_to_string(id_datatype_);
308294
}
309295

310-
constexpr auto adjacency_row_index_type() const {
311-
return adjacency_row_index_datatype_;
312-
}
313-
314-
inline auto adjacency_row_index_type_string() const {
315-
return datatype_to_string(adjacency_row_index_datatype_);
316-
}
317-
318296
private:
319297
static void read_types(
320298
const tiledb::Context& ctx,
321299
const std::string& group_uri,
322300
tiledb_datatype_t* feature_datatype,
323-
tiledb_datatype_t* id_datatype,
324-
tiledb_datatype_t* adjacency_row_index_datatype) {
301+
tiledb_datatype_t* id_datatype) {
325302
using metadata_element =
326303
std::tuple<std::string, tiledb_datatype_t*, tiledb_datatype_t>;
327304
std::vector<metadata_element> metadata{
328305
{"feature_datatype", feature_datatype, TILEDB_UINT32},
329-
{"id_datatype", id_datatype, TILEDB_UINT32},
330-
{"adjacency_row_index_datatype",
331-
adjacency_row_index_datatype,
332-
TILEDB_UINT32}};
306+
{"id_datatype", id_datatype, TILEDB_UINT32}};
333307

334308
tiledb::Group read_group(ctx, group_uri, TILEDB_READ, ctx.config());
335309

@@ -536,7 +510,8 @@ class IndexVamana {
536510
size_t b_backtrack_ = 0;
537511
tiledb_datatype_t feature_datatype_{TILEDB_ANY};
538512
tiledb_datatype_t id_datatype_{TILEDB_ANY};
539-
tiledb_datatype_t adjacency_row_index_datatype_{TILEDB_ANY};
513+
static constexpr tiledb_datatype_t adjacency_row_index_datatype_{
514+
TILEDB_UINT64};
540515
std::unique_ptr<index_base> index_;
541516
};
542517

src/include/index/index_metadata.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
* "base_sizes", // (json) list
3434
* "dataset_type", // "vector_search"
3535
* "dtype", // "float32", etc (Python dtype names)
36-
* "index_type", // "FLAT", "IVF_FLAT", "VAMANA"
36+
* "index_type", // "FLAT", "IVF_FLAT", "VAMANA", "IVF_PQ"
3737
* "ingestion_timestamps", // (json) list
3838
* "storage_version", // "0.3"
3939
* "temp_size", // TILEDB_INT64 or TILEDB_FLOAT64

src/include/index/vamana_group.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,12 @@ class vamana_index_group : public base_index_group<index_type> {
237237
* @todo Make this table-driven
238238
*************************************************************************/
239239
metadata_.adjacency_scores_datatype_ =
240-
type_to_tiledb_v<typename index_type::score_type>;
240+
type_to_tiledb_v<typename index_type::adjacency_scores_type>;
241241
metadata_.adjacency_row_index_datatype_ =
242242
type_to_tiledb_v<typename index_type::adjacency_row_index_type>;
243243

244244
metadata_.adjacency_scores_type_str_ =
245-
type_to_string_v<typename index_type::score_type>;
245+
type_to_string_v<typename index_type::adjacency_scores_type>;
246246
metadata_.adjacency_row_index_type_str_ =
247247
type_to_string_v<typename index_type::adjacency_row_index_type>;
248248

@@ -281,7 +281,7 @@ class vamana_index_group : public base_index_group<index_type> {
281281
tiledb_helpers::add_to_group(
282282
write_group, this->ids_uri(), this->ids_array_name());
283283

284-
create_empty_for_vector<typename index_type::score_type>(
284+
create_empty_for_vector<typename index_type::adjacency_scores_type>(
285285
cached_ctx_,
286286
adjacency_scores_uri(),
287287
default_domain,
@@ -299,7 +299,7 @@ class vamana_index_group : public base_index_group<index_type> {
299299
tiledb_helpers::add_to_group(
300300
write_group, adjacency_ids_uri(), adjacency_ids_array_name());
301301

302-
create_empty_for_vector<typename index_type::id_type>(
302+
create_empty_for_vector<typename index_type::adjacency_row_index_type>(
303303
cached_ctx_,
304304
adjacency_row_index_uri(),
305305
default_domain,

src/include/index/vamana_index.h

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,22 @@ auto medoid(auto&& P, Distance distance = Distance{}) {
9898
}
9999

100100
/**
101-
* @brief Index class for vamana search
102-
* @tparam feature_type Type of the elements in the feature vectors
103-
* @tparam id_type Type of the ids of the feature vectors
101+
* @brief The Vamana index.
102+
*
103+
* @tparam FeatureType Type of the elements in the feature vectors.
104+
* @tparam IdType Type of the ids of the feature vectors.
105+
* @tparam AdjacencyRowIndexType Type of the indexes used in the graph.
104106
*/
105-
template <class FeatureType, class IdType, class IndexType = uint64_t>
107+
template <
108+
class FeatureType,
109+
class IdType,
110+
class AdjacencyRowIndexType = uint64_t>
106111
class vamana_index {
107112
public:
108113
using feature_type = FeatureType;
109114
using id_type = IdType;
110-
using adjacency_row_index_type = IndexType;
111-
using score_type = float;
115+
using adjacency_row_index_type = AdjacencyRowIndexType;
116+
using adjacency_scores_type = float;
112117

113118
using group_type = vamana_index_group<vamana_index>;
114119
using metadata_type = vamana_index_metadata;
@@ -136,7 +141,7 @@ class vamana_index {
136141
uint64_t num_edges_{0};
137142

138143
/** The graph representing the index over `feature_vectors_` */
139-
::detail::graph::adj_list<score_type, id_type> graph_;
144+
::detail::graph::adj_list<adjacency_scores_type, id_type> graph_;
140145

141146
/*
142147
* The medoid of the feature vectors -- the vector in the set that is closest
@@ -250,7 +255,7 @@ class vamana_index {
250255
****************************************************************************/
251256
graph_ = ::detail::graph::adj_list<feature_type, id_type>(num_vectors_);
252257

253-
auto adj_scores = read_vector<score_type>(
258+
auto adj_scores = read_vector<adjacency_scores_type>(
254259
group_->cached_ctx(),
255260
group_->adjacency_scores_uri(),
256261
0,
@@ -467,7 +472,7 @@ class vamana_index {
467472

468473
auto top_k = ColMajorMatrix<id_type>(k_nn, ::num_vectors(queries));
469474
auto top_k_scores =
470-
ColMajorMatrix<score_type>(k_nn, ::num_vectors(queries));
475+
ColMajorMatrix<adjacency_scores_type>(k_nn, ::num_vectors(queries));
471476

472477
for (size_t i = 0; i < num_vectors(queries); ++i) {
473478
auto&& [tk_scores, tk, V] = ::best_first_O2(
@@ -494,7 +499,7 @@ class vamana_index {
494499

495500
auto top_k = ColMajorMatrix<id_type>(k_nn, ::num_vectors(queries));
496501
auto top_k_scores =
497-
ColMajorMatrix<score_type>(k_nn, ::num_vectors(queries));
502+
ColMajorMatrix<adjacency_scores_type>(k_nn, ::num_vectors(queries));
498503

499504
for (size_t i = 0; i < num_vectors(queries); ++i) {
500505
auto&& [tk_scores, tk, V] = ::best_first_O3(
@@ -521,7 +526,7 @@ class vamana_index {
521526

522527
auto top_k = ColMajorMatrix<id_type>(k_nn, ::num_vectors(queries));
523528
auto top_k_scores =
524-
ColMajorMatrix<score_type>(k_nn, ::num_vectors(queries));
529+
ColMajorMatrix<adjacency_scores_type>(k_nn, ::num_vectors(queries));
525530

526531
for (size_t i = 0; i < num_vectors(queries); ++i) {
527532
auto&& [tk_scores, tk, V] = ::best_first_O4(
@@ -548,7 +553,7 @@ class vamana_index {
548553

549554
auto top_k = ColMajorMatrix<id_type>(k_nn, ::num_vectors(queries));
550555
auto top_k_scores =
551-
ColMajorMatrix<score_type>(k_nn, ::num_vectors(queries));
556+
ColMajorMatrix<adjacency_scores_type>(k_nn, ::num_vectors(queries));
552557

553558
for (size_t i = 0; i < num_vectors(queries); ++i) {
554559
auto&& [tk_scores, tk, V] = ::best_first_O5(
@@ -588,7 +593,8 @@ class vamana_index {
588593
// L = std::min<size_t>(L, l_build_);
589594

590595
auto top_k = ColMajorMatrix<id_type>(k, ::num_vectors(query_set));
591-
auto top_k_scores = ColMajorMatrix<score_type>(k, ::num_vectors(query_set));
596+
auto top_k_scores =
597+
ColMajorMatrix<adjacency_scores_type>(k, ::num_vectors(query_set));
592598

593599
#if 0
594600
// Parallelized implementation -- we stay single-threaded for now
@@ -768,7 +774,7 @@ class vamana_index {
768774
false,
769775
temporal_policy_);
770776

771-
auto adj_scores = Vector<score_type>(graph_.num_edges());
777+
auto adj_scores = Vector<adjacency_scores_type>(graph_.num_edges());
772778
auto adj_ids = Vector<id_type>(graph_.num_edges());
773779
auto adj_index =
774780
Vector<adjacency_row_index_type>(graph_.num_vertices() + 1);

src/include/index/vamana_metadata.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,6 @@ class vamana_index_metadata
117117
{"alpha_min", &alpha_min_, TILEDB_FLOAT32, false},
118118
{"alpha_max", &alpha_max_, TILEDB_FLOAT32, false},
119119
{"medoid", &medoid_, TILEDB_UINT64, false},
120-
{"adjacency_row_index_datatype",
121-
&adjacency_row_index_datatype_,
122-
TILEDB_UINT32,
123-
false},
124120
};
125121

126122
void clear_history_impl(uint64_t timestamp) {

0 commit comments

Comments
 (0)