Skip to content

Commit ed60178

Browse files
authored
Adding IFMA big mod fix issue #121 (#123)
* Fix pre-built CpuFeatures Error with Cmake 3.16 * Adding big mod tests for modular reduction * Cleaning * Adding comment * Adding comment
1 parent 01a6da8 commit ed60178

File tree

3 files changed

+141
-7
lines changed

3 files changed

+141
-7
lines changed

hexl/eltwise/eltwise-reduce-mod.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n,
9999
}
100100

101101
#ifdef HEXL_HAS_AVX512IFMA
102-
if (has_avx512ifma && modulus < (1ULL << 52)) {
102+
// Modulus can be 52 bits only if input mod factors <= 4
103+
// otherwise modulus should be 51 bits max to give correct results
104+
if ((has_avx512ifma && modulus < (1ULL << 51)) ||
105+
(modulus < (1ULL << 52) && input_mod_factor <= 4)) {
103106
EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
104107
output_mod_factor);
105108
return;

hexl/util/avx512-util.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
472472

473473
// alpha - beta == 52, so we only need high 52 bits
474474
__m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64);
475-
476475
// Z = prod_lo - (p * q_hat)_lo
477476
x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod);
478477
} else {

test/test-eltwise-reduce-mod-avx512.cpp

Lines changed: 137 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) {
6565
}
6666

6767
#ifdef HEXL_HAS_AVX512IFMA
68+
6869
TEST(EltwiseReduceMod, avx512_52_mod_1) {
69-
if (!has_avx512dq) {
70+
if (!has_avx512ifma) {
7071
GTEST_SKIP();
7172
}
7273

@@ -82,8 +83,8 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) {
8283
CheckEqual(result, exp_out);
8384
}
8485

85-
TEST(EltwiseReduceMod, avx512Big_mod_1) {
86-
if (!has_avx512dq) {
86+
TEST(EltwiseReduceMod, avx512_52_Big_mod_1) {
87+
if (!has_avx512ifma) {
8788
GTEST_SKIP();
8889
}
8990

@@ -101,6 +102,7 @@ TEST(EltwiseReduceMod, avx512Big_mod_1) {
101102

102103
EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus,
103104
input_mod_factor, output_mod_factor);
105+
104106
CheckEqual(result, exp_out);
105107
}
106108

@@ -204,7 +206,7 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) {
204206
size_t num_trials = 100;
205207
#endif
206208
for (size_t trial = 0; trial < num_trials; ++trial) {
207-
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, modulus);
209+
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 1ULL << 63);
208210
auto op2 = op1;
209211

210212
std::vector<uint64_t> result1(length, 0);
@@ -306,10 +308,138 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
306308
std::vector<uint64_t> result1(length, 0);
307309
std::vector<uint64_t> result2(length, 0);
308310

311+
EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 2,
312+
1);
313+
EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 2,
314+
1);
315+
316+
ASSERT_EQ(result1, result2);
317+
ASSERT_EQ(result1, result2);
318+
}
319+
}
320+
}
321+
322+
#ifdef HEXL_HAS_AVX512IFMA
323+
// Checks AVX512 and native EltwiseReduceMod implementations match with randomly
324+
// generated inputs
325+
TEST(EltwiseReduceMod, AVX512_52_Big_0_1) {
326+
if (!has_avx512ifma) {
327+
GTEST_SKIP();
328+
}
329+
330+
size_t length = 8;
331+
332+
for (size_t bits = 45; bits <= 51; ++bits) {
333+
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
334+
#ifdef HEXL_DEBUG
335+
size_t num_trials = 10;
336+
#else
337+
size_t num_trials = 1;
338+
#endif
339+
for (size_t trial = 0; trial < num_trials; ++trial) {
340+
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 1ULL << 63);
341+
auto op2 = op1;
342+
343+
std::vector<uint64_t> result1(length, 0);
344+
std::vector<uint64_t> result2(length, 0);
345+
346+
EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus,
347+
modulus, 1);
348+
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
349+
modulus, modulus, 1);
350+
351+
ASSERT_EQ(result1, result2);
352+
ASSERT_EQ(result1, result2);
353+
}
354+
}
355+
}
356+
357+
TEST(EltwiseReduceMod, AVX512_52_Big_4_1) {
358+
if (!has_avx512ifma) {
359+
GTEST_SKIP();
360+
}
361+
362+
size_t length = 8;
363+
364+
for (size_t bits = 45; bits <= 52; ++bits) {
365+
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
366+
#ifdef HEXL_DEBUG
367+
size_t num_trials = 10;
368+
#else
369+
size_t num_trials = 1;
370+
#endif
371+
for (size_t trial = 0; trial < num_trials; ++trial) {
372+
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 4 * modulus);
373+
auto op2 = op1;
374+
std::vector<uint64_t> result1(length, 0);
375+
std::vector<uint64_t> result2(length, 0);
376+
309377
EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 4,
310378
1);
311-
EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 4,
379+
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
380+
modulus, 4, 1);
381+
382+
ASSERT_EQ(result1, result2);
383+
ASSERT_EQ(result1, result2);
384+
}
385+
}
386+
}
387+
388+
TEST(EltwiseReduceMod, AVX512_52_Big_4_2) {
389+
if (!has_avx512ifma) {
390+
GTEST_SKIP();
391+
}
392+
393+
size_t length = 8;
394+
395+
for (size_t bits = 45; bits <= 52; ++bits) {
396+
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
397+
#ifdef HEXL_DEBUG
398+
size_t num_trials = 10;
399+
#else
400+
size_t num_trials = 1;
401+
#endif
402+
for (size_t trial = 0; trial < num_trials; ++trial) {
403+
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 4 * modulus);
404+
auto op2 = op1;
405+
std::vector<uint64_t> result1(length, 0);
406+
std::vector<uint64_t> result2(length, 0);
407+
408+
EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 4,
409+
2);
410+
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
411+
modulus, 4, 2);
412+
413+
ASSERT_EQ(result1, result2);
414+
ASSERT_EQ(result1, result2);
415+
}
416+
}
417+
}
418+
419+
TEST(EltwiseReduceMod, AVX512_52_Big_2_1) {
420+
if (!has_avx512ifma) {
421+
GTEST_SKIP();
422+
}
423+
424+
size_t length = 8;
425+
426+
for (size_t bits = 45; bits <= 52; ++bits) {
427+
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
428+
#ifdef HEXL_DEBUG
429+
size_t num_trials = 10;
430+
#else
431+
size_t num_trials = 1;
432+
#endif
433+
for (size_t trial = 0; trial < num_trials; ++trial) {
434+
auto op1 = GenerateInsecureUniformIntRandomValues(length, 0, 2 * modulus);
435+
auto op2 = op1;
436+
std::vector<uint64_t> result1(length, 0);
437+
std::vector<uint64_t> result2(length, 0);
438+
439+
EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 2,
312440
1);
441+
EltwiseReduceModAVX512<52>(result2.data(), op2.data(), op1.size(),
442+
modulus, 2, 1);
313443

314444
ASSERT_EQ(result1, result2);
315445
ASSERT_EQ(result1, result2);
@@ -319,5 +449,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
319449

320450
#endif
321451

452+
#endif
453+
322454
} // namespace hexl
323455
} // namespace intel

0 commit comments

Comments
 (0)