Skip to content

Commit 0a1c971

Browse files
committed
improve REDC
1 parent e132616 commit 0a1c971

File tree

23 files changed

+825
-391
lines changed

23 files changed

+825
-391
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/MontgomeryForm.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,19 +111,25 @@ class MontgomeryForm final {
111111

112112
// Returns the converted value of the standard number 'a' into monty form.
113113
// Requires a >= 0. (Note there is no restriction on how large 'a' can be.)
114-
HURCHALLA_FORCE_INLINE
114+
// Normally you don't want to specify PTAG (just accept the default).
115+
// For advanced use: PTAG can be either LowlatencyTag or LowuopsTag, which
116+
// will optimize this function for either low latency or a low uop count.
117+
template <class PTAG = LowuopsTag> HURCHALLA_FORCE_INLINE
115118
MontgomeryValue convertIn(T a) const
116119
{
117120
HPBC_CLOCKWORK_API_PRECONDITION(a >= 0);
118-
return impl.convertIn(a);
121+
return impl.template convertIn<PTAG>(a);
119122
}
120123

121124
// Converts (montgomery value) x into a "normal" number; returns the result.
122125
// Guarantees 0 <= result < modulus.
123-
HURCHALLA_FORCE_INLINE
126+
// Normally you don't want to specify PTAG (just accept the default).
127+
// For advanced use: PTAG can be either LowlatencyTag or LowuopsTag, which
128+
// will optimize this function for either low latency or a low uop count.
129+
template <class PTAG = LowuopsTag> HURCHALLA_FORCE_INLINE
124130
T convertOut(MontgomeryValue x) const
125131
{
126-
T a = impl.convertOut(x);
132+
T a = impl.template convertOut<PTAG>(x);
127133
HPBC_CLOCKWORK_POSTCONDITION(0 <= a && a < getModulus());
128134
return a;
129135
}
@@ -556,7 +562,7 @@ class MontgomeryForm final {
556562

557563
// Returns the Montgomery division of x by a small power of two (requires
558564
// 0 <= power <= 7, which translates to Montgomery division by 1,2,4,8,16,
559-
// 32,64, or 128). This function always produces an exact correct result.
565+
// 32,64, or 128). This function always produces an exactly correct result.
560566
// Note that Montgomery division is modular division, which is different
561567
// from normal and non modular division - modular division performs modular
562568
// multiplication by the modular multiplicative inverse of the divisor. So,
@@ -617,10 +623,11 @@ class MontgomeryForm final {
617623
// If you have already instantiated this MontgomeryForm, then calling
618624
// remainder() should be faster than directly computing a % modulus,
619625
// even if your CPU has extremely fast division (like many new CPUs).
620-
HURCHALLA_FORCE_INLINE T remainder(T a) const
626+
template <class PTAG = LowlatencyTag> HURCHALLA_FORCE_INLINE
627+
T remainder(T a) const
621628
{
622629
HPBC_CLOCKWORK_API_PRECONDITION(a >= 0);
623-
return impl.remainder(a);
630+
return impl.template remainder<PTAG>(a);
624631
}
625632

626633
};

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/ImplMontgomeryForm.contents

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@ public:
4545
HURCHALLA_IMF_MAYBE_FORCE_INLINE
4646
T getModulus() const { return static_cast<T>(impl.getModulus()); }
4747

48-
HURCHALLA_IMF_MAYBE_FORCE_INLINE
48+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
4949
MontgomeryValue convertIn(T a) const
5050
{
51-
return impl.convertIn(static_cast<U>(a), LowlatencyTag());
51+
return impl.convertIn(static_cast<U>(a), PTAG());
5252
}
5353

54-
HURCHALLA_IMF_MAYBE_FORCE_INLINE
54+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
5555
T convertOut(MontgomeryValue x) const
5656
{
57-
return static_cast<T>(impl.convertOut(x));
57+
return static_cast<T>(impl.convertOut(x, PTAG()));
5858
}
5959

6060
HURCHALLA_IMF_MAYBE_FORCE_INLINE
@@ -240,11 +240,11 @@ public:
240240
return static_cast<T>(impl.gcd_with_modulus(x, gcd_functor));
241241
}
242242

243-
HURCHALLA_IMF_MAYBE_FORCE_INLINE
243+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
244244
T remainder(T a) const
245245
{
246246
HPBC_CLOCKWORK_PRECONDITION(a >= 0);
247-
return static_cast<T>(impl.remainder(static_cast<U>(a)));
247+
return static_cast<T>(impl.remainder(static_cast<U>(a), PTAG()));
248248
}
249249

250250

@@ -277,9 +277,10 @@ public:
277277
return impl.twoPowLimited_times_x_v2(exponent, x, PTAG());
278278
}
279279

280-
HURCHALLA_IMF_MAYBE_FORCE_INLINE U getMagicValue() const
280+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
281+
U getMagicValue() const
281282
{
282-
return impl.getMagicValue();
283+
return impl.getMagicValue(PTAG());
283284
}
284285

285286
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
@@ -306,16 +307,16 @@ public:
306307
return impl.getSquaringValue(x);
307308
}
308309

309-
HURCHALLA_IMF_MAYBE_FORCE_INLINE
310+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
310311
SquaringValue squareSV(SquaringValue sv) const
311312
{
312-
return impl.squareSV(sv);
313+
return impl.squareSV(sv, PTAG());
313314
}
314315

315-
HURCHALLA_IMF_MAYBE_FORCE_INLINE
316+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
316317
MontgomeryValue squareToMontgomeryValue(SquaringValue sv) const
317318
{
318-
return impl.squareToMontgomeryValue(sv);
319+
return impl.squareToMontgomeryValue(sv, PTAG());
319320
}
320321

321322
HURCHALLA_IMF_MAYBE_FORCE_INLINE

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontgomeryFormExtensions.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct MontgomeryFormExtensions final {
7878
HURCHALLA_FORCE_INLINE
7979
static RU getMagicValue(const MF& mf)
8080
{
81-
return mf.impl.getMagicValue();
81+
return mf.impl.template getMagicValue<PTAG>();
8282
}
8383

8484
HURCHALLA_FORCE_INLINE
@@ -123,14 +123,14 @@ struct MontgomeryFormExtensions final {
123123
HURCHALLA_FORCE_INLINE
124124
static SquaringValue squareSV(const MF& mf, SquaringValue sv)
125125
{
126-
return mf.impl.squareSV(sv);
126+
return mf.impl.template squareSV<PTAG>(sv);
127127
}
128128

129129
HURCHALLA_FORCE_INLINE
130130
static MontgomeryValue
131131
squareToMontgomeryValue(const MF& mf, SquaringValue sv)
132132
{
133-
return mf.impl.squareToMontgomeryValue(sv);
133+
return mf.impl.template squareToMontgomeryValue<PTAG>(sv);
134134
}
135135

136136
HURCHALLA_FORCE_INLINE

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyCommonBase.h

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,25 +93,29 @@ class MontyCommonBase {
9393
return result;
9494
}
9595

96-
HURCHALLA_FORCE_INLINE T convertOut(V x) const
96+
template <class PTAG> HURCHALLA_FORCE_INLINE
97+
T convertOut(V x, PTAG) const
9798
{
9899
T u_hi = 0;
99100
// get a Natural number (i.e. number >= 0) congruent to x (mod n)
100101
T u_lo = static_cast<const D*>(this)->getNaturalEquivalence(x);
101102
namespace hc = ::hurchalla;
102-
bool isNegative;
103-
T result = hc::REDC_incomplete(isNegative, u_hi, u_lo, n_, inv_n_);
103+
104+
T minuend, subtrahend;
105+
hc::REDC_incomplete(minuend, subtrahend, u_hi, u_lo, n_, inv_n_, PTAG());
106+
T result = static_cast<T>(minuend - subtrahend);
104107
// Because u_hi == 0, we should expect the following 'if' to be very
105108
// well predicted.
106-
if (isNegative)
109+
if (minuend < subtrahend)
107110
result = static_cast<T>(result + n_);
108-
HPBC_CLOCKWORK_ASSERT2(result==hc::REDC_standard(u_hi,u_lo,n_,inv_n_));
111+
HPBC_CLOCKWORK_ASSERT2(result == hc::REDC_standard(u_hi, u_lo, n_, inv_n_, PTAG()));
109112

110113
HPBC_CLOCKWORK_POSTCONDITION2(result < n_);
111114
return result;
112115
}
113116

114-
HURCHALLA_FORCE_INLINE T remainder(T a) const
117+
template <class PTAG> HURCHALLA_FORCE_INLINE
118+
T remainder(T a, PTAG) const
115119
{
116120
HPBC_CLOCKWORK_INVARIANT2(r_mod_n_ < n_);
117121
namespace hc = ::hurchalla;
@@ -120,7 +124,7 @@ class MontyCommonBase {
120124
// Since a is type T, 0 <= a < R. And since r_mod_n is type T and
121125
// r_mod_n < n, we know 0 <= r_mod_n < n. Therefore,
122126
// 0 <= u == a * r_mod_n < R*n, which will satisfy REDC's precondition.
123-
T result = hc::REDC_standard(u_hi, u_lo, n_, inv_n_, LowlatencyTag());
127+
T result = hc::REDC_standard(u_hi, u_lo, n_, inv_n_, PTAG());
124128

125129
HPBC_CLOCKWORK_POSTCONDITION2(result < n_);
126130
return result;
@@ -548,15 +552,16 @@ class MontyCommonBase {
548552
}
549553

550554
// returns (R*R*R) mod N
551-
HURCHALLA_FORCE_INLINE T getMagicValue() const
555+
template <class PTAG> HURCHALLA_FORCE_INLINE
556+
T getMagicValue(PTAG) const
552557
{
553558
HPBC_CLOCKWORK_INVARIANT2(r_squared_mod_n_ < n_);
554559
namespace hc = ::hurchalla;
555560
T u_lo;
556561
T u_hi = hc::unsigned_multiply_to_hilo_product(u_lo,
557562
r_squared_mod_n_, r_squared_mod_n_);
558563
HPBC_CLOCKWORK_ASSERT2(u_hi < n_); // verify that (u_hi*R + u_lo) < n*R
559-
T result = hc::REDC_standard(u_hi, u_lo, n_, inv_n_, LowlatencyTag());
564+
T result = hc::REDC_standard(u_hi, u_lo, n_, inv_n_, PTAG());
560565
HPBC_CLOCKWORK_POSTCONDITION2(result < n_);
561566
return result;
562567
}
@@ -566,7 +571,7 @@ class MontyCommonBase {
566571
V convertInExtended_aTimesR(T a, T magicValue, PTAG) const
567572
{
568573
// see convertIn() comments for explanation
569-
HPBC_CLOCKWORK_PRECONDITION2(magicValue == getMagicValue());
574+
HPBC_CLOCKWORK_PRECONDITION2(magicValue == getMagicValue(PTAG()));
570575
HPBC_CLOCKWORK_ASSERT2(magicValue < n_);
571576
namespace hc = ::hurchalla;
572577
T u_lo;
@@ -602,7 +607,7 @@ class MontyCommonBase {
602607
template <class PTAG> HURCHALLA_FORCE_INLINE
603608
V RTimesTwoPowLimited(size_t exponent, T magicValue, PTAG) const
604609
{
605-
HPBC_CLOCKWORK_PRECONDITION2(magicValue == getMagicValue());
610+
HPBC_CLOCKWORK_PRECONDITION2(magicValue == getMagicValue(PTAG()));
606611
HPBC_CLOCKWORK_ASSERT2(magicValue < n_);
607612
static constexpr int digitsT = ut_numeric_limits<T>::digits;
608613
int power = static_cast<int>(exponent);

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyFullRange.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ class MontyFullRange final :
216216
return SV(x.get(), 0);
217217
}
218218

219-
HURCHALLA_FORCE_INLINE SV squareSV(SV sv) const
219+
template <class PTAG> HURCHALLA_FORCE_INLINE
220+
SV squareSV(SV sv, PTAG) const
220221
{
221222
// see squareToHiLo in MontyFullRangeMasked.h for basic ideas of
222223
// proof for why u_hi and u_lo are correct
@@ -229,14 +230,16 @@ class MontyFullRange final :
229230
T u_lo = sqlo;
230231
HPBC_CLOCKWORK_ASSERT2(u_hi < n_);
231232

232-
bool isNegative;
233-
T res = hc::REDC_incomplete(isNegative, u_hi, u_lo, n_, BC::inv_n_);
234-
T subtrahend = isNegative ? res : static_cast<T>(0);
233+
T minu, subt;
234+
hc::REDC_incomplete(minu, subt, u_hi, u_lo, n_, BC::inv_n_, PTAG());
235+
T res = static_cast<T>(minu - subt);
236+
T subtrahend = (minu < subt) ? res : static_cast<T>(0);
235237
SV result(res, subtrahend);
236238
return result;
237239
}
238240

239-
HURCHALLA_FORCE_INLINE V squareToMontgomeryValue(SV sv) const
241+
template <class PTAG> HURCHALLA_FORCE_INLINE
242+
V squareToMontgomeryValue(SV sv, PTAG) const
240243
{
241244
// see squareToHiLo in MontyFullRangeMasked.h for basic ideas of
242245
// proof for why u_hi and u_lo are correct
@@ -249,8 +252,7 @@ class MontyFullRange final :
249252
T u_lo = sqlo;
250253
HPBC_CLOCKWORK_ASSERT2(u_hi < n_);
251254

252-
T res = hc::REDC_standard(
253-
u_hi, u_lo, n_, BC::inv_n_, hc::LowlatencyTag());
255+
T res = hc::REDC_standard(u_hi, u_lo, n_, BC::inv_n_, PTAG());
254256
V result(res);
255257
return result;
256258
}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyHalfRange.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -498,15 +498,17 @@ class MontyHalfRange final :
498498
static_assert(std::is_same<V, SV>::value, "");
499499
return x;
500500
}
501-
HURCHALLA_FORCE_INLINE SV squareSV(SV sv) const
501+
template <class PTAG> HURCHALLA_FORCE_INLINE
502+
SV squareSV(SV sv, PTAG) const
502503
{
503504
static_assert(std::is_same<V, SV>::value, "");
504-
return BC::square(sv, LowlatencyTag());
505+
return BC::square(sv, PTAG());
505506
}
506-
HURCHALLA_FORCE_INLINE V squareToMontgomeryValue(SV sv) const
507+
template <class PTAG> HURCHALLA_FORCE_INLINE
508+
V squareToMontgomeryValue(SV sv, PTAG) const
507509
{
508510
static_assert(std::is_same<V, SV>::value, "");
509-
return BC::square(sv, LowlatencyTag());
511+
return BC::square(sv, PTAG());
510512
}
511513
HURCHALLA_FORCE_INLINE V getMontgomeryValue(SV sv) const
512514
{
@@ -524,7 +526,7 @@ class MontyHalfRange final :
524526
{
525527
HPBC_CLOCKWORK_PRECONDITION2(u_hi < n_); // verifies that (u_hi*R + u_lo) < n*R
526528
namespace hc = ::hurchalla;
527-
T result = hc::REDC_incomplete(u_hi, u_lo, n_, BC::inv_n_);
529+
T result = hc::REDC_incomplete(u_hi, u_lo, n_, BC::inv_n_, PTAG());
528530
resultIsZero = (result == 0);
529531
V v = V(static_cast<S>(result));
530532
HPBC_CLOCKWORK_POSTCONDITION2(isValid(v));

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyQuarterRange.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,17 @@ class MontyQuarterRange final : public
317317
static_assert(std::is_same<V, SV>::value, "");
318318
return x;
319319
}
320-
HURCHALLA_FORCE_INLINE SV squareSV(SV sv) const
320+
template <class PTAG> HURCHALLA_FORCE_INLINE
321+
SV squareSV(SV sv, PTAG) const
321322
{
322323
static_assert(std::is_same<V, SV>::value, "");
323-
return BC::square(sv, LowlatencyTag());
324+
return BC::square(sv, PTAG());
324325
}
325-
HURCHALLA_FORCE_INLINE V squareToMontgomeryValue(SV sv) const
326+
template <class PTAG> HURCHALLA_FORCE_INLINE
327+
V squareToMontgomeryValue(SV sv, PTAG) const
326328
{
327329
static_assert(std::is_same<V, SV>::value, "");
328-
return BC::square(sv, LowlatencyTag());
330+
return BC::square(sv, PTAG());
329331
}
330332
HURCHALLA_FORCE_INLINE V getMontgomeryValue(SV sv) const
331333
{
@@ -343,7 +345,7 @@ class MontyQuarterRange final : public
343345
{
344346
HPBC_CLOCKWORK_PRECONDITION2(u_hi < n_); // verifies that (u_hi*R + u_lo) < n*R
345347
namespace hc = ::hurchalla;
346-
T result = hc::REDC_incomplete(u_hi, u_lo, n_, BC::inv_n_);
348+
T result = hc::REDC_incomplete(u_hi, u_lo, n_, BC::inv_n_, PTAG());
347349
resultIsZero = (result == 0);
348350
T sum = static_cast<T>(result + n_);
349351
HPBC_CLOCKWORK_POSTCONDITION2(0 < sum && sum < static_cast<T>(2*n_));

0 commit comments

Comments
 (0)