Skip to content

Commit 3afdbb3

Browse files
committed
add divideBySmallPowerOf2 function to MontgomeryForm
1 parent 7215f7b commit 3afdbb3

File tree

10 files changed

+198
-35
lines changed

10 files changed

+198
-35
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/MontgomeryForm.h

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,28 @@
2626
namespace hurchalla {
2727

2828

29-
// T must be a signed or unsigned integral type.
29+
// T must be a signed or unsigned integral type. You should normally set T to
30+
// the same type as the (integer) modulus that you will use for this class's
31+
// constructor.
3032
//
3133
// For InlineAllFunctions, you should usually accept the default rather than
32-
// specify an argument. However if you wish to reduce compilation times you can
34+
// specify an argument. However if you wish to reduce compilation times you can
3335
// set it to false, which may help.
3436
//
3537
// For MontyType, you should just accept the default (this parameter exists to
3638
// provide the alias classes in montgomery_form_aliases.h).
3739
template <class T,
38-
bool InlineAll = (ut_numeric_limits<T>::digits <= HURCHALLA_TARGET_BIT_WIDTH),
40+
bool InlineAllFunctions = (ut_numeric_limits<T>::digits <= HURCHALLA_TARGET_BIT_WIDTH),
3941
class MontyType = typename detail::MontgomeryDefault<T>::type>
4042
class MontgomeryForm final {
41-
const detail::ImplMontgomeryForm<T, InlineAll, MontyType> impl;
43+
const detail::ImplMontgomeryForm<T, InlineAllFunctions, MontyType> impl;
4244
template <class,class> friend struct detail::MontgomeryFormExtensions;
4345
static_assert(ut_numeric_limits<T>::is_integer, "");
4446
static_assert(ut_numeric_limits<T>::digits <=
4547
ut_numeric_limits<typename MontyType::uint_type>::digits, "");
46-
using SV = typename MontyType::squaringvalue_type;
47-
using RU = typename MontyType::uint_type;
4848
public:
4949
using IntegerType = T;
50-
using MontyTag = typename MontyType::MontyTag;
50+
using MontType = MontyType;
5151

5252
// If you need to compare MontgomeryValues for equality or inequality, call
5353
// getCanonicalValue() and compare the resulting CanonicalValues.
@@ -485,7 +485,7 @@ class MontgomeryForm final {
485485
HPBC_CLOCKWORK_API_PRECONDITION(exponent >= 0);
486486
std::array<MontgomeryValue, 1> bases = {{ base }};
487487
std::array<MontgomeryValue, 1> result =
488-
detail::montgomery_array_pow<MontyTag,
488+
detail::montgomery_array_pow<typename MontyType::MontyTag,
489489
MontgomeryForm>::pow(*this, bases, exponent);
490490
return result[0];
491491
//return detail::montgomery_pow<MontgomeryForm>::scalarpow(*this, base, exponent);
@@ -531,23 +531,19 @@ class MontgomeryForm final {
531531
pow(const std::array<MontgomeryValue, NUM_BASES>& bases, T exponent) const
532532
{
533533
HPBC_CLOCKWORK_API_PRECONDITION(exponent >= 0);
534-
return detail::montgomery_array_pow<MontyTag,
534+
return detail::montgomery_array_pow<typename MontyType::MontyTag,
535535
MontgomeryForm>::pow(*this, bases, exponent);
536536
}
537537

538538

539-
// Returns the multiplicative inverse of 'x' in the Montgomery domain if
540-
// the inverse exists. If the inverse does not exist, it returns zero (or
541-
// more precisely, it returns the value equal to getZeroValue()).
542-
// This is a convenience function to stay in the Montgomery domain when you
543-
// want to find the multiplicative inverse of a MontgomeryValue.
544-
//
545-
// Performance note: this function has no performance advantage over
546-
// hurchalla::modular_multiplicative_inverse if you need the inverse of a
547-
// number in standard integer domain - i.e. don't convert into Montgomery
548-
// domain just to call this function. However, when you intend to stay in
549-
// the Montgomery domain, this function is the fastest way to get the
550-
// multiplicative inverse.
539+
// Calculates and returns the multiplicative inverse of 'x' as a canonical
540+
// Montgomery value, if the inverse exists. If the inverse does not exist,
541+
// this function returns zero (more precisely it returns the value equal to
542+
// getZeroValue()).
543+
// Performance note: there is no performance advantage to converting into
544+
// Montgomery form if all you want is the inverse of a number in standard
545+
// integer domain - prefer hurchalla::modular_multiplicative_inverse() for
546+
// that case.
551547
template <class PTAG = LowlatencyTag> HURCHALLA_FORCE_INLINE
552548
CanonicalValue inverse(MontgomeryValue x) const
553549
{
@@ -558,6 +554,48 @@ class MontgomeryForm final {
558554
}
559555

560556

557+
// Returns the Montgomery division of x by a small power of two (requires
558+
// 0 <= power <= 7, which translates to Montgomery division by 1,2,4,8,16,
559+
// 32,64, or 128). This function always produces an exact correct result.
560+
// Note that Montgomery division is modular division, which is different
561+
// from ordinary (non-modular) division - modular division performs modular
562+
// multiplication by the modular multiplicative inverse of the divisor. So,
563+
// this function calculates and returns the product of x times the modular
564+
// multiplicative inverse of 2^power (^ denotes exponentiation), in
565+
// Montgomery form. Due to the requirement that every Montgomery modulus
566+
// must be odd, the inverse of all powers of two exist in Montgomery form,
567+
// and so multiplication by the inverse of 2^power is always valid. And
568+
// since all calculation is modular, the division result is exactly correct
569+
// - i.e. x is congruent to the return value times 2^power.
570+
//
571+
// In the common case that you wish to divide a MontgomeryValue x instead of
572+
// a CanonicalValue, use getCanonicalValue(x) as your first argument to this
573+
// function.
574+
//
575+
// If you wish to divide by some large power of 2, you can use the following
576+
// sequence of calls:
577+
// // Assume "mf" is a MontgomeryForm instance that you have constructed
578+
// // ...for example via... auto mf = MontgomeryForm(modulus);
579+
// // and assume that the large power is some_large_exponent.
580+
// auto inv_two = mf.divideBySmallPowerOf2(mf.getUnityValue(), 1);
581+
// auto full_inv = mf.pow(inv_two, some_large_exponent);
582+
// auto desired_result = mf.multiply(full_inv, x);
583+
//
584+
// Performance note: this function is very efficient. It should ordinarily
585+
// be faster than even a single call of multiply().
586+
template <class PTAG = LowlatencyTag> HURCHALLA_FORCE_INLINE
587+
MontgomeryValue divideBySmallPowerOf2(CanonicalValue x, int power) const
588+
{
589+
// API contract: only the eight exponents 0..7 are accepted.
HPBC_CLOCKWORK_API_PRECONDITION(0 <= power && power < 8);
590+
591+
// The work is delegated to the implementation object; PTAG is forwarded
// unchanged as the template argument.
MontgomeryValue ret= impl.template divideBySmallPowerOf2<PTAG>(x,power);
592+
593+
// Self-check: multiplying the result by two_pow(power) and canonicalizing
// must reproduce x.
// NOTE(review): two_pow(power) is presumably the Montgomery value of
// 2^power - its definition is outside this hunk; confirm.
HPBC_CLOCKWORK_POSTCONDITION(x ==
594+
getCanonicalValue(multiply(ret, two_pow(static_cast<T>(power)))));
595+
return ret;
596+
}
597+
598+
561599
// Returns the "greatest common divisor" of the standard representations
562600
// (non-montgomery) of both x and the modulus, using the gcd functor that
563601
// you supply. The functor must take two integral arguments of the same type

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/ImplMontgomeryForm.contents

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ public:
248248
}
249249

250250

251+
// Forwards divideBySmallPowerOf2 to the wrapped impl object. The PTAG
// template parameter is turned into a default-constructed tag object that is
// passed as the last function argument.
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
252+
MontgomeryValue divideBySmallPowerOf2(CanonicalValue x, int power) const
253+
{
254+
return impl.divideBySmallPowerOf2(x, power, PTAG());
255+
}
256+
257+
251258
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
252259
MontgomeryValue convertInExtended(U a) const
253260
{

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontgomeryFormExtensions.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ namespace hurchalla { namespace detail {
2525
template <class MF, class PTAG>
2626
struct MontgomeryFormExtensions final {
2727

28-
using RU = typename MF::RU;
28+
using RU = typename MF::MontType::uint_type;
2929
// conceptually, R = 1 << (ut_numeric_limits<RU>::digits)
3030
static_assert(ut_numeric_limits<RU>::is_integer, "");
3131
static_assert(!(ut_numeric_limits<RU>::is_signed), "");
3232

3333
using CanonicalValue = typename MF::CanonicalValue;
3434
using MontgomeryValue = typename MF::MontgomeryValue;
35-
using SquaringValue = typename MF::SV;
35+
using SquaringValue = typename MF::MontType::squaringvalue_type;
3636

3737
HURCHALLA_FORCE_INLINE
3838
static MontgomeryValue convertInExtended(const MF& mf, RU a)

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyCommonBase.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,30 @@ class MontyCommonBase {
475475
}
476476

477477

478+
// Divides the canonical value cx by 2^exponent (modular division) and returns
// the result as a V. Requires 0 <= exponent < digitsT.
// The division is carried out as a multiplication by 2^(digitsT - exponent)
// followed by a call to the derived class's montyREDC.
// NOTE(review): this presumably relies on montyREDC(u_hi, u_lo) reducing the
// double-width value (u_hi, u_lo) with a division by R = 2^digitsT; montyREDC
// is not visible in this hunk - confirm.
// This is close to being a copy/paste of twoPowLimited_times_x, but it's
479+
// adjusted for the different meaning of exponent.
480+
template <class PTAG> HURCHALLA_FORCE_INLINE
481+
V divideBySmallPowerOf2(C cx, int exponent, PTAG) const
482+
{
483+
static constexpr int digitsT = ut_numeric_limits<T>::digits;
484+
HPBC_CLOCKWORK_PRECONDITION2(0 <= exponent && exponent < digitsT);
485+
// power is the left-shift amount applied to cx; it lies in [1, digitsT].
int power = digitsT - exponent;
486+
HPBC_CLOCKWORK_ASSERT2(0 < power && power <= digitsT);
487+
488+
T tmp = cx.get();
489+
HPBC_CLOCKWORK_INVARIANT2(tmp < n_);
490+
// u_lo is the low word of (tmp << power). The shift is split into
// (<< 1) followed by (<< (power - 1)) so that no single shift count reaches
// digitsT, which a direct (tmp << power) would do when exponent == 0.
T u_lo = static_cast<T>((tmp << 1) << (power - 1));
491+
// rshift equals exponent. u_hi receives the bits of tmp that the left
// shift pushed out of the low word, so (u_hi, u_lo) together represent
// tmp * 2^power as a double-width value.
int rshift = digitsT - power;
492+
HPBC_CLOCKWORK_ASSERT2(0 <= rshift && rshift < digitsT);
493+
T u_hi = static_cast<T>(tmp >> rshift);
494+
495+
// Holds because tmp < n_ (checked above) and, assuming T is unsigned, a
// right shift never increases the value.
HPBC_CLOCKWORK_ASSERT2(u_hi < n_);
496+
// montyREDC and isValid are supplied by the derived class D, reached by
// downcasting this.
const D* child = static_cast<const D*>(this);
497+
V result = child->montyREDC(u_hi, u_lo, PTAG());
498+
HPBC_CLOCKWORK_POSTCONDITION2(child->isValid(result));
499+
return result;
500+
}
501+
478502

479503
// returns (R*R) mod N
480504
HURCHALLA_FORCE_INLINE C getMontvalueR() const

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyWrappedStandardMath.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,20 @@ class MontyWrappedStandardMath final {
296296
return C(inv);
297297
}
298298

299+
// Divides cx by a power of two using an explicit modular inverse: it obtains
// a power of two from twoPowLimited, inverts it, multiplies the inverse by cx,
// and returns the canonicalized product.
// NOTE(review): twoPowLimited(exponent) is presumably 2^exponent in this
// class's value representation, and may limit the allowed exponent - its
// definition is outside this hunk; confirm.
template <class PTAG> HURCHALLA_FORCE_INLINE
300+
V divideBySmallPowerOf2(C cx, int exponent, PTAG) const
301+
{
302+
V pow_of_two = twoPowLimited(static_cast<size_t>(exponent), PTAG());
303+
C inv_pow_of_two = inverse(pow_of_two, PTAG());
304+
C zero = getZeroValue();
305+
// A zero inverse is treated as an error here: the inverse of the power of
// two is expected to exist.
HPBC_CLOCKWORK_ASSERT2(inv_pow_of_two != zero);
306+
// isZero is an output of multiply(); it is only consumed by the assert
// below.
bool isZero;
307+
V product = multiply(inv_pow_of_two, cx, isZero, PTAG());
308+
// With a nonzero inverse, the product should be zero exactly when cx is.
HPBC_CLOCKWORK_ASSERT2((cx == zero) == isZero);
309+
// The canonical form is returned through the V return type (relies on C
// being convertible to V).
C result = getCanonicalValue(product);
310+
return result;
311+
}
312+
299313
// Returns the greatest common divisor of the standard representations
300314
// (non-montgomery) of both x and the modulus, using the supplied functor.
301315
// The functor must take two integral arguments of the same type and return

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/AbstractMontgomeryForm.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ class AbstractMontgomeryForm {
219219
virtual CanonicalValue inverse(MontgomeryValue x,
220220
bool useLowlatencyTag) const = 0;
221221

222+
// Pure virtual hook for dividing cx by a power of two given by exponent.
// The performance tag is passed as a runtime bool (true selects
// LowlatencyTag) rather than as a template parameter.
virtual MontgomeryValue divideBySmallPowerOf2(CanonicalValue cx,
223+
int exponent, bool useLowlatencyTag) const = 0;
224+
222225
virtual std::vector<MontgomeryValue> vectorPow(
223226
const std::vector<MontgomeryValue>& bases, IntegerType exponent) const = 0;
224227

@@ -310,6 +313,13 @@ class AbstractMontgomeryForm {
310313
return inverse(x, std::is_same<PTAG, LowlatencyTag>::value);
311314
}
312315

316+
// Template front end for the three-argument divideBySmallPowerOf2 overload:
// the compile-time PTAG is converted to a bool (true when PTAG is
// LowlatencyTag) and passed on as the third argument.
template <class PTAG = LowlatencyTag>
317+
MontgomeryValue divideBySmallPowerOf2(CanonicalValue cx, int exponent) const
318+
{
319+
return divideBySmallPowerOf2(cx, exponent,
320+
std::is_same<PTAG, LowlatencyTag>::value);
321+
}
322+
313323
template <std::size_t NUM_BASES>
314324
std::array<MontgomeryValue, NUM_BASES>
315325
pow(const std::array<MontgomeryValue, NUM_BASES>& bases, IntegerType exponent) const

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/AbstractMontgomeryWrapper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ class AbstractMontgomeryWrapper final {
122122
CanonicalValue inverse(MontgomeryValue x) const
123123
{ return pimpl->template inverse<PTAG>(x); }
124124

125+
// Forwards to the pimpl object's divideBySmallPowerOf2, passing PTAG through
// as the template argument.
template <class PTAG = LowlatencyTag>
126+
MontgomeryValue divideBySmallPowerOf2(CanonicalValue cx, int exponent) const
127+
{ return pimpl->template divideBySmallPowerOf2<PTAG>(cx, exponent); }
128+
125129
// Forwards the single-base pow call to the pimpl object.
MontgomeryValue pow(MontgomeryValue base, IntegerType exponent) const
126130
{ return pimpl->pow(base, exponent); }
127131

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/ConcreteMontgomeryForm.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
467467
mfv = mfv2;
468468
} else {
469469
OpenMFV mfv2(mf.template multiply<LowuopsTag>(OpenMFV(OpenV(x)), OpenMFV(OpenV(y))));
470-
mfv = mfv2;
470+
mfv = mfv2;
471471
}
472472
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
473473
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -484,7 +484,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
484484
} else {
485485
OpenMFV mfv2(mf.template multiply<LowuopsTag>(OpenMFV(OpenV(x)),
486486
OpenMFV(OpenV(y)), resultIsZero));
487-
mfv = mfv2;
487+
mfv = mfv2;
488488
}
489489
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
490490
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -500,7 +500,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
500500
} else {
501501
OpenMFV mfv2(mf.template fmsub<LowuopsTag>(OpenMFV(OpenV(x)),
502502
OpenMFV(OpenV(y)), OpenMFC(OpenC(z))));
503-
mfv = mfv2;
503+
mfv = mfv2;
504504
}
505505
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
506506
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -516,7 +516,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
516516
} else {
517517
OpenMFV mfv2(mf.template fmsub<LowuopsTag>(OpenMFV(OpenV(x)),
518518
OpenMFV(OpenV(y)), OpenMFFV(OpenFV(z))));
519-
mfv = mfv2;
519+
mfv = mfv2;
520520
}
521521
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
522522
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -532,7 +532,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
532532
} else {
533533
OpenMFV mfv2(mf.template fmadd<LowuopsTag>(OpenMFV(OpenV(x)),
534534
OpenMFV(OpenV(y)), OpenMFC(OpenC(z))));
535-
mfv = mfv2;
535+
mfv = mfv2;
536536
}
537537
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
538538
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -548,7 +548,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
548548
} else {
549549
OpenMFV mfv2(mf.template fmadd<LowuopsTag>(OpenMFV(OpenV(x)),
550550
OpenMFV(OpenV(y)), OpenMFFV(OpenFV(z))));
551-
mfv = mfv2;
551+
mfv = mfv2;
552552
}
553553
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
554554
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -562,7 +562,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
562562
mfv = mfv2;
563563
} else {
564564
OpenMFV mfv2(mf.template square<LowuopsTag>(OpenMFV(OpenV(x))));
565-
mfv = mfv2;
565+
mfv = mfv2;
566566
}
567567
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
568568
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -578,7 +578,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
578578
} else {
579579
OpenMFV mfv2(mf.template fusedSquareSub<LowuopsTag>(OpenMFV(OpenV(x)),
580580
OpenMFC(OpenC(cv))));
581-
mfv = mfv2;
581+
mfv = mfv2;
582582
}
583583
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
584584
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -594,7 +594,7 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
594594
} else {
595595
OpenMFV mfv2(mf.template fusedSquareAdd<LowuopsTag>(OpenMFV(OpenV(x)),
596596
OpenMFC(OpenC(cv))));
597-
mfv = mfv2;
597+
mfv = mfv2;
598598
}
599599
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
600600
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
@@ -614,6 +614,22 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
614614
return OpenC(static_cast<typename OpenC::OT>(mfc.get()));
615615
}
616616

617+
// Override of the runtime-tagged divideBySmallPowerOf2: picks the
// LowlatencyTag or LowuopsTag instantiation of mf.divideBySmallPowerOf2
// according to useLowlatencyTag. cx is wrapped via OpenC/OpenMFC before the
// call, and the call's result is held in an OpenMFV.
virtual V divideBySmallPowerOf2(C cx, int exponent, bool useLowlatencyTag)
618+
const override
619+
{
620+
OpenMFV mfv;
621+
if (useLowlatencyTag) {
622+
OpenMFV mfv2(mf.template divideBySmallPowerOf2<LowlatencyTag>(
623+
OpenMFC(OpenC(cx)), exponent));
624+
mfv = mfv2;
625+
} else {
626+
OpenMFV mfv2(mf.template divideBySmallPowerOf2<LowuopsTag>(
627+
OpenMFC(OpenC(cx)), exponent));
628+
mfv = mfv2;
629+
}
630+
// note: mfv.get() might be signed or unsigned; OpenV::OT is unsigned
631+
// The raw value is cast to OpenV::OT and wrapped in an OpenV to form the
// return value.
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
632+
}
617633

618634
// This class (ConcreteMontgomeryForm) only supports calling vectorPow()
619635
// with a std::vector that has size equal to one of the sizes given by the

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/platform_specific/montgomery_two_pow.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ struct montgomery_two_pow {
151151
using U = typename extensible_make_unsigned<T>::type;
152152
U n = static_cast<U>(nt);
153153

154-
using MontyTag = typename MF::MontyTag;
154+
using MontyTag = typename MF::MontType::MontyTag;
155155
using RU = typename MontgomeryFormExtensions<MF, LowlatencyTag>::RU;
156156
constexpr bool isBigPow = ut_numeric_limits<RU>::digits >
157157
HURCHALLA_TARGET_BIT_WIDTH;
@@ -176,7 +176,7 @@ struct montgomery_two_pow {
176176
static_assert(hurchalla::ut_numeric_limits<U>::is_integer, "");
177177
static_assert(!hurchalla::ut_numeric_limits<U>::is_signed, "");
178178

179-
using MontyTag = typename MF::MontyTag;
179+
using MontyTag = typename MF::MontType::MontyTag;
180180
using RU = typename MontgomeryFormExtensions<MF, LowlatencyTag>::RU;
181181
constexpr bool isBigPow = ut_numeric_limits<RU>::digits >
182182
HURCHALLA_TARGET_BIT_WIDTH;

0 commit comments

Comments
 (0)