@@ -65,8 +65,9 @@ TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) {
6565}
6666
6767#ifdef HEXL_HAS_AVX512IFMA
68+
6869TEST (EltwiseReduceMod, avx512_52_mod_1) {
69- if (!has_avx512dq ) {
70+ if (!has_avx512ifma ) {
7071 GTEST_SKIP ();
7172 }
7273
@@ -82,8 +83,8 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) {
8283 CheckEqual (result, exp_out);
8384}
8485
85- TEST (EltwiseReduceMod, avx512Big_mod_1 ) {
86- if (!has_avx512dq ) {
86+ TEST (EltwiseReduceMod, avx512_52_Big_mod_1 ) {
87+ if (!has_avx512ifma ) {
8788 GTEST_SKIP ();
8889 }
8990
@@ -101,6 +102,7 @@ TEST(EltwiseReduceMod, avx512Big_mod_1) {
101102
102103 EltwiseReduceModAVX512<52 >(result.data (), op.data (), op.size (), modulus,
103104 input_mod_factor, output_mod_factor);
105+
104106 CheckEqual (result, exp_out);
105107}
106108
@@ -204,7 +206,7 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) {
204206 size_t num_trials = 100 ;
205207#endif
206208 for (size_t trial = 0 ; trial < num_trials; ++trial) {
207- auto op1 = GenerateInsecureUniformIntRandomValues (length, 0 , modulus );
209+ auto op1 = GenerateInsecureUniformIntRandomValues (length, 0 , 1ULL << 63 );
208210 auto op2 = op1;
209211
210212 std::vector<uint64_t > result1 (length, 0 );
@@ -306,10 +308,138 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
306308 std::vector<uint64_t > result1 (length, 0 );
307309 std::vector<uint64_t > result2 (length, 0 );
308310
311+ EltwiseReduceModNative (result1.data (), op1.data (), op1.size (), modulus, 2 ,
312+ 1 );
313+ EltwiseReduceModAVX512 (result2.data (), op2.data (), op1.size (), modulus, 2 ,
314+ 1 );
315+
316+ ASSERT_EQ (result1, result2);
317+ ASSERT_EQ (result1, result2);
318+ }
319+ }
320+ }
321+
322+ #ifdef HEXL_HAS_AVX512IFMA
323+ // Checks AVX512 and native EltwiseReduceMod implementations match with randomly
324+ // generated inputs
325+ TEST (EltwiseReduceMod, AVX512_52_Big_0_1) {
326+ if (!has_avx512ifma) {
327+ GTEST_SKIP ();
328+ }
329+
330+ size_t length = 8 ;
331+
332+ for (size_t bits = 45 ; bits <= 51 ; ++bits) {
333+ uint64_t modulus = GeneratePrimes (1 , bits, true , length)[0 ];
334+ #ifdef HEXL_DEBUG
335+ size_t num_trials = 10 ;
336+ #else
337+ size_t num_trials = 1 ;
338+ #endif
339+ for (size_t trial = 0 ; trial < num_trials; ++trial) {
340+ auto op1 = GenerateInsecureUniformIntRandomValues (length, 0 , 1ULL << 63 );
341+ auto op2 = op1;
342+
343+ std::vector<uint64_t > result1 (length, 0 );
344+ std::vector<uint64_t > result2 (length, 0 );
345+
346+ EltwiseReduceModNative (result1.data (), op1.data (), op1.size (), modulus,
347+ modulus, 1 );
348+ EltwiseReduceModAVX512<52 >(result2.data (), op2.data (), op1.size (),
349+ modulus, modulus, 1 );
350+
351+ ASSERT_EQ (result1, result2);
352+ ASSERT_EQ (result1, result2);
353+ }
354+ }
355+ }
356+
357+ TEST (EltwiseReduceMod, AVX512_52_Big_4_1) {
358+ if (!has_avx512ifma) {
359+ GTEST_SKIP ();
360+ }
361+
362+ size_t length = 8 ;
363+
364+ for (size_t bits = 45 ; bits <= 52 ; ++bits) {
365+ uint64_t modulus = GeneratePrimes (1 , bits, true , length)[0 ];
366+ #ifdef HEXL_DEBUG
367+ size_t num_trials = 10 ;
368+ #else
369+ size_t num_trials = 1 ;
370+ #endif
371+ for (size_t trial = 0 ; trial < num_trials; ++trial) {
372+ auto op1 = GenerateInsecureUniformIntRandomValues (length, 0 , 4 * modulus);
373+ auto op2 = op1;
374+ std::vector<uint64_t > result1 (length, 0 );
375+ std::vector<uint64_t > result2 (length, 0 );
376+
309377 EltwiseReduceModNative (result1.data (), op1.data (), op1.size (), modulus, 4 ,
310378 1 );
311- EltwiseReduceModAVX512 (result2.data (), op2.data (), op1.size (), modulus, 4 ,
379+ EltwiseReduceModAVX512<52 >(result2.data (), op2.data (), op1.size (),
380+ modulus, 4 , 1 );
381+
382+ ASSERT_EQ (result1, result2);
383+ ASSERT_EQ (result1, result2);
384+ }
385+ }
386+ }
387+
388+ TEST (EltwiseReduceMod, AVX512_52_Big_4_2) {
389+ if (!has_avx512ifma) {
390+ GTEST_SKIP ();
391+ }
392+
393+ size_t length = 8 ;
394+
395+ for (size_t bits = 45 ; bits <= 52 ; ++bits) {
396+ uint64_t modulus = GeneratePrimes (1 , bits, true , length)[0 ];
397+ #ifdef HEXL_DEBUG
398+ size_t num_trials = 10 ;
399+ #else
400+ size_t num_trials = 1 ;
401+ #endif
402+ for (size_t trial = 0 ; trial < num_trials; ++trial) {
403+ auto op1 = GenerateInsecureUniformIntRandomValues (length, 0 , 4 * modulus);
404+ auto op2 = op1;
405+ std::vector<uint64_t > result1 (length, 0 );
406+ std::vector<uint64_t > result2 (length, 0 );
407+
408+ EltwiseReduceModNative (result1.data (), op1.data (), op1.size (), modulus, 4 ,
409+ 2 );
410+ EltwiseReduceModAVX512<52 >(result2.data (), op2.data (), op1.size (),
411+ modulus, 4 , 2 );
412+
413+ ASSERT_EQ (result1, result2);
414+ ASSERT_EQ (result1, result2);
415+ }
416+ }
417+ }
418+
419+ TEST (EltwiseReduceMod, AVX512_52_Big_2_1) {
420+ if (!has_avx512ifma) {
421+ GTEST_SKIP ();
422+ }
423+
424+ size_t length = 8 ;
425+
426+ for (size_t bits = 45 ; bits <= 52 ; ++bits) {
427+ uint64_t modulus = GeneratePrimes (1 , bits, true , length)[0 ];
428+ #ifdef HEXL_DEBUG
429+ size_t num_trials = 10 ;
430+ #else
431+ size_t num_trials = 1 ;
432+ #endif
433+ for (size_t trial = 0 ; trial < num_trials; ++trial) {
434+ auto op1 = GenerateInsecureUniformIntRandomValues (length, 0 , 2 * modulus);
435+ auto op2 = op1;
436+ std::vector<uint64_t > result1 (length, 0 );
437+ std::vector<uint64_t > result2 (length, 0 );
438+
439+ EltwiseReduceModNative (result1.data (), op1.data (), op1.size (), modulus, 2 ,
312440 1 );
441+ EltwiseReduceModAVX512<52 >(result2.data (), op2.data (), op1.size (),
442+ modulus, 2 , 1 );
313443
314444 ASSERT_EQ (result1, result2);
315445 ASSERT_EQ (result1, result2);
@@ -319,5 +449,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
319449
320450#endif
321451
452+ #endif
453+
322454} // namespace hexl
323455} // namespace intel
0 commit comments