 #include "modinv64_impl.h"
 #include "util.h"
 
+#ifdef __AVX2__
+# include <immintrin.h>
+#endif
+
 /* Limbs of the secp256k1 order. */
 #define SECP256K1_N_0 ((uint64_t)0xBFD25E8CD0364141ULL)
 #define SECP256K1_N_1 ((uint64_t)0xBAAEDCE6AF48A03BULL)
@@ -143,6 +147,7 @@ static void secp256k1_scalar_cadd_bit(secp256k1_scalar *r, unsigned int bit, int
 
 static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b32, int *overflow) {
     int over;
+
     r->d[0] = secp256k1_read_be64(&b32[24]);
     r->d[1] = secp256k1_read_be64(&b32[16]);
     r->d[2] = secp256k1_read_be64(&b32[8]);
@@ -866,14 +871,27 @@ static void secp256k1_scalar_mul(secp256k1_scalar *r, const secp256k1_scalar *a,
 static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r2, const secp256k1_scalar *k) {
     SECP256K1_SCALAR_VERIFY(k);
 
+#ifdef __AVX2__
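+    /* Copy the low and high 128 bits of k into r1 and r2, and zero their upper limbs, using 128-bit loads/stores. */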
+    {
+        __m128i k_01 = _mm_loadu_si128((const __m128i *)k->d);
+        __m128i k_23 = _mm_loadu_si128((const __m128i *)(k->d + 2));
+        const __m128i zeros = _mm_setzero_si128(); /* TODO: precompute */
+        _mm_storeu_si128((__m128i *)(r1->d + 2), zeros);
+        _mm_storeu_si128((__m128i *)(r2->d + 2), zeros);
+        _mm_storeu_si128((__m128i *)r1->d, k_01);
+        _mm_storeu_si128((__m128i *)r2->d, k_23);
+    }
+#else
     r1->d[0] = k->d[0];
     r1->d[1] = k->d[1];
     r1->d[2] = 0;
     r1->d[3] = 0;
+
     r2->d[0] = k->d[2];
     r2->d[1] = k->d[3];
     r2->d[2] = 0;
     r2->d[3] = 0;
+#endif
 
     SECP256K1_SCALAR_VERIFY(r1);
     SECP256K1_SCALAR_VERIFY(r2);
@@ -883,7 +901,19 @@ SECP256K1_INLINE static int secp256k1_scalar_eq(const secp256k1_scalar *a, const
     SECP256K1_SCALAR_VERIFY(a);
     SECP256K1_SCALAR_VERIFY(b);
 
-    return ((a->d[0] ^ b->d[0]) | (a->d[1] ^ b->d[1]) | (a->d[2] ^ b->d[2]) | (a->d[3] ^ b->d[3])) == 0;
+#ifdef __AVX2__
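+    /* XOR the two scalars limb-wise and test the full 256-bit result for zero. */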
+    {
+        __m256i vec_a = _mm256_loadu_si256((const __m256i *)a->d);
+        __m256i vec_b = _mm256_loadu_si256((const __m256i *)b->d);
+        __m256i vec_xor = _mm256_xor_si256(vec_a, vec_b);
+        return _mm256_testz_si256(vec_xor, vec_xor);
+    }
+#else
+    return ( (a->d[0] ^ b->d[0]) |
+             (a->d[1] ^ b->d[1]) |
+             (a->d[2] ^ b->d[2]) |
+             (a->d[3] ^ b->d[3]) ) == 0;
+#endif
 }
 
 SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b, unsigned int shift) {
@@ -899,6 +929,9 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
     shiftlimbs = shift >> 6;
     shiftlow = shift & 0x3F;
     shifthigh = 64 - shiftlow;
+
+    /* TODO: parallel? */
+
     r->d[0] = shift < 512 ? (l[0 + shiftlimbs] >> shiftlow | (shift < 448 && shiftlow ? (l[1 + shiftlimbs] << shifthigh) : 0)) : 0;
     r->d[1] = shift < 448 ? (l[1 + shiftlimbs] >> shiftlow | (shift < 384 && shiftlow ? (l[2 + shiftlimbs] << shifthigh) : 0)) : 0;
     r->d[2] = shift < 384 ? (l[2 + shiftlimbs] >> shiftlow | (shift < 320 && shiftlow ? (l[3 + shiftlimbs] << shifthigh) : 0)) : 0;
@@ -909,17 +942,34 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
 }
 
 static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) {
+#ifdef __AVX2__
+    /* Load both operands early to hide the load latency. */
+    __m256i vec_r = _mm256_loadu_si256((__m256i *)(r->d));
+    __m256i vec_a = _mm256_loadu_si256((const __m256i *)(a->d));
+#endif
     uint64_t mask0, mask1;
     volatile int vflag = flag;
     SECP256K1_SCALAR_VERIFY(a);
     SECP256K1_CHECKMEM_CHECK_VERIFY(r->d, sizeof(r->d));
 
     mask0 = vflag + ~((uint64_t)0);
     mask1 = ~mask0;
+
+#ifdef __AVX2__
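+    /* Branchless select: keep the limbs of r where mask0 is all-ones, take the limbs of a where mask1 is all-ones. */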
+    {
+        const __m256i vec_mask0 = _mm256_set1_epi64x(mask0); /* TODO: precompute */
+        const __m256i vec_mask1 = _mm256_set1_epi64x(mask1); /* TODO: precompute */
+        vec_r = _mm256_and_si256(vec_r, vec_mask0);
+        vec_a = _mm256_and_si256(vec_a, vec_mask1);
+        vec_r = _mm256_or_si256(vec_r, vec_a);
+        _mm256_storeu_si256((__m256i *)(r->d), vec_r);
+    }
+#else
     r->d[0] = (r->d[0] & mask0) | (a->d[0] & mask1);
     r->d[1] = (r->d[1] & mask0) | (a->d[1] & mask1);
     r->d[2] = (r->d[2] & mask0) | (a->d[2] & mask1);
     r->d[3] = (r->d[3] & mask0) | (a->d[3] & mask1);
+#endif
 
     SECP256K1_SCALAR_VERIFY(r);
 }
@@ -936,10 +986,23 @@ static void secp256k1_scalar_from_signed62(secp256k1_scalar *r, const secp256k1_
     VERIFY_CHECK(a3 >> 62 == 0);
     VERIFY_CHECK(a4 >> 8 == 0);
 
+#ifdef __AVX2__
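+    /* Per-lane variable shifts recombine the 62-bit limbs: lane i computes a_i >> (2*i) | a_{i+1} << (62 - 2*i), matching the scalar code below. */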
+    {
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        __m256i limbs_1234 = _mm256_setr_epi64x(a1, a2, a3, a4);
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 2, 4, 6);
+        const __m256i shift_rhs = _mm256_setr_epi64x(62, 60, 58, 56);
+        __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i *)(r->d), out);
+    }
+#else
     r->d[0] = a0 | a1 << 62;
     r->d[1] = a1 >> 2 | a2 << 60;
     r->d[2] = a2 >> 4 | a3 << 58;
     r->d[3] = a3 >> 6 | a4 << 56;
+#endif
 
     SECP256K1_SCALAR_VERIFY(r);
 }
@@ -949,10 +1012,25 @@ static void secp256k1_scalar_to_signed62(secp256k1_modinv64_signed62 *r, const s
     const uint64_t a0 = a->d[0], a1 = a->d[1], a2 = a->d[2], a3 = a->d[3];
     SECP256K1_SCALAR_VERIFY(a);
 
+#ifdef __AVX2__
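+    /* Split into 62-bit limbs: lane 0 keeps a0 (a shift count of 64 in _mm256_sllv_epi64 yields zero), lane i > 0 computes a_{i-1} >> (64 - 2*i) | a_i << (2*i); every lane is then masked to 62 bits. */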
+    {
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 62, 60, 58); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(64, 2, 4, 6);   /* TODO: precompute */
+        const __m256i mask62 = _mm256_set1_epi64x(M62); /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        out = _mm256_and_si256(out, mask62);
+        _mm256_storeu_si256((__m256i *)r->v, out);
+    }
+#else
     r->v[0] = a0 & M62;
     r->v[1] = (a0 >> 62 | a1 << 2) & M62;
     r->v[2] = (a1 >> 60 | a2 << 4) & M62;
     r->v[3] = (a2 >> 58 | a3 << 6) & M62;
+#endif
     r->v[4] = a3 >> 56;
 }
 