Skip to content

Commit cbe5577

Browse files
authored
Update add() for IVF Flat and IVF PQ to take in IDs (#402)
1 parent ca71d6f commit cbe5577

File tree

5 files changed

+59
-21
lines changed

5 files changed

+59
-21
lines changed

src/include/index/ivf_flat_index.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -331,6 +331,13 @@ class ivf_flat_index {
331331
#endif
332332
}
333333

334+
// @todo Use these IDs when we start using this index instead of the pure
335+
// Python one.
336+
// Overload of add() that accepts the IDs of the training vectors in addition
// to the training set itself. The IDs are not used yet: the call simply
// forwards to the single-argument add(training_set).
template <feature_vector_array Array, feature_vector Vector>
337+
void add(const Array& training_set, const Vector& training_set_ids) {
338+
add(training_set);
339+
}
340+
334341
/**
335342
* @brief Build the index from a training set, given the centroids. This
336343
* will partition the training set into a contiguous array, with one

src/include/index/ivf_pq_index.h

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -679,8 +679,14 @@ class ivf_pq_index {
679679
*
680680
* @todo Create and write index that is larger than RAM
681681
*/
682-
template <feature_vector_array V, class Distance = sum_of_squares_distance>
683-
void add(const V& training_set, Distance distance = Distance{}) {
682+
template <
683+
feature_vector_array V,
684+
feature_vector Vector,
685+
class Distance = sum_of_squares_distance>
686+
void add(
687+
const V& training_set,
688+
const Vector& training_set_ids,
689+
Distance distance = Distance{}) {
684690
auto num_unique_labels = ::num_vectors(flat_ivf_centroids_);
685691

686692
train_pq(training_set); // cluster_centroids_, distance_tables_
@@ -731,7 +737,7 @@ class ivf_pq_index {
731737
template <
732738
feature_vector_array U,
733739
class Distance = uncached_sub_sum_of_squares_distance>
734-
auto pq_encode(const U& training_set, Distance distance = Distance{}) {
740+
auto pq_encode(const U& training_set, Distance distance = Distance{}) const {
735741
auto pq_vectors = std::make_unique<ColMajorMatrix<pq_vector_feature_type>>(
736742
num_subspaces_, num_vectors(training_set));
737743
auto& pqv = *pq_vectors;

src/include/test/unit_api_ivf_flat_index.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ TEST_CASE("read index and query infinite and finite", "[api_ivf_flat_index]") {
257257
CHECK(nt == nv);
258258
auto recall = ((double)intersections_a) / ((double)nt * k_nn);
259259
if (nprobe == 32) {
260-
CHECK(recall >= 0.999);
260+
CHECK(recall >= 0.998);
261261
} else if (nprobe == 8) {
262262
CHECK(recall > 0.925);
263263
}
@@ -280,7 +280,7 @@ TEST_CASE("read index and query infinite and finite", "[api_ivf_flat_index]") {
280280
CHECK(nt == nv);
281281
auto recall = ((double)intersections_a) / ((double)nt * k_nn);
282282
if (nprobe == 32) {
283-
CHECK(recall >= 0.999);
283+
CHECK(recall >= 0.998);
284284
} else if (nprobe == 8) {
285285
CHECK(recall > 0.925);
286286
}
@@ -303,7 +303,7 @@ TEST_CASE("read index and query infinite and finite", "[api_ivf_flat_index]") {
303303
CHECK(nt == nv);
304304
auto recall = ((double)intersections_a) / ((double)nt * k_nn);
305305
if (nprobe == 32) {
306-
CHECK(recall >= 0.999);
306+
CHECK(recall >= 0.998);
307307
} else if (nprobe == 8) {
308308
CHECK(recall > 0.925);
309309
}

src/include/test/unit_ivf_pq_index.cc

Lines changed: 35 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -75,14 +75,29 @@ void debug_flat_ivf_centroids(auto& index) {
7575
std::cout << std::endl;
7676
}
7777

78+
// Instantiation check: default-constructs ivf_pq_index for every combination
// of first template argument in {int8_t, uint8_t, float} with uint32_t or
// uint64_t as the second and third template arguments. The test contains no
// assertions; it only requires that each instantiation compiles and
// default-constructs.
TEST_CASE("construct different types", "[ivf_pq_index]") {
79+
ivf_pq_index<int8_t, uint32_t, uint32_t> index1{};
80+
ivf_pq_index<uint8_t, uint32_t, uint32_t> index2{};
81+
ivf_pq_index<float, uint32_t, uint32_t> index3{};
82+
ivf_pq_index<int8_t, uint32_t, uint64_t> index4{};
83+
ivf_pq_index<uint8_t, uint32_t, uint64_t> index5{};
84+
ivf_pq_index<float, uint32_t, uint64_t> index6{};
85+
ivf_pq_index<int8_t, uint64_t, uint32_t> index7{};
86+
ivf_pq_index<uint8_t, uint64_t, uint32_t> index8{};
87+
ivf_pq_index<float, uint64_t, uint32_t> index9{};
88+
ivf_pq_index<int8_t, uint64_t, uint64_t> index10{};
89+
ivf_pq_index<uint8_t, uint64_t, uint64_t> index11{};
90+
ivf_pq_index<float, uint64_t, uint64_t> index12{};
91+
}
92+
7893
// Two default-constructed indexes of the same type must report matching
// cached metadata; the comparison is checked in both directions.
TEST_CASE("default construct two", "[ivf_pq_index]") {
7994
ivf_pq_index<float, uint32_t, uint32_t> x;
8095
ivf_pq_index<float, uint32_t, uint32_t> y;
8196
CHECK(x.compare_cached_metadata(y));
8297
CHECK(y.compare_cached_metadata(x));
8398
}
8499

85-
TEST_CASE("test kmeans initializations", "[ivf_index][init]") {
100+
TEST_CASE("test kmeans initializations", "[ivf_pq_index][init]") {
86101
const bool debug = false;
87102

88103
std::vector<float> data = {8, 6, 7, 5, 3, 3, 7, 2, 1, 4, 1, 3, 0, 5, 1, 2,
@@ -139,7 +154,7 @@ TEST_CASE("test kmeans initializations", "[ivf_index][init]") {
139154
CHECK(outer_counts == index.get_flat_ivf_centroids().num_cols());
140155
}
141156

142-
TEST_CASE("test kmeans", "[ivf_index][kmeans]") {
157+
TEST_CASE("test kmeans", "[ivf_pq_index][kmeans]") {
143158
const bool debug = false;
144159

145160
std::vector<float> data = {8, 6, 7, 5, 3, 3, 7, 2, 1, 4, 1, 3, 0, 5, 1, 2,
@@ -166,8 +181,8 @@ TEST_CASE("test kmeans", "[ivf_index][kmeans]") {
166181
}
167182
}
168183

169-
TEST_CASE("debug w/ sk", "[ivf_index]") {
170-
const bool debug = true;
184+
TEST_CASE("debug w/ sk", "[ivf_pq_index]") {
185+
const bool debug = false;
171186

172187
ColMajorMatrix<float> training_data{
173188
{1.0573647, 5.082087},
@@ -278,7 +293,7 @@ TEST_CASE("debug w/ sk", "[ivf_index]") {
278293
}
279294
}
280295

281-
TEST_CASE("ivf_index write and read", "[ivf_index]") {
296+
TEST_CASE("ivf_index write and read", "[ivf_pq_index]") {
282297
size_t dimension = 128;
283298
size_t nlist = 100;
284299
size_t num_subspaces = 16;
@@ -294,25 +309,27 @@ TEST_CASE("ivf_index write and read", "[ivf_index]") {
294309
if (vfs.is_dir(ivf_index_uri)) {
295310
vfs.remove_dir(ivf_index_uri);
296311
}
312+
297313
auto training_set = tdbColMajorMatrix<float>(ctx, siftsmall_inputs_uri, 0);
298314
load(training_set);
299-
315+
std::vector<siftsmall_ids_type> ids(num_vectors(training_set));
316+
std::iota(begin(ids), end(ids), 0);
300317
auto idx = ivf_pq_index<float, uint32_t, uint32_t>(
301318
/*dimension,*/ nlist, num_subspaces, max_iters, nthreads);
302-
303319
idx.train_ivf(training_set, kmeans_init::kmeanspp);
304-
idx.add(training_set);
305-
320+
idx.add(training_set, ids);
306321
ivf_index_uri =
307322
(std::filesystem::temp_directory_path() / "second_tmp_ivf_index")
308323
.string();
309324
if (vfs.is_dir(ivf_index_uri)) {
310325
vfs.remove_dir(ivf_index_uri);
311326
}
327+
312328
idx.write_index(ctx, ivf_index_uri);
313329
auto idx2 = ivf_pq_index<float, uint32_t, uint32_t>(ctx, ivf_index_uri);
314330
idx2.read_index_infinite();
315331

332+
CHECK(idx.compare_cached_metadata(idx2));
316333
CHECK(idx.compare_cached_metadata(idx2));
317334
CHECK(idx.compare_cluster_centroids(idx2));
318335
CHECK(idx.compare_flat_ivf_centroids(idx2));
@@ -324,19 +341,20 @@ TEST_CASE("ivf_index write and read", "[ivf_index]") {
324341
}
325342

326343
TEST_CASE(
327-
"flat_pq_index: verify pq_encoding and pq_distances with siftsmall",
328-
"[flat_pq_index]") {
344+
"verify pq_encoding and pq_distances with siftsmall", "[ivf_pq_index]") {
329345
tiledb::Context ctx;
330346
auto training_set = tdbColMajorMatrix<siftsmall_feature_type>(
331347
ctx, siftsmall_inputs_uri, 2500);
332348
training_set.load();
349+
std::vector<siftsmall_ids_type> ids(num_vectors(training_set));
350+
std::iota(begin(ids), end(ids), 0);
333351

334352
auto pq_idx = ivf_pq_index<
335353
siftsmall_feature_type,
336354
siftsmall_ids_type,
337355
siftsmall_indices_type>(20, 16, 50);
338356
pq_idx.train_ivf(training_set);
339-
pq_idx.add(training_set);
357+
pq_idx.add(training_set, ids);
340358

341359
SECTION("pq_encoding") {
342360
auto avg_error = pq_idx.verify_pq_encoding(training_set);
@@ -377,6 +395,9 @@ TEMPLATE_TEST_CASE(
377395
auto hypercube2 = ColMajorMatrix<TestType>(6, num_vectors(hypercube0));
378396
auto hypercube4 = ColMajorMatrix<TestType>(12, num_vectors(hypercube0));
379397

398+
std::vector<uint32_t> ids(num_vectors(hypercube0));
399+
std::iota(begin(ids), end(ids), 0);
400+
380401
for (size_t j = 0; j < 3; ++j) {
381402
for (size_t i = 0; i < num_vectors(hypercube4); ++i) {
382403
hypercube2(j, i) = hypercube0(j, i);
@@ -395,11 +416,11 @@ TEMPLATE_TEST_CASE(
395416
auto ivf_idx2 = ivf_pq_index<TestType, uint32_t, uint32_t>(
396417
/*128,*/ nlist, 2, 4, 1.e-4); // dim nlist maxiter eps nthreads
397418
ivf_idx2.train_ivf(hypercube2);
398-
ivf_idx2.add(hypercube2);
419+
ivf_idx2.add(hypercube2, ids);
399420
auto ivf_idx4 = ivf_pq_index<TestType, uint32_t, uint32_t>(
400421
/*128,*/ nlist, 2, 4, 1.e-4);
401422
ivf_idx4.train_ivf(hypercube4);
402-
ivf_idx4.add(hypercube4);
423+
ivf_idx4.add(hypercube4, ids);
403424

404425
auto top_k_ivf_scores = ColMajorMatrix<float>();
405426
auto top_k_ivf = ColMajorMatrix<unsigned>();

src/include/test/utils/query_common.h

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,10 @@ struct siftsmall_test_init : public siftsmall_test_init_defaults {
184184
training_set.load();
185185
query_set.load();
186186
groundtruth_set.load();
187+
188+
std::vector<id_type> ids(_cpo::num_vectors(training_set));
189+
std::iota(begin(ids), end(ids), 0);
190+
187191
std::tie(top_k_scores, top_k) = detail::flat::qv_query_heap(
188192
training_set, query_set, k_nn, 1, sum_of_squares_distance{});
189193

@@ -198,7 +202,7 @@ struct siftsmall_test_init : public siftsmall_test_init_defaults {
198202
} else {
199203
std::cout << "Unsupported index type" << std::endl;
200204
}
201-
idx.add(training_set);
205+
idx.add(training_set, ids);
202206
}
203207

204208
auto get_write_read_idx() {

0 commit comments

Comments
 (0)