Skip to content

Commit cc73cd1

Browse files
committed
add ARM64 asm, and improve montgomery two pow
1 parent 84f8feb commit cc73cd1

File tree

13 files changed

+846
-299
lines changed

13 files changed

+846
-299
lines changed

modular_arithmetic/include/hurchalla/modular_arithmetic/detail/platform_specific/impl_absolute_value_difference.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,12 @@ struct impl_absolute_value_difference_unsigned<std::uint32_t> {
147147

148148

149149

150-
#if 0 // dont enable arm64 assembly yet
151-
/*
150+
152151
// ARM64
153152
// MSVC doesn't support inline asm so we skip it.
154153
#if (defined(HURCHALLA_ALLOW_INLINE_ASM_ALL) || \
155154
defined(HURCHALLA_ALLOW_INLINE_ASM_ABSDIFF)) && \
156155
defined(HURCHALLA_TARGET_ISA_ARM_64) && !defined(_MSC_VER)
157-
*/
158156
# if (HURCHALLA_COMPILER_HAS_UINT128_T())
159157
template <>
160158
struct impl_absolute_value_difference_unsigned<__uint128_t> {
@@ -170,14 +168,16 @@ struct impl_absolute_value_difference_unsigned<__uint128_t> {
170168
uint64_t diffhi = static_cast<uint64_t>(diff >> 64);
171169
uint64_t blo = static_cast<uint64_t>(b);
172170
uint64_t bhi = static_cast<uint64_t>(b >> 64);
173-
__asm__ ("subs %[alo], %[alo], %[blo] \n\t" /* tmp = a - b */
174-
"sbcs %[ahi], %[ahi], %[bhi] \n\t"
175-
"csel %[alo], %[difflo], %[alo], lo \n\t" /* tmp = (a < b) ? diff : tmp */
176-
"csel %[ahi], %[diffhi], %[ahi], lo \n\t"
177-
: [alo]"+&r"(alo), [ahi]"+&r"(ahi)
178-
: [blo]"r"(blo), [bhi]"r"(bhi), [difflo]"r"(difflo), [diffhi]"r"(diffhi)
171+
uint64_t reslo;
172+
uint64_t reshi;
173+
__asm__ ("subs %[reslo], %[alo], %[blo] \n\t" /* res = a - b */
174+
"sbcs %[reshi], %[ahi], %[bhi] \n\t"
175+
"csel %[reslo], %[difflo], %[reslo], lo \n\t" /* res = (a < b) ? diff : res */
176+
"csel %[reshi], %[diffhi], %[reshi], lo \n\t"
177+
: [reslo]"=&r"(reslo), [reshi]"=&r"(reshi)
178+
: [alo]"r"(alo), [ahi]"r"(ahi), [blo]"r"(blo), [bhi]"r"(bhi), [difflo]"r"(difflo), [diffhi]"r"(diffhi)
179179
: "cc");
180-
__uint128_t result = (static_cast<__uint128_t>(ahi) << 64) | alo;
180+
__uint128_t result = (static_cast<__uint128_t>(reshi) << 64) | reslo;
181181

182182
HPBC_POSTCONDITION2(result<=a || result<=b);
183183
HPBC_POSTCONDITION2(result == default_impl_absdiff_unsigned::call(a, b));
@@ -197,7 +197,7 @@ struct impl_absolute_value_difference_unsigned<std::uint64_t> {
197197
__asm__ ("subs %[tmp], %[a], %[b] \n\t" /* tmp = a - b */
198198
"csel %[tmp], %[diff], %[tmp], lo \n\t" /* tmp = (a < b) ? diff : tmp */
199199
: [tmp]"=&r"(tmp)
200-
: [b]"r"(b), [diff]"r"(diff)
200+
: [a]"r"(a), [b]"r"(b), [diff]"r"(diff)
201201
: "cc");
202202
uint64_t result = tmp;
203203

modular_arithmetic/include/hurchalla/modular_arithmetic/detail/platform_specific/impl_modular_addition.h

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ struct impl_modular_addition_unsigned {
125125
defined(HURCHALLA_ALLOW_INLINE_ASM_MODADD)) && \
126126
defined(HURCHALLA_TARGET_ISA_X86_64) && !defined(_MSC_VER)
127127

128-
129128
// Note: these functions contain the calculation "b - modulus". If neither 'b'
130129
// nor 'modulus' was recently set/modified, then "b - modulus" will usually be
131130
// calculated at the same time as earlier work by the CPU, or in a loop it could
@@ -244,6 +243,95 @@ struct impl_modular_addition_unsigned<std::uint32_t> {
244243
#endif
245244

246245

246+
247+
248+
// ARM64
249+
// MSVC doesn't support inline asm so we skip it.
250+
#if (defined(HURCHALLA_ALLOW_INLINE_ASM_ALL) || \
251+
defined(HURCHALLA_ALLOW_INLINE_ASM_MODADD)) && \
252+
defined(HURCHALLA_TARGET_ISA_ARM_64) && !defined(_MSC_VER)
253+
# if (HURCHALLA_COMPILER_HAS_UINT128_T())
254+
template <>
255+
struct impl_modular_addition_unsigned<__uint128_t> {
256+
HURCHALLA_FORCE_INLINE static
257+
__uint128_t call(__uint128_t a, __uint128_t b, __uint128_t modulus)
258+
{
259+
using std::uint64_t;
260+
HPBC_PRECONDITION2(modulus>0);
261+
HPBC_PRECONDITION2(a<modulus); // __uint128_t guarantees a>=0.
262+
HPBC_PRECONDITION2(b<modulus); // __uint128_t guarantees b>=0.
263+
264+
__uint128_t tmp = static_cast<__uint128_t>(b - modulus);
265+
__uint128_t sum = static_cast<__uint128_t>(a + b);
266+
uint64_t alo = static_cast<uint64_t>(a);
267+
uint64_t ahi = static_cast<uint64_t>(a >> 64);
268+
uint64_t tmplo = static_cast<uint64_t>(tmp);
269+
uint64_t tmphi = static_cast<uint64_t>(tmp >> 64);
270+
uint64_t sumlo = static_cast<uint64_t>(sum);
271+
uint64_t sumhi = static_cast<uint64_t>(sum >> 64);
272+
uint64_t reslo;
273+
uint64_t reshi;
274+
__asm__ ("adds %[reslo], %[alo], %[tmplo] \n\t" /* res = a + tmp */
275+
"adcs %[reshi], %[ahi], %[tmphi] \n\t"
276+
"csel %[reslo], %[sumlo], %[reslo], lo \n\t" /* res = (res>=a) ? sum : res */
277+
"csel %[reshi], %[sumhi], %[reshi], lo \n\t"
278+
: [reslo]"=&r"(reslo), [reshi]"=&r"(reshi)
279+
: [alo]"r"(alo), [ahi]"r"(ahi), [tmplo]"r"(tmplo), [tmphi]"r"(tmphi), [sumlo]"r"(sumlo), [sumhi]"r"(sumhi)
280+
: "cc");
281+
__uint128_t result = (static_cast<__uint128_t>(reshi) << 64) | reslo;
282+
283+
HPBC_POSTCONDITION2(result < modulus); // __uint128_t guarantees result>=0.
284+
HPBC_POSTCONDITION2(result ==
285+
default_impl_modadd_unsigned::call(a, b, modulus));
286+
return result;
287+
}
288+
};
289+
# endif
290+
291+
template <>
292+
struct impl_modular_addition_unsigned<std::uint64_t> {
293+
HURCHALLA_FORCE_INLINE static
294+
std::uint64_t call(std::uint64_t a, std::uint64_t b, std::uint64_t modulus)
295+
{
296+
using std::uint64_t;
297+
HPBC_PRECONDITION2(modulus>0);
298+
HPBC_PRECONDITION2(a<modulus); // uint64_t guarantees a>=0.
299+
HPBC_PRECONDITION2(b<modulus); // uint64_t guarantees b>=0.
300+
301+
uint64_t sum = static_cast<uint64_t>(a + b);
302+
uint64_t tmp = static_cast<uint64_t>(b - modulus);
303+
uint64_t res;
304+
__asm__ ("adds %[res], %[a], %[tmp] \n\t" /* res = a + tmp */
305+
"csel %[res], %[sum], %[res], lo \n\t" /* res = (res>=a) ? sum : res */
306+
: [res]"=&r"(res)
307+
: [a]"r"(a), [tmp]"r"(tmp), [sum]"r"(sum)
308+
: "cc");
309+
uint64_t result = res;
310+
311+
HPBC_POSTCONDITION2(result < modulus); // uint64_t guarantees result>=0.
312+
HPBC_POSTCONDITION2(result ==
313+
default_impl_modadd_unsigned::call(a, b, modulus));
314+
return result;
315+
}
316+
};
317+
318+
template <>
319+
struct impl_modular_addition_unsigned<std::uint32_t> {
320+
using U = std::uint32_t;
321+
HURCHALLA_FORCE_INLINE static U call(U a, U b, U modulus)
322+
{
323+
std::uint64_t result = impl_modular_addition_unsigned
324+
<std::uint64_t>::call(a, b, modulus);
325+
return static_cast<U>(result);
326+
}
327+
};
328+
329+
// end of inline asm functions for ARM_64
330+
#endif
331+
332+
333+
334+
247335
template <>
248336
struct impl_modular_addition_unsigned<std::uint16_t> {
249337
using U = std::uint16_t;

modular_arithmetic/include/hurchalla/modular_arithmetic/detail/platform_specific/impl_modular_subtraction.h

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -425,14 +425,12 @@ struct impl_modular_subtraction_unsigned<std::uint32_t, LowlatencyTag> {
425425

426426

427427

428-
#if 0 // dont enable arm64 assembly yet
429-
/*
428+
430429
// ARM64
431430
// MSVC doesn't support inline asm so we skip it.
432431
#if (defined(HURCHALLA_ALLOW_INLINE_ASM_ALL) || \
433432
defined(HURCHALLA_ALLOW_INLINE_ASM_MODSUB)) && \
434433
defined(HURCHALLA_TARGET_ISA_ARM_64) && !defined(_MSC_VER)
435-
*/
436434

437435
# if (HURCHALLA_COMPILER_HAS_UINT128_T())
438436
template <>
@@ -445,6 +443,8 @@ struct impl_modular_subtraction_unsigned<__uint128_t, LowuopsTag> {
445443
HPBC_PRECONDITION2(a<modulus); // __uint128_t guarantees a>=0.
446444
HPBC_PRECONDITION2(b<modulus); // __uint128_t guarantees b>=0.
447445

446+
uint64_t difflo;
447+
uint64_t diffhi;
448448
uint64_t mozlo;
449449
uint64_t mozhi;
450450
uint64_t alo = static_cast<uint64_t>(a);
@@ -453,15 +453,13 @@ struct impl_modular_subtraction_unsigned<__uint128_t, LowuopsTag> {
453453
uint64_t bhi = static_cast<uint64_t>(b >> 64);
454454
uint64_t mlo = static_cast<uint64_t>(modulus);
455455
uint64_t mhi = static_cast<uint64_t>(modulus >> 64);
456-
__asm__ ("subs %[alo], %[alo], %[blo] \n\t" /* diff = a - b */
457-
"sbcs %[ahi], %[ahi], %[bhi] \n\t"
456+
__asm__ ("subs %[difflo], %[alo], %[blo] \n\t" /* diff = a - b */
457+
"sbcs %[diffhi], %[ahi], %[bhi] \n\t"
458458
"csel %[mozlo], %[mlo], xzr, lo \n\t" /* mozlo = (a<b) ? mlo : 0 */
459459
"csel %[mozhi], %[mhi], xzr, lo \n\t" /* mozhi = (a<b) ? mhi : 0 */
460-
: [alo]"+&r"(alo), [ahi]"+&r"(ahi), [mozlo]"=&r"(mozlo), [mozhi]"=r"(mozhi)
461-
: [blo]"r"(blo), [bhi]"r"(bhi), [mlo]"r"(mlo), [mhi]"r"(mhi)
460+
: [difflo]"=&r"(difflo), [diffhi]"=&r"(diffhi), [mozlo]"=&r"(mozlo), [mozhi]"=r"(mozhi)
461+
: [alo]"r"(alo), [ahi]"r"(ahi), [blo]"r"(blo), [bhi]"r"(bhi), [mlo]"r"(mlo), [mhi]"r"(mhi)
462462
: "cc");
463-
uint64_t difflo = alo;
464-
uint64_t diffhi = ahi;
465463
__uint128_t diff = (static_cast<__uint128_t>(diffhi) << 64) | difflo;
466464
__uint128_t moz = (static_cast<__uint128_t>(mozhi) << 64) | mozlo;
467465

@@ -485,14 +483,15 @@ struct impl_modular_subtraction_unsigned<std::uint64_t, LowuopsTag> {
485483
HPBC_PRECONDITION2(a<modulus); // uint64_t guarantees a>=0.
486484
HPBC_PRECONDITION2(b<modulus); // uint64_t guarantees b>=0.
487485

488-
uint64_t tmp;
489-
uint64_t result;
490-
__asm__ ("subs %[tmp], %[a], %[b] \n\t" /* tmp = a - b */
491-
"add %[res], %[tmp], %[m] \n\t" /* res = tmp + modulus */
492-
"csel %[res], %[res], %[tmp], lo \n\t" /* res = (a<b) ? res : tmp */
493-
: [tmp]"=&r"(tmp), [res]"=r"(result)
494-
: [m]"r"(modulus), [b]"r"(b)
486+
uint64_t diff;
487+
uint64_t res;
488+
__asm__ ("subs %[diff], %[a], %[b] \n\t" /* diff = a - b */
489+
"add %[res], %[diff], %[m] \n\t" /* res = diff + modulus */
490+
"csel %[res], %[res], %[diff], lo \n\t" /* res = (a<b) ? res : diff */
491+
: [diff]"=&r"(diff), [res]"=r"(res)
492+
: [a]"r"(a), [b]"r"(b), [m]"r"(modulus)
495493
: "cc");
494+
uint64_t result = res;
496495

497496
HPBC_POSTCONDITION2(result < modulus); // uint64_t guarantees result>=0.
498497
HPBC_POSTCONDITION2(result ==
@@ -553,12 +552,14 @@ struct impl_modular_subtraction_unsigned<std::uint64_t, LowlatencyTag> {
553552

554553
uint64_t diff = b - modulus;
555554
uint64_t tmp = a - diff;
556-
uint64_t result;
555+
556+
uint64_t res;
557557
__asm__ ("subs %[res], %[a], %[b] \n\t" /* res = a - b */
558558
"csel %[res], %[tmp], %[res], lo \n\t" /* res = (a<b) ? tmp : res */
559-
: [res]"=&r"(result)
559+
: [res]"=&r"(res)
560560
: [a]"r"(a), [b]"r"(b), [tmp]"r"(tmp)
561561
: "cc");
562+
uint64_t result = res;
562563

563564
HPBC_POSTCONDITION2(result < modulus); // uint64_t guarantees result>=0.
564565
HPBC_POSTCONDITION2(result ==

montgomery_arithmetic/include/hurchalla/montgomery_arithmetic/detail/MontyHalfRange.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ class MontyHalfRange final :
139139
HPBC_PRECONDITION2(isValid(x));
140140
S sx = x.get();
141141
S result = static_cast<S>(-sx);
142+
if (result == static_cast<S>(n_))
143+
result = 0;
142144
HPBC_POSTCONDITION2(isValid(V(result)));
143145
HPBC_POSTCONDITION2(getCanonicalValue(V(result)) ==
144146
getCanonicalValue(subtract(C(0), x, LowuopsTag())));

0 commit comments

Comments
 (0)