@@ -72,9 +72,27 @@ struct experimental_montgomery_pow_2kary {
7272
7373
7474if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0 ) {
75- // For comparison purposes, this is the current MontgomeryForm pow.
76- // the static cast may lose bits, so this might not be an exact benchmark
77- return mf.pow (x, static_cast <typename MF::IntegerType>(n));
75+ // this is a masked version of pow from montgomery_pow.h
76+ V base = x;
77+ U exponent = n;
78+
79+ V result;
80+ if (static_cast <size_t >(exponent) & 1u )
81+ result = base;
82+ else
83+ result = mf.getUnityValue ();
84+
85+ while (exponent > 1u ) {
86+ exponent = static_cast <U>(exponent >> 1 );
87+
88+ base = mf.square (base);
89+ V tmp = mf.getUnityValue ();
90+ // note: since we are doing masked selections, we definitely don't
91+ // want to use cselect_on_bit here
92+ tmp.template cmov <CSelectMaskedTag>(static_cast <size_t >(exponent) & 1u , base);
93+ result = mf.multiply (result, tmp);
94+ }
95+ return result;
7896
7997} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 1 ) {
8098 // this is a branch version of pow from montgomery_pow.h
@@ -1201,7 +1219,8 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12011219 bool USE_SLIDING_WINDOW_OPTIMIZATION = false ,
12021220 size_t TABLE_BITS = 4 ,
12031221 size_t CODE_SECTION = 0 ,
1204- bool USE_SQUARING_VALUE_OPTIMIZATION = false >
1222+ bool USE_SQUARING_VALUE_OPTIMIZATION = false ,
1223+ class PTAG = LowuopsTag>
12051224 static std::array<typename MF::MontgomeryValue, ARRAY_SIZE>
12061225 call (const MF& mf,
12071226 const std::array<typename MF::MontgomeryValue, ARRAY_SIZE>& x,
@@ -1215,7 +1234,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12151234 " the beginning of this function to calculate 1024+ table entries!)" );
12161235
12171236 using V = typename MF::MontgomeryValue;
1218- using MFE_LU = hurchalla::detail::MontgomeryFormExtensions<MF, LowuopsTag >;
1237+ using MFE_LU = hurchalla::detail::MontgomeryFormExtensions<MF, PTAG >;
12191238 using SV = typename MFE_LU::SquaringValue;
12201239 using std::size_t ;
12211240 U n = nexp;
@@ -1230,9 +1249,31 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12301249
12311250
12321251if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0 ) {
1233- // For comparison purposes, this is the current MontgomeryForm pow.
1234- // the static cast may lose bits, so this might not be an exact benchmark
1235- return mf.pow (x, static_cast <typename MF::IntegerType>(nexp));
1252+ // this is adapted from arraypow_cond_branch_unrolled() in montgomery_pow.h
1253+
1254+ std::array<V, ARRAY_SIZE> bases = x;
1255+ U exponent = n;
1256+
1257+ std::array<V, ARRAY_SIZE> result;
1258+ if (static_cast <size_t >(exponent) & 1u ) {
1259+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1260+ result[j] = bases[j];
1261+ } else {
1262+ V mont_one = mf.getUnityValue ();
1263+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1264+ result[j] = mont_one;
1265+ }
1266+
1267+ while (exponent > 1u ) {
1268+ exponent = static_cast <U>(exponent >> 1 );
1269+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1270+ bases[j] = mf.template square <PTAG>(bases[j]);
1271+ if (static_cast <size_t >(exponent) & 1u ) {
1272+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1273+ result[j] = mf.template multiply <PTAG>(result[j], bases[j]);
1274+ }
1275+ }
1276+ return result;
12361277
12371278} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 1 ) {
12381279
@@ -1245,9 +1286,9 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12451286 for (std::size_t i=2 ; i<TABLESIZE; i+=2 ) {
12461287 std::size_t halfi = i/2 ;
12471288 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1248- table[i][j] = mf.template square <LowuopsTag >(table[halfi][j]);
1289+ table[i][j] = mf.template square <PTAG >(table[halfi][j]);
12491290 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1250- table[i+1 ][j] = mf.template multiply <LowuopsTag >(table[halfi+1 ][j], table[halfi][j]);
1291+ table[i+1 ][j] = mf.template multiply <PTAG >(table[halfi+1 ][j], table[halfi][j]);
12511292 }
12521293 }
12531294
@@ -1302,21 +1343,21 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
13021343 if (USE_SLIDING_WINDOW_OPTIMIZATION) {
13031344 while (shift > P && (static_cast <size_t >(branchless_shift_right (n, shift-1 )) & 1u ) == 0 ) {
13041345 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1305- result[j] = mf.template square <LowuopsTag >(result[j]);
1346+ result[j] = mf.template square <PTAG >(result[j]);
13061347 --shift;
13071348 }
13081349 }
13091350
13101351 for (int i=0 ; i<P; ++i) {
13111352 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1312- result[j] = mf.template square <LowuopsTag >(result[j]);
1353+ result[j] = mf.template square <PTAG >(result[j]);
13131354 }
13141355 }
13151356
13161357 shift -= P;
13171358 index = static_cast <size_t >(branchless_shift_right (n, shift)) & MASK;
13181359 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1319- result[j] = mf.template multiply <LowuopsTag >(result[j], table[index][j]);
1360+ result[j] = mf.template multiply <PTAG >(result[j], table[index][j]);
13201361 }
13211362 }
13221363
@@ -1326,15 +1367,15 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
13261367
13271368 for (int i=0 ; i<shift; ++i) {
13281369 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1329- result[j] = mf.template square <LowuopsTag >(result[j]);
1370+ result[j] = mf.template square <PTAG >(result[j]);
13301371 }
13311372 size_t tmpmask = (1u << shift) - 1 ;
13321373 index = static_cast <size_t >(n) & tmpmask;
13331374 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1334- result[j] = mf.template multiply <LowuopsTag >(result[j], table[index][j]);
1375+ result[j] = mf.template multiply <PTAG >(result[j], table[index][j]);
13351376 }
13361377 return result;
1337- } else {
1378+ } else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 2 ) {
13381379
13391380 // This CODE_SECTION optimizes table initialization to skip the high even
13401381 // elements of the table. The while loop does clever cmovs to avoid ever
@@ -1364,17 +1405,17 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
13641405 for (std::size_t i=2 ; i<HALFSIZE; i+=2 ) {
13651406 std::size_t halfi = i/2 ;
13661407 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1367- table[i][j]= mf.template square <LowuopsTag >(table[halfi][j]);
1408+ table[i][j]= mf.template square <PTAG >(table[halfi][j]);
13681409 }
13691410 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1370- table[i+1 ][j] = mf.template multiply <LowuopsTag >(
1411+ table[i+1 ][j] = mf.template multiply <PTAG >(
13711412 table[halfi+1 ][j], table[halfi][j]);
13721413 }
13731414 }
13741415 constexpr size_t QUARTERSIZE = TABLESIZE/4 ;
13751416 for (std::size_t i=1 ; i<HALFSIZE; i+=2 ) {
13761417 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1377- table[HALFSIZE + i][j] = mf.template multiply <LowuopsTag >(
1418+ table[HALFSIZE + i][j] = mf.template multiply <PTAG >(
13781419 table[QUARTERSIZE + i][j], table[QUARTERSIZE][j]);
13791420 }
13801421 }
@@ -1412,7 +1453,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14121453 HPBC_CLOCKWORK_ASSERT (index1 % 2 == 1 );
14131454 HPBC_CLOCKWORK_ASSERT (index2 < TABLESIZE/2 );
14141455 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1415- result[j] = mf.template multiply <LowuopsTag >(table[index1][j], table[index2][j]);
1456+ result[j] = mf.template multiply <PTAG >(table[index1][j], table[index2][j]);
14161457
14171458
14181459 while (shift >= P) {
@@ -1442,7 +1483,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14421483 if (USE_SLIDING_WINDOW_OPTIMIZATION) {
14431484 while (shift > P && (static_cast <size_t >(branchless_shift_right (n, shift-1 )) & 1u ) == 0 ) {
14441485 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1445- result[j] = mf.template square <LowuopsTag >(result[j]);
1486+ result[j] = mf.template square <PTAG >(result[j]);
14461487 --shift;
14471488 }
14481489 }
@@ -1451,7 +1492,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14511492 static_assert (P > 0 , " " );
14521493 for (int i=0 ; i<P - 1 ; ++i) {
14531494 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1454- result[j] = mf.template square <LowuopsTag >(result[j]);
1495+ result[j] = mf.template square <PTAG >(result[j]);
14551496 }
14561497 }
14571498
@@ -1465,7 +1506,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14651506#else
14661507 V tmp = V::template cselect_on_bit_eq0<0 >(static_cast <uint64_t >(index), table[index/2 ][j], result[j]);
14671508#endif
1468- result[j] = mf.template multiply <LowuopsTag >(tmp, result[j]);
1509+ result[j] = mf.template multiply <PTAG >(tmp, result[j]);
14691510 }
14701511
14711512 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
@@ -1475,7 +1516,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14751516#else
14761517 V tmp = V::template cselect_on_bit_eq0<0 >(static_cast <uint64_t >(index), result[j], table[index][j]);
14771518#endif
1478- result[j] = mf.template multiply <LowuopsTag >(tmp, result[j]);
1519+ result[j] = mf.template multiply <PTAG >(tmp, result[j]);
14791520 }
14801521 }
14811522
@@ -1485,16 +1526,110 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14851526
14861527 for (int i=0 ; i<shift; ++i) {
14871528 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1488- result[j] = mf.template square <LowuopsTag >(result[j]);
1529+ result[j] = mf.template square <PTAG >(result[j]);
14891530 }
14901531
14911532 size_t tmpmask = (1u << shift) - 1 ;
14921533 HPBC_CLOCKWORK_ASSERT (tmpmask <= MASK_SMALL);
14931534 index = static_cast <size_t >(n) & tmpmask;
14941535 HPBC_CLOCKWORK_ASSERT (index < TABLESIZE/2 );
14951536 HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1496- result[j] = mf.template multiply <LowuopsTag>(result[j],table[index][j]);
1537+ result[j] = mf.template multiply <PTAG>(result[j],table[index][j]);
1538+
1539+ return result;
1540+
1541+ } else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 3 ) {
1542+ // this is adapted from arraypow_cmov() in montgomery_pow.h
1543+
1544+ std::array<V, ARRAY_SIZE> bases = x;
1545+ U exponent = n;
1546+
1547+ std::array<V, ARRAY_SIZE> result;
1548+ if (static_cast <size_t >(exponent) & 1u ) {
1549+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1550+ result[j] = bases[j];
1551+ } else {
1552+ V mont_one = mf.getUnityValue ();
1553+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1554+ result[j] = mont_one;
1555+ }
14971556
1557+ while (exponent > 1u ) {
1558+ exponent = static_cast <U>(exponent >> 1 );
1559+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1560+ bases[j] = mf.template square <PTAG>(bases[j]);
1561+
1562+ V mont_one = mf.getUnityValue ();
1563+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1564+ # ifndef HURCHALLA_MONTGOMERY_POW_2KARY_USE_CSELECT_ON_BIT
1565+ V tmp = mont_one;
1566+ tmp.cmov (static_cast <size_t >(exponent) & 1u , bases[j]);
1567+ # else
1568+ V tmp = V::template cselect_on_bit_ne0<0 >(
1569+ static_cast <uint64_t >(exponent), bases[j], mont_one);
1570+ # endif
1571+ result[j] = mf.template multiply <PTAG>(result[j], tmp);
1572+ }
1573+ }
1574+ return result;
1575+
1576+ } else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 4 ) {
1577+ // this is adapted from arraypow_masked() in montgomery_pow.h
1578+
1579+ std::array<V, ARRAY_SIZE> bases = x;
1580+ U exponent = n;
1581+
1582+ std::array<V, ARRAY_SIZE> result;
1583+ if (static_cast <size_t >(exponent) & 1u ) {
1584+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1585+ result[j] = bases[j];
1586+ } else {
1587+ V mont_one = mf.getUnityValue ();
1588+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1589+ result[j] = mont_one;
1590+ }
1591+
1592+ while (exponent > 1u ) {
1593+ exponent = static_cast <U>(exponent >> 1 );
1594+
1595+ V mont_one = mf.getUnityValue ();
1596+ HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0 ; j<ARRAY_SIZE; ++j) {
1597+ bases[j] = mf.template square <PTAG>(bases[j]);
1598+ V tmp = mont_one;
1599+ // note: since we are doing masked selections, we definitely don't
1600+ // want to use cselect_on_bit here
1601+ tmp.template cmov <CSelectMaskedTag>(static_cast <size_t >(exponent) & 1u , bases[j]);
1602+ result[j] = mf.template multiply <PTAG>(result[j], tmp);
1603+ }
1604+ }
1605+ return result;
1606+
1607+ } else {
1608+ static_assert (CODE_SECTION == 5 , " " );
1609+ // this is adapted from arraypow_cond_branch() in montgomery_pow.h
1610+
1611+ std::array<V, ARRAY_SIZE> bases = x;
1612+ U exponent = n;
1613+
1614+ std::array<V, ARRAY_SIZE> result;
1615+ if (static_cast <size_t >(exponent) & 1u ) {
1616+ for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1617+ result[j] = bases[j];
1618+ } else {
1619+ V mont_one = mf.getUnityValue ();
1620+ for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1621+ result[j] = mont_one;
1622+ }
1623+
1624+ while (exponent > 1u ) {
1625+ exponent = static_cast <U>(exponent >> 1 );
1626+ for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1627+ bases[j] = mf.template square <PTAG>(bases[j]);
1628+ if (static_cast <size_t >(exponent) & 1u ) {
1629+ for (size_t j=0 ; j<ARRAY_SIZE; ++j)
1630+ result[j] = mf.template multiply <PTAG>(result[j], bases[j]);
1631+ }
1632+ }
14981633 return result;
14991634}
15001635
0 commit comments