11use core:: arch:: x86_64:: * ;
2- use core:: mem;
3- use subspace_core_primitives:: pot:: PotCheckpoints ;
2+ use core:: { array, mem} ;
3+ use subspace_core_primitives:: pot:: { PotCheckpoints , PotOutput } ;
4+
5+ const NUM_ROUND_KEYS : usize = 11 ;
46
57/// Create PoT proof with checkpoints
68#[ target_feature( enable = "aes" ) ]
@@ -12,40 +14,116 @@ pub(super) unsafe fn create(
1214) -> PotCheckpoints {
1315 let mut checkpoints = PotCheckpoints :: default ( ) ;
1416
15- let keys_reg = expand_key ( key) ;
16- let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
17- let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
18- seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
19- for checkpoint in checkpoints. iter_mut ( ) {
20- for _ in 0 ..checkpoint_iterations {
21- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
22- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
23- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
24- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
25- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
26- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
27- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
28- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
29- seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
30- seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
31- }
17+ unsafe {
18+ let keys_reg = expand_key ( key) ;
19+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
20+ let mut seed_reg = _mm_loadu_si128 ( seed. as_ptr ( ) as * const __m128i ) ;
21+ seed_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
22+ for checkpoint in checkpoints. iter_mut ( ) {
23+ for _ in 0 ..checkpoint_iterations {
24+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 1 ] ) ;
25+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 2 ] ) ;
26+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 3 ] ) ;
27+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 4 ] ) ;
28+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 5 ] ) ;
29+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 6 ] ) ;
30+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 7 ] ) ;
31+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 8 ] ) ;
32+ seed_reg = _mm_aesenc_si128 ( seed_reg, keys_reg[ 9 ] ) ;
33+ seed_reg = _mm_aesenclast_si128 ( seed_reg, xor_key) ;
34+ }
3235
33- let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
34- _mm_storeu_si128 (
35- checkpoint. as_mut ( ) . as_mut_ptr ( ) as * mut __m128i ,
36- checkpoint_reg,
37- ) ;
36+ let checkpoint_reg = _mm_xor_si128 ( seed_reg, keys_reg[ 0 ] ) ;
37+ _mm_storeu_si128 ( checkpoint. as_mut_ptr ( ) as * mut __m128i , checkpoint_reg) ;
38+ }
3839 }
3940
4041 checkpoints
4142}
4243
44+ /// Verification mimics `create` function, but also has decryption half for better performance
45+ #[ target_feature( enable = "avx512f,vaes" ) ]
46+ #[ inline]
47+ pub ( super ) unsafe fn verify_sequential_avx512f (
48+ seed : & [ u8 ; 16 ] ,
49+ key : & [ u8 ; 16 ] ,
50+ checkpoints : & PotCheckpoints ,
51+ checkpoint_iterations : u32 ,
52+ ) -> bool {
53+ let checkpoints = PotOutput :: repr_from_slice ( checkpoints. as_slice ( ) ) ;
54+
55+ unsafe {
56+ let keys_reg = expand_key ( key) ;
57+ let xor_key = _mm_xor_si128 ( keys_reg[ 10 ] , keys_reg[ 0 ] ) ;
58+ let xor_key_512 = _mm512_broadcast_i32x4 ( xor_key) ;
59+
60+ // Invert keys for decryption
61+ let mut inv_keys = keys_reg;
62+ for i in 1 ..10 {
63+ inv_keys[ i] = _mm_aesimc_si128 ( keys_reg[ 10 - i] ) ;
64+ }
65+
66+ let keys_512 = array:: from_fn :: < _ , NUM_ROUND_KEYS , _ > ( |i| _mm512_broadcast_i32x4 ( keys_reg[ i] ) ) ;
67+ let inv_keys_512 =
68+ array:: from_fn :: < _ , NUM_ROUND_KEYS , _ > ( |i| _mm512_broadcast_i32x4 ( inv_keys[ i] ) ) ;
69+
70+ let mut input_0 = [ [ 0u8 ; 16 ] ; 4 ] ;
71+ input_0[ 0 ] = * seed;
72+ input_0[ 1 ..] . copy_from_slice ( & checkpoints[ ..3 ] ) ;
73+ let mut input_0 = _mm512_loadu_si512 ( input_0. as_ptr ( ) as * const __m512i ) ;
74+ let mut input_1 = _mm512_loadu_si512 ( checkpoints[ 3 ..7 ] . as_ptr ( ) as * const __m512i ) ;
75+
76+ let mut output_0 = _mm512_loadu_si512 ( checkpoints[ 0 ..4 ] . as_ptr ( ) as * const __m512i ) ;
77+ let mut output_1 = _mm512_loadu_si512 ( checkpoints[ 4 ..8 ] . as_ptr ( ) as * const __m512i ) ;
78+
79+ input_0 = _mm512_xor_si512 ( input_0, keys_512[ 0 ] ) ;
80+ input_1 = _mm512_xor_si512 ( input_1, keys_512[ 0 ] ) ;
81+
82+ output_0 = _mm512_xor_si512 ( output_0, keys_512[ 10 ] ) ;
83+ output_1 = _mm512_xor_si512 ( output_1, keys_512[ 10 ] ) ;
84+
85+ for _ in 0 ..checkpoint_iterations / 2 {
86+ for i in 1 ..10 {
87+ input_0 = _mm512_aesenc_epi128 ( input_0, keys_512[ i] ) ;
88+ input_1 = _mm512_aesenc_epi128 ( input_1, keys_512[ i] ) ;
89+
90+ output_0 = _mm512_aesdec_epi128 ( output_0, inv_keys_512[ i] ) ;
91+ output_1 = _mm512_aesdec_epi128 ( output_1, inv_keys_512[ i] ) ;
92+ }
93+
94+ input_0 = _mm512_aesenclast_epi128 ( input_0, xor_key_512) ;
95+ input_1 = _mm512_aesenclast_epi128 ( input_1, xor_key_512) ;
96+
97+ output_0 = _mm512_aesdeclast_epi128 ( output_0, xor_key_512) ;
98+ output_1 = _mm512_aesdeclast_epi128 ( output_1, xor_key_512) ;
99+ }
100+
101+ // Code below is a more efficient version of this:
102+ // input_0 = _mm512_xor_si512(input_0, keys_512[0]);
103+ // input_1 = _mm512_xor_si512(input_1, keys_512[0]);
104+ // output_0 = _mm512_xor_si512(output_0, keys_512[10]);
105+ // output_1 = _mm512_xor_si512(output_1, keys_512[10]);
106+ //
107+ // let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
108+ // let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
109+
110+ let diff_0 = _mm512_xor_si512 ( input_0, output_0) ;
111+ let diff_1 = _mm512_xor_si512 ( input_1, output_1) ;
112+
113+ let mask0 = _mm512_cmpeq_epu64_mask ( diff_0, xor_key_512) ;
114+ let mask1 = _mm512_cmpeq_epu64_mask ( diff_1, xor_key_512) ;
115+
116+ // All inputs match outputs
117+ ( mask0 & mask1) == u8:: MAX
118+ }
119+ }
120+
43121// Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom
44122// Pavlov:
45123// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs
46124
47125/// AES-128 round keys
48- type RoundKeys = [ __m128i ; 11 ] ;
126+ type RoundKeys = [ __m128i ; NUM_ROUND_KEYS ] ;
49127
50128macro_rules! expand_round {
51129 ( $keys: expr, $pos: expr, $round: expr) => {
@@ -72,9 +150,10 @@ macro_rules! expand_round {
72150unsafe fn expand_key ( key : & [ u8 ; 16 ] ) -> RoundKeys {
73151 // SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
74152 // with all zeroes.
75- let mut keys: RoundKeys = mem:: zeroed ( ) ;
153+ let mut keys: RoundKeys = unsafe { mem:: zeroed ( ) } ;
76154
77- let k = _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) ;
155+ // SAFETY: No alignment requirement in `_mm_loadu_si128`
156+ let k = unsafe { _mm_loadu_si128 ( key. as_ptr ( ) as * const __m128i ) } ;
78157 keys[ 0 ] = k;
79158
80159 expand_round ! ( keys, 1 , 0x01 ) ;
0 commit comments