@@ -75,14 +75,29 @@ void debug_flat_ivf_centroids(auto& index) {
7575 std::cout << std::endl;
7676}
7777
78+ TEST_CASE (" construct different types" , " [ivf_pq_index]" ) {
79+ ivf_pq_index<int8_t , uint32_t , uint32_t > index1{};
80+ ivf_pq_index<uint8_t , uint32_t , uint32_t > index2{};
81+ ivf_pq_index<float , uint32_t , uint32_t > index3{};
82+ ivf_pq_index<int8_t , uint32_t , uint64_t > index4{};
83+ ivf_pq_index<uint8_t , uint32_t , uint64_t > index5{};
84+ ivf_pq_index<float , uint32_t , uint64_t > index6{};
85+ ivf_pq_index<int8_t , uint64_t , uint32_t > index7{};
86+ ivf_pq_index<uint8_t , uint64_t , uint32_t > index8{};
87+ ivf_pq_index<float , uint64_t , uint32_t > index9{};
88+ ivf_pq_index<int8_t , uint64_t , uint64_t > index10{};
89+ ivf_pq_index<uint8_t , uint64_t , uint64_t > index11{};
90+ ivf_pq_index<float , uint64_t , uint64_t > index12{};
91+ }
92+
7893TEST_CASE (" default construct two" , " [ivf_pq_index]" ) {
7994 ivf_pq_index<float , uint32_t , uint32_t > x;
8095 ivf_pq_index<float , uint32_t , uint32_t > y;
8196 CHECK (x.compare_cached_metadata (y));
8297 CHECK (y.compare_cached_metadata (x));
8398}
8499
85- TEST_CASE (" test kmeans initializations" , " [ivf_index ][init]" ) {
100+ TEST_CASE (" test kmeans initializations" , " [ivf_pq_index ][init]" ) {
86101 const bool debug = false ;
87102
88103 std::vector<float > data = {8 , 6 , 7 , 5 , 3 , 3 , 7 , 2 , 1 , 4 , 1 , 3 , 0 , 5 , 1 , 2 ,
@@ -139,7 +154,7 @@ TEST_CASE("test kmeans initializations", "[ivf_index][init]") {
139154 CHECK (outer_counts == index.get_flat_ivf_centroids ().num_cols ());
140155}
141156
142- TEST_CASE (" test kmeans" , " [ivf_index ][kmeans]" ) {
157+ TEST_CASE (" test kmeans" , " [ivf_pq_index ][kmeans]" ) {
143158 const bool debug = false ;
144159
145160 std::vector<float > data = {8 , 6 , 7 , 5 , 3 , 3 , 7 , 2 , 1 , 4 , 1 , 3 , 0 , 5 , 1 , 2 ,
@@ -166,8 +181,8 @@ TEST_CASE("test kmeans", "[ivf_index][kmeans]") {
166181 }
167182}
168183
169- TEST_CASE (" debug w/ sk" , " [ivf_index ]" ) {
170- const bool debug = true ;
184+ TEST_CASE (" debug w/ sk" , " [ivf_pq_index ]" ) {
185+ const bool debug = false ;
171186
172187 ColMajorMatrix<float > training_data{
173188 {1.0573647 , 5.082087 },
@@ -278,7 +293,7 @@ TEST_CASE("debug w/ sk", "[ivf_index]") {
278293 }
279294}
280295
281- TEST_CASE (" ivf_index write and read" , " [ivf_index ]" ) {
296+ TEST_CASE (" ivf_index write and read" , " [ivf_pq_index ]" ) {
282297 size_t dimension = 128 ;
283298 size_t nlist = 100 ;
284299 size_t num_subspaces = 16 ;
@@ -294,25 +309,27 @@ TEST_CASE("ivf_index write and read", "[ivf_index]") {
294309 if (vfs.is_dir (ivf_index_uri)) {
295310 vfs.remove_dir (ivf_index_uri);
296311 }
312+
297313 auto training_set = tdbColMajorMatrix<float >(ctx, siftsmall_inputs_uri, 0 );
298314 load (training_set);
299-
315+ std::vector<siftsmall_ids_type> ids (num_vectors (training_set));
316+ std::iota (begin (ids), end (ids), 0 );
300317 auto idx = ivf_pq_index<float , uint32_t , uint32_t >(
301318 /* dimension,*/ nlist, num_subspaces, max_iters, nthreads);
302-
303319 idx.train_ivf (training_set, kmeans_init::kmeanspp);
304- idx.add (training_set);
305-
320+ idx.add (training_set, ids);
306321 ivf_index_uri =
307322 (std::filesystem::temp_directory_path () / " second_tmp_ivf_index" )
308323 .string ();
309324 if (vfs.is_dir (ivf_index_uri)) {
310325 vfs.remove_dir (ivf_index_uri);
311326 }
327+
312328 idx.write_index (ctx, ivf_index_uri);
313329 auto idx2 = ivf_pq_index<float , uint32_t , uint32_t >(ctx, ivf_index_uri);
314330 idx2.read_index_infinite ();
315331
332+ CHECK (idx.compare_cached_metadata (idx2));
316333 CHECK (idx.compare_cached_metadata (idx2));
317334 CHECK (idx.compare_cluster_centroids (idx2));
318335 CHECK (idx.compare_flat_ivf_centroids (idx2));
@@ -324,19 +341,20 @@ TEST_CASE("ivf_index write and read", "[ivf_index]") {
324341}
325342
326343TEST_CASE (
327- " flat_pq_index: verify pq_encoding and pq_distances with siftsmall" ,
328- " [flat_pq_index]" ) {
344+ " verify pq_encoding and pq_distances with siftsmall" , " [ivf_pq_index]" ) {
329345 tiledb::Context ctx;
330346 auto training_set = tdbColMajorMatrix<siftsmall_feature_type>(
331347 ctx, siftsmall_inputs_uri, 2500 );
332348 training_set.load ();
349+ std::vector<siftsmall_ids_type> ids (num_vectors (training_set));
350+ std::iota (begin (ids), end (ids), 0 );
333351
334352 auto pq_idx = ivf_pq_index<
335353 siftsmall_feature_type,
336354 siftsmall_ids_type,
337355 siftsmall_indices_type>(20 , 16 , 50 );
338356 pq_idx.train_ivf (training_set);
339- pq_idx.add (training_set);
357+ pq_idx.add (training_set, ids );
340358
341359 SECTION (" pq_encoding" ) {
342360 auto avg_error = pq_idx.verify_pq_encoding (training_set);
@@ -377,6 +395,9 @@ TEMPLATE_TEST_CASE(
377395 auto hypercube2 = ColMajorMatrix<TestType>(6, num_vectors(hypercube0));
378396 auto hypercube4 = ColMajorMatrix<TestType>(12, num_vectors(hypercube0));
379397
398+ std::vector<uint32_t> ids(num_vectors(hypercube0));
399+ std::iota(begin(ids), end(ids), 0);
400+
380401 for (size_t j = 0; j < 3; ++j) {
381402 for (size_t i = 0; i < num_vectors(hypercube4); ++i) {
382403 hypercube2(j, i) = hypercube0(j, i);
@@ -395,11 +416,11 @@ TEMPLATE_TEST_CASE(
395416 auto ivf_idx2 = ivf_pq_index<TestType, uint32_t, uint32_t>(
396417 /*128,*/ nlist, 2, 4, 1.e-4); // dim nlist maxiter eps nthreads
397418 ivf_idx2.train_ivf(hypercube2);
398- ivf_idx2.add(hypercube2);
419+ ivf_idx2.add(hypercube2, ids );
399420 auto ivf_idx4 = ivf_pq_index<TestType, uint32_t, uint32_t>(
400421 /*128,*/ nlist, 2, 4, 1.e-4);
401422 ivf_idx4.train_ivf(hypercube4);
402- ivf_idx4.add(hypercube4);
423+ ivf_idx4.add(hypercube4, ids );
403424
404425 auto top_k_ivf_scores = ColMajorMatrix<float>();
405426 auto top_k_ivf = ColMajorMatrix<unsigned>();
0 commit comments