Skip to content

Commit 01fe30b

Browse files
committed
update montgomery two pow function call choices to get ideal perf on ARM64 (M2). add two code sections to experimental mont two pow to potentially benefit x64
1 parent 6bf5408 commit 01fe30b

File tree

54 files changed

+16710
-40
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+16710
-40
lines changed

modular_arithmetic/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ include(FetchContent)
7575
FetchContent_Declare(
7676
hurchalla_util
7777
GIT_REPOSITORY https://github.com/hurchalla/util.git
78-
GIT_TAG 9fac434b586717052c648339eb0f0f89d23e0298
78+
GIT_TAG ea4d95c8852d8351cbd1529bbb48a9c10e7d61bf
7979
)
8080
FetchContent_MakeAvailable(hurchalla_util)
8181

montgomery_arithmetic/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ include(FetchContent)
7979
FetchContent_Declare(
8080
hurchalla_util
8181
GIT_REPOSITORY https://github.com/hurchalla/util.git
82-
GIT_TAG 9fac434b586717052c648339eb0f0f89d23e0298
82+
GIT_TAG ea4d95c8852d8351cbd1529bbb48a9c10e7d61bf
8383
)
8484
FetchContent_MakeAvailable(hurchalla_util)
8585

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_two_pow/experimental_montgomery_two_pow.h

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2726,6 +2726,171 @@ goto break_0_39;
27262726
}
27272727
result = mf.multiply(result, val1);
27282728
return result;
2729+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 40) {
2730+
// optimization of code section 28
2731+
// that replaces 'shift' with 'bits_remaining' in order to obtain more
2732+
// efficient shifts. It may or may not make a difference for speed...
2733+
2734+
if (n <= MASK) {
2735+
C cR1 = MFE::getMontvalueR(mf);
2736+
V result = MFE::twoPowLimited_times_x(mf, static_cast<size_t>(n), cR1);
2737+
return result;
2738+
}
2739+
HPBC_CLOCKWORK_ASSERT2(n > MASK);
2740+
2741+
HPBC_CLOCKWORK_ASSERT2(n > 0);
2742+
int leading_zeros = count_leading_zeros(n);
2743+
int bits_remaining = ut_numeric_limits<decltype(n)>::digits - leading_zeros;
2744+
HPBC_CLOCKWORK_ASSERT2(bits_remaining > P2);
2745+
2746+
U n2 = branchless_shift_left(n, leading_zeros);
2747+
2748+
// calculate the constexpr var 'high_word_shift' - when we right shift a
2749+
// type U variable by this amount, we'll get the size_t furthest most
2750+
// left bits of the type U variable. Note that we assume that a right
2751+
// shift by high_word_shift will be zero cost, since the shift is just a
2752+
// way to access the CPU register that has the most significant bits -
2753+
// unless the compiler is really dumb and misses this optimization,
2754+
// which I haven't seen happen and which would surprise me.
2755+
constexpr int size_t_digits = ut_numeric_limits<size_t>::digits;
2756+
constexpr int digits_U = ut_numeric_limits<U>::digits;
2757+
constexpr int digits_bigger = (digits_U > size_t_digits) ? digits_U : size_t_digits;
2758+
constexpr int digits_smaller = (digits_U < size_t_digits) ? digits_U : size_t_digits;
2759+
constexpr int high_word_shift = digits_bigger - size_t_digits;
2760+
2761+
size_t index = static_cast<size_t>(n2 >> high_word_shift) >> (digits_smaller - P2);
2762+
n2 = static_cast<U>(n2 << P2);
2763+
HPBC_CLOCKWORK_ASSERT2(index <= MASK);
2764+
// normally we use (index & MASK), but it's redundant with index <= MASK
2765+
C cR1 = MFE::getMontvalueR(mf);
2766+
V result = MFE::twoPowLimited_times_x_v2(mf, index + 1, cR1);
2767+
2768+
bits_remaining -= P2;
2769+
2770+
while (bits_remaining >= P2) {
2771+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
2772+
SV sv = MFE::getSquaringValue(mf, result);
2773+
static_assert(P2 > 0, "");
2774+
HURCHALLA_REQUEST_UNROLL_LOOP for (int i=0; i<P2 - 1; ++i)
2775+
sv = MFE::squareSV(mf, sv);
2776+
result = MFE::squareToMontgomeryValue(mf, sv);
2777+
} else {
2778+
HURCHALLA_REQUEST_UNROLL_LOOP for (int i=0; i<P2; ++i)
2779+
result = mf.square(result);
2780+
}
2781+
2782+
bits_remaining -= P2;
2783+
index = static_cast<size_t>(n2 >> high_word_shift) >> (digits_smaller - P2);
2784+
n2 = static_cast<U>(n2 << P2);
2785+
C tmp = mf.getCanonicalValue(result);
2786+
result = MFE::twoPowLimited_times_x_v2(mf, index + 1, tmp);
2787+
}
2788+
result = mf.halve(result);
2789+
2790+
if (bits_remaining == 0)
2791+
return result;
2792+
HPBC_CLOCKWORK_ASSERT2(0 < bits_remaining && bits_remaining < P2);
2793+
2794+
index = static_cast<size_t>(n2 >> high_word_shift) >> (digits_smaller - bits_remaining);
2795+
V tableVal = MFE::twoPowLimited_times_x(mf, index, cR1);
2796+
2797+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
2798+
SV sv = MFE::getSquaringValue(mf, result);
2799+
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= 1);
2800+
for (int i=0; i<bits_remaining-1; ++i)
2801+
sv = MFE::squareSV(mf, sv);
2802+
result = MFE::squareToMontgomeryValue(mf, sv);
2803+
}
2804+
else {
2805+
for (int i=0; i<bits_remaining; ++i)
2806+
result = mf.square(result);
2807+
}
2808+
result = mf.multiply(result, tableVal);
2809+
return result;
2810+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 41) {
2811+
// optimization of code section 29
2812+
// that replaces 'shift' with 'bits_remaining' in order to obtain more
2813+
// efficient shifts. It may or may not make a difference for speed...
2814+
2815+
if (n <= MASK) {
2816+
C cR1 = MFE::getMontvalueR(mf);
2817+
V result = MFE::twoPowLimited_times_x(mf, static_cast<size_t>(n), cR1);
2818+
return result;
2819+
}
2820+
HPBC_CLOCKWORK_ASSERT2(n > MASK);
2821+
2822+
HPBC_CLOCKWORK_ASSERT2(n > 0);
2823+
int leading_zeros = count_leading_zeros(n);
2824+
int bits_remaining = ut_numeric_limits<decltype(n)>::digits - leading_zeros;
2825+
HPBC_CLOCKWORK_ASSERT2(bits_remaining > P2);
2826+
2827+
U n2 = branchless_shift_left(n, leading_zeros);
2828+
2829+
// calculate the constexpr var 'high_word_shift' - when we right shift a
2830+
// type U variable by this amount, we'll get the size_t furthest most
2831+
// left bits of the type U variable. Note that we assume that a right
2832+
// shift by high_word_shift will be zero cost, since the shift is just a
2833+
// way to access the CPU register that has the most significant bits -
2834+
// unless the compiler is really dumb and misses this optimization,
2835+
// which I haven't seen happen and which would surprise me.
2836+
constexpr int size_t_digits = ut_numeric_limits<size_t>::digits;
2837+
constexpr int digits_U = ut_numeric_limits<U>::digits;
2838+
constexpr int digits_bigger = (digits_U > size_t_digits) ? digits_U : size_t_digits;
2839+
constexpr int digits_smaller = (digits_U < size_t_digits) ? digits_U : size_t_digits;
2840+
constexpr int high_word_shift = digits_bigger - size_t_digits;
2841+
2842+
C cresult = MFE::getMontvalueR(mf);
2843+
2844+
HPBC_CLOCKWORK_ASSERT2(bits_remaining > P2);
2845+
// we check against P2 + P2 because we always process P2 more bits after
2846+
// the loop ends -- so we need to ensure we'll actually have
2847+
// (bits_remaining >= P2) after the loop ends.
2848+
while (bits_remaining >= P2 + P2) {
2849+
size_t index = static_cast<size_t>(n2 >> high_word_shift) >> (digits_smaller - P2);
2850+
n2 = static_cast<U>(n2 << P2);
2851+
V result = MFE::twoPowLimited_times_x_v2(mf, index + 1, cresult);
2852+
2853+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
2854+
SV sv = MFE::getSquaringValue(mf, result);
2855+
static_assert(P2 > 0, "");
2856+
HURCHALLA_REQUEST_UNROLL_LOOP for (int i=0; i<P2 - 1; ++i)
2857+
sv = MFE::squareSV(mf, sv);
2858+
result = MFE::squareToMontgomeryValue(mf, sv);
2859+
} else {
2860+
HURCHALLA_REQUEST_UNROLL_LOOP for (int i=0; i<P2; ++i)
2861+
result = mf.square(result);
2862+
}
2863+
cresult = mf.getCanonicalValue(result);
2864+
2865+
bits_remaining -= P2;
2866+
}
2867+
HPBC_CLOCKWORK_ASSERT2(P2 <= bits_remaining && bits_remaining < P2 + P2);
2868+
2869+
size_t index = static_cast<size_t>(n2 >> high_word_shift) >> (digits_smaller - P2);
2870+
n2 = static_cast<U>(n2 << P2);
2871+
V result = MFE::twoPowLimited_times_x(mf, index, cresult);
2872+
bits_remaining -= P2;
2873+
if (bits_remaining == 0)
2874+
return result;
2875+
HPBC_CLOCKWORK_ASSERT2(0 < bits_remaining && bits_remaining < P2);
2876+
2877+
index = static_cast<size_t>(n2 >> high_word_shift) >> (digits_smaller - bits_remaining);
2878+
C cR1 = MFE::getMontvalueR(mf);
2879+
V tableVal = MFE::twoPowLimited_times_x(mf, index, cR1);
2880+
2881+
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
2882+
SV sv = MFE::getSquaringValue(mf, result);
2883+
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= 1);
2884+
for (int i=0; i<bits_remaining-1; ++i)
2885+
sv = MFE::squareSV(mf, sv);
2886+
result = MFE::squareToMontgomeryValue(mf, sv);
2887+
}
2888+
else {
2889+
for (int i=0; i<bits_remaining; ++i)
2890+
result = mf.square(result);
2891+
}
2892+
result = mf.multiply(result, tableVal);
2893+
return result;
27292894
}
27302895
}
27312896
else if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE == 2) {

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_two_pow/testbench_montgomery_two_pow.cpp

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -945,7 +945,7 @@ using namespace hurchalla;
945945
std::cout << "\nbegin benchmarks - array two_pow\n";
946946

947947
// warm up call
948-
bench_array_two_pow<1, 8, 8, MontType, false>(static_cast<U>(maxU - range), range, dummy, max_modulus_bits_reduce, seed, exponent_bits_reduce);
948+
bench_array_two_pow<0, 30, 6, MontType, false>(static_cast<U>(maxU - range), range, dummy, max_modulus_bits_reduce, seed, exponent_bits_reduce);
949949

950950
// format is bench_array_two_pow<TABLE_BITS, CODE_SECTION, ARRAY_SIZE, MontType, USE_SQUARING_VALUE_OPTIMIZATION>(...)
951951

@@ -954,21 +954,7 @@ using namespace hurchalla;
954954
for (size_t i=0; i<4; ++i) {
955955
for (size_t j=0; j<timingA[i].size(); ++j) {
956956

957-
timingA[i][j].push_back(
958-
bench_array_two_pow<1, 8, 3, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
959-
timingA[i][j].push_back(
960-
bench_array_two_pow<1, 8, 4, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
961-
timingA[i][j].push_back(
962-
bench_array_two_pow<1, 8, 5, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
963-
timingA[i][j].push_back(
964-
bench_array_two_pow<1, 8, 6, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
965-
timingA[i][j].push_back(
966-
bench_array_two_pow<1, 8, 7, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
967-
timingA[i][j].push_back(
968-
bench_array_two_pow<1, 8, 8, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
969-
970-
971-
#if 0
957+
#if 1
972958
timingA[i][j].push_back(
973959
bench_array_two_pow<0, 27, 3, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
974960
timingA[i][j].push_back(
@@ -1149,7 +1135,7 @@ using namespace hurchalla;
11491135
}
11501136
#endif
11511137

1152-
#if 0
1138+
#if 1
11531139
timingA[i][j].push_back(
11541140
bench_array_two_pow<0, 0, 10, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
11551141
timingA[i][j].push_back(
@@ -1784,7 +1770,7 @@ std::cout << "Timings By Test Type:\n";
17841770

17851771
// warm up to get cpu boost (or throttle) going
17861772
for (size_t i=0; i<1; ++i)
1787-
bench_range<1, false, 0, MontType, false>(static_cast<U>(maxU - range), range, dummy, max_modulus_bits_reduce, seed, exponent_bits_reduce);
1773+
bench_range<0, false, 34, MontType, true>(static_cast<U>(maxU - range), range, dummy, max_modulus_bits_reduce, seed, exponent_bits_reduce);
17881774

17891775
// std::array<std::vector<Timing>, 4> timings;
17901776

@@ -1795,8 +1781,6 @@ std::cout << "Timings By Test Type:\n";
17951781

17961782
// format is bench_range<TABLE_BITS, USE_SLIDING_WINDOW_OPTIMIZATION, CODE_SECTION,
17971783
// MontType, USE_SQUARING_VALUE_OPTIMIZATION>
1798-
timings[i][j].push_back(
1799-
bench_range<1, false, 0, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
18001784

18011785
#if 0
18021786
// This is a copy/paste of the "best of best" code sections from further below (nothing is new here).
@@ -2016,8 +2000,7 @@ std::cout << "Timings By Test Type:\n";
20162000

20172001

20182002

2019-
2020-
#if 0
2003+
#if 1
20212004
timings[i][j].push_back(
20222005
bench_range<0, true , 17, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
20232006
timings[i][j].push_back(
@@ -2092,6 +2075,11 @@ std::cout << "Timings By Test Type:\n";
20922075
timings[i][j].push_back(
20932076
bench_range<0, false, 39, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
20942077

2078+
timings[i][j].push_back(
2079+
bench_range<0, false, 40, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
2080+
timings[i][j].push_back(
2081+
bench_range<0, false, 41, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
2082+
20952083
timings[i][j].push_back(
20962084
bench_range<0, true , 19, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
20972085
timings[i][j].push_back(
@@ -2242,6 +2230,11 @@ std::cout << "Timings By Test Type:\n";
22422230
timings[i][j].push_back(
22432231
bench_range<0, false, 39, MontType, true>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
22442232

2233+
timings[i][j].push_back(
2234+
bench_range<0, false, 40, MontType, true>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
2235+
timings[i][j].push_back(
2236+
bench_range<0, false, 41, MontType, true>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
2237+
22452238
timings[i][j].push_back(
22462239
bench_range<0, true , 19, MontType, true>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
22472240
timings[i][j].push_back(
@@ -2327,7 +2320,7 @@ std::cout << "Timings By Test Type:\n";
23272320
bench_range<4, true , 1, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
23282321
#endif
23292322

2330-
#if 0
2323+
#if 1
23312324
timings[i][j].push_back(
23322325
bench_range<4, true , 0, MontType, false>(static_cast<U>(maxU - range), range, dummy, mmbr[i], seed, ebr[i]));
23332326
timings[i][j].push_back(

0 commit comments

Comments
 (0)