Skip to content

Commit 4508aee

Browse files
committed
Faster PoT verification for CPUs that support AVX512F+VAES
1 parent 310ba30 commit 4508aee

File tree

3 files changed

+162
-27
lines changed

3 files changed

+162
-27
lines changed

crates/subspace-proof-of-time/src/aes.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,25 @@ pub(crate) fn verify_sequential(
5151
) -> bool {
5252
assert_eq!(checkpoint_iterations % 2, 0);
5353

54+
#[cfg(target_arch = "x86_64")]
55+
{
56+
cpufeatures::new!(has_aes, "avx512f", "vaes");
57+
if has_aes::get() {
58+
return unsafe {
59+
x86_64::verify_sequential_avx512f(&seed, &key, checkpoints, checkpoint_iterations)
60+
};
61+
}
62+
}
63+
64+
verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations)
65+
}
66+
67+
fn verify_sequential_generic(
68+
seed: PotSeed,
69+
key: PotKey,
70+
checkpoints: &PotCheckpoints,
71+
checkpoint_iterations: u32,
72+
) -> bool {
5473
let key = Array::from(*key);
5574
let cipher = Aes128::new(&key);
5675

@@ -113,6 +132,12 @@ mod tests {
113132
&checkpoints,
114133
checkpoint_iterations,
115134
));
135+
assert!(verify_sequential_generic(
136+
seed,
137+
key,
138+
&checkpoints,
139+
checkpoint_iterations,
140+
));
116141

117142
// Decryption of invalid cipher text fails.
118143
let mut checkpoints_1 = checkpoints;
@@ -123,6 +148,12 @@ mod tests {
123148
&checkpoints_1,
124149
checkpoint_iterations,
125150
));
151+
assert!(!verify_sequential_generic(
152+
seed,
153+
key,
154+
&checkpoints_1,
155+
checkpoint_iterations,
156+
));
126157

127158
// Decryption with wrong number of iterations fails.
128159
assert!(!verify_sequential(
@@ -131,12 +162,24 @@ mod tests {
131162
&checkpoints,
132163
checkpoint_iterations + 2,
133164
));
165+
assert!(!verify_sequential_generic(
166+
seed,
167+
key,
168+
&checkpoints,
169+
checkpoint_iterations + 2,
170+
));
134171
assert!(!verify_sequential(
135172
seed,
136173
key,
137174
&checkpoints,
138175
checkpoint_iterations - 2,
139176
));
177+
assert!(!verify_sequential_generic(
178+
seed,
179+
key,
180+
&checkpoints,
181+
checkpoint_iterations - 2,
182+
));
140183

141184
// Decryption with wrong seed fails.
142185
assert!(!verify_sequential(
@@ -145,6 +188,12 @@ mod tests {
145188
&checkpoints,
146189
checkpoint_iterations,
147190
));
191+
assert!(!verify_sequential_generic(
192+
PotSeed::from(SEED_1),
193+
key,
194+
&checkpoints,
195+
checkpoint_iterations,
196+
));
148197

149198
// Decryption with wrong key fails.
150199
assert!(!verify_sequential(
@@ -153,5 +202,11 @@ mod tests {
153202
&checkpoints,
154203
checkpoint_iterations,
155204
));
205+
assert!(!verify_sequential_generic(
206+
seed,
207+
PotKey::from(KEY_1),
208+
&checkpoints,
209+
checkpoint_iterations,
210+
));
156211
}
157212
}

crates/subspace-proof-of-time/src/aes/x86_64.rs

Lines changed: 106 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use core::arch::x86_64::*;
2-
use core::mem;
3-
use subspace_core_primitives::pot::PotCheckpoints;
2+
use core::{array, mem};
3+
use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
4+
5+
const NUM_ROUND_KEYS: usize = 11;
46

57
/// Create PoT proof with checkpoints
68
#[target_feature(enable = "aes")]
@@ -12,40 +14,116 @@ pub(super) unsafe fn create(
1214
) -> PotCheckpoints {
1315
let mut checkpoints = PotCheckpoints::default();
1416

15-
let keys_reg = expand_key(key);
16-
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
17-
let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i);
18-
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
19-
for checkpoint in checkpoints.iter_mut() {
20-
for _ in 0..checkpoint_iterations {
21-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
22-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
23-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
24-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
25-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
26-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
27-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
28-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
29-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
30-
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
31-
}
17+
unsafe {
18+
let keys_reg = expand_key(key);
19+
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
20+
let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i);
21+
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
22+
for checkpoint in checkpoints.iter_mut() {
23+
for _ in 0..checkpoint_iterations {
24+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
25+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
26+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
27+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
28+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
29+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
30+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
31+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
32+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
33+
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
34+
}
3235

33-
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
34-
_mm_storeu_si128(
35-
checkpoint.as_mut().as_mut_ptr() as *mut __m128i,
36-
checkpoint_reg,
37-
);
36+
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
37+
_mm_storeu_si128(checkpoint.as_mut_ptr() as *mut __m128i, checkpoint_reg);
38+
}
3839
}
3940

4041
checkpoints
4142
}
4243

44+
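/// Verification mimics the `create` function, but also decrypts the expected outputs while encrypting the inputs, so only `checkpoint_iterations / 2` iterations are needed in each direction, for better performance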
45+
#[target_feature(enable = "avx512f,vaes")]
46+
#[inline]
47+
pub(super) unsafe fn verify_sequential_avx512f(
48+
seed: &[u8; 16],
49+
key: &[u8; 16],
50+
checkpoints: &PotCheckpoints,
51+
checkpoint_iterations: u32,
52+
) -> bool {
53+
let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
54+
55+
unsafe {
56+
let keys_reg = expand_key(key);
57+
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
58+
let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
59+
60+
// Invert keys for decryption
61+
let mut inv_keys = keys_reg;
62+
for i in 1..10 {
63+
inv_keys[i] = _mm_aesimc_si128(keys_reg[10 - i]);
64+
}
65+
66+
let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys_reg[i]));
67+
let inv_keys_512 =
68+
array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
69+
70+
let mut input_0 = [[0u8; 16]; 4];
71+
input_0[0] = *seed;
72+
input_0[1..].copy_from_slice(&checkpoints[..3]);
73+
let mut input_0 = _mm512_loadu_si512(input_0.as_ptr() as *const __m512i);
74+
let mut input_1 = _mm512_loadu_si512(checkpoints[3..7].as_ptr() as *const __m512i);
75+
76+
let mut output_0 = _mm512_loadu_si512(checkpoints[0..4].as_ptr() as *const __m512i);
77+
let mut output_1 = _mm512_loadu_si512(checkpoints[4..8].as_ptr() as *const __m512i);
78+
79+
input_0 = _mm512_xor_si512(input_0, keys_512[0]);
80+
input_1 = _mm512_xor_si512(input_1, keys_512[0]);
81+
82+
output_0 = _mm512_xor_si512(output_0, keys_512[10]);
83+
output_1 = _mm512_xor_si512(output_1, keys_512[10]);
84+
85+
for _ in 0..checkpoint_iterations / 2 {
86+
for i in 1..10 {
87+
input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
88+
input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
89+
90+
output_0 = _mm512_aesdec_epi128(output_0, inv_keys_512[i]);
91+
output_1 = _mm512_aesdec_epi128(output_1, inv_keys_512[i]);
92+
}
93+
94+
input_0 = _mm512_aesenclast_epi128(input_0, xor_key_512);
95+
input_1 = _mm512_aesenclast_epi128(input_1, xor_key_512);
96+
97+
output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
98+
output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
99+
}
100+
101+
// Code below is a more efficient version of this:
102+
// input_0 = _mm512_xor_si512(input_0, keys_512[0]);
103+
// input_1 = _mm512_xor_si512(input_1, keys_512[0]);
104+
// output_0 = _mm512_xor_si512(output_0, keys_512[10]);
105+
// output_1 = _mm512_xor_si512(output_1, keys_512[10]);
106+
//
107+
// let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
108+
// let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
109+
110+
let diff_0 = _mm512_xor_si512(input_0, output_0);
111+
let diff_1 = _mm512_xor_si512(input_1, output_1);
112+
113+
let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
114+
let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
115+
116+
// All inputs match outputs
117+
(mask0 & mask1) == u8::MAX
118+
}
119+
}
120+
43121
// The code below was copied with minor changes from the following place under MIT/Apache-2.0 license by Artyom
44122
// Pavlov:
45123
// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs
46124

47125
/// AES-128 round keys
48-
type RoundKeys = [__m128i; 11];
126+
type RoundKeys = [__m128i; NUM_ROUND_KEYS];
49127

50128
macro_rules! expand_round {
51129
($keys:expr, $pos:expr, $round:expr) => {
@@ -72,9 +150,10 @@ macro_rules! expand_round {
72150
unsafe fn expand_key(key: &[u8; 16]) -> RoundKeys {
73151
// SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
74152
// with all zeroes.
75-
let mut keys: RoundKeys = mem::zeroed();
153+
let mut keys: RoundKeys = unsafe { mem::zeroed() };
76154

77-
let k = _mm_loadu_si128(key.as_ptr() as *const __m128i);
155+
// SAFETY: No alignment requirement in `_mm_loadu_si128`
156+
let k = unsafe { _mm_loadu_si128(key.as_ptr() as *const __m128i) };
78157
keys[0] = k;
79158

80159
expand_round!(keys, 1, 0x01);

crates/subspace-proof-of-time/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Proof of time implementation.
22
3+
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
34
#![no_std]
45

56
mod aes;

0 commit comments

Comments
 (0)