Skip to content

Commit 343acab

Browse files
authored
Fix EltwiseCmpSubMod (#84)
* Fix EltwiseCmpSubMod
* Fix Windows build by pinning the cpu-features dependency to a fixed commit
1 parent ccb6063 commit 343acab

File tree

8 files changed

+95
-34
lines changed

8 files changed

+95
-34
lines changed

benchmark/bench-eltwise-cmp-sub-mod.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,20 @@ BENCHMARK(BM_EltwiseCmpSubModNative)
4242

4343
#ifdef HEXL_HAS_AVX512DQ
4444
// state.range(0) is the input size (number of elements)
45-
static void BM_EltwiseCmpSubModAVX512(benchmark::State& state) { // NOLINT
45+
static void BM_EltwiseCmpSubModAVX512_64(benchmark::State& state) { // NOLINT
4646
size_t input_size = state.range(0);
4747
uint64_t modulus = 100;
4848
uint64_t bound = GenerateInsecureUniformRandomValue(0, modulus);
4949
uint64_t diff = GenerateInsecureUniformRandomValue(1, modulus);
5050
auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, modulus);
5151

5252
for (auto _ : state) {
53-
EltwiseCmpSubModAVX512(input1.data(), input1.data(), input_size, modulus,
54-
CMPINT::NLT, bound, diff);
53+
EltwiseCmpSubModAVX512<64>(input1.data(), input1.data(), input_size,
54+
modulus, CMPINT::NLT, bound, diff);
5555
}
5656
}
5757

58-
BENCHMARK(BM_EltwiseCmpSubModAVX512)
58+
BENCHMARK(BM_EltwiseCmpSubModAVX512_64)
5959
->Unit(benchmark::kMicrosecond)
6060
->Args({1024})
6161
->Args({4096})

cmake/third-party/cpu-features/CMakeLists.txt.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ project(cpu-features-download NONE)
88
include(ExternalProject)
99
ExternalProject_Add(cpu_features
1010
GIT_REPOSITORY https://github.com/google/cpu_features.git
11-
GIT_TAG master
11+
GIT_TAG 32b49eb5e7809052a28422cfde2f2745fbb0eb76 # master branch on Oct 20, 2021
1212
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu-features-src"
1313
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/cpu-features-build"
1414
CONFIGURE_COMMAND ""

hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace intel {
1515
namespace hexl {
1616

1717
#ifdef HEXL_HAS_AVX512DQ
18-
template <int BitShift = 64>
18+
template <int BitShift>
1919
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
2020
uint64_t n, uint64_t modulus, CMPINT cmp,
2121
uint64_t bound, uint64_t diff) {
@@ -51,12 +51,25 @@ void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
5151
uint64_t prod_right_shift = ceil_log_mod + beta;
5252
__m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
5353

54+
uint64_t alpha = BitShift - 2;
55+
uint64_t mu_64 =
56+
MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift,
57+
modulus)
58+
.BarrettFactor();
59+
60+
if (BitShift == 64) {
61+
// Single-worded Barrett reduction.
62+
mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor();
63+
}
64+
65+
__m512i v_mu_64 = _mm512_set1_epi64(static_cast<int64_t>(mu_64));
66+
5467
for (size_t i = n / 8; i > 0; --i) {
5568
__m512i v_op = _mm512_loadu_si512(v_op_ptr);
5669
__mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));
5770

5871
v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
59-
v_op, v_modulus, v_mu, v_mu, prod_right_shift, v_neg_mod);
72+
v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod);
6073

6174
__m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
6275
v_to_add = _mm512_sub_epi64(v_to_add, v_diff);

hexl/eltwise/eltwise-cmp-sub-mod.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,8 @@ void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1,
5252

5353
for (size_t i = 0; i < n; ++i) {
5454
uint64_t op = operand1[i];
55-
5655
bool op_cmp = Compare(cmp, op, bound);
5756
op %= modulus;
58-
5957
if (op_cmp) {
6058
op = SubUIntMod(op, diff, modulus);
6159
}

hexl/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ namespace hexl {
1919
/// @param[in] bound Scalar to compare against
2020
/// @param[in] diff Scalar to subtract by
2121
/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? (\p
22-
/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1
22+
/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0,
23+
/// ..., n-1
2324
void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n,
2425
uint64_t modulus, CMPINT cmp, uint64_t bound,
2526
uint64_t diff);

hexl/util/avx512-util.hpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <vector>
99

10+
#include "hexl/logging/logging.hpp"
1011
#include "hexl/number-theory/number-theory.hpp"
1112
#include "hexl/util/check.hpp"
1213
#include "hexl/util/defines.hpp"
@@ -389,6 +390,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
389390
__mmask8 mask =
390391
_mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT);
391392
if (mask != 0) {
393+
// At least one element is >= 2^52 (mask is computed with CMPINT::NLT)
392394
__m512i x_hi = _mm512_srli_epi64(x, static_cast<unsigned int>(52ULL));
393395
__m512i x_intr = _mm512_slli_epi64(x, static_cast<unsigned int>(12ULL));
394396
__m512i x_lo =
@@ -408,33 +410,22 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
408410
x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod);
409411
} else {
410412
__m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52);
411-
// Barrett subtraction
412-
// tmp[0] = input - tmp[1] * q;
413413
__m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q);
414414
x = _mm512_sub_epi64(x, tmp1_times_mod);
415415
}
416416
}
417417
#endif
418418
if (BitShift == 64) {
419419
__m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64);
420-
// Barrett subtraction
421-
// tmp[0] = input - tmp[1] * q;
422420
__m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q);
423421
x = _mm512_sub_epi64(x, tmp1_times_mod);
424422
}
425423

426424
// Correction
427-
if (OutputModFactor == 2) {
428-
return x;
429-
} else {
430-
if (BitShift == 64) {
431-
x = _mm512_hexl_small_mod_epu64(x, q);
432-
}
433-
if (BitShift == 52) {
434-
x = _mm512_hexl_small_mod_epu64<2>(x, q);
435-
}
436-
return x;
425+
if (OutputModFactor == 1) {
426+
x = _mm512_hexl_small_mod_epu64<2>(x, q);
437427
}
428+
return x;
438429
}
439430

440431
// Concatenate packed 64-bit integers in x and y, producing an intermediate

test/test-eltwise-cmp-sub-mod-avx512.cpp

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,45 @@ namespace intel {
1818
namespace hexl {
1919

2020
// Checks AVX512 and native implementations match
21-
#ifdef HEXL_HAS_AVX512DQ
21+
#ifdef HEXL_HAS_AVX512IFMA
22+
TEST(EltwiseCmpSubMod, AVX512_52) {
23+
if (!has_avx512dq) {
24+
GTEST_SKIP();
25+
}
26+
uint64_t length = 9;
27+
uint64_t modulus = 1125896819525633;
28+
29+
for (size_t trial = 0; trial < 200; ++trial) {
30+
auto op1 = std::vector<uint64_t>(length, 1106601337915084531);
31+
uint64_t bound = 576460751967876096;
32+
uint64_t diff = 3160741504001;
33+
34+
auto op1_native = op1;
35+
auto op1_avx512 = op1;
36+
std::vector<uint64_t> op1_out(op1.size(), 0);
37+
std::vector<uint64_t> op1_native_out(op1.size(), 0);
38+
std::vector<uint64_t> op1_avx512_out(op1.size(), 0);
39+
40+
EltwiseCmpSubMod(op1_out.data(), op1.data(), op1.size(), modulus,
41+
intel::hexl::CMPINT::NLE, bound, diff);
42+
EltwiseCmpSubModNative(op1_native_out.data(), op1.data(), op1.size(),
43+
modulus, intel::hexl::CMPINT::NLE, bound, diff);
44+
EltwiseCmpSubModAVX512<52>(op1_avx512_out.data(), op1.data(), op1.size(),
45+
modulus, intel::hexl::CMPINT::NLE, bound, diff);
46+
47+
ASSERT_EQ(op1_out, op1_native_out);
48+
ASSERT_EQ(op1_native_out, op1_avx512_out);
49+
}
50+
}
51+
#endif
52+
53+
#ifdef HEXL_HAS_AVX512IFMA
2254
TEST(EltwiseCmpSubMod, AVX512) {
2355
if (!has_avx512dq) {
2456
GTEST_SKIP();
2557
}
2658

2759
uint64_t length = 172;
28-
2960
for (size_t cmp = 0; cmp < 8; ++cmp) {
3061
for (size_t bits = 48; bits <= 51; ++bits) {
3162
uint64_t modulus = GeneratePrimes(1, bits, true, 1024)[0];
@@ -48,15 +79,47 @@ TEST(EltwiseCmpSubMod, AVX512) {
4879
static_cast<CMPINT>(cmp), bound, diff);
4980
EltwiseCmpSubModNative(op1a_out.data(), op1a.data(), op1a.size(),
5081
modulus, static_cast<CMPINT>(cmp), bound, diff);
51-
EltwiseCmpSubModAVX512(op1b_out.data(), op1b.data(), op1b.size(),
52-
modulus, static_cast<CMPINT>(cmp), bound, diff);
82+
EltwiseCmpSubModAVX512<52>(op1b_out.data(), op1b.data(), op1b.size(),
83+
modulus, static_cast<CMPINT>(cmp), bound,
84+
diff);
5385

5486
ASSERT_EQ(op1_out, op1a_out);
5587
ASSERT_EQ(op1_out, op1b_out);
5688
}
5789
}
5890
}
5991
}
92+
93+
TEST(EltwiseCmpSubMod, AVX512_64) {
94+
if (!has_avx512dq) {
95+
GTEST_SKIP();
96+
}
97+
uint64_t length = 9;
98+
uint64_t modulus = 1152921504606748673;
99+
100+
for (size_t trial = 0; trial < 200; ++trial) {
101+
auto op1 = std::vector<uint64_t>(length, 64961);
102+
uint64_t bound = 576460752303415296;
103+
uint64_t diff = 81920;
104+
105+
auto op1_native = op1;
106+
auto op1_avx512 = op1;
107+
std::vector<uint64_t> op1_out(op1.size(), 0);
108+
std::vector<uint64_t> op1_native_out(op1.size(), 0);
109+
std::vector<uint64_t> op1_avx512_out(op1.size(), 0);
110+
111+
EltwiseCmpSubMod(op1_out.data(), op1.data(), op1.size(), modulus,
112+
intel::hexl::CMPINT::NLE, bound, diff);
113+
EltwiseCmpSubModNative(op1_native_out.data(), op1.data(), op1.size(),
114+
modulus, intel::hexl::CMPINT::NLE, bound, diff);
115+
EltwiseCmpSubModAVX512<64>(op1_avx512_out.data(), op1.data(), op1.size(),
116+
modulus, intel::hexl::CMPINT::NLE, bound, diff);
117+
118+
ASSERT_EQ(op1_out, op1_native_out);
119+
ASSERT_EQ(op1_native_out, op1_avx512_out);
120+
}
121+
}
122+
60123
#endif
61124
} // namespace hexl
62125
} // namespace intel

test/test-eltwise-reduce-mod-avx512.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,6 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) {
134134
GTEST_SKIP();
135135
}
136136

137-
std::random_device rd;
138-
std::mt19937 gen(rd());
139-
140137
size_t length = 1024;
141138

142139
for (size_t bits = 50; bits <= 62; ++bits) {
@@ -170,9 +167,6 @@ TEST(EltwiseReduceMod, AVX512Big_4_1) {
170167
GTEST_SKIP();
171168
}
172169

173-
std::random_device rd;
174-
std::mt19937 gen(rd());
175-
176170
size_t length = 1024;
177171

178172
for (size_t bits = 50; bits <= 62; ++bits) {
@@ -263,6 +257,7 @@ TEST(EltwiseReduceMod, AVX512Big_2_1) {
263257
}
264258
}
265259
}
260+
266261
#endif
267262

268263
} // namespace hexl

0 commit comments

Comments
 (0)