1111#include " hurchalla/util/compiler_macros.h"
1212#include " hurchalla/util/programming_by_contract.h"
1313#include " montgomery_two_pow.h"
14+ #include " ../montgomery_pow_kary.h"
1415#include < iostream>
1516#include < stdexcept>
1617#include < chrono>
@@ -131,7 +132,7 @@ void bench_range(U min, U range)
131132
132133 for (U x = max; x > min; x = x-2 ) {
133134 MontType mf (x);
134- auto val = hurchalla::montgomery_two_pow (mf, x-1 );
135+ auto val = hurchalla::montgomery_two_pow (mf, static_cast <U>( x-1 ) );
135136 // the sole purpose of the next line is to prevent the optimizer from
136137 // being able to eliminate the function call above.
137138 if (mf.getCanonicalValue (val) == mf.getZeroValue ())
@@ -164,9 +165,9 @@ void bench_range(U min, U range)
164165 {
165166 U total_zeros = 0 ;
166167 auto t0 = steady_clock::now ();
167- for (U x = max; x > min; x = x-4 ) {
168+ for (U x = max; x > min && x >= 4 ; x = x-4 ) {
168169 std::array<MontType, 2 > mf_arr { MontType (x), MontType (x - 2 ) };
169- std::array<U, 2 > exponent_arr { mf_arr[0 ].getModulus () - 1 , mf_arr[1 ].getModulus () - 1 };
170+ std::array<U, 2 > exponent_arr { static_cast <U>( mf_arr[0 ].getModulus () - 1 ) , static_cast <U>( mf_arr[1 ].getModulus () - 1 ) };
170171 auto mont_result_arr = hurchalla::array_montgomery_two_pow (mf_arr, exponent_arr);
171172 if (mf_arr[0 ].getCanonicalValue (mont_result_arr[0 ]) == mf_arr[0 ].getZeroValue ())
172173 total_zeros++;
@@ -183,7 +184,7 @@ void bench_range(U min, U range)
183184 U total_zeros = 0 ;
184185 auto t0 = steady_clock::now ();
185186 constexpr std::size_t ARRAY_SIZE = 3 ;
186- for (U x = max; x > min; x = x - (2 *ARRAY_SIZE)) {
187+ for (U x = max; x > min && x >= ( 2 *ARRAY_SIZE) ; x = x - (2 *ARRAY_SIZE)) {
187188 std::array<MontType, ARRAY_SIZE> mf_arr {
188189 MontType (x), MontType (x - 2 ), MontType (x - 4 ) };
189190 std::array<U, ARRAY_SIZE> exponent_arr;
@@ -205,7 +206,7 @@ void bench_range(U min, U range)
205206 U total_zeros = 0 ;
206207 auto t0 = steady_clock::now ();
207208 constexpr std::size_t ARRAY_SIZE = 4 ;
208- for (U x = max; x > min; x = x - (2 *ARRAY_SIZE)) {
209+ for (U x = max; x > min && x >= ( 2 *ARRAY_SIZE) ; x = x - (2 *ARRAY_SIZE)) {
209210 std::array<MontType, ARRAY_SIZE> mf_arr {
210211 MontType (x), MontType (x - 2 ), MontType (x - 4 ), MontType (x - 6 ) };
211212 std::array<U, ARRAY_SIZE> exponent_arr;
@@ -227,7 +228,7 @@ void bench_range(U min, U range)
227228 U total_zeros = 0 ;
228229 auto t0 = steady_clock::now ();
229230 constexpr std::size_t ARRAY_SIZE = 5 ;
230- for (U x = max; x > min; x = x - (2 *ARRAY_SIZE)) {
231+ for (U x = max; x > min && x >= ( 2 *ARRAY_SIZE) ; x = x - (2 *ARRAY_SIZE)) {
231232 std::array<MontType, ARRAY_SIZE> mf_arr {
232233 MontType (x), MontType (x - 2 ), MontType (x - 4 ), MontType (x - 6 ), MontType (x - 8 ) };
233234 std::array<U, ARRAY_SIZE> exponent_arr;
@@ -249,7 +250,7 @@ void bench_range(U min, U range)
249250 U total_zeros = 0 ;
250251 auto t0 = steady_clock::now ();
251252 constexpr std::size_t ARRAY_SIZE = 6 ;
252- for (U x = max; x > min; x = x - (2 *ARRAY_SIZE)) {
253+ for (U x = max; x > min && x >= ( 2 *ARRAY_SIZE) ; x = x - (2 *ARRAY_SIZE)) {
253254 std::array<MontType, ARRAY_SIZE> mf_arr {
254255 MontType (x), MontType (x - 2 ), MontType (x - 4 ), MontType (x - 6 ), MontType (x - 8 ), MontType (x - 10 ) };
255256 std::array<U, ARRAY_SIZE> exponent_arr;
@@ -271,7 +272,7 @@ void bench_range(U min, U range)
271272 U total_zeros = 0 ;
272273 auto t0 = steady_clock::now ();
273274 constexpr std::size_t ARRAY_SIZE = 8 ;
274- for (U x = max; x > min; x = x - (2 *ARRAY_SIZE)) {
275+ for (U x = max; x > min && x >= ( 2 *ARRAY_SIZE) ; x = x - (2 *ARRAY_SIZE)) {
275276 std::array<MontType, ARRAY_SIZE> mf_arr {
276277 MontType (x), MontType (x - 2 ), MontType (x - 4 ), MontType (x - 6 ), MontType (x - 8 ), MontType (x - 10 ), MontType (x - 12 ), MontType (x - 14 ) };
277278 std::array<U, ARRAY_SIZE> exponent_arr;
@@ -288,6 +289,45 @@ void bench_range(U min, U range)
288289 std::cout << uint_to_string (total_zeros) << " " ;
289290 }
290291
292+ dsec::rep mpkary_time = 0 ;
293+ dsec::rep mfpow_time = 0 ;
294+ {
295+ U total_zeros = 0 ;
296+ auto t0 = steady_clock::now ();
297+ constexpr std::size_t ARRAY_SIZE = 4 ;
298+ for (U x = max; x > min && x >= (2 *ARRAY_SIZE); x = x - (2 *ARRAY_SIZE)) {
299+ MontType mf (x);
300+ U exponent = mf.getModulus () - 1 ;
301+ std::array<typename MontType::MontgomeryValue, ARRAY_SIZE> bases;
302+ HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0 ; j<ARRAY_SIZE; ++j)
303+ bases[j] = mf.convertIn (j + 5 );
304+ auto mont_result_arr = hurchalla::array_montgomery_pow_kary (mf, bases, exponent);
305+ HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0 ; j<ARRAY_SIZE; ++j) {
306+ if (mf.getCanonicalValue (mont_result_arr[j]) == mf.getZeroValue ())
307+ total_zeros++;
308+ }
309+ }
310+ auto t1 = steady_clock::now ();
311+ mpkary_time = dsec (t1-t0).count ();
312+
313+ t0 = steady_clock::now ();
314+ for (U x = max; x > min && x >= (2 *ARRAY_SIZE); x = x - (2 *ARRAY_SIZE)) {
315+ MontType mf (x);
316+ U exponent = mf.getModulus () - 1 ;
317+ std::array<typename MontType::MontgomeryValue, ARRAY_SIZE> bases;
318+ HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0 ; j<ARRAY_SIZE; ++j)
319+ bases[j] = mf.convertIn (j + 5 );
320+ auto mont_result_arr = mf.pow (bases, exponent);
321+ HURCHALLA_REQUEST_UNROLL_LOOP for (int j=0 ; j<ARRAY_SIZE; ++j) {
322+ if (mf.getCanonicalValue (mont_result_arr[j]) == mf.getZeroValue ())
323+ total_zeros++;
324+ }
325+ }
326+ t1 = steady_clock::now ();
327+ mfpow_time = dsec (t1-t0).count ();
328+ std::cout << uint_to_string (total_zeros) << " " ;
329+ }
330+
291331 std::cout << " \n\n " ;
292332
293333 std::cout << " montgomery_two_pow() time: " << mtp_time << " \n " ;
@@ -305,6 +345,7 @@ void bench_range(U min, U range)
305345 std::cout << " array5 performance ratio = " << mtp_time / mtp_time_5 << " \n " ;
306346 std::cout << " array6 performance ratio = " << mtp_time / mtp_time_6 << " \n " ;
307347 std::cout << " array8 performance ratio = " << mtp_time / mtp_time_8 << " \n " ;
348+ std::cout << " \n arraykary performance ratio = " << mfpow_time / mpkary_time << " \n " ;
308349
309350 std::cout << ' \n ' ;
310351}
@@ -315,18 +356,32 @@ void bench_range(U min, U range)
315356
316357int main ()
317358{
359+ namespace hc = hurchalla;
318360 std::cout << " ---Running Example Program---\n\n " ;
319361
320362// These are types and values that you may wish to change:
321363 using U = __uint128_t ;
364+ // using U = uint64_t;
365+
366+ constexpr int UDIGITS = hc::ut_numeric_limits<U>::digits;
322367 // Note you're not required to use string_to_uint(). I just used it as a way to set values greater than 2^64 without getting a compile error.
323368 U exponent = string_to_uint<U>(" 8" );
324- U modulus = string_to_uint<U>(" 1234567890123456789012345678901" );
369+ U modulus;
370+ if constexpr (UDIGITS >= 128 )
371+ modulus = string_to_uint<U>(" 1234567890123456789012345678901" );
372+ else if constexpr (UDIGITS >= 64 )
373+ modulus = string_to_uint<U>(" 1234567890123456789" );
374+ else if constexpr (UDIGITS >= 32 )
375+ modulus = string_to_uint<U>(" 123456789" );
376+ else if constexpr (UDIGITS >= 16 )
377+ modulus = string_to_uint<U>(" 12345" );
378+ else
379+ modulus = string_to_uint<U>(" 63" );
325380 if (modulus % 2 == 0 ) {
326381 std::cout << " Error: modulus must be odd to use Montgomery arithmetic\n " ;
327382 return 1 ;
328383 }
329- namespace hc = hurchalla;
384+
330385#if 1
331386 // If you can guarantee your modulus will always be less than one quarter the
332387 // maximum value of type U, then use MontgomeryQuarter for speed.
@@ -350,7 +405,7 @@ int main()
350405 // as fastest per exponentiation, at roughly 1.9x the speed of the plain
351406 // (non-array) function montgomery_two_pow.
352407 std::array<MontType, 2 > mf_arr { MontType (modulus), MontType (modulus + 2 ) }; // modulus + 2 is just an arbitrary second modulus value
353- std::array<U, 2 > exponent_arr { exponent, exponent + 3 }; // exponent + 3 is just an arbitrary second exponent value
408+ std::array<U, 2 > exponent_arr { exponent, static_cast <U>( exponent + 3 ) }; // exponent + 3 is just an arbitrary second exponent value
354409 auto mont_result_arr = hc::array_montgomery_two_pow (mf_arr, exponent_arr);
355410 std::array<U, 2 > result_arr { mf_arr[0 ].convertOut (mont_result_arr[0 ]),
356411 mf_arr[1 ].convertOut (mont_result_arr[1 ]) };
@@ -375,7 +430,7 @@ int main()
375430// ------ Tests for correctneess ------
376431
377432 // test for correctness with a range of exponents
378- U range = 100000 ;
433+ U range = static_cast <U>( 100000 ) ;
379434 constexpr U maxU = hc::ut_numeric_limits<U>::max ();
380435 auto mont_two = mf.add (mf.getUnityValue (), mf.getUnityValue ());
381436 for (exponent = maxU; exponent > maxU-range; exponent = exponent-2 ) {
@@ -385,7 +440,7 @@ int main()
385440 if (result != standard_result) {
386441 std::cout << " bug in montgomery_two_pow found: got wrong result for " ;
387442 std::cout << " 2^" << uint_to_string (exponent) << " (mod " <<
388- uint_to_string (modulus) << ' \n ' ;
443+ uint_to_string (modulus) << " ) \n " ;
389444 return 1 ;
390445 }
391446 }
@@ -406,9 +461,9 @@ int main()
406461 result = mf.convertOut (mont_result_arr[j]);
407462 U standard_result = mf.convertOut (mf.pow (mont_two, exponent_arr[j]));
408463 if (result != standard_result) {
409- std::cout << " bug in array_montgomery_two_pow found: got wrong result for " ;
464+ std::cout << " bug2 in array_montgomery_two_pow found: got wrong result for " ;
410465 std::cout << " 2^" << uint_to_string (exponent_arr[j]) << " (mod " <<
411- uint_to_string (mf.getModulus ()) << ' \n ' ;
466+ uint_to_string (mf.getModulus ()) << " ) \n " ;
412467 return 1 ;
413468 }
414469 }
@@ -417,20 +472,25 @@ int main()
417472 // test for correctness with a range of moduli.
418473 // simulates fermat primality tests
419474 constexpr auto maxMF = MontType::max_modulus ();
420- for (auto mod = maxMF; mod > maxMF-range; mod = mod-2 ) {
475+ auto mod_range = static_cast <decltype (maxMF)>(range);
476+ if (mod_range >= maxMF)
477+ mod_range = maxMF - 1 ;
478+ for (auto mod = maxMF; mod > maxMF-mod_range; mod = mod-2 ) {
421479 MontType mt (mod);
422480 auto mont_two = mt.add (mt.getUnityValue (), mt.getUnityValue ());
423- mont_result = hc::montgomery_two_pow (mt, mod- 1 );
481+ mont_result = hc::montgomery_two_pow (mt, static_cast < decltype ( mod)>(mod- 1 ) );
424482 result = mt.convertOut (mont_result);
425483 U standard_result = mt.convertOut (mt.pow (mont_two, mod-1 ));
426484 if (result != standard_result) {
427- std::cout << " bug2 in montgomery_two_pow found: got wrong result for " ;
428- std::cout << " 2^" << uint_to_string (mod- 1 ) << " (mod " <<
429- uint_to_string (mod) << ' \n ' ;
485+ std::cout << " bug3 in montgomery_two_pow found: got wrong result for " ;
486+ std::cout << " 2^" << uint_to_string (static_cast < decltype ( mod)>(mod- 1 ) ) << " (mod " <<
487+ uint_to_string (mod) << " ) \n " ;
430488 return 1 ;
431489 }
432490 }
433- for (auto mod = maxMF; mod > maxMF-range; mod = mod-2 ) {
491+
492+ mod_range -= 16 ;
493+ for (auto mod = maxMF; mod > maxMF-mod_range; mod = mod-2 ) {
434494 constexpr size_t ARRAY_SIZE = 3 ;
435495 // We use std::vector to indirectly make a MontType array, since
436496 // MontType has no default constructor.
@@ -448,9 +508,9 @@ int main()
448508 auto mont_two = mf_arr[j].add (mf_arr[j].getUnityValue (), mf_arr[j].getUnityValue ());
449509 U standard_result = mf_arr[j].convertOut (mf_arr[j].pow (mont_two, exponent_arr[j]));
450510 if (result != standard_result) {
451- std::cout << " bug2 in array_montgomery_two_pow found: got wrong result for " ;
511+ std::cout << " bug4 in array_montgomery_two_pow found: got wrong result for " ;
452512 std::cout << " 2^" << uint_to_string (exponent_arr[j]) << " (mod " <<
453- uint_to_string (mf_arr[j].getModulus ()) << ' \n ' ;
513+ uint_to_string (mf_arr[j].getModulus ()) << " ) \n " ;
454514 return 1 ;
455515 }
456516 }
@@ -463,7 +523,7 @@ int main()
463523
464524// ------- Benchmarking --------
465525
466- bench_range<MontType>(maxU - range, range);
526+ bench_range<MontType>(static_cast <U>( maxU - range) , range);
467527
468528 std::cout << " ---Example Program Finished---\n " ;
469529 return 0 ;
0 commit comments