
Commit cb16f5c

Add intel simd
1 parent 7a05d18 commit cb16f5c

File tree

3 files changed: +197 -14 lines changed


src/field_5x52_impl.h

Lines changed: 91 additions & 0 deletions
@@ -14,6 +14,10 @@
 
 #include "field_5x52_int128_impl.h"
 
+#ifdef X86
+# include <immintrin.h>
+#endif
+
 #ifdef VERIFY
 static void secp256k1_fe_impl_verify(const secp256k1_fe *a) {
     const uint64_t *d = a->n;
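The AVX paths added below are compiled only when the build defines X86 (presumably set by the project's build configuration for x86 targets) and the compiler defines __AVX__/__AVX2__, which gcc and clang do when AVX code generation is enabled (for example with -mavx2 or -march=native). As a sketch, the two conditions could be collapsed into a single hypothetical convenience guard:

    /* Hypothetical convenience guard (not part of this commit): true only when
     * the target is x86 and the compiler may emit AVX2 instructions. */
    #if defined(X86) && defined(__AVX__) && defined(__AVX2__)
    #  define SECP256K1_HAVE_AVX2 1
    #else
    #  define SECP256K1_HAVE_AVX2 0
    #endif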
@@ -37,10 +41,15 @@ static void secp256k1_fe_impl_get_bounds(secp256k1_fe *r, int m) {
     const uint64_t bound1 = 0xFFFFFFFFFFFFFULL * two_m;
     const uint64_t bound2 = 0x0FFFFFFFFFFFFULL * two_m;
 
+#ifdef __AVX__
+    __m256i vec = _mm256_set1_epi64x(bound1);
+    _mm256_storeu_si256((__m256i *)r->n, vec);
+#else
     r->n[0] = bound1;
     r->n[1] = bound1;
     r->n[2] = bound1;
     r->n[3] = bound1;
+#endif
     r->n[4] = bound2;
 }
 
@@ -239,6 +248,8 @@ static void secp256k1_fe_impl_set_b32_mod(secp256k1_fe *r, const unsigned char *
     limbs[3] = BYTESWAP_64(limbs[3]);
 #endif
 
+    /* TODO: parallelize avx2 */
+
     r->n[0] = (limbs[3] & 0xFFFFFFFFFFFFFULL);
     r->n[1] = (limbs[3] >> 52) | ((limbs[2] & 0xFFFFFFFFFFULL) << 12);
     r->n[2] = (limbs[2] >> 40) | ((limbs[1] & 0xFFFFFFFULL) << 24);
@@ -291,6 +302,10 @@ static void secp256k1_fe_impl_get_b32(unsigned char *r, const secp256k1_fe *a) {
 }
 
 SECP256K1_INLINE static void secp256k1_fe_impl_negate_unchecked(secp256k1_fe *r, const secp256k1_fe *a, int m) {
+#if defined(__AVX__) && defined(__AVX2__)
+    /* load here to mitigate load latency */
+    __m256i vec_a = _mm256_loadu_si256((__m256i *)a->n);
+#endif
     const uint32_t two_m1 = 2 * (m + 1);
     const uint64_t bound1 = 0xFFFFEFFFFFC2FULL * two_m1;
     const uint64_t bound2 = 0xFFFFFFFFFFFFFULL * two_m1;
@@ -303,10 +318,18 @@ SECP256K1_INLINE static void secp256k1_fe_impl_negate_unchecked(secp256k1_fe *r,
 
     /* Due to the properties above, the left hand in the subtractions below is never less than
      * the right hand. */
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_bounds = _mm256_setr_epi64x(bound1, bound2, bound2, bound2);
+        __m256i out = _mm256_sub_epi64(vec_bounds, vec_a);
+        _mm256_storeu_si256((__m256i *)r->n, out);
+    }
+#else
     r->n[0] = bound1 - a->n[0];
     r->n[1] = bound2 - a->n[1];
     r->n[2] = bound2 - a->n[2];
     r->n[3] = bound2 - a->n[3];
+#endif
     r->n[4] = bound3 - a->n[4];
 }
 
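The vectorized negate above computes the four lower-limb subtractions with one _mm256_sub_epi64; the bounds vector is laid out as (bound1, bound2, bound2, bound2) because only limb 0 uses the different constant, and limb 4 keeps its scalar subtraction. A minimal standalone check, not part of the commit and assuming a gcc/clang build with -mavx2, that the vector subtraction produces the same four limbs as the scalar code:

    #include <stdint.h>
    #include <stdio.h>
    #include <immintrin.h>

    int main(void) {
        /* bound1/bound2 as in the m == 0 case; arbitrary small limb values for a. */
        const uint64_t bound1 = 0xFFFFEFFFFFC2FULL * 2, bound2 = 0xFFFFFFFFFFFFFULL * 2;
        uint64_t a[4] = {11, 22, 33, 44}, want[4], got[4];
        int i;
        want[0] = bound1 - a[0];
        for (i = 1; i < 4; i++) want[i] = bound2 - a[i];
        {
            __m256i vec_a = _mm256_loadu_si256((const __m256i *)a);
            __m256i vec_bounds = _mm256_setr_epi64x(bound1, bound2, bound2, bound2);
            _mm256_storeu_si256((__m256i *)got, _mm256_sub_epi64(vec_bounds, vec_a));
        }
        for (i = 0; i < 4; i++) printf("limb %d: %s\n", i, want[i] == got[i] ? "ok" : "MISMATCH");
        return 0;
    }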
@@ -339,15 +362,32 @@ SECP256K1_INLINE static void secp256k1_fe_impl_sqr(secp256k1_fe *r, const secp25
 }
 
 SECP256K1_INLINE static void secp256k1_fe_impl_cmov(secp256k1_fe *r, const secp256k1_fe *a, int flag) {
+#if defined(__AVX__) && defined(__AVX2__)
+    /* load here to mitigate load latency */
+    __m256i vec_r = _mm256_loadu_si256((__m256i *)(r->n));
+    __m256i vec_a = _mm256_loadu_si256((__m256i *)(a->n));
+#endif
+
     uint64_t mask0, mask1;
     volatile int vflag = flag;
     SECP256K1_CHECKMEM_CHECK_VERIFY(r->n, sizeof(r->n));
     mask0 = vflag + ~((uint64_t)0);
     mask1 = ~mask0;
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_mask0 = _mm256_set1_epi64x(mask0);
+        __m256i vec_mask1 = _mm256_set1_epi64x(mask1);
+        vec_r = _mm256_and_si256(vec_r, vec_mask0);
+        vec_a = _mm256_and_si256(vec_a, vec_mask1);
+        _mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(vec_r, vec_a));
+    }
+#else
     r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1);
     r->n[1] = (r->n[1] & mask0) | (a->n[1] & mask1);
     r->n[2] = (r->n[2] & mask0) | (a->n[2] & mask1);
     r->n[3] = (r->n[3] & mask0) | (a->n[3] & mask1);
+#endif
     r->n[4] = (r->n[4] & mask0) | (a->n[4] & mask1);
 }
 
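In the cmov above, mask0 = vflag + ~0 equals flag - 1: all ones when flag is 0 (keep r) and all zeros when flag is 1 (take a), so broadcasting the two masks and combining with and/or reproduces the scalar branch-free select in one 256-bit operation per input. A sketch of the same pattern on plain 64-bit arrays (hypothetical helper name, assuming -mavx2):

    #include <stdint.h>
    #include <immintrin.h>

    /* Hypothetical helper, illustration only: branch-free dst = flag ? src : dst
     * for four 64-bit words, mirroring the mask0/mask1 and/or pattern above. */
    static void cmov4_u64(uint64_t *dst, const uint64_t *src, int flag) {
        uint64_t mask0 = (uint64_t)flag + ~((uint64_t)0); /* flag ? 0 : all ones */
        uint64_t mask1 = ~mask0;                          /* flag ? all ones : 0 */
        __m256i vd = _mm256_and_si256(_mm256_loadu_si256((const __m256i *)dst),
                                      _mm256_set1_epi64x(mask0));
        __m256i vs = _mm256_and_si256(_mm256_loadu_si256((const __m256i *)src),
                                      _mm256_set1_epi64x(mask1));
        _mm256_storeu_si256((__m256i *)dst, _mm256_or_si256(vd, vs));
    }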
@@ -418,19 +458,42 @@ static SECP256K1_INLINE void secp256k1_fe_storage_cmov(secp256k1_fe_storage *r,
 }
 
 static void secp256k1_fe_impl_to_storage(secp256k1_fe_storage *r, const secp256k1_fe *a) {
+#if defined(__AVX__) && defined(__AVX2__)
+    __m256i limbs_0123 = _mm256_loadu_si256((__m256i *)a->n);
+    __m256i limbs_1234 = _mm256_loadu_si256((__m256i *)(a->n + 1));
+    const __m256i shift_lhs = _mm256_setr_epi64x(0, 12, 24, 36); /* TODO: precompute */
+    const __m256i shift_rhs = _mm256_setr_epi64x(52, 40, 28, 16); /* TODO: precompute */
+    __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+    __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+    _mm256_storeu_si256((__m256i *)r->n, _mm256_or_si256(lhs, rhs));
+#else
     r->n[0] = a->n[0] | a->n[1] << 52;
     r->n[1] = a->n[1] >> 12 | a->n[2] << 40;
     r->n[2] = a->n[2] >> 24 | a->n[3] << 28;
     r->n[3] = a->n[3] >> 36 | a->n[4] << 16;
+#endif
 }
 
 static SECP256K1_INLINE void secp256k1_fe_impl_from_storage(secp256k1_fe *r, const secp256k1_fe_storage *a) {
     const uint64_t a0 = a->n[0], a1 = a->n[1], a2 = a->n[2], a3 = a->n[3];
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        const __m256i shift_lhs = _mm256_setr_epi64x(64, 52, 40, 28); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(0, 12, 24, 36); /* TODO: precompute */
+        const __m256i mask52 = _mm256_set1_epi64x(0xFFFFFFFFFFFFFULL); /* TODO: precompute */
+        __m256i rhs = _mm256_and_si256(_mm256_sllv_epi64(limbs_0123, shift_rhs), mask52);
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        _mm256_storeu_si256((__m256i*)r->n, _mm256_or_si256(lhs, rhs));
+    }
+#else
     r->n[0] = a0 & 0xFFFFFFFFFFFFFULL;
     r->n[1] = a0 >> 52 | ((a1 << 12) & 0xFFFFFFFFFFFFFULL);
     r->n[2] = a1 >> 40 | ((a2 << 24) & 0xFFFFFFFFFFFFFULL);
     r->n[3] = a2 >> 28 | ((a3 << 36) & 0xFFFFFFFFFFFFFULL);
+#endif
     r->n[4] = a3 >> 16;
 }
 
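The storage conversions above are pure limb repacking: the scalar to_storage code computes r->n[i] = a->n[i] >> (12*i) | a->n[i+1] << (52 - 12*i), i.e. a per-lane right shift by (0, 12, 24, 36) OR'd with a per-lane left shift by (52, 40, 28, 16), so _mm256_srlv_epi64/_mm256_sllv_epi64 over limbs 0..3 and 1..4 (the latter obtained with an overlapping unaligned load at a->n + 1) reproduce it in one pass. A scalar reference written in loop form, illustration only with a hypothetical name, makes those shift tables explicit:

    #include <stdint.h>

    /* Hypothetical reference, not part of the commit: the to_storage packing as a
     * loop, so the (0,12,24,36) and (52,40,28,16) shift vectors are visible. */
    static void fe_to_storage_ref(uint64_t r[4], const uint64_t n[5]) {
        static const unsigned shift_lhs[4] = {0, 12, 24, 36};
        static const unsigned shift_rhs[4] = {52, 40, 28, 16};
        int i;
        for (i = 0; i < 4; i++) {
            r[i] = (n[i] >> shift_lhs[i]) | (n[i + 1] << shift_rhs[i]);
        }
    }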
@@ -447,21 +510,49 @@ static void secp256k1_fe_from_signed62(secp256k1_fe *r, const secp256k1_modinv64
     VERIFY_CHECK(a3 >> 62 == 0);
     VERIFY_CHECK(a4 >> 8 == 0);
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        const __m256i shift_lhs = _mm256_setr_epi64x(64, 52, 42, 32); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(0, 10, 20, 30); /* TODO: precompute */
+        const __m256i mask52 = _mm256_set1_epi64x(M52); /* TODO: precompute */
+        __m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i*)r->n, _mm256_and_si256(out, mask52));
+    }
+#else
     r->n[0] = a0 & M52;
     r->n[1] = (a0 >> 52 | a1 << 10) & M52;
     r->n[2] = (a1 >> 42 | a2 << 20) & M52;
     r->n[3] = (a2 >> 32 | a3 << 30) & M52;
+#endif
     r->n[4] = (a3 >> 22 | a4 << 40);
 }
 
 static void secp256k1_fe_to_signed62(secp256k1_modinv64_signed62 *r, const secp256k1_fe *a) {
     const uint64_t M62 = UINT64_MAX >> 2;
     const uint64_t a0 = a->n[0], a1 = a->n[1], a2 = a->n[2], a3 = a->n[3], a4 = a->n[4];
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_1234 = _mm256_setr_epi64x(a1, a2, a3, a4);
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 10, 20, 30); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(52, 42, 32, 22); /* TODO: precompute */
+        const __m256i mask62 = _mm256_set1_epi64x(M62); /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i *)r->v, _mm256_and_si256(out, mask62));
+    }
+#else
     r->v[0] = (a0 | a1 << 52) & M62;
     r->v[1] = (a1 >> 10 | a2 << 42) & M62;
     r->v[2] = (a2 >> 20 | a3 << 32) & M62;
     r->v[3] = (a3 >> 30 | a4 << 22) & M62;
+#endif
     r->v[4] = a4 >> 40;
 }
 
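One detail worth noting in the from_storage/from_signed62 paths above: the first lane of shift_lhs is 64. The AVX2 variable shifts _mm256_srlv_epi64/_mm256_sllv_epi64 return 0 for any per-lane count above 63, so that lane simply contributes nothing, whereas a scalar shift by 64 would be undefined behaviour. A tiny standalone demonstration (not part of the commit, assuming -mavx2):

    #include <stdint.h>
    #include <stdio.h>
    #include <immintrin.h>

    int main(void) {
        /* Lane 0 is shifted by 64 and therefore comes out as 0; a scalar x >> 64
         * would be undefined behaviour in C. */
        __m256i x = _mm256_set1_epi64x(0x0123456789ABCDEFLL);
        __m256i counts = _mm256_setr_epi64x(64, 4, 8, 0);
        uint64_t out[4];
        int i;
        _mm256_storeu_si256((__m256i *)out, _mm256_srlv_epi64(x, counts));
        for (i = 0; i < 4; i++) printf("lane %d: %016llx\n", i, (unsigned long long)out[i]);
        return 0;
    }
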
src/scalar_4x64_impl.h

Lines changed: 101 additions & 14 deletions
@@ -12,6 +12,10 @@
 #include "modinv64_impl.h"
 #include "util.h"
 
+#ifdef X86
+# include <immintrin.h>
+#endif
+
 /* Limbs of the secp256k1 order. */
 #define SECP256K1_N_0 ((uint64_t)0xBFD25E8CD0364141ULL)
 #define SECP256K1_N_1 ((uint64_t)0xBAAEDCE6AF48A03BULL)
@@ -143,10 +147,25 @@ static void secp256k1_scalar_cadd_bit(secp256k1_scalar *r, unsigned int bit, int
 
 static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b32, int *overflow) {
     int over;
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_b32 = _mm256_loadu_si256((__m256i*)b32);
+        vec_b32 = _mm256_permute4x64_epi64(vec_b32, _MM_SHUFFLE(0,1,2,3));
+        const __m256i bswap_mask = _mm256_setr_epi8( /* TODO: precompute */
+            7,6,5,4,3,2,1,0,
+            15,14,13,12,11,10,9,8,
+            23,22,21,20,19,18,17,16,
+            31,30,29,28,27,26,25,24);
+        __m256i output = _mm256_shuffle_epi8(vec_b32, bswap_mask);
+        _mm256_storeu_si256((__m256i*)r->d, output);
+    }
+#else
     r->d[0] = secp256k1_read_be64(&b32[24]);
     r->d[1] = secp256k1_read_be64(&b32[16]);
     r->d[2] = secp256k1_read_be64(&b32[8]);
     r->d[3] = secp256k1_read_be64(&b32[0]);
+#endif
     over = secp256k1_scalar_reduce(r, secp256k1_scalar_check_overflow(r));
     if (overflow) {
         *overflow = over;
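_mm256_shuffle_epi8 can only move bytes within each 128-bit lane, so it cannot reverse all 32 bytes by itself; the code above therefore first reverses the four 64-bit lanes with _mm256_permute4x64_epi64(..., _MM_SHUFFLE(0,1,2,3)) and then byte-swaps each 64-bit word with the shuffle mask, which together match the four big-endian reads r->d[i] = secp256k1_read_be64(&b32[24 - 8*i]). The get_b32 hunk below is the exact mirror image. A minimal standalone check against a portable big-endian read (not part of the commit, assuming -mavx2):

    #include <stdint.h>
    #include <stdio.h>
    #include <immintrin.h>

    /* Portable reference: read 8 bytes as a big-endian 64-bit value. */
    static uint64_t read_be64(const unsigned char *p) {
        uint64_t v = 0;
        int i;
        for (i = 0; i < 8; i++) v = (v << 8) | p[i];
        return v;
    }

    int main(void) {
        unsigned char b32[32];
        uint64_t d[4], want[4];
        int i;
        for (i = 0; i < 32; i++) b32[i] = (unsigned char)(i * 7 + 3);
        for (i = 0; i < 4; i++) want[i] = read_be64(&b32[24 - 8 * i]);
        {
            const __m256i bswap_mask = _mm256_setr_epi8(
                7,6,5,4,3,2,1,0,
                15,14,13,12,11,10,9,8,
                23,22,21,20,19,18,17,16,
                31,30,29,28,27,26,25,24);
            __m256i v = _mm256_loadu_si256((const __m256i *)b32);
            v = _mm256_permute4x64_epi64(v, _MM_SHUFFLE(0,1,2,3));
            _mm256_storeu_si256((__m256i *)d, _mm256_shuffle_epi8(v, bswap_mask));
        }
        for (i = 0; i < 4; i++) printf("d[%d]: %s\n", i, d[i] == want[i] ? "ok" : "MISMATCH");
        return 0;
    }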
@@ -157,16 +176,28 @@ static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b
 
 static void secp256k1_scalar_get_b32(unsigned char *bin, const secp256k1_scalar* a) {
     SECP256K1_SCALAR_VERIFY(a);
-
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_a = _mm256_loadu_si256((__m256i*)a->d);
+        vec_a = _mm256_permute4x64_epi64(vec_a, _MM_SHUFFLE(0,1,2,3));
+        const __m256i bswap_mask = _mm256_setr_epi8( /* TODO: precompute */
+            7,6,5,4,3,2,1,0,
+            15,14,13,12,11,10,9,8,
+            23,22,21,20,19,18,17,16,
+            31,30,29,28,27,26,25,24);
+        __m256i output = _mm256_shuffle_epi8(vec_a, bswap_mask);
+        _mm256_storeu_si256((__m256i*)bin, output);
+    }
+#else
     secp256k1_write_be64(&bin[0], a->d[3]);
     secp256k1_write_be64(&bin[8], a->d[2]);
     secp256k1_write_be64(&bin[16], a->d[1]);
     secp256k1_write_be64(&bin[24], a->d[0]);
+#endif
 }
 
 SECP256K1_INLINE static int secp256k1_scalar_is_zero(const secp256k1_scalar *a) {
     SECP256K1_SCALAR_VERIFY(a);
-
     return (a->d[0] | a->d[1] | a->d[2] | a->d[3]) == 0;
 }
 
@@ -882,8 +913,16 @@ static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r
 SECP256K1_INLINE static int secp256k1_scalar_eq(const secp256k1_scalar *a, const secp256k1_scalar *b) {
     SECP256K1_SCALAR_VERIFY(a);
     SECP256K1_SCALAR_VERIFY(b);
-
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_a = _mm256_loadu_si256((__m256i *)a->d);
+        __m256i vec_b = _mm256_loadu_si256((__m256i *)b->d);
+        __m256i vec_xor = _mm256_xor_si256(vec_a, vec_b);
+        return _mm256_testz_si256(vec_xor, vec_xor);
+    }
+#else
     return ((a->d[0] ^ b->d[0]) | (a->d[1] ^ b->d[1]) | (a->d[2] ^ b->d[2]) | (a->d[3] ^ b->d[3])) == 0;
+#endif
 }
 
 SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b, unsigned int shift) {
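_mm256_testz_si256(x, x) returns 1 exactly when the bitwise AND of its arguments, here x with itself, is all zero, so applying it to the XOR of the two operands yields 1 iff every limb matches, the same predicate as the scalar OR-of-XORs. The idea in isolation (hypothetical helper name, assuming -mavx2):

    #include <stdint.h>
    #include <immintrin.h>

    /* Hypothetical helper, illustration only: 1 iff two 4x64-bit arrays are
     * bitwise equal, using the same xor + testz combination as above. */
    static int eq4_u64(const uint64_t *a, const uint64_t *b) {
        __m256i x = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)a),
                                     _mm256_loadu_si256((const __m256i *)b));
        return _mm256_testz_si256(x, x); /* ZF of vptest: set only when x is all zero */
    }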
@@ -899,6 +938,9 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
     shiftlimbs = shift >> 6;
     shiftlow = shift & 0x3F;
     shifthigh = 64 - shiftlow;
+
+    /* TODO: parallelize */
+
     r->d[0] = shift < 512 ? (l[0 + shiftlimbs] >> shiftlow | (shift < 448 && shiftlow ? (l[1 + shiftlimbs] << shifthigh) : 0)) : 0;
     r->d[1] = shift < 448 ? (l[1 + shiftlimbs] >> shiftlow | (shift < 384 && shiftlow ? (l[2 + shiftlimbs] << shifthigh) : 0)) : 0;
     r->d[2] = shift < 384 ? (l[2 + shiftlimbs] >> shiftlow | (shift < 320 && shiftlow ? (l[3 + shiftlimbs] << shifthigh) : 0)) : 0;
@@ -909,37 +951,68 @@
 }
 
 static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) {
+#if defined(__AVX__) && defined(__AVX2__)
+    /* load here to mitigate load latency */
+    __m256i vec_r = _mm256_loadu_si256((__m256i *)(r->d));
+    __m256i vec_a = _mm256_loadu_si256((__m256i *)(a->d));
+#endif
+
     uint64_t mask0, mask1;
     volatile int vflag = flag;
     SECP256K1_SCALAR_VERIFY(a);
     SECP256K1_CHECKMEM_CHECK_VERIFY(r->d, sizeof(r->d));
 
     mask0 = vflag + ~((uint64_t)0);
     mask1 = ~mask0;
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i vec_mask0 = _mm256_set1_epi64x(mask0);
+        __m256i vec_mask1 = _mm256_set1_epi64x(mask1);
+        vec_r = _mm256_and_si256(vec_r, vec_mask0);
+        vec_a = _mm256_and_si256(vec_a, vec_mask1);
+        _mm256_storeu_si256((__m256i *)(r->d), _mm256_or_si256(vec_r, vec_a));
+    }
+#else
     r->d[0] = (r->d[0] & mask0) | (a->d[0] & mask1);
     r->d[1] = (r->d[1] & mask0) | (a->d[1] & mask1);
     r->d[2] = (r->d[2] & mask0) | (a->d[2] & mask1);
     r->d[3] = (r->d[3] & mask0) | (a->d[3] & mask1);
+#endif
 
     SECP256K1_SCALAR_VERIFY(r);
 }
 
 static void secp256k1_scalar_from_signed62(secp256k1_scalar *r, const secp256k1_modinv64_signed62 *a) {
-    const uint64_t a0 = a->v[0], a1 = a->v[1], a2 = a->v[2], a3 = a->v[3], a4 = a->v[4];
-
     /* The output from secp256k1_modinv64{_var} should be normalized to range [0,modulus), and
      * have limbs in [0,2^62). The modulus is < 2^256, so the top limb must be below 2^(256-62*4).
      */
-    VERIFY_CHECK(a0 >> 62 == 0);
-    VERIFY_CHECK(a1 >> 62 == 0);
-    VERIFY_CHECK(a2 >> 62 == 0);
-    VERIFY_CHECK(a3 >> 62 == 0);
-    VERIFY_CHECK(a4 >> 8 == 0);
+    VERIFY_CHECK(a->v[0] >> 62 == 0);
+    VERIFY_CHECK(a->v[1] >> 62 == 0);
+    VERIFY_CHECK(a->v[2] >> 62 == 0);
+    VERIFY_CHECK(a->v[3] >> 62 == 0);
+    VERIFY_CHECK(a->v[4] >> 8 == 0);
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0123 = _mm256_loadu_si256((__m256i *)a->v);
+        __m256i limbs_1234 = _mm256_loadu_si256((__m256i *)(a->v + 1));
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 2, 4, 6); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(62, 60, 58, 56); /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+        _mm256_storeu_si256((__m256i *)(r->d), _mm256_or_si256(lhs, rhs));
+    }
+#else
+    {
+        const uint64_t a0 = a->v[0], a1 = a->v[1], a2 = a->v[2], a3 = a->v[3], a4 = a->v[4];
 
-    r->d[0] = a0 | a1 << 62;
-    r->d[1] = a1 >> 2 | a2 << 60;
-    r->d[2] = a2 >> 4 | a3 << 58;
-    r->d[3] = a3 >> 6 | a4 << 56;
+        r->d[0] = a0 | a1 << 62;
+        r->d[1] = a1 >> 2 | a2 << 60;
+        r->d[2] = a2 >> 4 | a3 << 58;
+        r->d[3] = a3 >> 6 | a4 << 56;
+    }
+#endif
 
     SECP256K1_SCALAR_VERIFY(r);
 }
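For secp256k1_scalar_from_signed62 the vector path loads limbs 0..3 from a->v and limbs 1..4 via an overlapping unaligned load at a->v + 1 (in bounds, since signed62 has five limbs), then applies the (0, 2, 4, 6) right shifts and (62, 60, 58, 56) left shifts. The scalar packing in loop form, illustration only with a hypothetical name and treating the signed62 limbs as uint64_t:

    #include <stdint.h>

    /* Hypothetical reference, not part of the commit: the signed62 -> 4x64 packing
     * as a loop, making the (0,2,4,6) and (62,60,58,56) shift vectors visible. */
    static void scalar_from_signed62_ref(uint64_t d[4], const uint64_t v[5]) {
        static const unsigned shift_lhs[4] = {0, 2, 4, 6};
        static const unsigned shift_rhs[4] = {62, 60, 58, 56};
        int i;
        for (i = 0; i < 4; i++) {
            d[i] = (v[i] >> shift_lhs[i]) | (v[i + 1] << shift_rhs[i]);
        }
    }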
@@ -949,10 +1022,24 @@ static void secp256k1_scalar_to_signed62(secp256k1_modinv64_signed62 *r, const s
     const uint64_t a0 = a->d[0], a1 = a->d[1], a2 = a->d[2], a3 = a->d[3];
     SECP256K1_SCALAR_VERIFY(a);
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 62, 60, 58); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(64, 2, 4, 6); /* TODO: precompute */
+        const __m256i mask62 = _mm256_set1_epi64x(M62); /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i *)r->v, _mm256_and_si256(out, mask62));
+    }
+#else
     r->v[0] = a0 & M62;
     r->v[1] = (a0 >> 62 | a1 << 2) & M62;
     r->v[2] = (a1 >> 60 | a2 << 4) & M62;
     r->v[3] = (a2 >> 58 | a3 << 6) & M62;
+#endif
     r->v[4] = a3 >> 56;
 }
 