11use core:: arch:: x86_64:: * ;
2- use core:: { array, mem} ;
2+ use core:: array;
3+ use core:: simd:: { u8x16, u8x64} ;
34use subspace_core_primitives:: pot:: { PotCheckpoints , PotOutput } ;
45
56const NUM_ROUND_KEYS : usize = 11 ;
67
78/// Create PoT proof with checkpoints
89#[ target_feature( enable = "aes" ) ]
910#[ inline]
10- pub ( super ) unsafe fn create (
11+ pub ( super ) fn create (
1112 seed : & [ u8 ; 16 ] ,
1213 key : & [ u8 ; 16 ] ,
1314 checkpoint_iterations : u32 ,
1415) -> PotCheckpoints {
1516 let mut checkpoints = PotCheckpoints :: default ( ) ;
1617
17- unsafe {
18- let keys_reg = expand_key ( key) ;
19- let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
20- let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
21- seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
22- for checkpoint in checkpoints. iter_mut ( ) {
23- for _ in 0 ..checkpoint_iterations {
24- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
25- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
26- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
27- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
28- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
29- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
30- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
31- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
32- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
33- seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
34- }
35-
36- let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
37- _mm_storeu_si128 ( checkpoint. as_mut_ptr ( ) as * mut __m128i , checkpoint_reg) ;
18+ let keys_reg = expand_key ( key) ;
19+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
20+ let mut seed_reg = __m128i:: from ( u8x16:: from_array ( * seed) ) ;
21+ seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
22+ for checkpoint in checkpoints. iter_mut ( ) {
23+ for _ in 0 ..checkpoint_iterations {
24+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
25+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
26+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
27+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
28+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
29+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
30+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
31+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
32+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
33+ seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
3834 }
35+
36+ let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
37+ * * checkpoint = u8x16:: from ( checkpoint_reg) . to_array ( ) ;
3938 }
4039
4140 checkpoints
@@ -44,45 +43,48 @@ pub(super) unsafe fn create(
4443/// Verification mimics `create` function, but also has decryption half for better performance
4544#[ target_feature( enable = "avx512f,vaes" ) ]
4645#[ inline]
47- pub ( super ) unsafe fn verify_sequential_avx512f (
46+ pub ( super ) fn verify_sequential_avx512f (
4847 seed : & [ u8 ; 16 ] ,
4948 key : & [ u8 ; 16 ] ,
5049 checkpoints : & PotCheckpoints ,
5150 checkpoint_iterations : u32 ,
5251) -> bool {
5352 let checkpoints = PotOutput :: repr_from_slice ( checkpoints. as_slice ( ) ) ;
5453
55- unsafe {
56- let keys_reg = expand_key ( key) ;
57- let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
58- let xor_key_512 = _mm512_broadcast_i32x4 ( xor_key) ;
54+ let keys = expand_key ( key) ;
55+ let xor_key = _mm_xor_si128 ( keys[ 10 ] , keys[ 0 ] ) ;
56+ let xor_key_512 = _mm512_broadcast_i32x4 ( xor_key) ;
5957
60- // Invert keys for decryption
61- let mut inv_keys = keys_reg;
62- for i in 1 ..10 {
63- inv_keys[ i] = _mm_aesimc_si128 ( keys_reg[ 10 - i] ) ;
64- }
58+ // Invert keys for decryption, the first and last element is not used below, hence they are
59+ // copied as is from encryption keys (otherwise the first and last element would need to be
60+ // swapped)
61+ let mut inv_keys = keys;
62+ for i in 1 ..10 {
63+ inv_keys[ i] = _mm_aesimc_si128 ( keys[ 10 - i] ) ;
64+ }
6565
66- let keys_512 = array:: from_fn :: < _ , NUM_ROUND_KEYS , _ > ( |i| _mm512_broadcast_i32x4 ( keys_reg [ i] ) ) ;
67- let inv_keys_512 =
68- array:: from_fn :: < _ , NUM_ROUND_KEYS , _ > ( |i| _mm512_broadcast_i32x4 ( inv_keys[ i] ) ) ;
66+ let keys_512 = array:: from_fn :: < _ , NUM_ROUND_KEYS , _ > ( |i| _mm512_broadcast_i32x4 ( keys [ i] ) ) ;
67+ let inv_keys_512 =
68+ array:: from_fn :: < _ , NUM_ROUND_KEYS , _ > ( |i| _mm512_broadcast_i32x4 ( inv_keys[ i] ) ) ;
6969
70- let mut input_0 = [ [ 0u8 ; 16 ] ; 4 ] ;
71- input_0[ 0 ] = * seed;
72- input_0[ 1 ..] . copy_from_slice ( & checkpoints[ ..3 ] ) ;
73- let mut input_0 = _mm512_loadu_si512 ( input_0. as_ptr ( ) as * const __m512i ) ;
74- let mut input_1 = _mm512_loadu_si512 ( checkpoints[ 3 ..7 ] . as_ptr ( ) as * const __m512i ) ;
70+ let mut input_0 = [ [ 0u8 ; 16 ] ; 4 ] ;
71+ input_0[ 0 ] = * seed;
72+ input_0[ 1 ..] . copy_from_slice ( & checkpoints[ ..3 ] ) ;
73+ let mut input_0 = __m512i :: from ( u8x64 :: from_slice ( input_0. as_flattened ( ) ) ) ;
74+ let mut input_1 = __m512i :: from ( u8x64 :: from_slice ( checkpoints[ 3 ..7 ] . as_flattened ( ) ) ) ;
7575
76- let mut output_0 = _mm512_loadu_si512 ( checkpoints[ 0 ..4 ] . as_ptr ( ) as * const __m512i ) ;
77- let mut output_1 = _mm512_loadu_si512 ( checkpoints[ 4 ..8 ] . as_ptr ( ) as * const __m512i ) ;
76+ let mut output_0 = __m512i :: from ( u8x64 :: from_slice ( checkpoints[ 0 ..4 ] . as_flattened ( ) ) ) ;
77+ let mut output_1 = __m512i :: from ( u8x64 :: from_slice ( checkpoints[ 4 ..8 ] . as_flattened ( ) ) ) ;
7878
79- input_0 = _mm512_xor_si512 ( input_0, keys_512[ 0 ] ) ;
80- input_1 = _mm512_xor_si512 ( input_1, keys_512[ 0 ] ) ;
79+ input_0 = _mm512_xor_si512 ( input_0, keys_512[ 0 ] ) ;
80+ input_1 = _mm512_xor_si512 ( input_1, keys_512[ 0 ] ) ;
8181
82- output_0 = _mm512_xor_si512 ( output_0, keys_512[ 10 ] ) ;
83- output_1 = _mm512_xor_si512 ( output_1, keys_512[ 10 ] ) ;
82+ output_0 = _mm512_xor_si512 ( output_0, keys_512[ 10 ] ) ;
83+ output_1 = _mm512_xor_si512 ( output_1, keys_512[ 10 ] ) ;
8484
85- for _ in 0 ..checkpoint_iterations / 2 {
85+ for _ in 0 ..checkpoint_iterations / 2 {
86+ // TODO: Shouldn't be unsafe: https://github.com/rust-lang/rust/issues/141718
87+ unsafe {
8688 for i in 1 ..10 {
8789 input_0 = _mm512_aesenc_epi128 ( input_0, keys_512[ i] ) ;
8890 input_1 = _mm512_aesenc_epi128 ( input_1, keys_512[ i] ) ;
@@ -97,75 +99,66 @@ pub(super) unsafe fn verify_sequential_avx512f(
9799 output_0 = _mm512_aesdeclast_epi128 ( output_0, xor_key_512) ;
98100 output_1 = _mm512_aesdeclast_epi128 ( output_1, xor_key_512) ;
99101 }
102+ }
100103
101- // Code below is a more efficient version of this:
102- // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
103- // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
104- // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
105- // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
106- //
107- // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
108- // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
104+ // Code below is a more efficient version of this:
105+ // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
106+ // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
107+ // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
108+ // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
109+ //
110+ // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
111+ // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
109112
110- let diff_0 = _mm512_xor_si512 ( input_0, output_0) ;
111- let diff_1 = _mm512_xor_si512 ( input_1, output_1) ;
113+ let diff_0 = _mm512_xor_si512 ( input_0, output_0) ;
114+ let diff_1 = _mm512_xor_si512 ( input_1, output_1) ;
112115
113- let mask0 = _mm512_cmpeq_epu64_mask ( diff_0, xor_key_512) ;
114- let mask1 = _mm512_cmpeq_epu64_mask ( diff_1, xor_key_512) ;
116+ let mask0 = _mm512_cmpeq_epu64_mask ( diff_0, xor_key_512) ;
117+ let mask1 = _mm512_cmpeq_epu64_mask ( diff_1, xor_key_512) ;
115118
116- // All inputs match outputs
117- ( mask0 & mask1) == u8:: MAX
118- }
119+ // All inputs match outputs
120+ ( mask0 & mask1) == u8:: MAX
119121}
120122
121- // Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom
122- // Pavlov:
123- // https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2 /aes/src/ni/aes128 .rs
123+ // Below code copied with minor changes from the following place under MIT/Apache-2.0 license by
124+ // Artyom Pavlov:
125+ // https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f /aes/src/ni/expand .rs
124126
125- /// AES-128 round keys
126- type RoundKeys = [ __m128i ; NUM_ROUND_KEYS ] ;
127-
128- macro_rules! expand_round {
129- ( $keys: expr, $pos: expr, $round: expr) => {
130- let mut t1 = $keys[ $pos - 1 ] ;
127+ #[ target_feature( enable = "aes" ) ]
128+ fn expand_key ( key : & [ u8 ; 16 ] ) -> [ __m128i ; NUM_ROUND_KEYS ] {
129+ #[ target_feature( enable = "aes" ) ]
130+ fn expand_round < const RK : i32 > ( keys : & mut [ __m128i ; NUM_ROUND_KEYS ] , pos : usize ) {
131+ let mut t1 = keys[ pos - 1 ] ;
131132 let mut t2;
132133 let mut t3;
133134
134- t2 = _mm_aeskeygenassist_si128( t1, $round ) ;
135- t2 = _mm_shuffle_epi32( t2, 0xff ) ;
136- t3 = _mm_slli_si128( t1, 0x4 ) ;
135+ t2 = _mm_aeskeygenassist_si128 :: < RK > ( t1) ;
136+ t2 = _mm_shuffle_epi32 :: < 0xff > ( t2) ;
137+ t3 = _mm_slli_si128 :: < 0x4 > ( t1) ;
137138 t1 = _mm_xor_si128 ( t1, t3) ;
138- t3 = _mm_slli_si128( t3, 0x4 ) ;
139+ t3 = _mm_slli_si128 :: < 0x4 > ( t3) ;
139140 t1 = _mm_xor_si128 ( t1, t3) ;
140- t3 = _mm_slli_si128( t3, 0x4 ) ;
141+ t3 = _mm_slli_si128 :: < 0x4 > ( t3) ;
141142 t1 = _mm_xor_si128 ( t1, t3) ;
142143 t1 = _mm_xor_si128 ( t1, t2) ;
143144
144- $keys[ $pos] = t1;
145- } ;
146- }
145+ keys[ pos] = t1;
146+ }
147147
148- #[ target_feature( enable = "aes" ) ]
149- #[ inline]
150- unsafe fn expand_key ( key : & [ u8 ; 16 ] ) -> RoundKeys {
151- // SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
152- // with all zeroes.
153- let mut keys: RoundKeys = unsafe { mem:: zeroed ( ) } ;
154-
155- // SAFETY: No alignment requirement in `_mm_loadu_si128`
156- let k = unsafe { _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) } ;
157- keys[ 0 ] = k;
158-
159- expand_round ! ( keys, 1 , 0x01 ) ;
160- expand_round ! ( keys, 2 , 0x02 ) ;
161- expand_round ! ( keys, 3 , 0x04 ) ;
162- expand_round ! ( keys, 4 , 0x08 ) ;
163- expand_round ! ( keys, 5 , 0x10 ) ;
164- expand_round ! ( keys, 6 , 0x20 ) ;
165- expand_round ! ( keys, 7 , 0x40 ) ;
166- expand_round ! ( keys, 8 , 0x80 ) ;
167- expand_round ! ( keys, 9 , 0x1B ) ;
168- expand_round ! ( keys, 10 , 0x36 ) ;
148+ let mut keys = [ _mm_setzero_si128 ( ) ; NUM_ROUND_KEYS ] ;
149+ keys[ 0 ] = __m128i:: from ( u8x16:: from ( * key) ) ;
150+
151+ let kr = & mut keys;
152+ expand_round :: < 0x01 > ( kr, 1 ) ;
153+ expand_round :: < 0x02 > ( kr, 2 ) ;
154+ expand_round :: < 0x04 > ( kr, 3 ) ;
155+ expand_round :: < 0x08 > ( kr, 4 ) ;
156+ expand_round :: < 0x10 > ( kr, 5 ) ;
157+ expand_round :: < 0x20 > ( kr, 6 ) ;
158+ expand_round :: < 0x40 > ( kr, 7 ) ;
159+ expand_round :: < 0x80 > ( kr, 8 ) ;
160+ expand_round :: < 0x1B > ( kr, 9 ) ;
161+ expand_round :: < 0x36 > ( kr, 10 ) ;
169162
170163 keys
171164}
0 commit comments