Skip to content

Commit c6cdcdf

Browse files
Jeffrey Hurchalla
authored and committed
improve experimental montgomery_two_pow benchmarking
1 parent 59a05f2 commit c6cdcdf

File tree

1 file changed

+84
-24
lines changed

1 file changed

+84
-24
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_two_pow/example_montgomery_two_pow.cpp

Lines changed: 84 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#include "hurchalla/util/compiler_macros.h"
1212
#include "hurchalla/util/programming_by_contract.h"
1313
#include "montgomery_two_pow.h"
14+
#include "../montgomery_pow_kary.h"
1415
#include <iostream>
1516
#include <stdexcept>
1617
#include <chrono>
@@ -131,7 +132,7 @@ void bench_range(U min, U range)
131132

132133
for (U x = max; x > min; x = x-2) {
133134
MontType mf(x);
134-
auto val = hurchalla::montgomery_two_pow(mf, x-1);
135+
auto val = hurchalla::montgomery_two_pow(mf, static_cast<U>(x-1));
135136
// the sole purpose of the next line is to prevent the optimizer from
136137
// being able to eliminate the function call above.
137138
if (mf.getCanonicalValue(val) == mf.getZeroValue())
@@ -164,9 +165,9 @@ void bench_range(U min, U range)
164165
{
165166
U total_zeros = 0;
166167
auto t0 = steady_clock::now();
167-
for (U x = max; x > min; x = x-4) {
168+
for (U x = max; x > min && x >= 4; x = x-4) {
168169
std::array<MontType, 2> mf_arr { MontType(x), MontType(x - 2) };
169-
std::array<U, 2> exponent_arr { mf_arr[0].getModulus() - 1, mf_arr[1].getModulus() - 1 };
170+
std::array<U, 2> exponent_arr { static_cast<U>(mf_arr[0].getModulus() - 1), static_cast<U>(mf_arr[1].getModulus() - 1) };
170171
auto mont_result_arr = hurchalla::array_montgomery_two_pow(mf_arr, exponent_arr);
171172
if (mf_arr[0].getCanonicalValue(mont_result_arr[0]) == mf_arr[0].getZeroValue())
172173
total_zeros++;
@@ -183,7 +184,7 @@ void bench_range(U min, U range)
183184
U total_zeros = 0;
184185
auto t0 = steady_clock::now();
185186
constexpr std::size_t ARRAY_SIZE = 3;
186-
for (U x = max; x > min; x = x - (2*ARRAY_SIZE)) {
187+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
187188
std::array<MontType, ARRAY_SIZE> mf_arr {
188189
MontType(x), MontType(x - 2), MontType(x - 4) };
189190
std::array<U, ARRAY_SIZE> exponent_arr;
@@ -205,7 +206,7 @@ void bench_range(U min, U range)
205206
U total_zeros = 0;
206207
auto t0 = steady_clock::now();
207208
constexpr std::size_t ARRAY_SIZE = 4;
208-
for (U x = max; x > min; x = x - (2*ARRAY_SIZE)) {
209+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
209210
std::array<MontType, ARRAY_SIZE> mf_arr {
210211
MontType(x), MontType(x - 2), MontType(x - 4), MontType(x - 6) };
211212
std::array<U, ARRAY_SIZE> exponent_arr;
@@ -227,7 +228,7 @@ void bench_range(U min, U range)
227228
U total_zeros = 0;
228229
auto t0 = steady_clock::now();
229230
constexpr std::size_t ARRAY_SIZE = 5;
230-
for (U x = max; x > min; x = x - (2*ARRAY_SIZE)) {
231+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
231232
std::array<MontType, ARRAY_SIZE> mf_arr {
232233
MontType(x), MontType(x - 2), MontType(x - 4), MontType(x - 6), MontType(x - 8) };
233234
std::array<U, ARRAY_SIZE> exponent_arr;
@@ -249,7 +250,7 @@ void bench_range(U min, U range)
249250
U total_zeros = 0;
250251
auto t0 = steady_clock::now();
251252
constexpr std::size_t ARRAY_SIZE = 6;
252-
for (U x = max; x > min; x = x - (2*ARRAY_SIZE)) {
253+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
253254
std::array<MontType, ARRAY_SIZE> mf_arr {
254255
MontType(x), MontType(x - 2), MontType(x - 4), MontType(x - 6), MontType(x - 8), MontType(x - 10) };
255256
std::array<U, ARRAY_SIZE> exponent_arr;
@@ -271,7 +272,7 @@ void bench_range(U min, U range)
271272
U total_zeros = 0;
272273
auto t0 = steady_clock::now();
273274
constexpr std::size_t ARRAY_SIZE = 8;
274-
for (U x = max; x > min; x = x - (2*ARRAY_SIZE)) {
275+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
275276
std::array<MontType, ARRAY_SIZE> mf_arr {
276277
MontType(x), MontType(x - 2), MontType(x - 4), MontType(x - 6), MontType(x - 8), MontType(x - 10), MontType(x - 12), MontType(x - 14) };
277278
std::array<U, ARRAY_SIZE> exponent_arr;
@@ -288,6 +289,45 @@ void bench_range(U min, U range)
288289
std::cout << uint_to_string(total_zeros) << " ";
289290
}
290291

292+
dsec::rep mpkary_time = 0;
293+
dsec::rep mfpow_time = 0;
294+
{
295+
U total_zeros = 0;
296+
auto t0 = steady_clock::now();
297+
constexpr std::size_t ARRAY_SIZE = 4;
298+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
299+
MontType mf(x);
300+
U exponent = mf.getModulus() - 1;
301+
std::array<typename MontType::MontgomeryValue, ARRAY_SIZE> bases;
302+
HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0; j<ARRAY_SIZE; ++j)
303+
bases[j] = mf.convertIn(j + 5);
304+
auto mont_result_arr = hurchalla::array_montgomery_pow_kary(mf, bases, exponent);
305+
HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0; j<ARRAY_SIZE; ++j) {
306+
if (mf.getCanonicalValue(mont_result_arr[j]) == mf.getZeroValue())
307+
total_zeros++;
308+
}
309+
}
310+
auto t1 = steady_clock::now();
311+
mpkary_time = dsec(t1-t0).count();
312+
313+
t0 = steady_clock::now();
314+
for (U x = max; x > min && x >= (2*ARRAY_SIZE); x = x - (2*ARRAY_SIZE)) {
315+
MontType mf(x);
316+
U exponent = mf.getModulus() - 1;
317+
std::array<typename MontType::MontgomeryValue, ARRAY_SIZE> bases;
318+
HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0; j<ARRAY_SIZE; ++j)
319+
bases[j] = mf.convertIn(j + 5);
320+
auto mont_result_arr = mf.pow(bases, exponent);
321+
HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0; j<ARRAY_SIZE; ++j) {
322+
if (mf.getCanonicalValue(mont_result_arr[j]) == mf.getZeroValue())
323+
total_zeros++;
324+
}
325+
}
326+
t1 = steady_clock::now();
327+
mfpow_time = dsec(t1-t0).count();
328+
std::cout << uint_to_string(total_zeros) << " ";
329+
}
330+
291331
std::cout << "\n\n";
292332

293333
std::cout << "montgomery_two_pow() time: " << mtp_time << "\n";
@@ -305,6 +345,7 @@ void bench_range(U min, U range)
305345
std::cout << "array5 performance ratio = " << mtp_time / mtp_time_5 << "\n";
306346
std::cout << "array6 performance ratio = " << mtp_time / mtp_time_6 << "\n";
307347
std::cout << "array8 performance ratio = " << mtp_time / mtp_time_8 << "\n";
348+
std::cout << "\narraykary performance ratio = " << mfpow_time / mpkary_time << "\n";
308349

309350
std::cout << '\n';
310351
}
@@ -315,18 +356,32 @@ void bench_range(U min, U range)
315356

316357
int main()
317358
{
359+
namespace hc = hurchalla;
318360
std::cout << "---Running Example Program---\n\n";
319361

320362
// These are types and values that you may wish to change:
321363
using U = __uint128_t;
364+
// using U = uint64_t;
365+
366+
constexpr int UDIGITS = hc::ut_numeric_limits<U>::digits;
322367
// Note you're not required to use string_to_uint(). I just used it as a way to set values greater than 2^64 without getting a compile error.
323368
U exponent = string_to_uint<U>("8");
324-
U modulus = string_to_uint<U>("1234567890123456789012345678901");
369+
U modulus;
370+
if constexpr (UDIGITS >= 128)
371+
modulus = string_to_uint<U>("1234567890123456789012345678901");
372+
else if constexpr (UDIGITS >= 64)
373+
modulus = string_to_uint<U>("1234567890123456789");
374+
else if constexpr (UDIGITS >= 32)
375+
modulus = string_to_uint<U>("123456789");
376+
else if constexpr (UDIGITS >= 16)
377+
modulus = string_to_uint<U>("12345");
378+
else
379+
modulus = string_to_uint<U>("63");
325380
if (modulus % 2 == 0) {
326381
std::cout << "Error: modulus must be odd to use Montgomery arithmetic\n";
327382
return 1;
328383
}
329-
namespace hc = hurchalla;
384+
330385
#if 1
331386
// If you can guarantee your modulus will always be less than one quarter the
332387
// maximum value of type U, then use MontgomeryQuarter for speed.
@@ -350,7 +405,7 @@ int main()
350405
// as fastest per exponentiation, at roughly 1.9x the speed of the plain
351406
// (non-array) function montgomery_two_pow.
352407
std::array<MontType, 2> mf_arr { MontType(modulus), MontType(modulus + 2) }; // modulus + 2 is just an arbitrary second modulus value
353-
std::array<U, 2> exponent_arr { exponent, exponent + 3 }; // exponent + 3 is just an arbitrary second exponent value
408+
std::array<U, 2> exponent_arr { exponent, static_cast<U>(exponent + 3) }; // exponent + 3 is just an arbitrary second exponent value
354409
auto mont_result_arr = hc::array_montgomery_two_pow(mf_arr, exponent_arr);
355410
std::array<U, 2> result_arr { mf_arr[0].convertOut(mont_result_arr[0]),
356411
mf_arr[1].convertOut(mont_result_arr[1]) };
@@ -375,7 +430,7 @@ int main()
375430
// ------ Tests for correctness ------
376431

377432
// test for correctness with a range of exponents
378-
U range = 100000;
433+
U range = static_cast<U>(100000);
379434
constexpr U maxU = hc::ut_numeric_limits<U>::max();
380435
auto mont_two = mf.add(mf.getUnityValue(), mf.getUnityValue());
381436
for (exponent = maxU; exponent > maxU-range; exponent = exponent-2) {
@@ -385,7 +440,7 @@ int main()
385440
if (result != standard_result) {
386441
std::cout << "bug in montgomery_two_pow found: got wrong result for ";
387442
std::cout << "2^" << uint_to_string(exponent) << " (mod " <<
388-
uint_to_string(modulus) << '\n';
443+
uint_to_string(modulus) << ")\n";
389444
return 1;
390445
}
391446
}
@@ -406,9 +461,9 @@ int main()
406461
result = mf.convertOut(mont_result_arr[j]);
407462
U standard_result = mf.convertOut(mf.pow(mont_two, exponent_arr[j]));
408463
if (result != standard_result) {
409-
std::cout << "bug in array_montgomery_two_pow found: got wrong result for ";
464+
std::cout << "bug2 in array_montgomery_two_pow found: got wrong result for ";
410465
std::cout << "2^" << uint_to_string(exponent_arr[j]) << " (mod " <<
411-
uint_to_string(mf.getModulus()) << '\n';
466+
uint_to_string(mf.getModulus()) << ")\n";
412467
return 1;
413468
}
414469
}
@@ -417,20 +472,25 @@ int main()
417472
// test for correctness with a range of moduli.
418473
// simulates fermat primality tests
419474
constexpr auto maxMF = MontType::max_modulus();
420-
for (auto mod = maxMF; mod > maxMF-range; mod = mod-2) {
475+
auto mod_range = static_cast<decltype(maxMF)>(range);
476+
if (mod_range >= maxMF)
477+
mod_range = maxMF - 1;
478+
for (auto mod = maxMF; mod > maxMF-mod_range; mod = mod-2) {
421479
MontType mt(mod);
422480
auto mont_two = mt.add(mt.getUnityValue(), mt.getUnityValue());
423-
mont_result = hc::montgomery_two_pow(mt, mod-1);
481+
mont_result = hc::montgomery_two_pow(mt, static_cast<decltype(mod)>(mod-1));
424482
result = mt.convertOut(mont_result);
425483
U standard_result = mt.convertOut(mt.pow(mont_two, mod-1));
426484
if (result != standard_result) {
427-
std::cout << "bug2 in montgomery_two_pow found: got wrong result for ";
428-
std::cout << "2^" << uint_to_string(mod-1) << " (mod " <<
429-
uint_to_string(mod) << '\n';
485+
std::cout << "bug3 in montgomery_two_pow found: got wrong result for ";
486+
std::cout << "2^" << uint_to_string(static_cast<decltype(mod)>(mod-1)) << " (mod " <<
487+
uint_to_string(mod) << ")\n";
430488
return 1;
431489
}
432490
}
433-
for (auto mod = maxMF; mod > maxMF-range; mod = mod-2) {
491+
492+
mod_range -= 16;
493+
for (auto mod = maxMF; mod > maxMF-mod_range; mod = mod-2) {
434494
constexpr size_t ARRAY_SIZE = 3;
435495
// We use std::vector to indirectly make a MontType array, since
436496
// MontType has no default constructor.
@@ -448,9 +508,9 @@ int main()
448508
auto mont_two = mf_arr[j].add(mf_arr[j].getUnityValue(), mf_arr[j].getUnityValue());
449509
U standard_result = mf_arr[j].convertOut(mf_arr[j].pow(mont_two, exponent_arr[j]));
450510
if (result != standard_result) {
451-
std::cout << "bug2 in array_montgomery_two_pow found: got wrong result for ";
511+
std::cout << "bug4 in array_montgomery_two_pow found: got wrong result for ";
452512
std::cout << "2^" << uint_to_string(exponent_arr[j]) << " (mod " <<
453-
uint_to_string(mf_arr[j].getModulus()) << '\n';
513+
uint_to_string(mf_arr[j].getModulus()) << ")\n";
454514
return 1;
455515
}
456516
}
@@ -463,7 +523,7 @@ int main()
463523

464524
// ------- Benchmarking --------
465525

466-
bench_range<MontType>(maxU - range, range);
526+
bench_range<MontType>(static_cast<U>(maxU - range), range);
467527

468528
std::cout << "---Example Program Finished---\n";
469529
return 0;

0 commit comments

Comments
 (0)