|
10 | 10 |
|
11 | 11 |
|
12 | 12 | #include "hurchalla/montgomery_arithmetic/detail/MontgomeryFormExtensions.h" |
| 13 | +#include "hurchalla/montgomery_arithmetic/detail/platform_specific/montgomery_two_pow.h" |
13 | 14 | #include "hurchalla/modular_arithmetic/detail/optimization_tag_structs.h" |
14 | 15 | #include "hurchalla/util/traits/ut_numeric_limits.h" |
15 | 16 | #include "hurchalla/util/count_leading_zeros.h" |
@@ -1634,18 +1635,55 @@ goto break_0_18; |
1634 | 1635 | result = MFE_LU::twoPowLimited_times_x(mf, loindex, table_mid[midindex]); |
1635 | 1636 |
|
1636 | 1637 | V next = r4; // R^4 |
1637 | | - HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
1638 | | - tables_extra[i][0] = mf.getUnityValue(); // R^0 |
1639 | | - tables_extra[i][1] = next; |
1640 | | - V nextSq = mf.square(next); |
1641 | | - V nexttmp = mf.square(nextSq); |
1642 | | - tables_extra[i][2] = nextSq; |
1643 | | - tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next); |
1644 | | - next = nexttmp; |
1645 | | - |
1646 | | - int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
1647 | | - size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
1648 | | - result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result); |
| 1638 | + |
| 1639 | + // Check whether we have 128bit MontgomeryForm or 64bit (or less). |
| 1640 | + // We use this to choose whether to unroll the loop. |
| 1641 | + if HURCHALLA_CPP17_CONSTEXPR (digitsRU > HURCHALLA_TARGET_BIT_WIDTH) { |
| 1642 | + for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
| 1643 | + int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 1644 | +#if 1 |
| 1645 | + // This early exit is optional for us to include or not include. |
| 1646 | + // (On M2 benches, this check didn't hurt 128bit perf, but 64bit |
| 1647 | + // perf was slightly slowed. Thus it's enabled only for 128bit) |
| 1648 | + if (n < (static_cast<size_t>(1) << P_extra)) |
| 1649 | + return result; |
| 1650 | +#endif |
| 1651 | + tables_extra[i][0] = mf.getUnityValue(); // R^0 |
| 1652 | + tables_extra[i][1] = next; |
| 1653 | + V nextSq = mf.square(next); |
| 1654 | + V nexttmp = mf.square(nextSq); |
| 1655 | + tables_extra[i][2] = nextSq; |
| 1656 | + tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next); |
| 1657 | + next = nexttmp; |
| 1658 | + |
| 1659 | + size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
| 1660 | + result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result); |
| 1661 | + } |
| 1662 | + } |
| 1663 | + else { |
| 1664 | +#if defined(__GNUC__) && !defined(__clang__) |
| 1665 | + HURCHALLA_REQUEST_UNROLL_LOOP |
| 1666 | +#endif |
| 1667 | + for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
| 1668 | + int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 1669 | +#if 0 |
| 1670 | + // This early exit is optional for us to include or not include. |
| 1671 | + // (On M2 benches, this check didn't hurt 128bit perf, but 64bit |
| 1672 | + // perf was slightly slowed. Thus it's enabled only for 128bit) |
| 1673 | + if (n < (static_cast<size_t>(1) << P_extra)) |
| 1674 | + return result; |
| 1675 | +#endif |
| 1676 | + tables_extra[i][0] = mf.getUnityValue(); // R^0 |
| 1677 | + tables_extra[i][1] = next; |
| 1678 | + V nextSq = mf.square(next); |
| 1679 | + V nexttmp = mf.square(nextSq); |
| 1680 | + tables_extra[i][2] = nextSq; |
| 1681 | + tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next); |
| 1682 | + next = nexttmp; |
| 1683 | + |
| 1684 | + size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
| 1685 | + result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result); |
| 1686 | + } |
1649 | 1687 | } |
1650 | 1688 | } |
1651 | 1689 |
|
@@ -1726,34 +1764,40 @@ goto break_0_18; |
1726 | 1764 | V val1 = MFE::twoPowLimited_times_x(mf, loindex, table_mid[midindex]); |
1727 | 1765 |
|
1728 | 1766 | if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) { |
1729 | | - // could use: |
1730 | | - //for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < shift); ++i) |
1731 | | - |
1732 | | - HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
1733 | | - int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 1767 | + SV sv = MFE::getSquaringValue(mf, result); |
| 1768 | + int i=0; |
| 1769 | + for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < shift; ++i) { |
| 1770 | + int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3; |
1734 | 1771 | size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
1735 | | - val1 = mf.multiply(val1, tables_extra[i][index_extra]); |
| 1772 | + HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k) |
| 1773 | + sv = MFE::squareSV(mf, sv); |
| 1774 | + val1 = mf.template multiply<LowuopsTag>( |
| 1775 | + val1, tables_extra[static_cast<size_t>(i)][index_extra]); |
1736 | 1776 | } |
| 1777 | + //make 'i' the count of how many squarings of sv (i.e. result) we just did |
| 1778 | + i = i * NUMBITS_TABLE_HIGH_SIZE; |
1737 | 1779 |
|
1738 | | - SV sv = MFE::getSquaringValue(mf, result); |
1739 | 1780 | HPBC_CLOCKWORK_ASSERT2(shift >= 1); |
1740 | | - for (int i=0; i<shift-1; ++i) |
| 1781 | + for (; i<shift-1; ++i) |
1741 | 1782 | sv = MFE::squareSV(mf, sv); |
| 1783 | + HPBC_CLOCKWORK_ASSERT2(i == shift-1); |
1742 | 1784 | result = MFE::squareToMontgomeryValue(mf, sv); |
1743 | 1785 | } |
1744 | 1786 | else { |
1745 | | - result = mf.square(result); |
1746 | | - |
1747 | | - // could use: |
1748 | | - //for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < shift); ++i) |
1749 | | - |
1750 | | - HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
1751 | | - int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 1787 | + int i=0; |
| 1788 | + for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < shift; ++i) { |
| 1789 | + int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3; |
1752 | 1790 | size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
1753 | | - val1 = mf.multiply(val1, tables_extra[i][index_extra]); |
| 1791 | + HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k) |
| 1792 | + result = mf.square(result); |
| 1793 | + val1 = mf.template multiply<LowuopsTag>( |
| 1794 | + val1, tables_extra[static_cast<size_t>(i)][index_extra]); |
1754 | 1795 | } |
| 1796 | + //make 'i' the count of how many squarings of result we just did |
| 1797 | + i = i * NUMBITS_TABLE_HIGH_SIZE; |
| 1798 | + HPBC_CLOCKWORK_ASSERT2(i <= shift); |
1755 | 1799 |
|
1756 | | - for (int i=1; i<shift; ++i) |
| 1800 | + for (; i<shift; ++i) |
1757 | 1801 | result = mf.square(result); |
1758 | 1802 | } |
1759 | 1803 | result = mf.multiply(result, val1); |
@@ -2268,6 +2312,9 @@ goto break_0_18; |
2268 | 2312 | std::array<C, TABLE_HIGH_SIZE> table_mid; |
2269 | 2313 | std::array<std::array<V, TABLE_HIGH_SIZE>, NUM_EXTRA_TABLES> tables_extra; |
2270 | 2314 |
|
| 2315 | + auto n_orig = n; |
| 2316 | + (void)n_orig; // silence potential unused var warnings |
| 2317 | + |
2271 | 2318 | int shift = 0; |
2272 | 2319 | size_t tmp; |
2273 | 2320 | if (n > MASKBIG) { |
@@ -2306,18 +2353,55 @@ goto break_0_18; |
2306 | 2353 | result = MFE_LU::twoPowLimited_times_x(mf, loindex, table_mid[midindex]); |
2307 | 2354 |
|
2308 | 2355 | V next = r4; // R^4 |
2309 | | - HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
2310 | | - tables_extra[i][0] = mf.getUnityValue(); // R^0 |
2311 | | - tables_extra[i][1] = next; |
2312 | | - V nextSq = mf.square(next); |
2313 | | - V nexttmp = mf.square(nextSq); |
2314 | | - tables_extra[i][2] = nextSq; |
2315 | | - tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next); |
2316 | | - next = nexttmp; |
2317 | | - |
2318 | | - int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
2319 | | - size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
2320 | | - result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result); |
| 2356 | + |
| 2357 | + // Check whether we have 128bit MontgomeryForm or 64bit (or less). |
| 2358 | + // We use this to choose whether to unroll the loop. |
| 2359 | + if HURCHALLA_CPP17_CONSTEXPR (digitsRU > HURCHALLA_TARGET_BIT_WIDTH) { |
| 2360 | + for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
| 2361 | + int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 2362 | +#if 1 |
| 2363 | + // This early exit is optional for us to include or not include. |
| 2364 | + // (On M2 benches, this check didn't hurt 128bit perf, but 64bit |
| 2365 | + // perf was slightly slowed. Thus it's enabled only for 128bit) |
| 2366 | + if (n_orig < (static_cast<size_t>(1) << P_extra)) |
| 2367 | + return result; |
| 2368 | +#endif |
| 2369 | + tables_extra[i][0] = mf.getUnityValue(); // R^0 |
| 2370 | + tables_extra[i][1] = next; |
| 2371 | + V nextSq = mf.square(next); |
| 2372 | + V nexttmp = mf.square(nextSq); |
| 2373 | + tables_extra[i][2] = nextSq; |
| 2374 | + tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next); |
| 2375 | + next = nexttmp; |
| 2376 | + |
| 2377 | + size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
| 2378 | + result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result); |
| 2379 | + } |
| 2380 | + } |
| 2381 | + else { |
| 2382 | +#if defined(__GNUC__) && !defined(__clang__) |
| 2383 | + HURCHALLA_REQUEST_UNROLL_LOOP |
| 2384 | +#endif |
| 2385 | + for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
| 2386 | + int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 2387 | +#if 0 |
| 2388 | + // This early exit is optional for us to include or not include. |
| 2389 | + // (On M2 benches, this check didn't hurt 128bit perf, but 64bit |
| 2390 | + // perf was slightly slowed. Thus it's enabled only for 128bit) |
| 2391 | + if (n_orig < (static_cast<size_t>(1) << P_extra)) |
| 2392 | + return result; |
| 2393 | +#endif |
| 2394 | + tables_extra[i][0] = mf.getUnityValue(); // R^0 |
| 2395 | + tables_extra[i][1] = next; |
| 2396 | + V nextSq = mf.square(next); |
| 2397 | + V nexttmp = mf.square(nextSq); |
| 2398 | + tables_extra[i][2] = nextSq; |
| 2399 | + tables_extra[i][3] = mf.template multiply<LowuopsTag>(nextSq, next); |
| 2400 | + next = nexttmp; |
| 2401 | + |
| 2402 | + size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
| 2403 | + result = mf.template multiply<LowuopsTag>(tables_extra[i][index_extra], result); |
| 2404 | + } |
2321 | 2405 | } |
2322 | 2406 | } |
2323 | 2407 |
|
@@ -2431,34 +2515,40 @@ goto break_0_18; |
2431 | 2515 | V val1 = MFE::twoPowLimited_times_x(mf, loindex, table_mid[midindex]); |
2432 | 2516 |
|
2433 | 2517 | if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) { |
2434 | | - // could use: |
2435 | | - //for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < bits_remaining); ++i) |
2436 | | - |
2437 | | - HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
2438 | | - int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 2518 | + SV sv = MFE::getSquaringValue(mf, result); |
| 2519 | + int i=0; |
| 2520 | + for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < bits_remaining; ++i) { |
| 2521 | + int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3; |
2439 | 2522 | size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
2440 | | - val1 = mf.multiply(val1, tables_extra[i][index_extra]); |
| 2523 | + HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k) |
| 2524 | + sv = MFE::squareSV(mf, sv); |
| 2525 | + val1 = mf.template multiply<LowuopsTag>( |
| 2526 | + val1, tables_extra[static_cast<size_t>(i)][index_extra]); |
2441 | 2527 | } |
| 2528 | + //make 'i' the count of how many squarings of sv (i.e. result) we just did |
| 2529 | + i = i * NUMBITS_TABLE_HIGH_SIZE; |
2442 | 2530 |
|
2443 | | - SV sv = MFE::getSquaringValue(mf, result); |
2444 | 2531 | HPBC_CLOCKWORK_ASSERT2(bits_remaining >= 1); |
2445 | | - for (int i=0; i<bits_remaining-1; ++i) |
| 2532 | + for (; i<bits_remaining-1; ++i) |
2446 | 2533 | sv = MFE::squareSV(mf, sv); |
| 2534 | + HPBC_CLOCKWORK_ASSERT2(i == bits_remaining-1); |
2447 | 2535 | result = MFE::squareToMontgomeryValue(mf, sv); |
2448 | 2536 | } |
2449 | 2537 | else { |
2450 | | - result = mf.square(result); |
2451 | | - |
2452 | | - // could use: |
2453 | | - //for (int i=0; i < NUM_EXTRA_TABLES && (2*i + P3 < bits_remaining); ++i) |
2454 | | - |
2455 | | - HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i < NUM_EXTRA_TABLES; ++i) { |
2456 | | - int P_extra = P3 + static_cast<int>(i * NUMBITS_TABLE_HIGH_SIZE); |
| 2538 | + int i=0; |
| 2539 | + for (; i * NUMBITS_TABLE_HIGH_SIZE + P3 < bits_remaining; ++i) { |
| 2540 | + int P_extra = i * NUMBITS_TABLE_HIGH_SIZE + P3; |
2457 | 2541 | size_t index_extra = (tmp >> P_extra) & (TABLE_HIGH_SIZE - 1); |
2458 | | - val1 = mf.multiply(val1, tables_extra[i][index_extra]); |
| 2542 | + HURCHALLA_REQUEST_UNROLL_LOOP for (int k=0; k < NUMBITS_TABLE_HIGH_SIZE; ++k) |
| 2543 | + result = mf.square(result); |
| 2544 | + val1 = mf.template multiply<LowuopsTag>( |
| 2545 | + val1, tables_extra[static_cast<size_t>(i)][index_extra]); |
2459 | 2546 | } |
| 2547 | + //make 'i' the count of how many squarings of result we just did |
| 2548 | + i = i * NUMBITS_TABLE_HIGH_SIZE; |
| 2549 | + HPBC_CLOCKWORK_ASSERT2(i <= bits_remaining); |
2460 | 2550 |
|
2461 | | - for (int i=1; i<bits_remaining; ++i) |
| 2551 | + for (; i<bits_remaining; ++i) |
2462 | 2552 | result = mf.square(result); |
2463 | 2553 | } |
2464 | 2554 | result = mf.multiply(result, val1); |
@@ -2891,7 +2981,13 @@ goto break_0_39; |
2891 | 2981 | } |
2892 | 2982 | result = mf.multiply(result, tableVal); |
2893 | 2983 | return result; |
| 2984 | + |
| 2985 | +} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 42) { |
| 2986 | + // call our non-experimental (presumed best) implementation |
| 2987 | + |
| 2988 | + return hurchalla::detail::montgomery_two_pow::call(mf, n); |
2894 | 2989 | } |
| 2990 | + |
2895 | 2991 | } |
2896 | 2992 | else if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE == 2) { |
2897 | 2993 | table[0] = mf.getUnityValue(); // montgomery one |
@@ -4719,9 +4815,8 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) { |
4719 | 4815 | tableVal[j]); |
4720 | 4816 | } |
4721 | 4817 | return result; |
4722 | | - } else { // CODE_SECTION 31 |
| 4818 | + } else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 31) { |
4723 | 4819 | // this is an optimization of CODE_SECTION 29 |
4724 | | - static_assert(CODE_SECTION == 31, ""); |
4725 | 4820 |
|
4726 | 4821 | std::array<V, ARRAY_SIZE> result; |
4727 | 4822 | if (n_max <= MASK) { |
@@ -4822,6 +4917,12 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) { |
4822 | 4917 | tableVal[j]); |
4823 | 4918 | } |
4824 | 4919 | return result; |
| 4920 | + |
| 4921 | + } else { |
| 4922 | + // call our non-experimental (presumed best) implementation |
| 4923 | + static_assert(CODE_SECTION == 32, ""); |
| 4924 | + |
| 4925 | + return hurchalla::detail::montgomery_two_pow::call(mf, n); |
4825 | 4926 | } |
4826 | 4927 |
|
4827 | 4928 | } |
|
0 commit comments