diff --git a/crates/subspace-proof-of-time/Cargo.toml b/crates/subspace-proof-of-time/Cargo.toml index e24657d128..b912714d20 100644 --- a/crates/subspace-proof-of-time/Cargo.toml +++ b/crates/subspace-proof-of-time/Cargo.toml @@ -19,7 +19,7 @@ aes.workspace = true subspace-core-primitives.workspace = true thiserror.workspace = true -[target.'cfg(target_arch = "x86_64")'.dependencies] +[target.'cfg(any(target_arch = "aarch64", target_arch = "x86_64"))'.dependencies] cpufeatures = { workspace = true } [dev-dependencies] diff --git a/crates/subspace-proof-of-time/src/aes.rs b/crates/subspace-proof-of-time/src/aes.rs index fba14c1a21..89843286ac 100644 --- a/crates/subspace-proof-of-time/src/aes.rs +++ b/crates/subspace-proof-of-time/src/aes.rs @@ -1,5 +1,7 @@ //! AES related functionality. +#[cfg(target_arch = "aarch64")] +mod aarch64; #[cfg(target_arch = "x86_64")] mod x86_64; @@ -19,6 +21,14 @@ pub(crate) fn create(seed: PotSeed, key: PotKey, checkpoint_iterations: u32) -> return unsafe { x86_64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) }; } } + #[cfg(target_arch = "aarch64")] + { + cpufeatures::new!(has_aes, "aes"); + if has_aes::get() { + // SAFETY: Checked `aes` feature + return unsafe { aarch64::create(seed.as_ref(), key.as_ref(), checkpoint_iterations) }; + } + } create_generic(seed, key, checkpoint_iterations) } @@ -83,6 +93,16 @@ pub(crate) fn verify_sequential( }; } } + #[cfg(target_arch = "aarch64")] + { + cpufeatures::new!(has_aes, "aes"); + if has_aes::get() { + // SAFETY: Checked `aes` feature + return unsafe { + aarch64::verify_sequential_aes(&seed, &key, checkpoints, checkpoint_iterations) + }; + } + } verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations) } @@ -143,9 +163,8 @@ mod tests { checkpoint_iterations: u32, ) -> bool { let sequential = verify_sequential(seed, key, checkpoints, checkpoint_iterations); - let sequential_generic = - verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations); - 
assert_eq!(sequential, sequential_generic); + let generic = verify_sequential_generic(seed, key, checkpoints, checkpoint_iterations); + assert_eq!(sequential, generic); #[cfg(target_arch = "x86_64")] { @@ -180,7 +199,7 @@ mod tests { cpufeatures::new!(has_aes_sse41, "aes", "sse4.1"); if has_aes_sse41::get() { // SAFETY: Checked `aes` and `sse4.1` features - let aes = unsafe { + let aes_sse41 = unsafe { x86_64::verify_sequential_aes_sse41( &seed, &key, @@ -188,6 +207,17 @@ mod tests { checkpoint_iterations, ) }; + assert_eq!(sequential, aes_sse41); + } + } + #[cfg(target_arch = "aarch64")] + { + cpufeatures::new!(has_aes, "aes"); + if has_aes::get() { + // SAFETY: Checked `aes` feature + let aes = unsafe { + aarch64::verify_sequential_aes(&seed, &key, checkpoints, checkpoint_iterations) + }; assert_eq!(sequential, aes); } } diff --git a/crates/subspace-proof-of-time/src/aes/aarch64.rs b/crates/subspace-proof-of-time/src/aes/aarch64.rs new file mode 100644 index 0000000000..4c443755e5 --- /dev/null +++ b/crates/subspace-proof-of-time/src/aes/aarch64.rs @@ -0,0 +1,170 @@ +use core::arch::aarch64::*; +use core::simd::u8x16; +use core::slice; +use subspace_core_primitives::pot::{PotCheckpoints, PotOutput}; + +const NUM_ROUND_KEYS: usize = 11; + +/// Create PoT proof with checkpoints +#[target_feature(enable = "aes")] +#[inline] +pub(super) fn create( + seed: &[u8; 16], + key: &[u8; 16], + checkpoint_iterations: u32, +) -> PotCheckpoints { + let mut checkpoints = PotCheckpoints::default(); + + let keys = expand_key(key); + let xor_key = veorq_u8(keys[10], keys[0]); + let mut seed = uint8x16_t::from(u8x16::from(*seed)); + seed = veorq_u8(seed, keys[10]); + for checkpoint in checkpoints.iter_mut() { + for _ in 0..checkpoint_iterations { + seed = vaesmcq_u8(vaeseq_u8(seed, xor_key)); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[1])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[2])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[3])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[4])); + 
seed = vaesmcq_u8(vaeseq_u8(seed, keys[5])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[6])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[7])); + seed = vaesmcq_u8(vaeseq_u8(seed, keys[8])); + seed = vaeseq_u8(seed, keys[9]); + } + + let checkpoint_reg = veorq_u8(seed, keys[10]); + **checkpoint = u8x16::from(checkpoint_reg).to_array(); + } + + checkpoints +} + +/// Verification mimics `create` function, but also has decryption half for better performance +#[target_feature(enable = "aes")] +#[inline] +pub(super) fn verify_sequential_aes( + seed: &[u8; 16], + key: &[u8; 16], + checkpoints: &PotCheckpoints, + checkpoint_iterations: u32, +) -> bool { + let checkpoints = PotOutput::repr_from_slice(checkpoints.as_slice()); + + let keys = expand_key(key); + let xor_key = veorq_u8(keys[10], keys[0]); + + // Invert keys for decryption, the first and last element is not used below, hence they are + // copied as is from encryption keys (otherwise the first and last element would need to be + // swapped) + let mut inv_keys = keys; + for i in 1..10 { + inv_keys[i] = vaesimcq_u8(keys[10 - i]); + } + + let mut inputs: [uint8x16_t; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ + uint8x16_t::from(u8x16::from(*seed)), + uint8x16_t::from(u8x16::from(checkpoints[0])), + uint8x16_t::from(u8x16::from(checkpoints[1])), + uint8x16_t::from(u8x16::from(checkpoints[2])), + uint8x16_t::from(u8x16::from(checkpoints[3])), + uint8x16_t::from(u8x16::from(checkpoints[4])), + uint8x16_t::from(u8x16::from(checkpoints[5])), + uint8x16_t::from(u8x16::from(checkpoints[6])), + ]; + + let mut outputs: [uint8x16_t; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ + uint8x16_t::from(u8x16::from(checkpoints[0])), + uint8x16_t::from(u8x16::from(checkpoints[1])), + uint8x16_t::from(u8x16::from(checkpoints[2])), + uint8x16_t::from(u8x16::from(checkpoints[3])), + uint8x16_t::from(u8x16::from(checkpoints[4])), + uint8x16_t::from(u8x16::from(checkpoints[5])), + uint8x16_t::from(u8x16::from(checkpoints[6])), + 
uint8x16_t::from(u8x16::from(checkpoints[7])), + ]; + + inputs = inputs.map(|input| veorq_u8(input, keys[10])); + outputs = outputs.map(|output| veorq_u8(output, keys[0])); + + for _ in 0..checkpoint_iterations / 2 { + inputs = inputs.map(|input| vaesmcq_u8(vaeseq_u8(input, xor_key))); + outputs = outputs.map(|output| vaesimcq_u8(vaesdq_u8(output, xor_key))); + + for i in 1..9 { + inputs = inputs.map(|input| vaesmcq_u8(vaeseq_u8(input, keys[i]))); + outputs = outputs.map(|output| vaesimcq_u8(vaesdq_u8(output, inv_keys[i]))); + } + + inputs = inputs.map(|input| vaeseq_u8(input, keys[9])); + outputs = outputs.map(|output| vaesdq_u8(output, inv_keys[9])); + } + + inputs.into_iter().zip(outputs).all(|(input, output)| { + let diff = veorq_u8(input, output); + let cmp = vceqq_u8(diff, xor_key); + vminvq_u8(cmp) == u8::MAX + }) +} + +// Below code copied with minor changes from the following place under MIT/Apache-2.0 license by +// Artyom Pavlov: +// https://github.com/RustCrypto/block-ciphers/blob/fbb68f40b122909d92e40ee8a50112b6e5d0af8f/aes/src/armv8/expand.rs + +/// There are 4 AES words in a block. +const BLOCK_WORDS: usize = 4; + +/// The AES (nee Rijndael) notion of a word is always 32-bits, or 4-bytes. +const WORD_SIZE: usize = 4; + +/// AES round constants. +const ROUND_CONSTS: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]; + +/// AES key expansion. +#[target_feature(enable = "aes")] +fn expand_key(key: &[u8; 16]) -> [uint8x16_t; NUM_ROUND_KEYS] { + let mut expanded_keys = [uint8x16_t::from(u8x16::default()); NUM_ROUND_KEYS]; + + // Sanity check, as this is required in order for the subsequent conversion to be sound. 
+ const _: () = assert!(align_of::<uint8x16_t>() >= align_of::<u32>()); + let columns = unsafe { + slice::from_raw_parts_mut( + expanded_keys.as_mut_ptr().cast::<u32>(), + NUM_ROUND_KEYS * BLOCK_WORDS, + ) + }; + + for (i, chunk) in key.array_chunks::<WORD_SIZE>().enumerate() { + columns[i] = u32::from_ne_bytes(*chunk); + } + + // From "The Rijndael Block Cipher" Section 4.1: + // > The number of columns of the Cipher Key is denoted by `Nk` and is + // > equal to the key length divided by 32 [bits]. + let nk = 16 / WORD_SIZE; + + for i in nk..NUM_ROUND_KEYS * BLOCK_WORDS { + let mut word = columns[i - 1]; + + if i % nk == 0 { + word = sub_word(word).rotate_right(8) ^ ROUND_CONSTS[i / nk - 1]; + } else if nk > 6 && i % nk == 4 { + word = sub_word(word); + } + + columns[i] = columns[i - nk] ^ word; + } + + expanded_keys +} + +/// Sub bytes for a single AES word: used for key expansion +#[target_feature(enable = "aes")] +fn sub_word(input: u32) -> u32 { + let input = vreinterpretq_u8_u32(vdupq_n_u32(input)); + + // AES single round encryption (with a "round" key of all zeros) + let sub_input = vaeseq_u8(input, vdupq_n_u8(0)); + + vgetq_lane_u32::<0>(vreinterpretq_u32_u8(sub_input)) +} diff --git a/crates/subspace-proof-of-time/src/aes/x86_64.rs b/crates/subspace-proof-of-time/src/aes/x86_64.rs index 7bf70b99af..fec1517875 100644 --- a/crates/subspace-proof-of-time/src/aes/x86_64.rs +++ b/crates/subspace-proof-of-time/src/aes/x86_64.rs @@ -15,25 +15,25 @@ pub(super) fn create( ) -> PotCheckpoints { let mut checkpoints = PotCheckpoints::default(); - let keys_reg = expand_key(key); - let xor_key = _mm_xor_si128(keys_reg[10], keys_reg[0]); - let mut seed_reg = __m128i::from(u8x16::from_array(*seed)); - seed_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + let keys = expand_key(key); + let xor_key = _mm_xor_si128(keys[10], keys[0]); + let mut seed = __m128i::from(u8x16::from_array(*seed)); + seed = _mm_xor_si128(seed, keys[0]); for checkpoint in checkpoints.iter_mut() { for _ in 0..checkpoint_iterations { - 
seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[1]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[2]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[3]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[4]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[5]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[6]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[7]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[8]); - seed_reg = _mm_aesenc_si128(seed_reg, keys_reg[9]); - seed_reg = _mm_aesenclast_si128(seed_reg, xor_key); + seed = _mm_aesenc_si128(seed, keys[1]); + seed = _mm_aesenc_si128(seed, keys[2]); + seed = _mm_aesenc_si128(seed, keys[3]); + seed = _mm_aesenc_si128(seed, keys[4]); + seed = _mm_aesenc_si128(seed, keys[5]); + seed = _mm_aesenc_si128(seed, keys[6]); + seed = _mm_aesenc_si128(seed, keys[7]); + seed = _mm_aesenc_si128(seed, keys[8]); + seed = _mm_aesenc_si128(seed, keys[9]); + seed = _mm_aesenclast_si128(seed, xor_key); } - let checkpoint_reg = _mm_xor_si128(seed_reg, keys_reg[0]); + let checkpoint_reg = _mm_xor_si128(seed, keys[0]); **checkpoint = u8x16::from(checkpoint_reg).to_array(); } @@ -62,7 +62,7 @@ pub(super) fn verify_sequential_aes_sse41( inv_keys[i] = _mm_aesimc_si128(keys[10 - i]); } - let mut inputs: [__m128i; 8] = [ + let mut inputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ __m128i::from(u8x16::from(*seed)), __m128i::from(u8x16::from(checkpoints[0])), __m128i::from(u8x16::from(checkpoints[1])), @@ -73,7 +73,7 @@ pub(super) fn verify_sequential_aes_sse41( __m128i::from(u8x16::from(checkpoints[6])), ]; - let mut outputs: [__m128i; 8] = [ + let mut outputs: [__m128i; PotCheckpoints::NUM_CHECKPOINTS.get() as usize] = [ __m128i::from(u8x16::from(checkpoints[0])), __m128i::from(u8x16::from(checkpoints[1])), __m128i::from(u8x16::from(checkpoints[2])), diff --git a/crates/subspace-proof-of-time/src/lib.rs b/crates/subspace-proof-of-time/src/lib.rs index 0f3ca68a6b..386922094d 100644 --- 
a/crates/subspace-proof-of-time/src/lib.rs +++ b/crates/subspace-proof-of-time/src/lib.rs @@ -1,5 +1,6 @@ //! Proof of time implementation. +#![cfg_attr(target_arch = "aarch64", feature(array_chunks))] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![feature(portable_simd)] #![no_std]