Skip to content

Commit 6bb588e

Browse files
authored
Check finite and infinite IVF_PQ queries return the same ids and distances, fix count_intersections() to not modify inputs, update read_index_finite() to return data (#509)
1 parent fab865f commit 6bb588e

File tree

7 files changed

+258
-46
lines changed

7 files changed

+258
-46
lines changed

src/include/api/feature_vector_array.h

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,4 +543,203 @@ auto count_intersections(
543543
}
544544
}
545545

546+
bool are_equal(
547+
const FeatureVectorArray& a,
548+
const FeatureVectorArray& b,
549+
double epsilon = 0.0) {
550+
if (a.feature_type() != b.feature_type()) {
551+
std::cout << "[feature_vector_array@are_equal] Feature types do not match: "
552+
<< a.feature_type_string() << " vs " << b.feature_type_string()
553+
<< std::endl;
554+
return false;
555+
}
556+
if (a.feature_size() != b.feature_size()) {
557+
std::cout << "[feature_vector_array@are_equal] Feature sizes do not match: "
558+
<< a.feature_size() << " vs " << b.feature_size() << std::endl;
559+
return false;
560+
}
561+
if (a.num_vectors() != b.num_vectors()) {
562+
std::cout
563+
<< "[feature_vector_array@are_equal] Number of vectors do not match: "
564+
<< a.num_vectors() << " vs " << b.num_vectors() << std::endl;
565+
return false;
566+
}
567+
if (a.dimensions() != b.dimensions()) {
568+
std::cout << "[feature_vector_array@are_equal] Number of dimensions do not "
569+
"match: "
570+
<< a.dimensions() << " vs " << b.dimensions() << std::endl;
571+
return false;
572+
}
573+
574+
if (a.ids_type() != b.ids_type()) {
575+
std::cout << "[feature_vector_array@are_equal] IDs types do not match: "
576+
<< a.ids_type_string() << " vs " << b.ids_type_string()
577+
<< std::endl;
578+
return false;
579+
}
580+
if (a.ids_size() != b.ids_size()) {
581+
std::cout << "[feature_vector_array@are_equal] IDs sizes do not match: "
582+
<< a.ids_size() << " vs " << b.ids_size() << std::endl;
583+
return false;
584+
}
585+
if (a.num_ids() != b.num_ids()) {
586+
std::cout << "[feature_vector_array@are_equal] Number of IDs do not match: "
587+
<< a.num_ids() << " vs " << b.num_ids() << std::endl;
588+
return false;
589+
}
590+
591+
if (a.extents() != b.extents()) {
592+
std::cout << "[feature_vector_array@are_equal] Extents do not match: "
593+
<< "A: {" << a.extents()[0] << ", " << a.dimensions() << "}, "
594+
<< "B: {" << b.extents()[0] << ", " << b.dimensions() << "}"
595+
<< std::endl;
596+
return false;
597+
}
598+
599+
// Compare the data of the feature vectors with optional epsilon
600+
auto compare_data = [&epsilon](
601+
const auto* data_a, const auto* data_b, size_t size) {
602+
if (epsilon > 0.0) {
603+
for (size_t i = 0; i < size; ++i) {
604+
if (std::abs(
605+
static_cast<double>(data_a[i]) -
606+
static_cast<double>(data_b[i])) > epsilon) {
607+
std::cout
608+
<< "[feature_vector_array@are_equal] Data mismatch at index " << i
609+
<< ": " << data_a[i] << " vs " << data_b[i]
610+
<< " (epsilon: " << epsilon << ")" << std::endl;
611+
return false;
612+
}
613+
}
614+
} else {
615+
for (size_t i = 0; i < size; ++i) {
616+
if (data_a[i] != data_b[i]) {
617+
std::cout
618+
<< "[feature_vector_array@are_equal] Data mismatch at index " << i
619+
<< ": " << data_a[i] << " vs " << data_b[i] << std::endl;
620+
return false;
621+
}
622+
}
623+
}
624+
return true;
625+
};
626+
627+
switch (a.feature_type()) {
628+
case TILEDB_FLOAT32: {
629+
const auto* data_a = static_cast<const float*>(a.data());
630+
const auto* data_b = static_cast<const float*>(b.data());
631+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
632+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
633+
"comparison failed for type FLOAT32"
634+
<< std::endl;
635+
return false;
636+
}
637+
break;
638+
}
639+
case TILEDB_INT8: {
640+
const auto* data_a = static_cast<const int8_t*>(a.data());
641+
const auto* data_b = static_cast<const int8_t*>(b.data());
642+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
643+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
644+
"comparison failed for type INT8"
645+
<< std::endl;
646+
return false;
647+
}
648+
break;
649+
}
650+
case TILEDB_UINT8: {
651+
const auto* data_a = static_cast<const uint8_t*>(a.data());
652+
const auto* data_b = static_cast<const uint8_t*>(b.data());
653+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
654+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
655+
"comparison failed for type UINT8"
656+
<< std::endl;
657+
return false;
658+
}
659+
break;
660+
}
661+
case TILEDB_INT32: {
662+
const auto* data_a = static_cast<const int32_t*>(a.data());
663+
const auto* data_b = static_cast<const int32_t*>(b.data());
664+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
665+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
666+
"comparison failed for type INT32"
667+
<< std::endl;
668+
return false;
669+
}
670+
break;
671+
}
672+
case TILEDB_UINT32: {
673+
const auto* data_a = static_cast<const uint32_t*>(a.data());
674+
const auto* data_b = static_cast<const uint32_t*>(b.data());
675+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
676+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
677+
"comparison failed for type UINT32"
678+
<< std::endl;
679+
return false;
680+
}
681+
break;
682+
}
683+
case TILEDB_INT64: {
684+
const auto* data_a = static_cast<const int64_t*>(a.data());
685+
const auto* data_b = static_cast<const int64_t*>(b.data());
686+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
687+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
688+
"comparison failed for type INT64"
689+
<< std::endl;
690+
return false;
691+
}
692+
break;
693+
}
694+
case TILEDB_UINT64: {
695+
const auto* data_a = static_cast<const uint64_t*>(a.data());
696+
const auto* data_b = static_cast<const uint64_t*>(b.data());
697+
if (!compare_data(data_a, data_b, a.num_vectors() * a.dimensions())) {
698+
std::cout << "[feature_vector_array@are_equal] Feature vector data "
699+
"comparison failed for type UINT64"
700+
<< std::endl;
701+
return false;
702+
}
703+
break;
704+
}
705+
default:
706+
std::cout << "[feature_vector_array@are_equal] Unsupported feature "
707+
"vector attribute type"
708+
<< std::endl;
709+
throw std::runtime_error("Unsupported attribute type");
710+
}
711+
712+
// If the arrays have IDs, compare the IDs as well
713+
if (a.ids_type() != TILEDB_ANY && b.ids_type() != TILEDB_ANY) {
714+
switch (a.ids_type()) {
715+
case TILEDB_UINT32: {
716+
const auto* ids_a = static_cast<const uint32_t*>(a.ids());
717+
const auto* ids_b = static_cast<const uint32_t*>(b.ids());
718+
if (!compare_data(ids_a, ids_b, a.num_ids())) {
719+
std::cout << "[feature_vector_array@are_equal] ID comparison failed "
720+
"for type UINT32"
721+
<< std::endl;
722+
return false;
723+
}
724+
break;
725+
}
726+
case TILEDB_UINT64: {
727+
const auto* ids_a = static_cast<const uint64_t*>(a.ids());
728+
const auto* ids_b = static_cast<const uint64_t*>(b.ids());
729+
if (!compare_data(ids_a, ids_b, a.num_ids())) {
730+
std::cout << "[feature_vector_array@are_equal] ID comparison failed "
731+
"for type UINT64"
732+
<< std::endl;
733+
return false;
734+
}
735+
break;
736+
}
737+
default:
738+
throw std::runtime_error("Unsupported ID type");
739+
}
740+
}
741+
742+
return true;
743+
}
744+
546745
#endif // TILEDB_API_FEATURE_VECTOR_ARRAY_H

src/include/detail/linalg/matrix.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,10 @@ constexpr auto SubMatrix(
442442
template <class Matrix>
443443
void debug_matrix(
444444
const Matrix& matrix, const std::string& msg = "", size_t max_size = 10) {
445-
auto rowsEnd = std::min(dimensions(matrix), static_cast<size_t>(max_size));
446-
auto colsEnd = std::min(num_vectors(matrix), static_cast<size_t>(max_size));
445+
auto rowsEnd = std::min(
446+
dimensions(matrix), static_cast<typename Matrix::size_type>(max_size));
447+
auto colsEnd = std::min(
448+
num_vectors(matrix), static_cast<typename Matrix::size_type>(max_size));
447449

448450
std::cout << "# " << msg << " (" << dimensions(matrix) << " rows x "
449451
<< num_vectors(matrix) << " cols) ("

src/include/detail/linalg/matrix_with_ids.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,12 @@ void debug_matrix_with_ids(
218218
const MatrixWithIds& matrix,
219219
const std::string& msg = "",
220220
size_t max_size = 10) {
221-
auto rowsEnd = std::min(dimensions(matrix), static_cast<size_t>(max_size));
222-
auto colsEnd = std::min(num_vectors(matrix), static_cast<size_t>(max_size));
221+
auto rowsEnd = std::min(
222+
dimensions(matrix),
223+
static_cast<typename MatrixWithIds::size_type>(max_size));
224+
auto colsEnd = std::min(
225+
num_vectors(matrix),
226+
static_cast<typename MatrixWithIds::size_type>(max_size));
223227

224228
debug_matrix(matrix, msg, max_size);
225229

src/include/detail/linalg/partitioned_matrix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ void debug_partitioned_matrix(
242242
debug_matrix(matrix, msg, max_size);
243243

244244
std::cout << "# ids: [";
245-
auto end = std::min(matrix.ids().size(), static_cast<size_t>(max_size));
245+
auto end = std::min(matrix.num_vectors(), static_cast<size_t>(max_size));
246246
for (size_t i = 0; i < end; ++i) {
247247
std::cout << matrix.ids()[i];
248248
if (i != matrix.ids().size() - 1) {

src/include/index/ivf_pq_index.h

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,6 @@ class ivf_pq_index {
452452
auto sub_begin = subspace * dimensions_ / num_subspaces_;
453453
auto sub_end = (subspace + 1) * dimensions_ / num_subspaces_;
454454

455-
auto local_sub_distance = SubDistance{sub_begin, sub_end};
456-
457455
// @todo Make choice of kmeans init configurable
458456
sub_kmeans_random_init(
459457
training_set, cluster_centroids_, sub_begin, sub_end);
@@ -909,8 +907,7 @@ class ivf_pq_index {
909907
auto&& [active_partitions, active_queries] =
910908
detail::ivf::partition_ivf_flat_index<indices_type>(
911909
flat_ivf_centroids_, query_vectors, nprobe, num_threads_);
912-
913-
partitioned_pq_vectors_ = std::make_unique<tdb_pq_storage_type>(
910+
auto partitioned_pq_vectors = std::make_unique<tdb_pq_storage_type>(
914911
group_->cached_ctx(),
915912
group_->pq_ivf_vectors_uri(),
916913
group_->pq_ivf_indices_uri(),
@@ -923,7 +920,9 @@ class ivf_pq_index {
923920
// NB: We don't load the partitioned_pq_vectors here. We will load them
924921
// when we do the query.
925922
return std::make_tuple(
926-
std::move(active_partitions), std::move(active_queries));
923+
std::move(active_partitions),
924+
std::move(active_queries),
925+
std::move(partitioned_pq_vectors));
927926
}
928927

929928
/**
@@ -1236,21 +1235,17 @@ class ivf_pq_index {
12361235
"run if you're loading the index by URI. Please open it by URI and "
12371236
"try again. If you just wrote the index, open it up again by URI.");
12381237
}
1239-
if (partitioned_pq_vectors_) {
1240-
// We did an infinite query before this. Reset so we can load again.
1241-
partitioned_pq_vectors_.reset();
1242-
}
12431238
if (::num_vectors(flat_ivf_centroids_) < nprobe) {
12441239
nprobe = ::num_vectors(flat_ivf_centroids_);
12451240
}
1246-
auto&& [active_partitions, active_queries] =
1241+
auto&& [active_partitions, active_queries, partitioned_pq_vectors] =
12471242
read_index_finite(query_vectors, nprobe, upper_bound);
12481243
auto query_to_pq_centroid_distance_tables =
12491244
std::move(*generate_query_to_pq_centroid_distance_tables<
12501245
Q,
12511246
ColMajorMatrix<float>>(query_vectors));
12521247
return detail::ivf::query_finite_ram(
1253-
*partitioned_pq_vectors_,
1248+
*partitioned_pq_vectors,
12541249
query_to_pq_centroid_distance_tables,
12551250
active_queries,
12561251
k_nn,

src/include/scoring.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -680,36 +680,39 @@ bool validate_top_k(TK& top_k, const G& g) {
680680

681681
template <feature_vector_array U, feature_vector_array V>
682682
auto count_intersections(const U& I, const V& groundtruth, size_t k_nn) {
683-
// print_types(I, groundtruth);
684-
685683
size_t total_intersected = 0;
686684

687685
if constexpr (feature_vector_array<std::remove_cvref_t<decltype(I)>>) {
688686
for (size_t i = 0; i < I.num_cols(); ++i) {
689-
std::sort(begin(I[i]), end(I[i]));
690-
std::sort(begin(groundtruth[i]), begin(groundtruth[i]) + k_nn);
687+
std::vector<std::decay_t<decltype(I[i][0])>> sorted_I(
688+
begin(I[i]), end(I[i]));
689+
std::vector<std::decay_t<decltype(groundtruth[i][0])>> sorted_groundtruth(
690+
begin(groundtruth[i]), begin(groundtruth[i]) + k_nn);
691691

692-
// @todo remove -- for debugging only
693-
std::vector<size_t> x(begin(I[i]), end(I[i]));
694-
std::vector<size_t> y(begin(groundtruth[i]), end(groundtruth[i]));
692+
std::sort(begin(sorted_I), end(sorted_I));
693+
std::sort(begin(sorted_groundtruth), end(sorted_groundtruth));
695694

696695
total_intersected += std::set_intersection(
697-
begin(I[i]),
698-
end(I[i]),
699-
begin(groundtruth[i]),
700-
/*end(groundtruth[i]*/ begin(groundtruth[i]) + k_nn,
696+
begin(sorted_I),
697+
end(sorted_I),
698+
begin(sorted_groundtruth),
699+
end(sorted_groundtruth),
701700
assignment_counter{});
702701
}
703702
} else {
704703
if constexpr (feature_vector<std::remove_cvref_t<decltype(I)>>) {
705-
std::sort(begin(I), end(I));
706-
std::sort(begin(groundtruth), begin(groundtruth) + k_nn);
704+
std::vector<std::decay_t<decltype(I[0])>> sorted_I(begin(I), end(I));
705+
std::vector<std::decay_t<decltype(groundtruth[0])>> sorted_groundtruth(
706+
begin(groundtruth), begin(groundtruth) + k_nn);
707+
708+
std::sort(begin(sorted_I), end(sorted_I));
709+
std::sort(begin(sorted_groundtruth), end(sorted_groundtruth));
707710

708711
total_intersected += std::set_intersection(
709-
begin(I),
710-
end(I),
711-
begin(groundtruth),
712-
/*end(groundtruth)*/ begin(groundtruth) + k_nn,
712+
begin(sorted_I),
713+
end(sorted_I),
714+
begin(sorted_groundtruth),
715+
end(sorted_groundtruth),
713716
assignment_counter{});
714717
} else {
715718
static_assert(
@@ -718,7 +721,7 @@ auto count_intersections(const U& I, const V& groundtruth, size_t k_nn) {
718721
}
719722
}
720723
return total_intersected;
721-
};
724+
}
722725

723726
#if defined(TILEDB_VS_ENABLE_BLAS) && 0
724727

0 commit comments

Comments
 (0)