Skip to content

Commit 0927f0b

Browse files
committed
improve mont 2^k-ary pow benchmarking
1 parent 716e0fd commit 0927f0b

File tree

2 files changed

+296
-443
lines changed

2 files changed

+296
-443
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_pow_2kary/experimental_montgomery_pow_2kary.h

Lines changed: 161 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,27 @@ struct experimental_montgomery_pow_2kary {
7272

7373

7474
if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
75-
// For comparison purposes, this is the current MontgomeryForm pow.
76-
// the static cast may lose bits, so this might not be an exact benchmark
77-
return mf.pow(x, static_cast<typename MF::IntegerType>(n));
75+
// this is a masked version of pow from montgomery_pow.h
76+
V base = x;
77+
U exponent = n;
78+
79+
V result;
80+
if (static_cast<size_t>(exponent) & 1u)
81+
result = base;
82+
else
83+
result = mf.getUnityValue();
84+
85+
while (exponent > 1u) {
86+
exponent = static_cast<U>(exponent >> 1);
87+
88+
base = mf.square(base);
89+
V tmp = mf.getUnityValue();
90+
// note: since we are doing masked selections, we definitely don't
91+
// want to use cselect_on_bit here
92+
tmp.template cmov<CSelectMaskedTag>(static_cast<size_t>(exponent) & 1u, base);
93+
result = mf.multiply(result, tmp);
94+
}
95+
return result;
7896

7997
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 1) {
8098
// this is a branch version of pow from montgomery_pow.h
@@ -1201,7 +1219,8 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12011219
bool USE_SLIDING_WINDOW_OPTIMIZATION = false,
12021220
size_t TABLE_BITS = 4,
12031221
size_t CODE_SECTION = 0,
1204-
bool USE_SQUARING_VALUE_OPTIMIZATION = false>
1222+
bool USE_SQUARING_VALUE_OPTIMIZATION = false,
1223+
class PTAG = LowuopsTag>
12051224
static std::array<typename MF::MontgomeryValue, ARRAY_SIZE>
12061225
call(const MF& mf,
12071226
const std::array<typename MF::MontgomeryValue, ARRAY_SIZE>& x,
@@ -1215,7 +1234,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12151234
"the beginning of this function to calculate 1024+ table entries!)");
12161235

12171236
using V = typename MF::MontgomeryValue;
1218-
using MFE_LU = hurchalla::detail::MontgomeryFormExtensions<MF, LowuopsTag>;
1237+
using MFE_LU = hurchalla::detail::MontgomeryFormExtensions<MF, PTAG>;
12191238
using SV = typename MFE_LU::SquaringValue;
12201239
using std::size_t;
12211240
U n = nexp;
@@ -1230,9 +1249,31 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12301249

12311250

12321251
if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
1233-
// For comparison purposes, this is the current MontgomeryForm pow.
1234-
// the static cast may lose bits, so this might not be an exact benchmark
1235-
return mf.pow(x, static_cast<typename MF::IntegerType>(nexp));
1252+
// this is adapted from arraypow_cond_branch_unrolled() in montgomery_pow.h
1253+
1254+
std::array<V, ARRAY_SIZE> bases = x;
1255+
U exponent = n;
1256+
1257+
std::array<V, ARRAY_SIZE> result;
1258+
if (static_cast<size_t>(exponent) & 1u) {
1259+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1260+
result[j] = bases[j];
1261+
} else {
1262+
V mont_one = mf.getUnityValue();
1263+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1264+
result[j] = mont_one;
1265+
}
1266+
1267+
while (exponent > 1u) {
1268+
exponent = static_cast<U>(exponent >> 1);
1269+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1270+
bases[j] = mf.template square<PTAG>(bases[j]);
1271+
if (static_cast<size_t>(exponent) & 1u) {
1272+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1273+
result[j] = mf.template multiply<PTAG>(result[j], bases[j]);
1274+
}
1275+
}
1276+
return result;
12361277

12371278
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 1) {
12381279

@@ -1245,9 +1286,9 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
12451286
for (std::size_t i=2; i<TABLESIZE; i+=2) {
12461287
std::size_t halfi = i/2;
12471288
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1248-
table[i][j] = mf.template square<LowuopsTag>(table[halfi][j]);
1289+
table[i][j] = mf.template square<PTAG>(table[halfi][j]);
12491290
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1250-
table[i+1][j] = mf.template multiply<LowuopsTag>(table[halfi+1][j], table[halfi][j]);
1291+
table[i+1][j] = mf.template multiply<PTAG>(table[halfi+1][j], table[halfi][j]);
12511292
}
12521293
}
12531294

@@ -1302,21 +1343,21 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
13021343
if (USE_SLIDING_WINDOW_OPTIMIZATION) {
13031344
while (shift > P && (static_cast<size_t>(branchless_shift_right(n, shift-1)) & 1u) == 0) {
13041345
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1305-
result[j] = mf.template square<LowuopsTag>(result[j]);
1346+
result[j] = mf.template square<PTAG>(result[j]);
13061347
--shift;
13071348
}
13081349
}
13091350

13101351
for (int i=0; i<P; ++i) {
13111352
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1312-
result[j] = mf.template square<LowuopsTag>(result[j]);
1353+
result[j] = mf.template square<PTAG>(result[j]);
13131354
}
13141355
}
13151356

13161357
shift -= P;
13171358
index = static_cast<size_t>(branchless_shift_right(n, shift)) & MASK;
13181359
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1319-
result[j] = mf.template multiply<LowuopsTag>(result[j], table[index][j]);
1360+
result[j] = mf.template multiply<PTAG>(result[j], table[index][j]);
13201361
}
13211362
}
13221363

@@ -1326,15 +1367,15 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
13261367

13271368
for (int i=0; i<shift; ++i) {
13281369
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1329-
result[j] = mf.template square<LowuopsTag>(result[j]);
1370+
result[j] = mf.template square<PTAG>(result[j]);
13301371
}
13311372
size_t tmpmask = (1u << shift) - 1;
13321373
index = static_cast<size_t>(n) & tmpmask;
13331374
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1334-
result[j] = mf.template multiply<LowuopsTag>(result[j], table[index][j]);
1375+
result[j] = mf.template multiply<PTAG>(result[j], table[index][j]);
13351376
}
13361377
return result;
1337-
} else {
1378+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 2) {
13381379

13391380
// This CODE_SECTION optimizes table initialization to skip the high even
13401381
// elements of the table. The while loop does clever cmovs to avoid ever
@@ -1364,17 +1405,17 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
13641405
for (std::size_t i=2; i<HALFSIZE; i+=2) {
13651406
std::size_t halfi = i/2;
13661407
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1367-
table[i][j]= mf.template square<LowuopsTag>(table[halfi][j]);
1408+
table[i][j]= mf.template square<PTAG>(table[halfi][j]);
13681409
}
13691410
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1370-
table[i+1][j] = mf.template multiply<LowuopsTag>(
1411+
table[i+1][j] = mf.template multiply<PTAG>(
13711412
table[halfi+1][j], table[halfi][j]);
13721413
}
13731414
}
13741415
constexpr size_t QUARTERSIZE = TABLESIZE/4;
13751416
for (std::size_t i=1; i<HALFSIZE; i+=2) {
13761417
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1377-
table[HALFSIZE + i][j] = mf.template multiply<LowuopsTag>(
1418+
table[HALFSIZE + i][j] = mf.template multiply<PTAG>(
13781419
table[QUARTERSIZE + i][j], table[QUARTERSIZE][j]);
13791420
}
13801421
}
@@ -1412,7 +1453,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14121453
HPBC_CLOCKWORK_ASSERT(index1 % 2 == 1);
14131454
HPBC_CLOCKWORK_ASSERT(index2 < TABLESIZE/2);
14141455
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1415-
result[j] = mf.template multiply<LowuopsTag>(table[index1][j], table[index2][j]);
1456+
result[j] = mf.template multiply<PTAG>(table[index1][j], table[index2][j]);
14161457

14171458

14181459
while (shift >= P) {
@@ -1442,7 +1483,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14421483
if (USE_SLIDING_WINDOW_OPTIMIZATION) {
14431484
while (shift > P && (static_cast<size_t>(branchless_shift_right(n, shift-1)) & 1u) == 0) {
14441485
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1445-
result[j] = mf.template square<LowuopsTag>(result[j]);
1486+
result[j] = mf.template square<PTAG>(result[j]);
14461487
--shift;
14471488
}
14481489
}
@@ -1451,7 +1492,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14511492
static_assert(P > 0, "");
14521493
for (int i=0; i<P - 1; ++i) {
14531494
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1454-
result[j] = mf.template square<LowuopsTag>(result[j]);
1495+
result[j] = mf.template square<PTAG>(result[j]);
14551496
}
14561497
}
14571498

@@ -1465,7 +1506,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14651506
#else
14661507
V tmp = V::template cselect_on_bit_eq0<0>(static_cast<uint64_t>(index), table[index/2][j], result[j]);
14671508
#endif
1468-
result[j] = mf.template multiply<LowuopsTag>(tmp, result[j]);
1509+
result[j] = mf.template multiply<PTAG>(tmp, result[j]);
14691510
}
14701511

14711512
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
@@ -1475,7 +1516,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14751516
#else
14761517
V tmp = V::template cselect_on_bit_eq0<0>(static_cast<uint64_t>(index), result[j], table[index][j]);
14771518
#endif
1478-
result[j] = mf.template multiply<LowuopsTag>(tmp, result[j]);
1519+
result[j] = mf.template multiply<PTAG>(tmp, result[j]);
14791520
}
14801521
}
14811522

@@ -1485,16 +1526,110 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
14851526

14861527
for (int i=0; i<shift; ++i) {
14871528
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1488-
result[j] = mf.template square<LowuopsTag>(result[j]);
1529+
result[j] = mf.template square<PTAG>(result[j]);
14891530
}
14901531

14911532
size_t tmpmask = (1u << shift) - 1;
14921533
HPBC_CLOCKWORK_ASSERT(tmpmask <= MASK_SMALL);
14931534
index = static_cast<size_t>(n) & tmpmask;
14941535
HPBC_CLOCKWORK_ASSERT(index < TABLESIZE/2);
14951536
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1496-
result[j] = mf.template multiply<LowuopsTag>(result[j],table[index][j]);
1537+
result[j] = mf.template multiply<PTAG>(result[j],table[index][j]);
1538+
1539+
return result;
1540+
1541+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 3) {
1542+
// this is adapted from arraypow_cmov() in montgomery_pow.h
1543+
1544+
std::array<V, ARRAY_SIZE> bases = x;
1545+
U exponent = n;
1546+
1547+
std::array<V, ARRAY_SIZE> result;
1548+
if (static_cast<size_t>(exponent) & 1u) {
1549+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1550+
result[j] = bases[j];
1551+
} else {
1552+
V mont_one = mf.getUnityValue();
1553+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1554+
result[j] = mont_one;
1555+
}
14971556

1557+
while (exponent > 1u) {
1558+
exponent = static_cast<U>(exponent >> 1);
1559+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1560+
bases[j] = mf.template square<PTAG>(bases[j]);
1561+
1562+
V mont_one = mf.getUnityValue();
1563+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1564+
# ifndef HURCHALLA_MONTGOMERY_POW_2KARY_USE_CSELECT_ON_BIT
1565+
V tmp = mont_one;
1566+
tmp.cmov(static_cast<size_t>(exponent) & 1u, bases[j]);
1567+
# else
1568+
V tmp = V::template cselect_on_bit_ne0<0>(
1569+
static_cast<uint64_t>(exponent), bases[j], mont_one);
1570+
# endif
1571+
result[j] = mf.template multiply<PTAG>(result[j], tmp);
1572+
}
1573+
}
1574+
return result;
1575+
1576+
} else if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 4) {
1577+
// this is adapted from arraypow_masked() in montgomery_pow.h
1578+
1579+
std::array<V, ARRAY_SIZE> bases = x;
1580+
U exponent = n;
1581+
1582+
std::array<V, ARRAY_SIZE> result;
1583+
if (static_cast<size_t>(exponent) & 1u) {
1584+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1585+
result[j] = bases[j];
1586+
} else {
1587+
V mont_one = mf.getUnityValue();
1588+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j)
1589+
result[j] = mont_one;
1590+
}
1591+
1592+
while (exponent > 1u) {
1593+
exponent = static_cast<U>(exponent >> 1);
1594+
1595+
V mont_one = mf.getUnityValue();
1596+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1597+
bases[j] = mf.template square<PTAG>(bases[j]);
1598+
V tmp = mont_one;
1599+
// note: since we are doing masked selections, we definitely don't
1600+
// want to use cselect_on_bit here
1601+
tmp.template cmov<CSelectMaskedTag>(static_cast<size_t>(exponent) & 1u, bases[j]);
1602+
result[j] = mf.template multiply<PTAG>(result[j], tmp);
1603+
}
1604+
}
1605+
return result;
1606+
1607+
} else {
1608+
static_assert(CODE_SECTION == 5, "");
1609+
// this is adapted from arraypow_cond_branch() in montgomery_pow.h
1610+
1611+
std::array<V, ARRAY_SIZE> bases = x;
1612+
U exponent = n;
1613+
1614+
std::array<V, ARRAY_SIZE> result;
1615+
if (static_cast<size_t>(exponent) & 1u) {
1616+
for (size_t j=0; j<ARRAY_SIZE; ++j)
1617+
result[j] = bases[j];
1618+
} else {
1619+
V mont_one = mf.getUnityValue();
1620+
for (size_t j=0; j<ARRAY_SIZE; ++j)
1621+
result[j] = mont_one;
1622+
}
1623+
1624+
while (exponent > 1u) {
1625+
exponent = static_cast<U>(exponent >> 1);
1626+
for (size_t j=0; j<ARRAY_SIZE; ++j)
1627+
bases[j] = mf.template square<PTAG>(bases[j]);
1628+
if (static_cast<size_t>(exponent) & 1u) {
1629+
for (size_t j=0; j<ARRAY_SIZE; ++j)
1630+
result[j] = mf.template multiply<PTAG>(result[j], bases[j]);
1631+
}
1632+
}
14981633
return result;
14991634
}
15001635

0 commit comments

Comments
 (0)