Skip to content

Commit 6e58ea3

Browse files
committed
improve experimental pow functions
1 parent c1f1350 commit 6e58ea3

File tree

16 files changed

+234
-72
lines changed

16 files changed

+234
-72
lines changed

modular_arithmetic/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ include(FetchContent)
7575
FetchContent_Declare(
7676
hurchalla_util
7777
GIT_REPOSITORY https://github.com/hurchalla/util.git
78-
GIT_TAG 57c73322b067d003eed058642c0f482967ed8e56
78+
GIT_TAG e3a0fd02c86b67dcbf833fdd4ccf0732552f6e3e
7979
)
8080
FetchContent_MakeAvailable(hurchalla_util)
8181

montgomery_arithmetic/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ include(FetchContent)
7878
FetchContent_Declare(
7979
hurchalla_util
8080
GIT_REPOSITORY https://github.com/hurchalla/util.git
81-
GIT_TAG 57c73322b067d003eed058642c0f482967ed8e56
81+
GIT_TAG e3a0fd02c86b67dcbf833fdd4ccf0732552f6e3e
8282
)
8383
FetchContent_MakeAvailable(hurchalla_util)
8484

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/MontgomeryForm.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,12 +511,15 @@ class MontgomeryForm final {
511511
MontgomeryValue pow(MontgomeryValue base, T exponent) const
512512
{
513513
HPBC_CLOCKWORK_API_PRECONDITION(exponent >= 0);
514+
#if 1
514515
std::array<MontgomeryValue, 1> bases = {{ base }};
515516
std::array<MontgomeryValue, 1> result =
516517
detail::montgomery_array_pow<typename MontyType::MontyTag,
517518
MontgomeryForm>::pow(*this, bases, exponent);
518519
return result[0];
519-
//return detail::montgomery_pow<MontgomeryForm>::scalarpow(*this, base, exponent);
520+
#else
521+
return detail::montgomery_pow<MontgomeryForm>::scalarpow(*this, base, exponent);
522+
#endif
520523
}
521524

522525
// Calculates and returns the modular exponentiation of 2 (converted into a

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyFullRange.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "hurchalla/util/traits/ut_numeric_limits.h"
2020
#include "hurchalla/util/unsigned_multiply_to_hilo_product.h"
2121
#include "hurchalla/util/conditional_select.h"
22+
#include "hurchalla/util/cselect_on_bit.h"
2223
#include "hurchalla/util/compiler_macros.h"
2324
#include "hurchalla/modular_arithmetic/detail/clockwork_programming_by_contract.h"
2425
#include <type_traits>
@@ -38,6 +39,19 @@ struct MontyFRValueTypes {
3839
// regular montgomery value type
3940
struct V : public BaseMontgomeryValue<T> {
4041
HURCHALLA_FORCE_INLINE V() = default;
42+
43+
template <int BITNUM> HURCHALLA_FORCE_INLINE
44+
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
45+
{
46+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
47+
return V(sel);
48+
}
49+
template <int BITNUM> HURCHALLA_FORCE_INLINE
50+
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
51+
{
52+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
53+
return V(sel);
54+
}
4155
protected:
4256
template <typename> friend class MontyFullRange;
4357
HURCHALLA_FORCE_INLINE explicit V(T a) : BaseMontgomeryValue<T>(a) {}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyHalfRange.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "hurchalla/util/traits/extensible_make_signed.h"
2222
#include "hurchalla/util/signed_multiply_to_hilo_product.h"
2323
#include "hurchalla/util/conditional_select.h"
24+
#include "hurchalla/util/cselect_on_bit.h"
2425
#include "hurchalla/util/compiler_macros.h"
2526
#include "hurchalla/modular_arithmetic/detail/clockwork_programming_by_contract.h"
2627
#include <type_traits>
@@ -53,6 +54,19 @@ struct MontyHRValueTypes {
5354
// regular montgomery value type
5455
struct V : public BaseMontgomeryValue<SignedT> {
5556
HURCHALLA_FORCE_INLINE V() = default;
57+
58+
template <int BITNUM> HURCHALLA_FORCE_INLINE
59+
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
60+
{
61+
SignedT sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
62+
return V(sel);
63+
}
64+
template <int BITNUM> HURCHALLA_FORCE_INLINE
65+
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
66+
{
67+
SignedT sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
68+
return V(sel);
69+
}
5670
protected:
5771
friend struct C;
5872
template <typename> friend class MontyHalfRange;

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyQuarterRange.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "hurchalla/util/traits/ut_numeric_limits.h"
2222
#include "hurchalla/util/traits/extensible_make_signed.h"
2323
#include "hurchalla/util/unsigned_multiply_to_hilo_product.h"
24+
#include "hurchalla/util/cselect_on_bit.h"
2425
#include "hurchalla/util/compiler_macros.h"
2526
#include "hurchalla/modular_arithmetic/detail/clockwork_programming_by_contract.h"
2627
#include <type_traits>
@@ -58,6 +59,19 @@ struct MontyQRValueTypes {
5859
// regular montgomery value type
5960
struct V : public BaseMontgomeryValue<T> {
6061
HURCHALLA_FORCE_INLINE V() = default;
62+
63+
template <int BITNUM> HURCHALLA_FORCE_INLINE
64+
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
65+
{
66+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
67+
return V(sel);
68+
}
69+
template <int BITNUM> HURCHALLA_FORCE_INLINE
70+
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
71+
{
72+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
73+
return V(sel);
74+
}
6175
protected:
6276
template <typename> friend class MontyQuarterRange;
6377
HURCHALLA_FORCE_INLINE explicit V(T a) : BaseMontgomeryValue<T>(a) {}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyWrappedStandardMath.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "hurchalla/modular_arithmetic/absolute_value_difference.h"
2020
#include "hurchalla/util/traits/ut_numeric_limits.h"
2121
#include "hurchalla/modular_arithmetic/detail/clockwork_programming_by_contract.h"
22+
#include "hurchalla/util/cselect_on_bit.h"
2223
#include "hurchalla/util/compiler_macros.h"
2324
#include <type_traits>
2425

@@ -39,6 +40,19 @@ class MontyWrappedStandardMath final {
3940

4041
struct V : public BaseMontgomeryValue<T> { // regular montgomery value type
4142
HURCHALLA_FORCE_INLINE V() = default;
43+
44+
template <int BITNUM> HURCHALLA_FORCE_INLINE
45+
static V cselect_on_bit_ne0(uint64_t num, V v1, V v2)
46+
{
47+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, v1.get(), v2.get());
48+
return V(sel);
49+
}
50+
template <int BITNUM> HURCHALLA_FORCE_INLINE
51+
static V cselect_on_bit_eq0(uint64_t num, V v1, V v2)
52+
{
53+
T sel = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, v1.get(), v2.get());
54+
return V(sel);
55+
}
4256
protected:
4357
friend MontyWrappedStandardMath;
4458
HURCHALLA_FORCE_INLINE explicit V(T a) : BaseMontgomeryValue<T>(a) {}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/MontyFullRangeMasked.h

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,13 @@
1818
#include "hurchalla/util/traits/ut_numeric_limits.h"
1919
#include "hurchalla/util/traits/safely_promote_unsigned.h"
2020
#include "hurchalla/util/conditional_select.h"
21+
#include "hurchalla/util/cselect_on_bit.h"
2122
#include "hurchalla/util/unsigned_multiply_to_hilo_product.h"
2223
#include "hurchalla/util/compiler_macros.h"
2324
#include "hurchalla/modular_arithmetic/detail/clockwork_programming_by_contract.h"
2425
#include <type_traits>
26+
#include <cstdint>
27+
#include <array>
2528

2629
namespace hurchalla { namespace detail {
2730

@@ -60,6 +63,68 @@ struct MfrmValueTypes {
6063
signmask = ::hurchalla::conditional_select<T, PerfTag>(
6164
cond, v.signmask, signmask);
6265
}
66+
// cselect_on_bit_ne0() for up to 64 bit T
67+
template <int BITNUM, typename Tp = T>
68+
HURCHALLA_FORCE_INLINE static
69+
typename std::enable_if<ut_numeric_limits<Tp>::digits <= 64, V>::type
70+
cselect_on_bit_ne0(uint64_t num, V v1, V v2)
71+
{
72+
std::array<uint64_t, 2> arg1 =
73+
{ static_cast<uint64_t>(v1.lowbits), static_cast<uint64_t>(v1.signmask) };
74+
std::array<uint64_t, 2> arg2 =
75+
{ static_cast<uint64_t>(v2.lowbits), static_cast<uint64_t>(v2.signmask) };
76+
std::array<uint64_t, 2> tmp = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, arg1, arg2);
77+
return V(static_cast<T>(tmp[0]), static_cast<T>(tmp[1]));
78+
}
79+
// cselect_on_bit_ne0() for 128 bit T
80+
template <int BITNUM, typename Tp = T>
81+
HURCHALLA_FORCE_INLINE static
82+
typename std::enable_if<(ut_numeric_limits<Tp>::digits > 64) &&
83+
(ut_numeric_limits<Tp>::digits <= 128), V>::type
84+
cselect_on_bit_ne0(uint64_t num, V v1, V v2)
85+
{
86+
std::array<uint64_t, 4> arg1 =
87+
{ static_cast<uint64_t>(v1.lowbits), static_cast<uint64_t>(v1.lowbits >> 64),
88+
static_cast<uint64_t>(v1.signmask), static_cast<uint64_t>(v1.signmask >> 64) };
89+
std::array<uint64_t, 4> arg2 =
90+
{ static_cast<uint64_t>(v2.lowbits), static_cast<uint64_t>(v2.lowbits >> 64),
91+
static_cast<uint64_t>(v2.signmask), static_cast<uint64_t>(v2.signmask >> 64) };
92+
std::array<uint64_t, 4> tmp = ::hurchalla::cselect_on_bit<BITNUM>::ne_0(num, arg1, arg2);
93+
T bits = (static_cast<T>(tmp[1]) << 64) | static_cast<T>(tmp[0]);
94+
T smask = (static_cast<T>(tmp[3]) << 64) | static_cast<T>(tmp[2]);
95+
return V(bits, smask);
96+
}
97+
// cselect_on_bit_eq0() for up to 64 bit T
98+
template <int BITNUM, typename Tp = T>
99+
HURCHALLA_FORCE_INLINE static
100+
typename std::enable_if<ut_numeric_limits<Tp>::digits <= 64, V>::type
101+
cselect_on_bit_eq0(uint64_t num, V v1, V v2)
102+
{
103+
std::array<uint64_t, 2> arg1 =
104+
{ static_cast<uint64_t>(v1.lowbits), static_cast<uint64_t>(v1.signmask) };
105+
std::array<uint64_t, 2> arg2 =
106+
{ static_cast<uint64_t>(v2.lowbits), static_cast<uint64_t>(v2.signmask) };
107+
std::array<uint64_t, 2> tmp = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, arg1, arg2);
108+
return V(static_cast<T>(tmp[0]), static_cast<T>(tmp[1]));
109+
}
110+
// cselect_on_bit_eq0() for 128 bit T
111+
template <int BITNUM, typename Tp = T>
112+
HURCHALLA_FORCE_INLINE static
113+
typename std::enable_if<(ut_numeric_limits<Tp>::digits > 64) &&
114+
(ut_numeric_limits<Tp>::digits <= 128), V>::type
115+
cselect_on_bit_eq0(uint64_t num, V v1, V v2)
116+
{
117+
std::array<uint64_t, 4> arg1 =
118+
{ static_cast<uint64_t>(v1.lowbits), static_cast<uint64_t>(v1.lowbits >> 64),
119+
static_cast<uint64_t>(v1.signmask), static_cast<uint64_t>(v1.signmask >> 64) };
120+
std::array<uint64_t, 4> arg2 =
121+
{ static_cast<uint64_t>(v2.lowbits), static_cast<uint64_t>(v2.lowbits >> 64),
122+
static_cast<uint64_t>(v2.signmask), static_cast<uint64_t>(v2.signmask >> 64) };
123+
std::array<uint64_t, 4> tmp = ::hurchalla::cselect_on_bit<BITNUM>::eq_0(num, arg1, arg2);
124+
T bits = (static_cast<T>(tmp[1]) << 64) | static_cast<T>(tmp[0]);
125+
T smask = (static_cast<T>(tmp[3]) << 64) | static_cast<T>(tmp[2]);
126+
return V(bits, smask);
127+
}
63128
protected:
64129
friend struct C;
65130
friend struct FV;
@@ -300,13 +365,13 @@ class MontyFullRangeMasked final :
300365
return x;
301366
}
302367
template <class PTAG> HURCHALLA_FORCE_INLINE
303-
SV squareSV(SV sv) const
368+
SV squareSV(SV sv, PTAG) const
304369
{
305370
static_assert(std::is_same<V, SV>::value, "");
306371
return BC::square(sv, PTAG());
307372
}
308373
template <class PTAG> HURCHALLA_FORCE_INLINE
309-
V squareToMontgomeryValue(SV sv) const
374+
V squareToMontgomeryValue(SV sv, PTAG) const
310375
{
311376
static_assert(std::is_same<V, SV>::value, "");
312377
return BC::square(sv, PTAG());

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_pow_2kary/experimental_montgomery_pow_2kary.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
573573
static_assert(TABLESIZE >= 2 && TABLESIZE % 2 == 0, "");
574574
constexpr size_t MASK = TABLESIZE - 1;
575575

576-
constexpr int NUM_TABLES = CODE_SECTION - 7;
576+
constexpr size_t NUM_TABLES = CODE_SECTION - 7;
577577
static_assert(NUM_TABLES > 0, "");
578578
constexpr int NUMBITS_MASKBIG = NUM_TABLES * TABLE_BITS;
579579
static_assert(std::numeric_limits<size_t>::digits > NUMBITS_MASKBIG, "");
@@ -610,7 +610,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
610610
V result = table[0][tmp & MASK];
611611

612612

613-
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=1; k < NUM_TABLES; ++k) {
613+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t k=1; k < NUM_TABLES; ++k) {
614614
table[k][0] = mf.getUnityValue();
615615
table[k][1] = mf.square(table[k - 1][TABLESIZE / 2]);
616616
if HURCHALLA_CPP17_CONSTEXPR (TABLESIZE >= 4) {
@@ -654,7 +654,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
654654
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i<TABLE_BITS - 1; ++i)
655655
sv = MFE::squareSV(mf, sv);
656656

657-
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=1; k<NUM_TABLES; ++k) {
657+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t k=1; k<NUM_TABLES; ++k) {
658658
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
659659
val1 = mf.template multiply<LowuopsTag>(val1, table[k][index]);
660660

@@ -681,7 +681,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
681681
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t i=0; i<TABLE_BITS; ++i)
682682
result = mf.square(result);
683683

684-
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=1; k<NUM_TABLES; ++k) {
684+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t k=1; k<NUM_TABLES; ++k) {
685685
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
686686
val1 = mf.template multiply<LowuopsTag>(val1, table[k][index]);
687687

@@ -703,7 +703,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
703703
V val1 = table[0][tmp & MASK];
704704

705705
if HURCHALLA_CPP17_CONSTEXPR (USE_SQUARING_VALUE_OPTIMIZATION) {
706-
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=1; k<NUM_TABLES; ++k) {
706+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t k=1; k<NUM_TABLES; ++k) {
707707
size_t index = (tmp >> (k * TABLE_BITS)) & MASK;
708708
val1 = mf.template multiply<LowuopsTag>(val1, table[k][index]);
709709
}
@@ -726,7 +726,7 @@ if HURCHALLA_CPP17_CONSTEXPR (CODE_SECTION == 0) {
726726

727727
result = mf.square(result);
728728

729-
HURCHALLA_REQUEST_UNROLL_LOOP for (int k=2; k<NUM_TABLES; ++k) {
729+
HURCHALLA_REQUEST_UNROLL_LOOP for (size_t k=2; k<NUM_TABLES; ++k) {
730730
index = (tmp >> (k * TABLE_BITS)) & MASK;
731731
val1 = mf.template multiply<LowuopsTag>(val1, table[k][index]);
732732
}

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/experimental/montgomery_pow_2kary/testbench_2kary.sh

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ define_mont_type=-DDEF_MONT_TYPE=$3
2727
define_uint_type=-DDEF_UINT_TYPE=$4
2828

2929

30-
# argument $8 (if present), should be -DHURCHALLA_ALLOW_INLINE_ASM_ALL
31-
define_use_asm=$8
32-
33-
3430
cpp_standard=c++17
3531

3632

@@ -54,19 +50,18 @@ fi
5450

5551

5652

57-
# argument $8 (if present), should be -DHURCHALLA_ALLOW_INLINE_ASM_ALL
58-
59-
60-
# to debug you can compile with the below also
53+
# You can use arguments $8 and $9 and ${10} etc to define macros such as
54+
# -DHURCHALLA_ALLOW_INLINE_ASM_ALL
55+
# for debugging, defining the following macros may be useful
6156
# -DHURCHALLA_CLOCKWORK_ENABLE_ASSERTS -DHURCHALLA_UTIL_ENABLE_ASSERTS
6257

63-
# we could also use -g to get debug symbols (for lldb/gdb, and objdump)
6458

59+
# we could also use -g to get debug symbols (for lldb/gdb, and objdump)
6560

6661
$cppcompiler \
6762
$error_limit -$optimization_level \
68-
$define_mont_type $define_uint_type $define_use_asm \
69-
-Wall -Wextra -Wpedantic $warn_nrvo \
63+
$define_mont_type $define_uint_type $8 $9 ${10} ${11} ${12} ${13} ${14} \
64+
-Wall -Wextra -Wpedantic -Wconversion -Wsign-conversion $warn_nrvo \
7065
-std=$cpp_standard \
7166
-I${repo_directory}/modular_arithmetic/modular_arithmetic/include \
7267
-I${repo_directory}/modular_arithmetic/montgomery_arithmetic/include \

0 commit comments

Comments
 (0)