 #include "modinv64_impl.h"
 #include "util.h"
 
+#if defined(__AVX__) && defined(__AVX2__)
+# include <immintrin.h>
+#endif
+
 /* Limbs of the secp256k1 order. */
 #define SECP256K1_N_0 ((uint64_t)0xBFD25E8CD0364141ULL)
 #define SECP256K1_N_1 ((uint64_t)0xBAAEDCE6AF48A03BULL)
@@ -143,10 +147,25 @@ static void secp256k1_scalar_cadd_bit(secp256k1_scalar *r, unsigned int bit, int
 
 static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b32, int *overflow) {
     int over;
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
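+        /* Reverse the order of the four 64-bit limbs, then byte-swap each limb:
+         * the 32 big-endian input bytes become the four uint64 limbs, with b32[24..31] ending up in r->d[0]. */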
+        __m256i vec_b32 = _mm256_loadu_si256((const __m256i *)b32);
+        vec_b32 = _mm256_permute4x64_epi64(vec_b32, _MM_SHUFFLE(0,1,2,3));
+        const __m256i bswap_mask = _mm256_setr_epi8( /* TODO: precompute */
+            7,6,5,4,3,2,1,0,
+            15,14,13,12,11,10,9,8,
+            23,22,21,20,19,18,17,16,
+            31,30,29,28,27,26,25,24);
+        __m256i output = _mm256_shuffle_epi8(vec_b32, bswap_mask);
+        _mm256_storeu_si256((__m256i *)r->d, output);
+    }
+#else
     r->d[0] = secp256k1_read_be64(&b32[24]);
     r->d[1] = secp256k1_read_be64(&b32[16]);
     r->d[2] = secp256k1_read_be64(&b32[8]);
     r->d[3] = secp256k1_read_be64(&b32[0]);
+#endif
     over = secp256k1_scalar_reduce(r, secp256k1_scalar_check_overflow(r));
     if (overflow) {
         *overflow = over;
@@ -157,16 +176,28 @@ static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *b
 
 static void secp256k1_scalar_get_b32(unsigned char *bin, const secp256k1_scalar *a) {
     SECP256K1_SCALAR_VERIFY(a);
-
+#if defined(__AVX__) && defined(__AVX2__)
+    {
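+        /* Mirror of the set_b32 path: reverse the limb order and byte-swap each limb to emit 32 big-endian bytes. */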
+        __m256i vec_a = _mm256_loadu_si256((const __m256i *)a->d);
+        vec_a = _mm256_permute4x64_epi64(vec_a, _MM_SHUFFLE(0,1,2,3));
+        const __m256i bswap_mask = _mm256_setr_epi8( /* TODO: precompute */
+            7,6,5,4,3,2,1,0,
+            15,14,13,12,11,10,9,8,
+            23,22,21,20,19,18,17,16,
+            31,30,29,28,27,26,25,24);
+        __m256i output = _mm256_shuffle_epi8(vec_a, bswap_mask);
+        _mm256_storeu_si256((__m256i *)bin, output);
+    }
+#else
     secp256k1_write_be64(&bin[0],  a->d[3]);
     secp256k1_write_be64(&bin[8],  a->d[2]);
     secp256k1_write_be64(&bin[16], a->d[1]);
     secp256k1_write_be64(&bin[24], a->d[0]);
+#endif
 }
 
 SECP256K1_INLINE static int secp256k1_scalar_is_zero(const secp256k1_scalar *a) {
     SECP256K1_SCALAR_VERIFY(a);
-
     return (a->d[0] | a->d[1] | a->d[2] | a->d[3]) == 0;
 }
 
@@ -882,8 +913,16 @@ static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r
 SECP256K1_INLINE static int secp256k1_scalar_eq(const secp256k1_scalar *a, const secp256k1_scalar *b) {
     SECP256K1_SCALAR_VERIFY(a);
     SECP256K1_SCALAR_VERIFY(b);
-
+#if defined(__AVX__) && defined(__AVX2__)
+    {
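+        /* a == b iff the XOR of the limbs is all-zero; VPTEST checks the whole 256-bit register at once. */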
+        __m256i vec_a = _mm256_loadu_si256((const __m256i *)a->d);
+        __m256i vec_b = _mm256_loadu_si256((const __m256i *)b->d);
+        __m256i vec_xor = _mm256_xor_si256(vec_a, vec_b);
+        return _mm256_testz_si256(vec_xor, vec_xor);
+    }
+#else
     return ((a->d[0] ^ b->d[0]) | (a->d[1] ^ b->d[1]) | (a->d[2] ^ b->d[2]) | (a->d[3] ^ b->d[3])) == 0;
+#endif
 }
 
 SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b, unsigned int shift) {
@@ -899,6 +938,9 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
     shiftlimbs = shift >> 6;
     shiftlow = shift & 0x3F;
     shifthigh = 64 - shiftlow;
+
+    /* TODO: parallelize */
+
     r->d[0] = shift < 512 ? (l[0 + shiftlimbs] >> shiftlow | (shift < 448 && shiftlow ? (l[1 + shiftlimbs] << shifthigh) : 0)) : 0;
     r->d[1] = shift < 448 ? (l[1 + shiftlimbs] >> shiftlow | (shift < 384 && shiftlow ? (l[2 + shiftlimbs] << shifthigh) : 0)) : 0;
     r->d[2] = shift < 384 ? (l[2 + shiftlimbs] >> shiftlow | (shift < 320 && shiftlow ? (l[3 + shiftlimbs] << shifthigh) : 0)) : 0;
@@ -909,37 +951,68 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r,
 }
 
 static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) {
+#if defined(__AVX__) && defined(__AVX2__)
+    /* Load early to hide the load latency behind the mask computation below. */
+    __m256i vec_r = _mm256_loadu_si256((__m256i *)(r->d));
+    __m256i vec_a = _mm256_loadu_si256((const __m256i *)(a->d));
+#endif
+
     uint64_t mask0, mask1;
     volatile int vflag = flag;
     SECP256K1_SCALAR_VERIFY(a);
     SECP256K1_CHECKMEM_CHECK_VERIFY(r->d, sizeof(r->d));
 
     mask0 = vflag + ~((uint64_t)0);
     mask1 = ~mask0;
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
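+        /* Same constant-time selection as the scalar path below: mask0 keeps r when flag is 0, mask1 selects a when flag is 1. */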
+        __m256i vec_mask0 = _mm256_set1_epi64x(mask0);
+        __m256i vec_mask1 = _mm256_set1_epi64x(mask1);
+        vec_r = _mm256_and_si256(vec_r, vec_mask0);
+        vec_a = _mm256_and_si256(vec_a, vec_mask1);
+        _mm256_storeu_si256((__m256i *)(r->d), _mm256_or_si256(vec_r, vec_a));
+    }
+#else
     r->d[0] = (r->d[0] & mask0) | (a->d[0] & mask1);
     r->d[1] = (r->d[1] & mask0) | (a->d[1] & mask1);
     r->d[2] = (r->d[2] & mask0) | (a->d[2] & mask1);
     r->d[3] = (r->d[3] & mask0) | (a->d[3] & mask1);
+#endif
 
     SECP256K1_SCALAR_VERIFY(r);
 }
 
 static void secp256k1_scalar_from_signed62(secp256k1_scalar *r, const secp256k1_modinv64_signed62 *a) {
-    const uint64_t a0 = a->v[0], a1 = a->v[1], a2 = a->v[2], a3 = a->v[3], a4 = a->v[4];
-
     /* The output from secp256k1_modinv64{_var} should be normalized to range [0,modulus), and
      * have limbs in [0,2^62). The modulus is < 2^256, so the top limb must be below 2^(256-62*4).
      */
-    VERIFY_CHECK(a0 >> 62 == 0);
-    VERIFY_CHECK(a1 >> 62 == 0);
-    VERIFY_CHECK(a2 >> 62 == 0);
-    VERIFY_CHECK(a3 >> 62 == 0);
-    VERIFY_CHECK(a4 >> 8 == 0);
+    VERIFY_CHECK(a->v[0] >> 62 == 0);
+    VERIFY_CHECK(a->v[1] >> 62 == 0);
+    VERIFY_CHECK(a->v[2] >> 62 == 0);
+    VERIFY_CHECK(a->v[3] >> 62 == 0);
+    VERIFY_CHECK(a->v[4] >> 8 == 0);
+
+#if defined(__AVX__) && defined(__AVX2__)
+    {
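+        /* Each 64-bit output limb combines two adjacent 62-bit limbs: limb i shifted right by 2*i bits,
+         * OR'd with limb i+1 shifted left by 62 - 2*i bits, using per-lane variable shifts. */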
+        __m256i limbs_0123 = _mm256_loadu_si256((const __m256i *)a->v);
+        __m256i limbs_1234 = _mm256_loadu_si256((const __m256i *)(a->v + 1));
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 2, 4, 6); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(62, 60, 58, 56); /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0123, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_1234, shift_rhs);
+        _mm256_storeu_si256((__m256i *)(r->d), _mm256_or_si256(lhs, rhs));
+    }
+#else
+    {
+        const uint64_t a0 = a->v[0], a1 = a->v[1], a2 = a->v[2], a3 = a->v[3], a4 = a->v[4];
 
-    r->d[0] = a0      | a1 << 62;
-    r->d[1] = a1 >> 2 | a2 << 60;
-    r->d[2] = a2 >> 4 | a3 << 58;
-    r->d[3] = a3 >> 6 | a4 << 56;
+        r->d[0] = a0      | a1 << 62;
+        r->d[1] = a1 >> 2 | a2 << 60;
+        r->d[2] = a2 >> 4 | a3 << 58;
+        r->d[3] = a3 >> 6 | a4 << 56;
+    }
+#endif
 
     SECP256K1_SCALAR_VERIFY(r);
 }
@@ -949,10 +1022,24 @@ static void secp256k1_scalar_to_signed62(secp256k1_modinv64_signed62 *r, const s
     const uint64_t a0 = a->d[0], a1 = a->d[1], a2 = a->d[2], a3 = a->d[3];
     SECP256K1_SCALAR_VERIFY(a);
 
+#if defined(__AVX__) && defined(__AVX2__)
+    {
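+        /* Lane 0 is just a0 & M62 (its shift count of 64 makes the sllv contribution zero);
+         * lanes 1..3 each combine two adjacent 64-bit limbs before masking to 62 bits. */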
+        __m256i limbs_0012 = _mm256_setr_epi64x(a0, a0, a1, a2);
+        __m256i limbs_0123 = _mm256_setr_epi64x(a0, a1, a2, a3);
+        const __m256i shift_lhs = _mm256_setr_epi64x(0, 62, 60, 58); /* TODO: precompute */
+        const __m256i shift_rhs = _mm256_setr_epi64x(64, 2, 4, 6); /* TODO: precompute */
+        const __m256i mask62 = _mm256_set1_epi64x(M62); /* TODO: precompute */
+        __m256i lhs = _mm256_srlv_epi64(limbs_0012, shift_lhs);
+        __m256i rhs = _mm256_sllv_epi64(limbs_0123, shift_rhs);
+        __m256i out = _mm256_or_si256(lhs, rhs);
+        _mm256_storeu_si256((__m256i *)r->v, _mm256_and_si256(out, mask62));
+    }
+#else
     r->v[0] = a0 & M62;
     r->v[1] = (a0 >> 62 | a1 << 2) & M62;
     r->v[2] = (a1 >> 60 | a2 << 4) & M62;
     r->v[3] = (a2 >> 58 | a3 << 6) & M62;
+#endif
     r->v[4] = a3 >> 56;
 }
 