Skip to content

Commit 9e619f7

Browse files
committed
improve montgomery two pow
1 parent 654f736 commit 9e619f7

File tree

4 files changed

+350
-513
lines changed

4 files changed

+350
-513
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
./testbench.sh clang++ O3 MontgomeryQuarter uint64_t 191 8 22 -DTEST_SCALAR -DHURCHALLA_MONTGOMERY_TWO_POW_USE_CSELECT_ON_BIT -DHURCHALLA_ALLOW_INLINE_ASM_ALL
3+
4+
./testbench.sh clang++ O3 MontgomeryHalf uint64_t 191 8 22 -DTEST_SCALAR -DHURCHALLA_MONTGOMERY_TWO_POW_USE_CSELECT_ON_BIT -DHURCHALLA_ALLOW_INLINE_ASM_ALL
5+
6+
./testbench.sh clang++ O3 MontgomeryFull uint64_t 191 8 22 -DTEST_SCALAR -DHURCHALLA_MONTGOMERY_TWO_POW_USE_CSELECT_ON_BIT -DHURCHALLA_ALLOW_INLINE_ASM_ALL
7+
8+
9+
./testbench.sh g++ O3 MontgomeryQuarter uint64_t 191 8 22 -DTEST_SCALAR -DHURCHALLA_MONTGOMERY_TWO_POW_USE_CSELECT_ON_BIT -DHURCHALLA_ALLOW_INLINE_ASM_ALL
10+
11+
./testbench.sh g++ O3 MontgomeryHalf uint64_t 191 8 22 -DTEST_SCALAR -DHURCHALLA_MONTGOMERY_TWO_POW_USE_CSELECT_ON_BIT -DHURCHALLA_ALLOW_INLINE_ASM_ALL
12+
13+
./testbench.sh g++ O3 MontgomeryFull uint64_t 191 8 22 -DTEST_SCALAR -DHURCHALLA_MONTGOMERY_TWO_POW_USE_CSELECT_ON_BIT -DHURCHALLA_ALLOW_INLINE_ASM_ALL
14+

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_two_pow/experimental_montgomery_two_pow.h

Lines changed: 161 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
#include "hurchalla/montgomery_arithmetic/detail/MontgomeryFormExtensions.h"
13+
#include "hurchalla/montgomery_arithmetic/detail/platform_specific/montgomery_two_pow.h"
1314
#include "hurchalla/modular_arithmetic/detail/optimization_tag_structs.h"
1415
#include "hurchalla/util/traits/ut_numeric_limits.h"
1516
#include "hurchalla/util/count_leading_zeros.h"
@@ -1634,18 +1635,55 @@ goto break_0_18;
16341635
result = MFE_LU::twoPowLimited_times_x(mf, loindex, table_mid[midindex]);
16351636

16361637
V next = r4; // R^4
1637-
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
1638-
tables_extra[i][0] = mf.getUnityValue(); // R^0
1639-
tables_extra[i][1] = next;
1640-
V nextSq = mf.square(next);
1641-
V nexttmp = mf.square(nextSq);
1642-
tables_extra[i][2] = nextSq;
1643-
tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next);
1644-
next = nexttmp;
1645-
1646-
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
1647-
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
1648-
result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result);
1638+
1639+
// Check whether we have 128bit MontgomeryForm or 64bit (or less).
1640+
// We use this to choose whether to unroll the loop.
1641+
if HURCHALLA_CPP17_CONSTEXPR (digitsRU > HURCHALLA_TARGET_BIT_WIDTH) {
1642+
for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
1643+
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
1644+
#if 1
1645+
// This early exit is optional for us to include or not include.
1646+
// (On M2 benches, this check didn't hurt 128bit perf, but 64bit
1647+
// perf was slightly slowed. Thus it's enabled only for 128bit)
1648+
if (n < (static_cast<size_t>(1) << P_extra))
1649+
return result;
1650+
#endif
1651+
tables_extra[i][0] = mf.getUnityValue(); // R^0
1652+
tables_extra[i][1] = next;
1653+
V nextSq = mf.square(next);
1654+
V nexttmp = mf.square(nextSq);
1655+
tables_extra[i][2] = nextSq;
1656+
tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next);
1657+
next = nexttmp;
1658+
1659+
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
1660+
result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result);
1661+
}
1662+
}
1663+
else {
1664+
#if defined(__GNUC__) && !defined(__clang__)
1665+
HURCHALLA_REQUEST_UNROLL_LOOP
1666+
#endif
1667+
for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
1668+
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
1669+
#if 0
1670+
// This early exit is optional for us to include or not include.
1671+
// (On M2 benches, this check didn't hurt 128bit perf, but 64bit
1672+
// perf was slightly slowed. Thus it's enabled only for 128bit)
1673+
if (n < (static_cast<size_t>(1) << P_extra))
1674+
return result;
1675+
#endif
1676+
tables_extra[i][0] = mf.getUnityValue(); // R^0
1677+
tables_extra[i][1] = next;
1678+
V nextSq = mf.square(next);
1679+
V nexttmp = mf.square(nextSq);
1680+
tables_extra[i][2] = nextSq;
1681+
tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next);
1682+
next = nexttmp;
1683+
1684+
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
1685+
result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result);
1686+
}
16491687
}
16501688
}
16511689

@@ -1726,34 +1764,40 @@ goto break_0_18;
17261764
V val1 = MFE::twoPowLimited_times_x(mf, loindex, table_mid[midindex]);
17271765

17281766
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
1729-
// could use:
1730-
//for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < shift); ++i)
1731-
1732-
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
1733-
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
1767+
SV sv = MFE::getSquaringValue(mf, result);
1768+
int i=0;
1769+
for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < shift; ++i) {
1770+
int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3;
17341771
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
1735-
val1 = mf.multiply(val1, tables_extra[i][index_extra]);
1772+
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k)
1773+
sv = MFE::squareSV(mf, sv);
1774+
val1 = mf.template multiply<LowuopsTag>(
1775+
val1, tables_extra[static_cast<size_t>(i)][index_extra]);
17361776
}
1777+
//make 'i' the count of how many squarings of sv (i.e. result) we just did
1778+
i = i * NUMBITS_TABLE_HIGH_SIZE;
17371779

1738-
SV sv = MFE::getSquaringValue(mf, result);
17391780
HPBC_CLOCKWORK_ASSERT2(shift >= 1);
1740-
for (int i=0; i<shift-1; ++i)
1781+
for (; i<shift-1; ++i)
17411782
sv = MFE::squareSV(mf, sv);
1783+
HPBC_CLOCKWORK_ASSERT2(i == shift-1);
17421784
result = MFE::squareToMontgomeryValue(mf, sv);
17431785
}
17441786
else {
1745-
result = mf.square(result);
1746-
1747-
// could use:
1748-
//for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < shift); ++i)
1749-
1750-
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
1751-
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
1787+
int i=0;
1788+
for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < shift; ++i) {
1789+
int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3;
17521790
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
1753-
val1 = mf.multiply(val1, tables_extra[i][index_extra]);
1791+
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k)
1792+
result = mf.square(result);
1793+
val1 = mf.template multiply<LowuopsTag>(
1794+
val1, tables_extra[static_cast<size_t>(i)][index_extra]);
17541795
}
1796+
//make 'i' the count of how many squarings of result we just did
1797+
i = i * NUMBITS_TABLE_HIGH_SIZE;
1798+
HPBC_CLOCKWORK_ASSERT2(i <= shift);
17551799

1756-
for (int i=1; i<shift; ++i)
1800+
for (; i<shift; ++i)
17571801
result = mf.square(result);
17581802
}
17591803
result = mf.multiply(result, val1);
@@ -2268,6 +2312,9 @@ goto break_0_18;
22682312
std::array<C, TABLE_HIGH_SIZE> table_mid;
22692313
std::array<std::array<V, TABLE_HIGH_SIZE>, NUM_EXTRA_TABLES> tables_extra;
22702314

2315+
auto n_orig = n;
2316+
(void)n_orig; // silence potential unused var warnings
2317+
22712318
int shift = 0;
22722319
size_t tmp;
22732320
if (n > MASKBIG) {
@@ -2306,18 +2353,55 @@ goto break_0_18;
23062353
result = MFE_LU::twoPowLimited_times_x(mf, loindex, table_mid[midindex]);
23072354

23082355
V next = r4; // R^4
2309-
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
2310-
tables_extra[i][0] = mf.getUnityValue(); // R^0
2311-
tables_extra[i][1] = next;
2312-
V nextSq = mf.square(next);
2313-
V nexttmp = mf.square(nextSq);
2314-
tables_extra[i][2] = nextSq;
2315-
tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next);
2316-
next = nexttmp;
2317-
2318-
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
2319-
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
2320-
result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result);
2356+
2357+
// Check whether we have 128bit MontgomeryForm or 64bit (or less).
2358+
// We use this to choose whether to unroll the loop.
2359+
if HURCHALLA_CPP17_CONSTEXPR (digitsRU > HURCHALLA_TARGET_BIT_WIDTH) {
2360+
for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
2361+
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
2362+
#if 1
2363+
// This early exit is optional for us to include or not include.
2364+
// (On M2 benches, this check didn't hurt 128bit perf, but 64bit
2365+
// perf was slightly slowed. Thus it's enabled only for 128bit)
2366+
if (n_orig < (static_cast<size_t>(1) << P_extra))
2367+
return result;
2368+
#endif
2369+
tables_extra[i][0] = mf.getUnityValue(); // R^0
2370+
tables_extra[i][1] = next;
2371+
V nextSq = mf.square(next);
2372+
V nexttmp = mf.square(nextSq);
2373+
tables_extra[i][2] = nextSq;
2374+
tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next);
2375+
next = nexttmp;
2376+
2377+
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
2378+
result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result);
2379+
}
2380+
}
2381+
else {
2382+
#if defined(__GNUC__) && !defined(__clang__)
2383+
HURCHALLA_REQUEST_UNROLL_LOOP
2384+
#endif
2385+
for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
2386+
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
2387+
#if 0
2388+
// This early exit is optional for us to include or not include.
2389+
// (On M2 benches, this check didn't hurt 128bit perf, but 64bit
2390+
// perf was slightly slowed. Thus it's enabled only for 128bit)
2391+
if (n_orig < (static_cast<size_t>(1) << P_extra))
2392+
return result;
2393+
#endif
2394+
tables_extra[i][0] = mf.getUnityValue(); // R^0
2395+
tables_extra[i][1] = next;
2396+
V nextSq = mf.square(next);
2397+
V nexttmp = mf.square(nextSq);
2398+
tables_extra[i][2] = nextSq;
2399+
tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next);
2400+
next = nexttmp;
2401+
2402+
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
2403+
result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result);
2404+
}
23212405
}
23222406
}
23232407

@@ -2431,34 +2515,40 @@ goto break_0_18;
24312515
V val1 = MFE::twoPowLimited_times_x(mf, loindex, table_mid[midindex]);
24322516

24332517
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
2434-
// could use:
2435-
//for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < bits_remaining); ++i)
2436-
2437-
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
2438-
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
2518+
SV sv = MFE::getSquaringValue(mf, result);
2519+
int i=0;
2520+
for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < bits_remaining; ++i) {
2521+
int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3;
24392522
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
2440-
val1 = mf.multiply(val1, tables_extra[i][index_extra]);
2523+
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k)
2524+
sv = MFE::squareSV(mf, sv);
2525+
val1 = mf.template multiply<LowuopsTag>(
2526+
val1, tables_extra[static_cast<size_t>(i)][index_extra]);
24412527
}
2528+
//make 'i' the count of how many squarings of sv (i.e. result) we just did
2529+
i = i * NUMBITS_TABLE_HIGH_SIZE;
24422530

2443-
SV sv = MFE::getSquaringValue(mf, result);
24442531
HPBC_CLOCKWORK_ASSERT2(bits_remaining >= 1);
2445-
for (int i=0; i<bits_remaining-1; ++i)
2532+
for (; i<bits_remaining-1; ++i)
24462533
sv = MFE::squareSV(mf, sv);
2534+
HPBC_CLOCKWORK_ASSERT2(i == bits_remaining-1);
24472535
result = MFE::squareToMontgomeryValue(mf, sv);
24482536
}
24492537
else {
2450-
result = mf.square(result);
2451-
2452-
// could use:
2453-
//for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < bits_remaining); ++i)
2454-
2455-
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) {
2456-
int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE);
2538+
int i=0;
2539+
for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < bits_remaining; ++i) {
2540+
int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3;
24572541
size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1);
2458-
val1 = mf.multiply(val1, tables_extra[i][index_extra]);
2542+
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k)
2543+
result = mf.square(result);
2544+
val1 = mf.template multiply<LowuopsTag>(
2545+
val1, tables_extra[static_cast<size_t>(i)][index_extra]);
24592546
}
2547+
//make 'i' the count of how many squarings of result we just did
2548+
i = i * NUMBITS_TABLE_HIGH_SIZE;
2549+
HPBC_CLOCKWORK_ASSERT2(i <= bits_remaining);
24602550

2461-
for (int i=1; i<bits_remaining; ++i)
2551+
for (; i<bits_remaining; ++i)
24622552
result = mf.square(result);
24632553
}
24642554
result = mf.multiply(result, val1);
@@ -2891,7 +2981,13 @@ goto break_0_39;
28912981
}
28922982
result = mf.multiply(result, tableVal);
28932983
return result;
2984+
2985+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 42) {
2986+
// call our non-experimental (presumed best) implementation
2987+
2988+
return hurchalla::detail::montgomery_two_pow::call(mf, n);
28942989
}
2990+
28952991
}
28962992
else if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE == 2) {
28972993
table[0] = mf.getUnityValue(); // montgomery one
@@ -4719,9 +4815,8 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
47194815
tableVal[j]);
47204816
}
47214817
return result;
4722-
} else { // CODE_SECTION 31
4818+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 31) {
47234819
// this is an optimization of CODE_SECTION 29
4724-
static_assert(CODE_SECTION == 31, "");
47254820

47264821
std::array<V, ARRAY_SIZE> result;
47274822
if (n_max <= MASK) {
@@ -4822,6 +4917,12 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
48224917
tableVal[j]);
48234918
}
48244919
return result;
4920+
4921+
} else {
4922+
// call our non-experimental (presumed best) implementation
4923+
static_assert(CODE_SECTION == 32, "");
4924+
4925+
return hurchalla::detail::montgomery_two_pow::call(mf, n);
48254926
}
48264927

48274928
}

0 commit comments

Comments
 (0)