Skip to content

Commit bb4d9af

Browse files
committed
Reduce unsafe in subspace-proof-of-time
1 parent 4508aee commit bb4d9af

File tree

3 files changed

+99
-103
lines changed

3 files changed

+99
-103
lines changed

crates/subspace-proof-of-time/src/aes.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub(crate) fn create(seed: PotSeed, key: PotKey, checkpoint_iterations: u32) ->
1515
{
1616
cpufeatures::new!(has_aes, "aes");
1717
if has_aes::get() {
18+
// SAFETY: Checked `aes` feature
1819
return unsafe { x86_64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) };
1920
}
2021
}
@@ -53,8 +54,9 @@ pub(crate) fn verify_sequential(
5354

5455
#[cfg(target_arch = "x86_64")]
5556
{
56-
cpufeatures::new!(has_aes, "avx512f", "vaes");
57-
if has_aes::get() {
57+
cpufeatures::new!(has_avx512f_vaes, "avx512f", "vaes");
58+
if has_avx512f_vaes::get() {
59+
// SAFETY: Checked `avx512f` and `vaes` features
5860
return unsafe {
5961
x86_64::verify_sequential_avx512f(&seed, &key, checkpoints, checkpoint_iterations)
6062
};
Lines changed: 94 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,40 @@
11
use core::arch::x86_64::*;
2-
use core::{array, mem};
2+
use core::array;
3+
use core::simd::{u8x16, u8x64};
34
use subspace_core_primitives::pot::{PotCheckpoints, PotOutput};
45

56
const NUM_ROUND_KEYS: usize = 11;
67

78
/// Create PoT proof with checkpoints
89
#[target_feature(enable = "aes")]
910
#[inline]
10-
pub(super) unsafe fn create(
11+
pub(super) fn create(
1112
seed: &[u8; 16],
1213
key: &[u8; 16],
1314
checkpoint_iterations: u32,
1415
) -> PotCheckpoints {
1516
let mut checkpoints = PotCheckpoints::default();
1617

17-
unsafe {
18-
let keys_reg = expand_key(key);
19-
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
20-
let mut seed_reg = _mm_loadu_si128(seed.as_ptr() as *const __m128i);
21-
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
22-
for checkpoint in checkpoints.iter_mut() {
23-
for _ in 0..checkpoint_iterations {
24-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
25-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
26-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
27-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
28-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
29-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
30-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
31-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
32-
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
33-
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
34-
}
35-
36-
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
37-
_mm_storeu_si128(checkpoint.as_mut_ptr() as *mut __m128i, checkpoint_reg);
18+
let keys_reg = expand_key(key);
19+
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
20+
let mut seed_reg = __m128i::from(u8x16::from_array(*seed));
21+
seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
22+
for checkpoint in checkpoints.iter_mut() {
23+
for _ in 0..checkpoint_iterations {
24+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]);
25+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]);
26+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]);
27+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]);
28+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]);
29+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]);
30+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]);
31+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]);
32+
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]);
33+
seed_reg = _mm_aesenclast_si128(seed_reg, xor_key);
3834
}
35+
36+
let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]);
37+
**checkpoint = u8x16::from(checkpoint_reg).to_array();
3938
}
4039

4140
checkpoints
@@ -44,45 +43,48 @@ pub(super) unsafe fn create(
4443
/// Verification mimics `create` function, but also has decryption half for better performance
4544
#[target_feature(enable = "avx512f,vaes")]
4645
#[inline]
47-
pub(super) unsafe fn verify_sequential_avx512f(
46+
pub(super) fn verify_sequential_avx512f(
4847
seed: &[u8; 16],
4948
key: &[u8; 16],
5049
checkpoints: &PotCheckpoints,
5150
checkpoint_iterations: u32,
5251
) -> bool {
5352
let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice());
5453

55-
unsafe {
56-
let keys_reg = expand_key(key);
57-
let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]);
58-
let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
54+
let keys = expand_key(key);
55+
let xor_key = _mm_xor_si128(keys[10], keys[0]);
56+
let xor_key_512 = _mm512_broadcast_i32x4(xor_key);
5957

60-
// Invert keys for decryption
61-
let mut inv_keys = keys_reg;
62-
for i in 1..10 {
63-
inv_keys[i] = _mm_aesimc_si128(keys_reg[10 - i]);
64-
}
58+
// Invert keys for decryption, the first and last element is not used below, hence they are
59+
// copied as is from encryption keys (otherwise the first and last element would need to be
60+
// swapped)
61+
let mut inv_keys = keys;
62+
for i in 1..10 {
63+
inv_keys[i] = _mm_aesimc_si128(keys[10 - i]);
64+
}
6565

66-
let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys_reg[i]));
67-
let inv_keys_512 =
68-
array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
66+
let keys_512 = array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(keys[i]));
67+
let inv_keys_512 =
68+
array::from_fn::<_, NUM_ROUND_KEYS, _>(|i| _mm512_broadcast_i32x4(inv_keys[i]));
6969

70-
let mut input_0 = [[0u8; 16]; 4];
71-
input_0[0] = *seed;
72-
input_0[1..].copy_from_slice(&checkpoints[..3]);
73-
let mut input_0 = _mm512_loadu_si512(input_0.as_ptr() as *const __m512i);
74-
let mut input_1 = _mm512_loadu_si512(checkpoints[3..7].as_ptr() as *const __m512i);
70+
let mut input_0 = [[0u8; 16]; 4];
71+
input_0[0] = *seed;
72+
input_0[1..].copy_from_slice(&checkpoints[..3]);
73+
let mut input_0 = __m512i::from(u8x64::from_slice(input_0.as_flattened()));
74+
let mut input_1 = __m512i::from(u8x64::from_slice(checkpoints[3..7].as_flattened()));
7575

76-
let mut output_0 = _mm512_loadu_si512(checkpoints[0..4].as_ptr() as *const __m512i);
77-
let mut output_1 = _mm512_loadu_si512(checkpoints[4..8].as_ptr() as *const __m512i);
76+
let mut output_0 = __m512i::from(u8x64::from_slice(checkpoints[0..4].as_flattened()));
77+
let mut output_1 = __m512i::from(u8x64::from_slice(checkpoints[4..8].as_flattened()));
7878

79-
input_0 = _mm512_xor_si512(input_0, keys_512[0]);
80-
input_1 = _mm512_xor_si512(input_1, keys_512[0]);
79+
input_0 = _mm512_xor_si512(input_0, keys_512[0]);
80+
input_1 = _mm512_xor_si512(input_1, keys_512[0]);
8181

82-
output_0 = _mm512_xor_si512(output_0, keys_512[10]);
83-
output_1 = _mm512_xor_si512(output_1, keys_512[10]);
82+
output_0 = _mm512_xor_si512(output_0, keys_512[10]);
83+
output_1 = _mm512_xor_si512(output_1, keys_512[10]);
8484

85-
for _ in 0..checkpoint_iterations / 2 {
85+
for _ in 0..checkpoint_iterations / 2 {
86+
// TODO: Shouldn't be unsafe: https://github.com/rust-lang/rust/issues/141718
87+
unsafe {
8688
for i in 1..10 {
8789
input_0 = _mm512_aesenc_epi128(input_0, keys_512[i]);
8890
input_1 = _mm512_aesenc_epi128(input_1, keys_512[i]);
@@ -97,75 +99,66 @@ pub(super) unsafe fn verify_sequential_avx512f(
9799
output_0 = _mm512_aesdeclast_epi128(output_0, xor_key_512);
98100
output_1 = _mm512_aesdeclast_epi128(output_1, xor_key_512);
99101
}
102+
}
100103

101-
// Code below is a more efficient version of this:
102-
// input_0 = _mm512_xor_si512(input_0, keys_512[0]);
103-
// input_1 = _mm512_xor_si512(input_1, keys_512[0]);
104-
// output_0 = _mm512_xor_si512(output_0, keys_512[10]);
105-
// output_1 = _mm512_xor_si512(output_1, keys_512[10]);
106-
//
107-
// let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
108-
// let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
104+
// Code below is a more efficient version of this:
105+
// input_0 = _mm512_xor_si512(input_0, keys_512[0]);
106+
// input_1 = _mm512_xor_si512(input_1, keys_512[0]);
107+
// output_0 = _mm512_xor_si512(output_0, keys_512[10]);
108+
// output_1 = _mm512_xor_si512(output_1, keys_512[10]);
109+
//
110+
// let mask0 = _mm512_cmpeq_epu64_mask(input_0, output_0);
111+
// let mask1 = _mm512_cmpeq_epu64_mask(input_1, output_1);
109112

110-
let diff_0 = _mm512_xor_si512(input_0, output_0);
111-
let diff_1 = _mm512_xor_si512(input_1, output_1);
113+
let diff_0 = _mm512_xor_si512(input_0, output_0);
114+
let diff_1 = _mm512_xor_si512(input_1, output_1);
112115

113-
let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
114-
let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
116+
let mask0 = _mm512_cmpeq_epu64_mask(diff_0, xor_key_512);
117+
let mask1 = _mm512_cmpeq_epu64_mask(diff_1, xor_key_512);
115118

116-
// All inputs match outputs
117-
(mask0 & mask1) == u8::MAX
118-
}
119+
// All inputs match outputs
120+
(mask0 & mask1) == u8::MAX
119121
}
120122

121-
// Below code copied with minor changes from following place under MIT/Apache-2.0 license by Artyom
122-
// Pavlov:
123-
// https://github.com/RustCrypto/block-ciphers/blob/9413fcadd28d53854954498c0589b747d8e4ade2/aes/src/ni/aes128.rs
123+
// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by
124+
// Artyom Pavlov:
125+
// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/ni/expand.rs
124126

125-
/// AES-128 round keys
126-
type RoundKeys = [__m128i; NUM_ROUND_KEYS];
127-
128-
macro_rules! expand_round {
129-
($keys:expr, $pos:expr, $round:expr) => {
130-
let mut t1 = $keys[$pos - 1];
127+
#[target_feature(enable = "aes")]
128+
fn expand_key(key: &[u8; 16]) -> [__m128i; NUM_ROUND_KEYS] {
129+
#[target_feature(enable = "aes")]
130+
fn expand_round<const RK: i32>(keys: &mut [__m128i; NUM_ROUND_KEYS], pos: usize) {
131+
let mut t1 = keys[pos - 1];
131132
let mut t2;
132133
let mut t3;
133134

134-
t2 = _mm_aeskeygenassist_si128(t1, $round);
135-
t2 = _mm_shuffle_epi32(t2, 0xff);
136-
t3 = _mm_slli_si128(t1, 0x4);
135+
t2 = _mm_aeskeygenassist_si128::<RK>(t1);
136+
t2 = _mm_shuffle_epi32::<0xff>(t2);
137+
t3 = _mm_slli_si128::<0x4>(t1);
137138
t1 = _mm_xor_si128(t1, t3);
138-
t3 = _mm_slli_si128(t3, 0x4);
139+
t3 = _mm_slli_si128::<0x4>(t3);
139140
t1 = _mm_xor_si128(t1, t3);
140-
t3 = _mm_slli_si128(t3, 0x4);
141+
t3 = _mm_slli_si128::<0x4>(t3);
141142
t1 = _mm_xor_si128(t1, t3);
142143
t1 = _mm_xor_si128(t1, t2);
143144

144-
$keys[$pos] = t1;
145-
};
146-
}
145+
keys[pos] = t1;
146+
}
147147

148-
#[target_feature(enable = "aes")]
149-
#[inline]
150-
unsafe fn expand_key(key: &[u8; 16]) -> RoundKeys {
151-
// SAFETY: `RoundKeys` is a `[__m128i; 11]` which can be initialized
152-
// with all zeroes.
153-
let mut keys: RoundKeys = unsafe { mem::zeroed() };
154-
155-
// SAFETY: No alignment requirement in `_mm_loadu_si128`
156-
let k = unsafe { _mm_loadu_si128(key.as_ptr() as *const __m128i) };
157-
keys[0] = k;
158-
159-
expand_round!(keys, 1, 0x01);
160-
expand_round!(keys, 2, 0x02);
161-
expand_round!(keys, 3, 0x04);
162-
expand_round!(keys, 4, 0x08);
163-
expand_round!(keys, 5, 0x10);
164-
expand_round!(keys, 6, 0x20);
165-
expand_round!(keys, 7, 0x40);
166-
expand_round!(keys, 8, 0x80);
167-
expand_round!(keys, 9, 0x1B);
168-
expand_round!(keys, 10, 0x36);
148+
let mut keys = [_mm_setzero_si128(); NUM_ROUND_KEYS];
149+
keys[0] = __m128i::from(u8x16::from(*key));
150+
151+
let kr = &mut keys;
152+
expand_round::<0x01>(kr, 1);
153+
expand_round::<0x02>(kr, 2);
154+
expand_round::<0x04>(kr, 3);
155+
expand_round::<0x08>(kr, 4);
156+
expand_round::<0x10>(kr, 5);
157+
expand_round::<0x20>(kr, 6);
158+
expand_round::<0x40>(kr, 7);
159+
expand_round::<0x80>(kr, 8);
160+
expand_round::<0x1B>(kr, 9);
161+
expand_round::<0x36>(kr, 10);
169162

170163
keys
171164
}

crates/subspace-proof-of-time/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Proof of time implementation.
22
33
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
4+
#![feature(portable_simd)]
45
#![no_std]
56

67
mod aes;

0 commit comments

Comments
 (0)