@@ -18,14 +18,45 @@ namespace intel {
1818namespace hexl {
1919
2020// Checks AVX512 and native implementations match
21- #ifdef HEXL_HAS_AVX512DQ
21+ #ifdef HEXL_HAS_AVX512IFMA
22+ TEST (EltwiseCmpSubMod, AVX512_52) {
23+ if (!has_avx512dq) {
24+ GTEST_SKIP ();
25+ }
26+ uint64_t length = 9 ;
27+ uint64_t modulus = 1125896819525633 ;
28+
29+ for (size_t trial = 0 ; trial < 200 ; ++trial) {
30+ auto op1 = std::vector<uint64_t >(length, 1106601337915084531 );
31+ uint64_t bound = 576460751967876096 ;
32+ uint64_t diff = 3160741504001 ;
33+
34+ auto op1_native = op1;
35+ auto op1_avx512 = op1;
36+ std::vector<uint64_t > op1_out (op1.size (), 0 );
37+ std::vector<uint64_t > op1_native_out (op1.size (), 0 );
38+ std::vector<uint64_t > op1_avx512_out (op1.size (), 0 );
39+
40+ EltwiseCmpSubMod (op1_out.data (), op1.data (), op1.size (), modulus,
41+ intel::hexl::CMPINT::NLE, bound, diff);
42+ EltwiseCmpSubModNative (op1_native_out.data (), op1.data (), op1.size (),
43+ modulus, intel::hexl::CMPINT::NLE, bound, diff);
44+ EltwiseCmpSubModAVX512<52 >(op1_avx512_out.data (), op1.data (), op1.size (),
45+ modulus, intel::hexl::CMPINT::NLE, bound, diff);
46+
47+ ASSERT_EQ (op1_out, op1_native_out);
48+ ASSERT_EQ (op1_native_out, op1_avx512_out);
49+ }
50+ }
51+ #endif
52+
53+ #ifdef HEXL_HAS_AVX512IFMA
2254TEST (EltwiseCmpSubMod, AVX512) {
2355 if (!has_avx512dq) {
2456 GTEST_SKIP ();
2557 }
2658
2759 uint64_t length = 172 ;
28-
2960 for (size_t cmp = 0 ; cmp < 8 ; ++cmp) {
3061 for (size_t bits = 48 ; bits <= 51 ; ++bits) {
3162 uint64_t modulus = GeneratePrimes (1 , bits, true , 1024 )[0 ];
@@ -48,15 +79,47 @@ TEST(EltwiseCmpSubMod, AVX512) {
4879 static_cast <CMPINT>(cmp), bound, diff);
4980 EltwiseCmpSubModNative (op1a_out.data (), op1a.data (), op1a.size (),
5081 modulus, static_cast <CMPINT>(cmp), bound, diff);
51- EltwiseCmpSubModAVX512 (op1b_out.data (), op1b.data (), op1b.size (),
52- modulus, static_cast <CMPINT>(cmp), bound, diff);
82+ EltwiseCmpSubModAVX512<52 >(op1b_out.data (), op1b.data (), op1b.size (),
83+ modulus, static_cast <CMPINT>(cmp), bound,
84+ diff);
5385
5486 ASSERT_EQ (op1_out, op1a_out);
5587 ASSERT_EQ (op1_out, op1b_out);
5688 }
5789 }
5890 }
5991}
92+
93+ TEST (EltwiseCmpSubMod, AVX512_64) {
94+ if (!has_avx512dq) {
95+ GTEST_SKIP ();
96+ }
97+ uint64_t length = 9 ;
98+ uint64_t modulus = 1152921504606748673 ;
99+
100+ for (size_t trial = 0 ; trial < 200 ; ++trial) {
101+ auto op1 = std::vector<uint64_t >(length, 64961 );
102+ uint64_t bound = 576460752303415296 ;
103+ uint64_t diff = 81920 ;
104+
105+ auto op1_native = op1;
106+ auto op1_avx512 = op1;
107+ std::vector<uint64_t > op1_out (op1.size (), 0 );
108+ std::vector<uint64_t > op1_native_out (op1.size (), 0 );
109+ std::vector<uint64_t > op1_avx512_out (op1.size (), 0 );
110+
111+ EltwiseCmpSubMod (op1_out.data (), op1.data (), op1.size (), modulus,
112+ intel::hexl::CMPINT::NLE, bound, diff);
113+ EltwiseCmpSubModNative (op1_native_out.data (), op1.data (), op1.size (),
114+ modulus, intel::hexl::CMPINT::NLE, bound, diff);
115+ EltwiseCmpSubModAVX512<64 >(op1_avx512_out.data (), op1.data (), op1.size (),
116+ modulus, intel::hexl::CMPINT::NLE, bound, diff);
117+
118+ ASSERT_EQ (op1_out, op1_native_out);
119+ ASSERT_EQ (op1_native_out, op1_avx512_out);
120+ }
121+ }
122+
60123#endif
61124} // namespace hexl
62125} // namespace intel
0 commit comments