Skip to content

Commit c3d0d1c

Browse files
committed
improve MontgomeryForm pow(), add 128 bit lowuop REDC (not enabled yet), add and improve functions to experimental montgomery two_pow
1 parent 08d26db commit c3d0d1c

File tree

19 files changed

+3382
-846
lines changed

19 files changed

+3382
-846
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/MontgomeryForm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class MontgomeryForm final {
4343
static_assert(ut_numeric_limits<T>::is_integer, "");
4444
static_assert(ut_numeric_limits<T>::digits <=
4545
ut_numeric_limits<typename MontyType::uint_type>::digits, "");
46+
using SV = typename MontyType::squaringvalue_type;
4647
using RU = typename MontyType::uint_type;
4748
public:
4849
using IntegerType = T;
@@ -477,6 +478,7 @@ class MontgomeryForm final {
477478

478479
// Calculates and returns the modular exponentiation of the montgomery value
479480
// 'base' to the power of (the type T variable) 'exponent'.
481+
// Performance note: if your base is 2, two_pow() is much more efficient.
480482
HURCHALLA_FORCE_INLINE
481483
MontgomeryValue pow(MontgomeryValue base, T exponent) const
482484
{

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/ImplMontgomeryForm.contents

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public:
3838
using MontgomeryValue = typename MontyType::montvalue_type;
3939
using CanonicalValue = typename MontyType::canonvalue_type;
4040
using FusingValue = typename MontyType::fusingvalue_type;
41+
using SquaringValue = typename MontyType::squaringvalue_type;
4142

4243
explicit ImplMontgomeryForm(T modulus) : impl(static_cast<U>(modulus)) {}
4344

@@ -247,6 +248,17 @@ public:
247248
return impl.convertIn(a, PTAG());
248249
}
249250

251+
HURCHALLA_IMF_MAYBE_FORCE_INLINE CanonicalValue getMontvalueR() const
252+
{
253+
return impl.getMontvalueR();
254+
}
255+
256+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
257+
MontgomeryValue twoPowLimited_times_x(size_t exponent, CanonicalValue x) const
258+
{
259+
return impl.twoPowLimited_times_x(exponent, x, PTAG());
260+
}
261+
250262
HURCHALLA_IMF_MAYBE_FORCE_INLINE U getMagicValue() const
251263
{
252264
return impl.getMagicValue();
@@ -268,3 +280,28 @@ public:
268280
{
269281
return impl.RTimesTwoPowLimited(exponent, magicValue, PTAG());
270282
}
283+
284+
285+
HURCHALLA_IMF_MAYBE_FORCE_INLINE
286+
SquaringValue getSquaringValue(MontgomeryValue x) const
287+
{
288+
return impl.getSquaringValue(x);
289+
}
290+
291+
HURCHALLA_IMF_MAYBE_FORCE_INLINE
292+
SquaringValue squareSV(SquaringValue sv) const
293+
{
294+
return impl.squareSV(sv);
295+
}
296+
297+
HURCHALLA_IMF_MAYBE_FORCE_INLINE
298+
MontgomeryValue squareToMontgomeryValue(SquaringValue sv) const
299+
{
300+
return impl.squareToMontgomeryValue(sv);
301+
}
302+
303+
HURCHALLA_IMF_MAYBE_FORCE_INLINE
304+
MontgomeryValue getMontgomeryValue(SquaringValue sv) const
305+
{
306+
return impl.getMontgomeryValue(sv);
307+
}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontgomeryFormExtensions.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,44 @@ struct MontgomeryFormExtensions final {
3030
static_assert(ut_numeric_limits<RU>::is_integer, "");
3131
static_assert(!(ut_numeric_limits<RU>::is_signed), "");
3232

33+
using CanonicalValue = typename MF::CanonicalValue;
3334
using MontgomeryValue = typename MF::MontgomeryValue;
35+
using SquaringValue = typename MF::SV;
3436

3537
HURCHALLA_FORCE_INLINE
3638
static MontgomeryValue convertInExtended(const MF& mf, RU a)
3739
{
3840
return mf.impl.template convertInExtended<PTAG>(a);
3941
}
4042

43+
// note: montvalueR is the Montgomery representation of R.
44+
// In normal integer form it is literally R squared mod N.
45+
HURCHALLA_FORCE_INLINE
46+
static CanonicalValue getMontvalueR(const MF& mf)
47+
{
48+
return mf.impl.getMontvalueR();
49+
}
50+
51+
// this first shifts x by exponent, which is equivalent to
52+
// multiplying x by 2^exponent, and then it completes the
53+
// mont mul as usual by calling REDC.
54+
// -- IMPORTANT NOTE -- because (2^exponent) is an integer domain
55+
// value rather than a montgomery domain value, the returned
56+
// result viewed as an integer value is
57+
// REDC((x_int * R) * (2^exponent)) == (x_int * (2^exponent) * R) * R^(-1)
58+
// To counteract the inverse R factor, so that you get what most likely
59+
// you wanted, being just plain (x_int * (2^exponent) * R),
60+
// you need to ensure that x has an extra factor of R built into it,
61+
// rather than just the normal single factor of x_int * R. To build an
62+
// extra factor of R into x, you first get montR = getMontvalueR(mf),
63+
// and then you do a normal montgomery multiply of x and montR.
64+
HURCHALLA_FORCE_INLINE
65+
static MontgomeryValue twoPowLimited_times_x(const MF& mf, size_t exponent, CanonicalValue x)
66+
{
67+
HPBC_CLOCKWORK_PRECONDITION(exponent < ut_numeric_limits<RU>::digits);
68+
return mf.impl.template twoPowLimited_times_x<PTAG>(exponent, x);
69+
}
70+
4171
// note: magicValue is R cubed mod N (in normal integer form)
4272
HURCHALLA_FORCE_INLINE
4373
static RU getMagicValue(const MF& mf)
@@ -76,6 +106,32 @@ struct MontgomeryFormExtensions final {
76106
HPBC_CLOCKWORK_PRECONDITION(exponent < ut_numeric_limits<RU>::digits);
77107
return mf.impl.template RTimesTwoPowLimited<PTAG>(exponent, magicValue);
78108
}
109+
110+
111+
HURCHALLA_FORCE_INLINE
112+
static SquaringValue getSquaringValue(const MF& mf, MontgomeryValue x)
113+
{
114+
return mf.impl.getSquaringValue(x);
115+
}
116+
117+
HURCHALLA_FORCE_INLINE
118+
static SquaringValue squareSV(const MF& mf, SquaringValue sv)
119+
{
120+
return mf.impl.squareSV(sv);
121+
}
122+
123+
HURCHALLA_FORCE_INLINE
124+
static MontgomeryValue
125+
squareToMontgomeryValue(const MF& mf, SquaringValue sv)
126+
{
127+
return mf.impl.squareToMontgomeryValue(sv);
128+
}
129+
130+
HURCHALLA_FORCE_INLINE
131+
static MontgomeryValue getMontgomeryValue(const MF& mf, SquaringValue sv)
132+
{
133+
return mf.impl.getMontgomeryValue(sv);
134+
}
79135
};
80136

81137

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyCommonBase.h

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,8 @@ class MontyCommonBase {
9999
namespace hc = ::hurchalla;
100100
bool isNegative;
101101
T result = hc::REDC_incomplete(isNegative, u_hi, u_lo, n_, inv_n_);
102-
HPBC_CLOCKWORK_ASSERT2(result == 0 ?
103-
isNegative == false : isNegative == true);
104-
// We would expect that result == 0 is usually very rare, so going by
105-
// the assert above, the next 'if' should be well predicted.
102+
// Because u_hi == 0, we should expect the following 'if' to be very
103+
// well predicted.
106104
if (isNegative)
107105
result = static_cast<T>(result + n_);
108106
HPBC_CLOCKWORK_ASSERT2(result==hc::REDC_standard(u_hi,u_lo,n_,inv_n_));
@@ -444,6 +442,32 @@ class MontyCommonBase {
444442

445443

446444

445+
// returns (R*R) mod N
446+
HURCHALLA_FORCE_INLINE C getMontvalueR() const
447+
{
448+
HPBC_CLOCKWORK_INVARIANT2(r_squared_mod_n_ < n_);
449+
return C(r_squared_mod_n_);
450+
}
451+
// Builds the two-word value (u_hi, u_lo) == cx * 2^exponent by shifting the
// canonical value cx, then hands it to the derived class's montyREDC().
// exponent must be smaller than the bit width of T (checked below).
// PTAG is forwarded unchanged to montyREDC().
template <class PTAG> HURCHALLA_FORCE_INLINE
452+
V twoPowLimited_times_x(size_t exponent, C cx, PTAG) const
453+
{
454+
static constexpr int digitsT = ut_numeric_limits<T>::digits;
455+
int power = static_cast<int>(exponent);
456+
HPBC_CLOCKWORK_PRECONDITION2(0 <= power && power < digitsT);
457+
458+
T tmp = cx.get();
459+
HPBC_CLOCKWORK_INVARIANT2(tmp < n_);
// low word of (tmp * 2^power): bits shifted past the top of T are dropped
460+
T u_lo = static_cast<T>(tmp << power);
461+
int rshift = digitsT - power;
462+
HPBC_CLOCKWORK_ASSERT2(rshift > 0);
// high word of (tmp * 2^power), i.e. tmp >> rshift.  When power == 0,
// rshift equals digitsT, and a single shift by the full width of T can be
// undefined behavior; shifting by 1 and then by (rshift - 1) keeps each
// shift count below digitsT.
463+
T u_hi = (tmp >> 1) >> (rshift - 1);
464+
465+
// u_hi is tmp shifted right, so it cannot exceed tmp, which is < n_.
HPBC_CLOCKWORK_ASSERT2(u_hi < n_);
// downcast to the derived type D to reach its montyREDC()/isValid()
466+
const D* child = static_cast<const D*>(this);
467+
V result = child->montyREDC(u_hi, u_lo, PTAG());
468+
HPBC_CLOCKWORK_POSTCONDITION2(child->isValid(result));
469+
return result;
470+
}
447471
// returns (R*R*R) mod N
448472
HURCHALLA_FORCE_INLINE T getMagicValue() const
449473
{

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyFullRange.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ struct MontyFRValueTypes {
6161
template <typename> friend class MontyFullRange;
6262
HURCHALLA_FORCE_INLINE explicit FV(T a) : V(a) {}
6363
};
64+
// squaring value type - used for square() optimizations (fyi, those
65+
// optimizations wouldn't help much or at all for monty types other than
66+
// MontyFullRange).
67+
struct SV {
68+
HURCHALLA_FORCE_INLINE SV() = default;
69+
protected:
70+
template <typename> friend class MontyFullRange;
71+
HURCHALLA_FORCE_INLINE T getbits() const { return bits; }
72+
HURCHALLA_FORCE_INLINE T get_subtrahend() const { return subtrahend; }
73+
HURCHALLA_FORCE_INLINE SV(T bbits, T subt) :
74+
bits(bbits), subtrahend(subt) {}
75+
private:
76+
T bits;
77+
T subtrahend;
78+
};
6479
};
6580

6681

@@ -77,12 +92,14 @@ class MontyFullRange final :
7792
using typename BC::V;
7893
using typename BC::C;
7994
using FV = typename MontyFRValueTypes<T>::FV;
95+
using SV = typename MontyFRValueTypes<T>::SV;
8096
public:
8197
using MontyTag = TagMontyFullrange;
8298
using uint_type = T;
8399
using montvalue_type = V;
84100
using canonvalue_type = C;
85101
using fusingvalue_type = FV;
102+
using squaringvalue_type = SV;
86103

87104
explicit MontyFullRange(T modulus) : BC(modulus) {}
88105

@@ -193,6 +210,59 @@ class MontyFullRange final :
193210
return add(x, x);
194211
}
195212

213+
214+
HURCHALLA_FORCE_INLINE SV getSquaringValue(V x) const
215+
{
216+
return SV(x.get(), 0);
217+
}
218+
219+
// Squares the value carried by sv and returns the outcome as another SV.
// The result keeps the raw output of REDC_incomplete in 'bits', and records
// in 'subtrahend' either that same output (when REDC_incomplete reported a
// negative result) or 0.
HURCHALLA_FORCE_INLINE SV squareSV(SV sv) const
220+
{
221+
// see squareToHiLo in MontyFullRangeMasked.h for basic ideas of
222+
// proof for why u_hi and u_lo are correct
223+
namespace hc = ::hurchalla;
224+
T a = sv.getbits();
225+
T sqlo;
// double-width square of the raw bits: high word returned, low word in sqlo
226+
T sqhi = hc::unsigned_multiply_to_hilo_product(sqlo, a, a);
// the stored subtrahend is either 0 or a copy of the bits (see the
// construction of 'result' below and getSquaringValue)
227+
T a_or_zero = sv.get_subtrahend();
// subtract the subtrahend twice from the high word, modulo 2^(bits of T).
// NOTE(review): presumably this corrects the square when 'a' stands for a
// negative value - confirm against the referenced proof.
228+
T u_hi = static_cast<T>(sqhi - a_or_zero - a_or_zero);
229+
T u_lo = sqlo;
230+
HPBC_CLOCKWORK_ASSERT2(u_hi < n_);
231+
232+
bool isNegative;
233+
T res = hc::REDC_incomplete(isNegative, u_hi, u_lo, n_, BC::inv_n_);
// remember res as the next subtrahend only when the result was negative
234+
T subtrahend = isNegative ? res : static_cast<T>(0);
235+
SV result(res, subtrahend);
236+
return result;
237+
}
238+
239+
HURCHALLA_FORCE_INLINE V squareToMontgomeryValue(SV sv) const
240+
{
241+
// see squareToHiLo in MontyFullRangeMasked.h for basic ideas of
242+
// proof for why u_hi and u_lo are correct
243+
namespace hc = ::hurchalla;
244+
T a = sv.getbits();
245+
T sqlo;
246+
T sqhi = hc::unsigned_multiply_to_hilo_product(sqlo, a, a);
247+
T a_or_zero = sv.get_subtrahend();
248+
T u_hi = static_cast<T>(sqhi - a_or_zero - a_or_zero);
249+
T u_lo = sqlo;
250+
HPBC_CLOCKWORK_ASSERT2(u_hi < n_);
251+
252+
T res = hc::REDC_standard(
253+
u_hi, u_lo, n_, BC::inv_n_, hc::LowlatencyTag());
254+
V result(res);
255+
return result;
256+
}
257+
258+
// probably I would not want to use this, instead preferring to get a V
259+
// via squareToMontgomeryValue
260+
// Converts a squaring value back into an ordinary montgomery value V.
// A nonzero subtrahend is what squareSV stores when REDC_incomplete reported
// a negative result; in that case n_ is added to the raw bits (modulo
// 2^(bits of T)).  Otherwise the bits are used unchanged.
HURCHALLA_FORCE_INLINE V getMontgomeryValue(SV sv) const
261+
{
262+
// SV exposes getbits()/get_subtrahend() only; it has no get() accessor.
// The cast keeps the sum in type T when T is narrower than int.
T nonneg_value = (sv.get_subtrahend() != 0) ? static_cast<T>(sv.getbits() + n_) : sv.getbits();
263+
return V(nonneg_value);
264+
}
265+
196266
private:
197267
// functions called by the 'curiously recurring template pattern' base (BC).
198268
friend BC;

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyHalfRange.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,14 @@ class MontyHalfRange final :
113113
using typename BC::V;
114114
using typename BC::C;
115115
using FV = typename MontyHRValueTypes<T>::FV;
116+
using SV = V;
116117
public:
117118
using MontyTag = TagMontyHalfrange;
118119
using uint_type = T;
119120
using montvalue_type = V;
120121
using canonvalue_type = C;
121122
using fusingvalue_type = FV;
123+
using squaringvalue_type = SV;
122124

123125
explicit MontyHalfRange(T modulus) : BC(modulus)
124126
{
@@ -490,6 +492,28 @@ class MontyHalfRange final :
490492
return C(result);
491493
}
492494

495+
496+
HURCHALLA_FORCE_INLINE SV getSquaringValue(V x) const
497+
{
498+
static_assert(std::is_same<V, SV>::value, "");
499+
return x;
500+
}
501+
HURCHALLA_FORCE_INLINE SV squareSV(SV sv) const
502+
{
503+
static_assert(std::is_same<V, SV>::value, "");
504+
return BC::square(sv, LowlatencyTag());
505+
}
506+
HURCHALLA_FORCE_INLINE V squareToMontgomeryValue(SV sv) const
507+
{
508+
static_assert(std::is_same<V, SV>::value, "");
509+
return BC::square(sv, LowlatencyTag());
510+
}
511+
HURCHALLA_FORCE_INLINE V getMontgomeryValue(SV sv) const
512+
{
513+
static_assert(std::is_same<V, SV>::value, "");
514+
return sv;
515+
}
516+
493517
private:
494518
// functions called by the 'curiously recurring template pattern' base (BC).
495519
friend BC;
@@ -500,8 +524,7 @@ class MontyHalfRange final :
500524
{
501525
HPBC_CLOCKWORK_PRECONDITION2(u_hi < n_); // verifies that (u_hi*R + u_lo) < n*R
502526
namespace hc = ::hurchalla;
503-
bool isNegative; // ignored
504-
T result = hc::REDC_incomplete(isNegative, u_hi, u_lo, n_, BC::inv_n_);
527+
T result = hc::REDC_incomplete(u_hi, u_lo, n_, BC::inv_n_);
505528
resultIsZero = (result == 0);
506529
V v = V(static_cast<S>(result));
507530
HPBC_CLOCKWORK_POSTCONDITION2(isValid(v));

0 commit comments

Comments
 (0)