Skip to content

Commit e493894

Browse files
authored
[SVS] Add unit test to validate the compute distance fix (#858)
* Add test with direct call to SVS distance computations * Use system page protection to catch the unmasked vector loading issue * Do not use pre-compiled SVS binaries but compile from sources * Update SVS submodule to the latest main with AVX2 fixes * Address code review comment
1 parent 1a89238 commit e493894

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

tests/unit/test_svs.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3158,6 +3158,75 @@ TEST(SVSTest, scalar_quantization_query) {
31583158
}
31593159
}
31603160

3161+
#if defined(__linux__) && defined(__x86_64__)
3162+
TEST(SVSTest, compute_distance) {
3163+
// Test svs::distance computation for custom data allocations and alignments
3164+
constexpr size_t dim = 4;
3165+
3166+
// get system pagesize
3167+
size_t page_size = sysconf(_SC_PAGESIZE);
3168+
ASSERT_GT(page_size, 16);
3169+
3170+
// Allocate two consecutive pages: one for data, one as a guard (inaccessible)
3171+
uint8_t *raw_a = (uint8_t *)mmap(nullptr, 2 * page_size, PROT_READ | PROT_WRITE,
3172+
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
3173+
ASSERT_NE(raw_a, MAP_FAILED);
3174+
// Protect the second page to prevent access
3175+
ASSERT_EQ(mprotect(raw_a + page_size, page_size, PROT_NONE), 0);
3176+
3177+
// Allocate the second buffer
3178+
uint8_t *raw_b = (uint8_t *)mmap(nullptr, 2 * page_size, PROT_READ | PROT_WRITE,
3179+
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
3180+
ASSERT_NE(raw_b, MAP_FAILED);
3181+
// Protect the second page to prevent access
3182+
ASSERT_EQ(mprotect(raw_b + page_size, page_size, PROT_NONE), 0);
3183+
3184+
// use last bytes of page for data
3185+
// Note: Accessing above 'dim' should trigger Memory Access Error.
3186+
constexpr size_t data_size = dim * sizeof(float);
3187+
float *a = reinterpret_cast<float *>(raw_a + page_size - data_size);
3188+
float *b = reinterpret_cast<float *>(raw_b + page_size - data_size);
3189+
3190+
std::iota(a, a + dim, 1.f);
3191+
std::iota(b, b + dim, 2.f);
3192+
3193+
// Verify default implementation
3194+
auto dist_l2 = svs::distance::compute(svs::DistanceL2{}, std::span(a, dim), std::span(b, dim));
3195+
EXPECT_GT(dist_l2, 0.0);
3196+
auto dist_ip = svs::distance::compute(svs::DistanceIP{}, std::span(a, dim), std::span(b, dim));
3197+
EXPECT_GT(dist_ip, 0.0);
3198+
3199+
// Verify AVX2 and AVX512 implementations
3200+
if (svs::detail::avx_runtime_flags.is_avx2_supported()) {
3201+
// AVX2 implementations
3202+
auto dist_l2_avx2 = svs::distance::
3203+
L2Impl<svs::Dynamic, float, float, svs::distance::AVX_AVAILABILITY::AVX2>::compute(
3204+
a, b, svs::lib::MaybeStatic(dim));
3205+
auto dist_ip_avx2 = svs::distance::
3206+
IPImpl<svs::Dynamic, float, float, svs::distance::AVX_AVAILABILITY::AVX2>::compute(
3207+
a, b, svs::lib::MaybeStatic(dim));
3208+
EXPECT_DOUBLE_EQ(dist_l2, dist_l2_avx2);
3209+
EXPECT_DOUBLE_EQ(dist_ip, dist_ip_avx2);
3210+
}
3211+
3212+
if (svs::detail::avx_runtime_flags.is_avx512f_supported()) {
3213+
// AVX512 implementations
3214+
auto dist_l2_avx512 = svs::distance::
3215+
L2Impl<svs::Dynamic, float, float, svs::distance::AVX_AVAILABILITY::AVX512>::compute(
3216+
a, b, svs::lib::MaybeStatic(dim));
3217+
auto dist_ip_avx512 = svs::distance::
3218+
IPImpl<svs::Dynamic, float, float, svs::distance::AVX_AVAILABILITY::AVX512>::compute(
3219+
a, b, svs::lib::MaybeStatic(dim));
3220+
EXPECT_DOUBLE_EQ(dist_l2, dist_l2_avx512);
3221+
EXPECT_DOUBLE_EQ(dist_ip, dist_ip_avx512);
3222+
}
3223+
3224+
// unmap pages
3225+
munmap(raw_a, 2 * page_size);
3226+
munmap(raw_b, 2 * page_size);
3227+
}
3228+
#endif // defined(__linux__) && defined(__x86_64__)
3229+
31613230
#else // HAVE_SVS
31623231

31633232
TEST(SVSTest, svs_not_supported) {

0 commit comments

Comments
 (0)