Skip to content

Commit 7215f7b

Browse files
committed
add inverse function to MontgomeryForm, and add tests for it.
1 parent c3d0d1c commit 7215f7b

File tree

9 files changed

+192
-10
lines changed

9 files changed

+192
-10
lines changed

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/MontgomeryForm.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ class MontgomeryForm final {
488488
detail::montgomery_array_pow<MontyTag,
489489
MontgomeryForm>::pow(*this, bases, exponent);
490490
return result[0];
491+
//return detail::montgomery_pow<MontgomeryForm>::scalarpow(*this, base, exponent);
491492
}
492493

493494
// Calculates and returns the modular exponentiation of 2 (converted into a
@@ -498,11 +499,10 @@ class MontgomeryForm final {
498499
MontgomeryValue two_pow(T exponent) const
499500
{
500501
HPBC_CLOCKWORK_API_PRECONDITION(exponent >= 0);
501-
MontgomeryValue result =
502-
detail::montgomery_two_pow::call(*this, exponent);
503-
HPBC_CLOCKWORK_POSTCONDITION(getCanonicalValue(result) ==
502+
MontgomeryValue ret = detail::montgomery_two_pow::call(*this, exponent);
503+
HPBC_CLOCKWORK_POSTCONDITION(getCanonicalValue(ret) ==
504504
getCanonicalValue(pow(convertIn(2), exponent)));
505-
return result;
505+
return ret;
506506
}
507507

508508
// This is a specially optimized version of the pow() function above.
@@ -536,6 +536,28 @@ class MontgomeryForm final {
536536
}
537537

538538

539+
// Returns the multiplicative inverse of 'x' in the Montgomery domain if
540+
// the inverse exists. If the inverse does not exist, it returns zero (or
541+
// more precisely, it returns the value equal to getZeroValue()).
542+
// This is a convenience function to stay in the Montgomery domain when you
543+
// want to find the multiplicative inverse of a MontgomeryValue.
544+
//
545+
// Performance note: this function has no performance advantage over
546+
// hurchalla::modular_multiplicative_inverse if you need the inverse of a
547+
// number in standard integer domain - i.e. don't convert into Montgomery
548+
// domain just to call this function. However, when you intend to stay in
549+
// the Montgomery domain, this function is the fastest way to get the
550+
// multiplicative inverse.
551+
template <class PTAG = LowlatencyTag> HURCHALLA_FORCE_INLINE
552+
CanonicalValue inverse(MontgomeryValue x) const
553+
{
554+
CanonicalValue ret = impl.template inverse<PTAG>(x);
555+
HPBC_CLOCKWORK_POSTCONDITION(ret == getZeroValue() ||
556+
getCanonicalValue(multiply(x, ret)) == getUnityValue());
557+
return ret;
558+
}
559+
560+
539561
// Returns the "greatest common divisor" of the standard representations
540562
// (non-montgomery) of both x and the modulus, using the gcd functor that
541563
// you supply. The functor must take two integral arguments of the same type

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/ImplMontgomeryForm.contents

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ public:
228228
return impl.fusedSquareAdd(x, cv, PTAG());
229229
}
230230

231+
template <class PTAG> HURCHALLA_IMF_MAYBE_FORCE_INLINE
232+
CanonicalValue inverse(MontgomeryValue x) const
233+
{
234+
return impl.inverse(x, PTAG());
235+
}
236+
231237
template <class F> HURCHALLA_IMF_MAYBE_FORCE_INLINE
232238
T gcd_with_modulus(MontgomeryValue x, const F& gcd_functor) const
233239
{

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyCommonBase.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#include "hurchalla/montgomery_arithmetic/low_level_api/REDC.h"
1313
#include "hurchalla/modular_arithmetic/detail/optimization_tag_structs.h"
14+
#include "hurchalla/modular_arithmetic/modular_multiplication.h"
15+
#include "hurchalla/modular_arithmetic/modular_multiplicative_inverse.h"
1416
#include "hurchalla/montgomery_arithmetic/low_level_api/get_Rsquared_mod_n.h"
1517
#include "hurchalla/montgomery_arithmetic/low_level_api/get_R_mod_n.h"
1618
#include "hurchalla/montgomery_arithmetic/low_level_api/inverse_mod_R.h"
@@ -383,6 +385,38 @@ class MontyCommonBase {
383385
}
384386

385387

388+
template <class PTAG> // Performance TAG (see optimization_tag_structs.h)
389+
HURCHALLA_FORCE_INLINE C inverse(V x, PTAG) const
390+
{
391+
// Given x == a*R, we do 2 REDCs to get a*R^(-1), then we call
392+
// the standard integer domain inverse() to get a^(-1)*R.
393+
394+
namespace hc = ::hurchalla;
395+
const D* child = static_cast<const D*>(this);
396+
HPBC_CLOCKWORK_PRECONDITION2(child->isValid(x));
397+
T u_hi = 0;
398+
// get a Natural number (i.e. number >= 0) congruent to x (mod n)
399+
T u_lo = static_cast<const D*>(this)->getNaturalEquivalence(x);
400+
V result = child->montyREDC(u_hi, u_lo, PTAG());
401+
402+
u_hi = 0;
403+
u_lo = static_cast<const D*>(this)->getNaturalEquivalence(result);
404+
V result2 = child->montyREDC(u_hi, u_lo, PTAG());
405+
406+
T result3 = static_cast<const D*>(this)->getNaturalEquivalence(result2);
407+
T gcd; // ignored
408+
T inv = hc::modular_multiplicative_inverse(result3, n_, gcd);
409+
410+
HPBC_CLOCKWORK_POSTCONDITION2(inv < n_);
411+
//POSTCONDITION: Return 0 if the inverse does not exist. Otherwise
412+
// return the value of the inverse (which would never be 0, given that
413+
// the modulus n_ > 1).
414+
HPBC_CLOCKWORK_POSTCONDITION2(inv == 0 ||
415+
hc::modular_multiplication_prereduced_inputs(result3, inv, n_) == 1);
416+
return C(inv);
417+
}
418+
419+
386420
// Returns the greatest common divisor of the standard representations
387421
// (non-montgomery) of both x and the modulus, using the supplied functor.
388422
// The functor must take two integral arguments of the same type and return

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyWrappedStandardMath.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "hurchalla/montgomery_arithmetic/low_level_api/get_R_mod_n.h"
1414
#include "hurchalla/montgomery_arithmetic/detail/MontyTags.h"
1515
#include "hurchalla/modular_arithmetic/modular_multiplication.h"
16+
#include "hurchalla/modular_arithmetic/modular_multiplicative_inverse.h"
1617
#include "hurchalla/modular_arithmetic/modular_addition.h"
1718
#include "hurchalla/modular_arithmetic/modular_subtraction.h"
1819
#include "hurchalla/modular_arithmetic/absolute_value_difference.h"
@@ -278,6 +279,22 @@ class MontyWrappedStandardMath final {
278279
return sv;
279280
}
280281

282+
template <class PTAG> // Performance TAG (see optimization_tag_structs.h)
283+
HURCHALLA_FORCE_INLINE C inverse(V x, PTAG) const
284+
{
285+
namespace hc = ::hurchalla;
286+
HPBC_CLOCKWORK_PRECONDITION2(isCanonical(x));
287+
T gcd; // ignored
288+
T inv = hc::modular_multiplicative_inverse(x.get(), modulus_, gcd);
289+
290+
HPBC_CLOCKWORK_POSTCONDITION2(inv < modulus_);
291+
//POSTCONDITION: Return 0 if the inverse does not exist. Otherwise
292+
// return the value of the inverse (which would never be 0, given that
293+
// modulus_ > 1).
294+
HPBC_CLOCKWORK_POSTCONDITION2(inv == 0 || 1 ==
295+
hc::modular_multiplication_prereduced_inputs(inv,x.get(),modulus_));
296+
return C(inv);
297+
}
281298

282299
// Returns the greatest common divisor of the standard representations
283300
// (non-montgomery) of both x and the modulus, using the supplied functor.

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/AbstractMontgomeryForm.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ class AbstractMontgomeryForm {
216216
virtual MontgomeryValue fusedSquareAdd(MontgomeryValue x, CanonicalValue cv,
217217
bool useLowlatencyTag) const = 0;
218218

219+
virtual CanonicalValue inverse(MontgomeryValue x,
220+
bool useLowlatencyTag) const = 0;
221+
219222
virtual std::vector<MontgomeryValue> vectorPow(
220223
const std::vector<MontgomeryValue>& bases, IntegerType exponent) const = 0;
221224

@@ -301,6 +304,12 @@ class AbstractMontgomeryForm {
301304
return fusedSquareAdd(x, cv, std::is_same<PTAG, LowlatencyTag>::value);
302305
}
303306

307+
template <class PTAG = LowlatencyTag>
308+
CanonicalValue inverse(MontgomeryValue x) const
309+
{
310+
return inverse(x, std::is_same<PTAG, LowlatencyTag>::value);
311+
}
312+
304313
template <std::size_t NUM_BASES>
305314
std::array<MontgomeryValue, NUM_BASES>
306315
pow(const std::array<MontgomeryValue, NUM_BASES>& bases, IntegerType exponent) const

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/AbstractMontgomeryWrapper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ class AbstractMontgomeryWrapper final {
118118
MontgomeryValue fusedSquareAdd(MontgomeryValue x, CanonicalValue cv) const
119119
{ return pimpl->template fusedSquareAdd<PTAG>(x, cv); }
120120

121+
template <class PTAG = LowlatencyTag>
122+
CanonicalValue inverse(MontgomeryValue x) const
123+
{ return pimpl->template inverse<PTAG>(x); }
124+
121125
MontgomeryValue pow(MontgomeryValue base, IntegerType exponent) const
122126
{ return pimpl->pow(base, exponent); }
123127

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/ConcreteMontgomeryForm.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,20 @@ class ConcreteMontgomeryForm final : public AbstractMontgomeryForm<ut_numeric_li
600600
return OpenV(static_cast<typename OpenV::OT>(mfv.get()));
601601
}
602602

603+
virtual C inverse(V x, bool useLowlatencyTag) const override
604+
{
605+
OpenMFC mfc;
606+
if (useLowlatencyTag) {
607+
OpenMFC mfc2(mf.template inverse<LowlatencyTag>(OpenMFV(OpenV(x))));
608+
mfc = mfc2;
609+
} else {
610+
OpenMFC mfc2(mf.template inverse<LowuopsTag>(OpenMFV(OpenV(x))));
611+
mfc = mfc2;
612+
}
613+
// note: mfc.get() might be signed or unsigned; OpenC::OT is unsigned
614+
return OpenC(static_cast<typename OpenC::OT>(mfc.get()));
615+
}
616+
603617

604618
// This class (ConcreteMontgomeryForm) only supports calling vectorPow()
605619
// with a std::vector that has size equal to one of the sizes given by the

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/platform_specific/montgomery_pow.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,32 @@ struct montgomery_pow {
168168
}
169169
while (exponent > static_cast<T>(1)) {
170170
exponent = static_cast<T>(exponent >> static_cast<T>(1));
171+
#if 0
171172
std::array<V, NUM_BASES> tmp;
173+
172174
Unroll<NUM_BASES>::call([&](std::size_t i) HURCHALLA_INLINE_LAMBDA {
173175
bases[i] = mf.template square<LowuopsTag>(bases[i]);
174176
tmp[i] = mf.template multiply<LowuopsTag>(result[i], bases[i]);
175177
});
176178
Unroll<NUM_BASES>::call([&](std::size_t i) HURCHALLA_INLINE_LAMBDA {
177179
result[i].cmov(exponent & static_cast<T>(1), tmp[i]);
178180
});
181+
#else
182+
// see scalarpow)() comments for why this #else section might be
183+
// preferable to the #if alternative above. There's probably little
184+
// difference at larger NUM_BASES though, where total uops is the
185+
// bottleneck rather than dependency chain length.
186+
Unroll<NUM_BASES>::call([&](std::size_t i) HURCHALLA_INLINE_LAMBDA {
187+
bases[i] = mf.template square<LowuopsTag>(bases[i]);
188+
});
189+
190+
V mont_one = mf.getUnityValue();
191+
Unroll<NUM_BASES>::call([&](std::size_t i) HURCHALLA_INLINE_LAMBDA {
192+
V tmp = mont_one;
193+
tmp.cmov(exponent & static_cast<T>(1), bases[i]);
194+
result[i] = mf.template multiply<LowlatencyTag>(result[i], tmp);
195+
});
196+
#endif
179197
}
180198
return result;
181199
}
@@ -198,12 +216,24 @@ struct montgomery_pow {
198216
}
199217
while (exponent > static_cast<T>(1)) {
200218
exponent = static_cast<T>(exponent >> 1u);
219+
#if 0
201220
Unroll<NUM_BASES>::call([&](std::size_t i) HURCHALLA_INLINE_LAMBDA {
202221
bases[i] = mf.template square<LowuopsTag>(bases[i]);
203222
V tmp = mf.template multiply<LowuopsTag>(result[i], bases[i]);
204223
result[i].template
205224
cmov<CSelectMaskedTag>(exponent & static_cast<T>(1), tmp);
206225
});
226+
#else
227+
// the comments in arraypow_cmov() apply equally in this section
228+
V mont_one = mf.getUnityValue();
229+
Unroll<NUM_BASES>::call([&](std::size_t i) HURCHALLA_INLINE_LAMBDA {
230+
bases[i] = mf.template square<LowuopsTag>(bases[i]);
231+
V tmp = mont_one;
232+
tmp.template
233+
cmov<CSelectMaskedTag>(exponent & static_cast<T>(1), bases[i]);
234+
result[i] = mf.template multiply<LowlatencyTag>(result[i], tmp);
235+
});
236+
#endif
207237
}
208238
return result;
209239
}

test/montgomery_arithmetic/test_MontgomeryForm.h

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,40 @@ void test_remainder(const M& mf)
293293
EXPECT_TRUE(mf.remainder(static_cast<T>(mid+1)) == ((mid+1) % modulus));
294294
}
295295

296+
template <typename M>
297+
void test_single_inverse(const M& mf, typename M::IntegerType a)
298+
{
299+
namespace hc = ::hurchalla;
300+
using T = typename M::IntegerType;
301+
using U = typename hc::extensible_make_unsigned<T>::type;
302+
303+
U n = static_cast<U>(mf.getModulus());
304+
U gcd; // ignored
305+
auto answer = hc::modular_multiplicative_inverse(static_cast<U>(a), n, gcd);
306+
U val = static_cast<U>(mf.convertOut(mf.inverse(mf.convertIn(a))));
307+
EXPECT_TRUE(val == answer);
308+
}
309+
310+
template <typename M>
311+
void test_inverse(const M& mf)
312+
{
313+
using T = typename M::IntegerType;
314+
T max = ::hurchalla::ut_numeric_limits<T>::max();
315+
T mid = static_cast<T>(max/2);
316+
T modulus = mf.getModulus();
317+
test_single_inverse(mf, static_cast<T>(0));
318+
test_single_inverse(mf, static_cast<T>(1));
319+
test_single_inverse(mf, static_cast<T>(2));
320+
test_single_inverse(mf, static_cast<T>(max-0));
321+
test_single_inverse(mf, static_cast<T>(max-1));
322+
test_single_inverse(mf, static_cast<T>(mid-0));
323+
test_single_inverse(mf, static_cast<T>(mid-1));
324+
test_single_inverse(mf, static_cast<T>(modulus-1));
325+
test_single_inverse(mf, static_cast<T>(modulus-2));
326+
test_single_inverse(mf, static_cast<T>(modulus/2));
327+
test_single_inverse(mf, static_cast<T>((modulus/2) - 1));
328+
}
329+
296330
template <typename M>
297331
void test_mf_general_checks(const M& mf, typename M::IntegerType a,
298332
typename M::IntegerType b, typename M::IntegerType c)
@@ -739,16 +773,28 @@ void test_MontgomeryForm()
739773
EXPECT_TRUE(mf.gcd_with_modulus(mf.convertIn(12), GcdFunctor()) == 3);
740774
}
741775

742-
// test remainder()
776+
// test remainder() and inverse()
743777
{
744778
T max = max_modulus;
745779
T mid = static_cast<T>(max/2);
746780
mid = (mid % 2 == 0) ? static_cast<T>(mid + 1) : mid;
747-
test_remainder(MFactory::construct(3)); // smallest possible modulus
748-
test_remainder(MFactory::construct(max)); // largest possible modulus
749-
if (121 <= max)
750-
test_remainder(MFactory::construct(121));
751-
test_remainder(MFactory::construct(mid));
781+
auto mf_3 = MFactory::construct(3);
782+
test_remainder(mf_3); // smallest possible modulus
783+
test_inverse(mf_3);
784+
785+
auto mf_max = MFactory::construct(max);
786+
test_remainder(mf_max); // largest possible modulus
787+
test_inverse(mf_max);
788+
789+
if (121 <= max) {
790+
auto mf_121 = MFactory::construct(121);
791+
test_remainder(mf_121);
792+
test_inverse(mf_121);
793+
}
794+
795+
auto mf_mid = MFactory::construct(mid);
796+
test_remainder(mf_mid);
797+
test_inverse(mf_mid);
752798
}
753799
}
754800

0 commit comments

Comments
 (0)