Skip to content

Commit 2c42c1b

Browse files
committed
improve experimental pow functions
1 parent a2c2b08 commit 2c42c1b

File tree

12 files changed

+961
-204
lines changed

12 files changed

+961
-204
lines changed

montgomery_arithmetic/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ target_sources(hurchalla_montgomery_arithmetic INTERFACE
4545
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/experimental/unit_testing_helpers/ConcreteMontgomeryForm.h>
4646
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/platform_specific/montgomery_pow.h>
4747
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/platform_specific/montgomery_two_pow.h>
48+
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/platform_specific/subtract_returning_difference_or_zero.h>
4849
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/platform_specific/two_times_restricted.h>
4950
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/platform_specific/quarterrange_get_canonical.h>
5051
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/hurchalla/montgomery_arithmetic/detail/platform_specific/halfrange_get_canonical.h>

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyFullRange.h

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
#include "hurchalla/montgomery_arithmetic/low_level_api/REDC.h"
1414
#include "hurchalla/montgomery_arithmetic/detail/MontyCommonBase.h"
1515
#include "hurchalla/montgomery_arithmetic/detail/MontyTags.h"
16+
#include "hurchalla/montgomery_arithmetic/detail/platform_specific/subtract_returning_difference_or_zero.h"
1617
#include "hurchalla/modular_arithmetic/modular_addition.h"
1718
#include "hurchalla/modular_arithmetic/modular_subtraction.h"
1819
#include "hurchalla/modular_arithmetic/absolute_value_difference.h"
@@ -40,14 +41,14 @@ struct MontyFRValueTypes {
4041
struct V : public BaseMontgomeryValue<T> {
4142
HURCHALLA_FORCE_INLINE V() = default;
4243

43-
template <int BITNUM> HURCHALLA_FORCE_INLINE
44-
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
44+
template <int BITNUM>
45+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
4546
{
4647
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
4748
return V(sel);
4849
}
49-
template <int BITNUM> HURCHALLA_FORCE_INLINE
50-
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
50+
template <int BITNUM>
51+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
5152
{
5253
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
5354
return V(sel);
@@ -63,6 +64,19 @@ struct MontyFRValueTypes {
6364
{ return x.get() == y.get(); }
6465
HURCHALLA_FORCE_INLINE friend bool operator!=(const C& x, const C& y)
6566
{ return !(x == y); }
67+
68+
template <int BITNUM>
69+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_ne0(uint64_t num, C c1, C c2)
70+
{
71+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, c1.get(), c2.get());
72+
return C(sel);
73+
}
74+
template <int BITNUM>
75+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_eq0(uint64_t num, C c1, C c2)
76+
{
77+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, c1.get(), c2.get());
78+
return C(sel);
79+
}
6680
protected:
6781
template <template<class> class, template<class> class, typename>
6882
friend class MontyCommonBase;
@@ -271,28 +285,15 @@ class MontyFullRange final :
271285

272286
T minu, subt;
273287
hc::REDC_incomplete(minu, subt, u_hi, u_lo, n_, BC::inv_n_, PTAG());
274-
#if 1
275-
T res = static_cast<T>(minu - subt);
276-
// T subtrahend = (minu < subt) ? res : static_cast<T>(0);
277-
T subtrahend = hc::conditional_select((minu < subt), res, static_cast<T>(0));
288+
#if 0
289+
T diff = static_cast<T>(minu - subt);
290+
// T subtrahend = (minu < subt) ? diff : static_cast<T>(0);
291+
T subtrahend = hc::conditional_select((minu < subt), diff, static_cast<T>(0));
278292
#else
279-
uint64_t minu_lo = static_cast<uint64_t>(minu);
280-
uint64_t minu_hi = static_cast<uint64_t>(minu >> 64);
281-
uint64_t subt_lo = static_cast<uint64_t>(subt);
282-
uint64_t subt_hi = static_cast<uint64_t>(subt >> 64);
283-
uint64_t reslo, reshi;
284-
uint64_t subtrahend_lo, subtrahend_hi;
285-
__asm__ ("subs %[reslo], %[minu_lo], %[subt_lo] \n\t" /* res = minu - subt */
286-
"sbcs %[reshi], %[minu_hi], %[subt_hi] \n\t"
287-
"csel %[subtrahend_lo], %[reslo], xzr, lo \n\t" /* res = (minu < subt) ? res : 0 */
288-
"csel %[subtrahend_hi], %[reshi], xzr, lo \n\t"
289-
: [reslo]"=&r"(reslo), [reshi]"=r"(reshi), [subtrahend_lo]"=r"(subtrahend_lo), [subtrahend_hi]"=r"(subtrahend_hi)
290-
: [minu_lo]"r"(minu_lo), [minu_hi]"r"(minu_hi), [subt_lo]"r"(subt_lo), [subt_hi]"r"(subt_hi)
291-
: "cc");
292-
__uint128_t res = (static_cast<__uint128_t>(reshi) << 64) | reslo;
293-
__uint128_t subtrahend = (static_cast<__uint128_t>(subtrahend_hi) << 64) | subtrahend_lo;
293+
T diff;
294+
T subtrahend = subtract_returning_difference_or_zero(diff, minu, subt);
294295
#endif
295-
SV result(res, subtrahend);
296+
SV result(diff, subtrahend);
296297
return result;
297298
}
298299

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyHalfRange.h

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,14 +55,14 @@ struct MontyHRValueTypes {
5555
struct V : public BaseMontgomeryValue<SignedT> {
5656
HURCHALLA_FORCE_INLINE V() = default;
5757

58-
template <int BITNUM> HURCHALLA_FORCE_INLINE
59-
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
58+
template <int BITNUM>
59+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
6060
{
6161
SignedT sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
6262
return V(sel);
6363
}
64-
template <int BITNUM> HURCHALLA_FORCE_INLINE
65-
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
64+
template <int BITNUM>
65+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
6666
{
6767
SignedT sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
6868
return V(sel);
@@ -88,6 +88,19 @@ struct MontyHRValueTypes {
8888
{ return x.get() == y.get(); }
8989
HURCHALLA_FORCE_INLINE friend bool operator!=(const C& x, const C& y)
9090
{ return !(x == y); }
91+
92+
template <int BITNUM>
93+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_ne0(uint64_t num, C c1, C c2)
94+
{
95+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, c1.get(), c2.get());
96+
return C(sel);
97+
}
98+
template <int BITNUM>
99+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_eq0(uint64_t num, C c1, C c2)
100+
{
101+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, c1.get(), c2.get());
102+
return C(sel);
103+
}
91104
protected:
92105
template <typename> friend class MontyHalfRange;
93106
template <template<class> class, template<class> class, typename>

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyQuarterRange.h

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -60,14 +60,14 @@ struct MontyQRValueTypes {
6060
struct V : public BaseMontgomeryValue<T> {
6161
HURCHALLA_FORCE_INLINE V() = default;
6262

63-
template <int BITNUM> HURCHALLA_FORCE_INLINE
64-
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
63+
template <int BITNUM>
64+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
6565
{
6666
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
6767
return V(sel);
6868
}
69-
template <int BITNUM> HURCHALLA_FORCE_INLINE
70-
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
69+
template <int BITNUM>
70+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
7171
{
7272
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
7373
return V(sel);
@@ -83,6 +83,19 @@ struct MontyQRValueTypes {
8383
{ return x.get() == y.get(); }
8484
HURCHALLA_FORCE_INLINE friend bool operator!=(const C& x, const C& y)
8585
{ return !(x == y); }
86+
87+
template <int BITNUM>
88+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_ne0(uint64_t num, C c1, C c2)
89+
{
90+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, c1.get(), c2.get());
91+
return C(sel);
92+
}
93+
template <int BITNUM>
94+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_eq0(uint64_t num, C c1, C c2)
95+
{
96+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, c1.get(), c2.get());
97+
return C(sel);
98+
}
8699
protected:
87100
template <template<class> class, template<class> class, typename>
88101
friend class MontyCommonBase;

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyWrappedStandardMath.h

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,14 +41,14 @@ class MontyWrappedStandardMath final {
4141
struct V : public BaseMontgomeryValue<T> { // regular montgomery value type
4242
HURCHALLA_FORCE_INLINE V() = default;
4343

44-
template <int BITNUM> HURCHALLA_FORCE_INLINE
45-
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
44+
template <int BITNUM>
45+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
4646
{
4747
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
4848
return V(sel);
4949
}
50-
template <int BITNUM> HURCHALLA_FORCE_INLINE
51-
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
50+
template <int BITNUM>
51+
HURCHALLA_FORCE_INLINE static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
5252
{
5353
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
5454
return V(sel);
@@ -63,6 +63,19 @@ class MontyWrappedStandardMath final {
6363
{ return x.get() == y.get(); }
6464
HURCHALLA_FORCE_INLINE friend bool operator!=(const C& x, const C& y)
6565
{ return !(x == y); }
66+
67+
template <int BITNUM>
68+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_ne0(uint64_t num, C c1, C c2)
69+
{
70+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, c1.get(), c2.get());
71+
return C(sel);
72+
}
73+
template <int BITNUM>
74+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_eq0(uint64_t num, C c1, C c2)
75+
{
76+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, c1.get(), c2.get());
77+
return C(sel);
78+
}
6679
protected:
6780
friend MontyWrappedStandardMath;
6881
HURCHALLA_FORCE_INLINE explicit C(T a) : V(a) {}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/MontyFullRangeMasked.h

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,19 @@ struct MfrmValueTypes {
146146
{ return x.get() == y.get(); }
147147
HURCHALLA_FORCE_INLINE friend bool operator!=(const C& x, const C& y)
148148
{ return !(x == y); }
149+
150+
template <int BITNUM>
151+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_ne0(uint64_t num, C c1, C c2)
152+
{
153+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, c1.get(), c2.get());
154+
return C(sel);
155+
}
156+
template <int BITNUM>
157+
HURCHALLA_FORCE_INLINE static C cselect_on_bit_eq0(uint64_t num, C c1, C c2)
158+
{
159+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, c1.get(), c2.get());
160+
return C(sel);
161+
}
149162
protected:
150163
template <typename> friend class MontyFullRangeMasked;
151164
template <template<class> class, template<class> class, typename>
@@ -417,7 +430,7 @@ class MontyFullRangeMasked final :
417430
T umlo;
418431
T umhi = ::hurchalla::unsigned_multiply_to_hilo_product(umlo, a, a);
419432
T masked_a = static_cast<T>(x.getmask() & a);
420-
T result_hi = static_cast<T>(umhi - static_cast<T>(2) * masked_a);
433+
T result_hi = static_cast<T>(umhi - masked_a - masked_a);
421434
u_lo = umlo;
422435
// Complete details are in the proof below, but roughly what we do here
423436
// is get a*a as a two-word product (umhi, umlo). We let s == 1 if x is

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_pow_2kary/experimental_montgomery_pow_2kary.h

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,15 +98,23 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
9898
U exponent = n;
9999

100100
V mont_one = mf.getUnityValue();
101+
#ifndef HURCHALLA_MONTGOMERY_POW_2KARY_USE_CSELECT_ON_BIT
101102
V result = mont_one;
102103
result.cmov(static_cast<size_t>(exponent) & 1u, base);
104+
#else
105+
V result = V::template cselect_on_bit_ne0<0>(static_cast<uint64_t>(exponent), base, mont_one);
106+
#endif
103107
exponent = static_cast<U>(exponent >> 1);
104108

105109
while (exponent > 0u) {
106110
base = mf.square(base);
107111

112+
#ifndef HURCHALLA_MONTGOMERY_POW_2KARY_USE_CSELECT_ON_BIT
108113
V tmp = mont_one;
109114
tmp.cmov(static_cast<size_t>(exponent) & 1u, base);
115+
#else
116+
V tmp = V::template cselect_on_bit_ne0<0>(static_cast<uint64_t>(exponent), base, mont_one);
117+
#endif
110118
result = mf.multiply(result, tmp);
111119

112120
exponent = static_cast<U>(exponent >> 1);
@@ -1148,14 +1156,22 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
11481156
index = static_cast<size_t>(branchless_shift_right(n, shift)) & MASK;
11491157

11501158
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1159+
#ifndef HURCHALLA_MONTGOMERY_POW_2KARY_USE_CSELECT_ON_BIT
11511160
V tmp = result[j];
11521161
tmp.cmov((index % 2 == 0), table[index/2][j]);
1162+
#else
1163+
V tmp = V::template cselect_on_bit_eq0<0>(static_cast<uint64_t>(index), table[index/2][j], result[j]);
1164+
#endif
11531165
result[j] = mf.template multiply<LowuopsTag>(tmp, result[j]);
11541166
}
11551167

11561168
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t j=0; j<ARRAY_SIZE; ++j) {
1169+
#ifndef HURCHALLA_MONTGOMERY_POW_2KARY_USE_CSELECT_ON_BIT
11571170
V tmp = table[index][j];
11581171
tmp.cmov((index % 2 == 0), result[j]);
1172+
#else
1173+
V tmp = V::template cselect_on_bit_eq0<0>(static_cast<uint64_t>(index), result[j], table[index][j]);
1174+
#endif
11591175
result[j] = mf.template multiply<LowuopsTag>(tmp, result[j]);
11601176
}
11611177
}

0 commit comments

Comments
 (0)