diff --git a/.github/workflows/aes.yml b/.github/workflows/aes.yml index f9edc7cb..6ef9c2ad 100644 --- a/.github/workflows/aes.yml +++ b/.github/workflows/aes.yml @@ -16,6 +16,8 @@ defaults: env: CARGO_INCREMENTAL: 0 RUSTFLAGS: "-Dwarnings" + QEMU_FULL_VERSION: 8.2.0 + LLVM_MAJOR_VERSION: 17 jobs: # Builds for no_std platforms @@ -234,6 +236,112 @@ jobs: - run: cross test --package aes --target ${{ matrix.target }} --features hazmat - run: cross test --package aes --target ${{ matrix.target }} --all-features + # Build and cache latest QEMUs; needed for RISC-V features + qemu-build-and-cache: + runs-on: ubuntu-latest + defaults: + run: + working-directory: /home/runner + steps: + - uses: silvanshade/rustcrypto-actions/qemu-cache-build@master + with: + qemu-full-version: ${{ env.QEMU_FULL_VERSION }} + qemu-target-archs: riscv64 + + # RISC-V rv64 cross-compiled tests for AES intrinsics + riscv64: + needs: qemu-build-and-cache + strategy: + matrix: + include: + - target: riscv64gc-unknown-linux-gnu + rust: nightly-2024-01-27 # TODO(silvanshade): stable MSRV once available + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + # NOTE: Install a recent QEMU for RISC-V support + - uses: silvanshade/rustcrypto-actions/qemu-cache-install@master + with: + qemu-full-version: ${{ env.QEMU_FULL_VERSION }} + qemu-target-archs: riscv64 + # NOTE: Install a recent LLVM/GNU toolchain configured for RISC-V multiarch cross-compilation + - uses: silvanshade/rustcrypto-actions/llvm-gnu-multiarch-install@master + with: + llvm-major-version: ${{ env.LLVM_MAJOR_VERSION }} + ubuntu-archs: riscv64 + - uses: RustCrypto/actions/cargo-cache@master + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + # NOTE: Write a rust-toolchain.toml to override the default toolchain + - name: write rust-toolchain.toml + shell: bash + run: | + cd ../aes/.. + echo '[toolchain]' > rust-toolchain.toml + echo 'channel = "${{ matrix.rust }}"' >> rust-toolchain.toml + echo '' >> rust-toolchain.toml + # NOTE: Write a `.cargo/config.toml` to configure the RISC-V target for scalar tests + - name: write .cargo/config.toml (for scalar tests) + shell: bash + run: | + cd ../aes/.. + mkdir -p .cargo + echo '[target.${{ matrix.target }}]' > .cargo/config.toml + echo 'runner = "qemu-riscv64 -cpu rv64,zkne=true,zknd=true"' >> .cargo/config.toml + echo 'linker = "clang-${{ env.LLVM_MAJOR_VERSION }}"' >> .cargo/config.toml + echo 'rustflags = [' >> .cargo/config.toml + echo ' "-C", "link-arg=-fuse-ld=lld-${{ env.LLVM_MAJOR_VERSION }}",' >> .cargo/config.toml + echo ' "-C", "link-arg=-march=rv64gc_zkne_zknd",' >> .cargo/config.toml + echo ' "-C", "link-arg=--target=${{ matrix.target }}",' >> .cargo/config.toml + echo ' "-C", "target-feature=+zkne,+zknd"' >> .cargo/config.toml + echo ']' >> .cargo/config.toml + - name: riscv64 scalar tests + run: unset RUSTFLAGS && cargo test --package aes + - name: riscv64 scalar tests (all features) + run: unset RUSTFLAGS && cargo test --package aes --all-features + # NOTE: Write a `.cargo/config.toml` to configure the RISC-V target for vector tests + - name: write .cargo/config.toml (for vector tests) + shell: bash + run: | + cd ../aes/.. 
+          mkdir -p .cargo
+          echo '[target.${{ matrix.target }}]' > .cargo/config.toml
+          echo 'runner = "qemu-riscv64 -cpu rv64,v=true,vext_spec=v1.0,zvkned=true"' >> .cargo/config.toml
+          echo 'linker = "clang-${{ env.LLVM_MAJOR_VERSION }}"' >> .cargo/config.toml
+          echo 'rustflags = [' >> .cargo/config.toml
+          echo ' "-C", "link-arg=-fuse-ld=lld-${{ env.LLVM_MAJOR_VERSION }}",' >> .cargo/config.toml
+          echo ' "-C", "link-arg=-march=rv64gc_v1p0_zvkned1p0",' >> .cargo/config.toml
+          echo ' "-C", "link-arg=--target=riscv64-unknown-linux-gnu",' >> .cargo/config.toml
+          echo ' "-C", "target-feature=+v",' >> .cargo/config.toml
+          echo ' "--cfg", "target_feature_zvkned"' >> .cargo/config.toml
+          echo ']' >> .cargo/config.toml
+      - name: riscv64 vector tests
+        run: unset RUSTFLAGS && cargo test --package aes --target ${{ matrix.target }}
+      - name: riscv64 vector tests (all features)
+        run: unset RUSTFLAGS && cargo test --package aes --target ${{ matrix.target }} --all-features
+      # NOTE: Write a `.cargo/config.toml` to configure the RISC-V target for scalar AND vector tests
+      - name: write .cargo/config.toml (for scalar and vector tests)
+        shell: bash
+        run: |
+          cd ../aes/..
+          mkdir -p .cargo
+          echo '[target.${{ matrix.target }}]' > .cargo/config.toml
+          echo 'runner = "qemu-riscv64 -cpu rv64,zkne=true,zknd=true,v=true,vext_spec=v1.0,zvkned=true"' >> .cargo/config.toml
+          echo 'linker = "clang-${{ env.LLVM_MAJOR_VERSION }}"' >> .cargo/config.toml
+          echo 'rustflags = [' >> .cargo/config.toml
+          echo ' "-C", "link-arg=-fuse-ld=lld-${{ env.LLVM_MAJOR_VERSION }}",' >> .cargo/config.toml
+          echo ' "-C", "link-arg=-march=rv64gc_zkne_zknd_v1p0_zvkned1p0",' >> .cargo/config.toml
+          echo ' "-C", "link-arg=--target=riscv64-unknown-linux-gnu",' >> .cargo/config.toml
+          echo ' "-C", "target-feature=+zkne,+zknd,+v",' >> .cargo/config.toml
+          echo ' "--cfg", "target_feature_zvkned"' >> .cargo/config.toml
+          echo ']' >> .cargo/config.toml
+      - name: riscv64 scalar and vector tests
+        run: unset RUSTFLAGS && cargo test --package aes --target ${{ matrix.target }}
+      - name: riscv64 scalar and vector tests (all features)
+        run: unset RUSTFLAGS && cargo test --package aes --target ${{ matrix.target }} --all-features
+
   clippy:
     env:
       RUSTFLAGS: "-Dwarnings --cfg aes_compact"
diff --git a/aes/benches/mod.rs b/aes/benches/mod.rs
index 579b0731..8847a04c 100644
--- a/aes/benches/mod.rs
+++ b/aes/benches/mod.rs
@@ -13,11 +13,27 @@ block_decryptor_bench!(
     aes128_decrypt_block,
     aes128_decrypt_blocks,
 );
+#[cfg(any(
+    not(target_arch = "riscv64"),
+    all(
+        target_arch = "riscv64",
+        target_feature = "zknd",
+        target_feature = "zkne"
+    )
+))]
 block_encryptor_bench!(
     Key: aes::Aes192,
     aes192_encrypt_block,
     aes192_encrypt_blocks,
 );
+#[cfg(any(
+    not(target_arch = "riscv64"),
+    all(
+        target_arch = "riscv64",
+        target_feature = "zknd",
+        target_feature = "zkne"
+    )
+))]
 block_decryptor_bench!(
     Key: aes::Aes192,
     aes192_decrypt_block,
@@ -43,6 +59,14 @@ fn aes128_new(bh: &mut test::Bencher) {
     });
 }
 
+#[cfg(any(
+    not(target_arch = "riscv64"),
+    all(
+        target_arch = "riscv64",
+        target_feature = "zknd",
+        target_feature = "zkne"
+    )
+))]
 #[bench]
 fn aes192_new(bh: &mut test::Bencher) {
     bh.iter(|| {
diff --git a/aes/src/lib.rs b/aes/src/lib.rs
index 0f8bab50..bf2c87f3 100644
--- a/aes/src/lib.rs
+++ b/aes/src/lib.rs
@@ -35,6 +35,27 @@
 //! runtime. On other platforms the `aes` target feature must be enabled via
 //! RUSTFLAGS.
 //!
+//!
+//! ## RISC-V rv64 (scalar) {Zkne, Zknd} extensions
+//!
+//! Support is available for the RISC-V rv64 scalar crypto extensions for AES.
+//! This is not currently autodetected at runtime: the appropriate target
+//! features must be enabled at compile time, for example with
+//! `RUSTFLAGS="-C target-feature=+zkne,+zknd"`.
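+//!
+//! For instance, instead of setting `RUSTFLAGS` globally, the flags can be set
+//! per target in a `.cargo/config.toml`. The sketch below mirrors what the CI
+//! configuration in this PR writes (the `runner` line is only needed when
+//! testing under QEMU emulation and can be omitted on real hardware):
+//!
+//! ```text
+//! [target.riscv64gc-unknown-linux-gnu]
+//! rustflags = ["-C", "target-feature=+zkne,+zknd"]
+//! runner = "qemu-riscv64 -cpu rv64,zkne=true,zknd=true"
+//! ```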
+//!
+//! ## RISC-V rvv (vector) {Zvkned} extensions
+//!
+//! Support is available for the RISC-V vector crypto extensions for AES. This
+//! is likewise not autodetected at runtime: the appropriate target features
+//! must be enabled at compile time, for example with
+//! `RUSTFLAGS="-C target-feature=+v --cfg target_feature_zvkned"`.
+//!
+//! NOTE: Hardware-accelerated vector key-schedule routines for AES-192 are not
+//! available in the RISC-V vector crypto extensions. It is still possible to
+//! fall back to the scalar key-schedule routines for AES-192 in this case if
+//! the appropriate scalar target features are also enabled, for example with
+//! `RUSTFLAGS="-C target-feature=+zkne,+zknd,+v --cfg target_feature_zvkned"`.
+//!
 //! ## `x86`/`x86_64` intrinsics (AES-NI)
 //! By default this crate uses runtime detection on `i686`/`x86_64` targets
 //! in order to determine if AES-NI is available, and if it is not, it will
@@ -118,6 +139,14 @@
 )]
 #![cfg_attr(docsrs, feature(doc_cfg))]
 #![warn(missing_docs, rust_2018_idioms)]
+#![cfg_attr(
+    all(
+        any(target_arch = "riscv32", target_arch = "riscv64"),
+        target_feature = "zknd",
+        target_feature = "zkne"
+    ),
+    feature(riscv_ext_intrinsics, stdsimd)
+)]
 
 #[cfg(feature = "hazmat")]
 #[cfg_attr(docsrs, doc(cfg(feature = "hazmat")))]
@@ -132,6 +161,13 @@ cfg_if! {
         mod armv8;
         mod autodetect;
         pub use autodetect::*;
+    // TODO(silvanshade): switch to target_feature for `zvkned` when available
+    } else if #[cfg(all(any(target_arch = "riscv32", target_arch = "riscv64"), target_feature = "v", target_feature_zvkned))] {
+        mod riscv;
+        pub use riscv::rvv::*;
+    } else if #[cfg(all(target_arch = "riscv64", target_feature = "zknd", target_feature = "zkne"))] {
+        mod riscv;
+        pub use riscv::rv64::*;
     } else if #[cfg(all(
         any(target_arch = "x86", target_arch = "x86_64"),
         not(aes_force_soft)
diff --git a/aes/src/riscv.rs b/aes/src/riscv.rs
new file mode 100644
index 00000000..42c00fac
--- /dev/null
+++ b/aes/src/riscv.rs
@@ -0,0 +1,145 @@
+//! AES block cipher implementations for RISC-V using the Cryptography Extensions
+//!
+//! Supported targets: rv64 (scalar), rvv
+//!
+//! NOTE: rv32 (scalar) is not currently implemented, primarily due to the difficulty of obtaining
+//! a suitable development environment (lack of distro support and of precompiled toolchains), the
+//! effort required to maintain a test environment as 32-bit becomes less supported, and the
+//! overall scarcity of relevant hardware. If someone has a specific need for such an
+//! implementation, please open an issue. Theoretically, the rvv implementation should work on
+//! rv32, for a hypothetical rv32 target satisfying the vector feature requirements.
+//!
+//! NOTE: These implementations are currently not enabled through auto-detection. In order to use
+//! them, you must enable the appropriate target features.
+//!
+//! Additionally, for the vector implementation, since the `zvkned` target feature is not yet
+//! defined in Rust, you must pass `--cfg target_feature_zvkned` to the compiler (through
+//! `RUSTFLAGS` or some other means) in addition to enabling the `v` target feature.
+//!
+//! Examining the module structure for this implementation should give you an idea of how to
+//! specify these features in your own code.
+//!
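+//! As a minimal usage sketch (backend selection is transparent to callers; this
+//! assumes only the standard `cipher` traits that this crate re-exports):
+//!
+//! ```ignore
+//! use aes::Aes128;
+//! use aes::cipher::{BlockCipherEncrypt, KeyInit};
+//!
+//! let key = [0x42u8; 16];
+//! let cipher = Aes128::new(&key.into());
+//! let mut block = [0u8; 16].into();
+//! cipher.encrypt_block(&mut block); // dispatches to rv64/rvv when compiled in
+//! ```
+//!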
+//! NOTE: AES-128, AES-192, and AES-256 are supported for both the scalar and vector
+//! implementations.
+//!
+//! However, key expansion is not vector-accelerated for the AES-192 case (because RISC-V does not
+//! provide vector instructions for this case). Users concerned with vector performance are advised
+//! to select AES-128 or AES-256 instead. Nevertheless, the AES-192 vector implementation will
+//! still fall back to the scalar AES-192 key-schedule implementation if the appropriate scalar
+//! target features are enabled.
+
+#[cfg(all(
+    target_arch = "riscv64",
+    target_feature = "zknd",
+    target_feature = "zkne"
+))]
+pub(crate) mod rv64;
+#[cfg(all(
+    any(target_arch = "riscv32", target_arch = "riscv64"),
+    target_feature = "v",
+    target_feature_zvkned
+))]
+pub(crate) mod rvv;
+
+#[cfg(test)]
+mod test {
+    use hex_literal::hex;
+
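+    // NOTE: These are the key-expansion known-answer vectors for the FIPS-197
+    // (Appendix A) example keys; the `*_EXP_INVKEYS` variants are the same
+    // schedules with the middle round keys passed through the inverse
+    // MixColumns transform (cf. `aes64im` in `inv_expanded_keys`).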
hex!("603deb1015ca71be2b73aef0857d7781"), + hex!("1f352c073b6108d72d9810a30914dff4"), + hex!("9ba354118e6925afa51a8b5f2067fcde"), + hex!("a8b09c1a93d194cdbe49846eb75d5b9a"), + hex!("d59aecb85bf3c917fee94248de8ebe96"), + hex!("b5a9328a2678a647983122292f6c79b3"), + hex!("812c81addadf48ba24360af2fab8b464"), + hex!("98c5bfc9bebd198e268c3ba709e04214"), + hex!("68007bacb2df331696e939e46c518d80"), + hex!("c814e20476a9fb8a5025c02d59c58239"), + hex!("de1369676ccc5a71fa2563959674ee15"), + hex!("5886ca5d2e2f31d77e0af1fa27cf73c3"), + hex!("749c47ab18501ddae2757e4f7401905a"), + hex!("cafaaae3e4d59b349adf6acebd10190d"), + hex!("fe4890d1e6188d0b046df344706c631e"), + ]; + pub(crate) const AES256_EXP_INVKEYS: [[u8; 16]; 15] = [ + hex!("603deb1015ca71be2b73aef0857d7781"), + hex!("8ec6bff6829ca03b9e49af7edba96125"), + hex!("42107758e9ec98f066329ea193f8858b"), + hex!("4a7459f9c8e8f9c256a156bc8d083799"), + hex!("6c3d632985d1fbd9e3e36578701be0f3"), + hex!("54fb808b9c137949cab22ff547ba186c"), + hex!("25ba3c22a06bc7fb4388a28333934270"), + hex!("d669a7334a7ade7a80c8f18fc772e9e3"), + hex!("c440b289642b757227a3d7f114309581"), + hex!("32526c367828b24cf8e043c33f92aa20"), + hex!("34ad1e4450866b367725bcc763152946"), + hex!("b668b621ce40046d36a047ae0932ed8e"), + hex!("57c96cf6074f07c0706abb07137f9241"), + hex!("ada23f4963e23b2455427c8a5c709104"), + hex!("fe4890d1e6188d0b046df344706c631e"), + ]; +} diff --git a/aes/src/riscv/rv64.rs b/aes/src/riscv/rv64.rs new file mode 100644 index 00000000..be007ad7 --- /dev/null +++ b/aes/src/riscv/rv64.rs @@ -0,0 +1,333 @@ +//! AES block cipher implementation for RISC-V 64 using Scalar Cryptography Extensions: Zkne, Zknd +//! +//! RISC-V Scalar Cryptography Extension v1.0.1: +//! https://github.com/riscv/riscv-crypto/releases/download/v1.0.1-scalar/riscv-crypto-spec-scalar-v1.0.1.pdf +//! +//! For reference, see the following other implementations: +//! +//! 1. The RISC-V Cryptography Extensions "benchmarks" reference for RISC-V 64 with Zkn{ed}: +//! https://github.com/riscv/riscv-crypto/tree/main/benchmarks/aes/zscrypto_rv64 +//! +//! 2. The OpenSSL implementation for RISC-V 64 with Zkn{ed}: +//! https://github.com/openssl/openssl/blob/master/crypto/aes/asm/aes-riscv64-zkn.pl + +mod encdec; +pub(crate) mod expand; +#[cfg(test)] +pub(crate) mod test_expand; + +use self::{ + encdec::{decrypt1, decrypt8, encrypt1, encrypt8}, + expand::{inv_expanded_keys, KeySchedule}, +}; +use crate::{Block, Block8}; +use cipher::{ + consts::{U16, U24, U32, U8}, + inout::InOut, + AlgorithmName, BlockBackend, BlockCipher, BlockCipherDecrypt, BlockCipherEncrypt, BlockClosure, + BlockSizeUser, Key, KeyInit, KeySizeUser, ParBlocksSizeUser, +}; +use core::fmt; + +type RoundKey = [u64; 2]; +type RoundKeys = [RoundKey; N]; + +macro_rules! define_aes_impl { + ( + $name:ident, + $name_enc:ident, + $name_dec:ident, + $name_back_enc:ident, + $name_back_dec:ident, + $key_size:ty, + $words:tt, + $rounds:tt, + $doc:expr $(,)? 
+ ) => { + #[doc=$doc] + #[doc = "block cipher"] + #[derive(Clone)] + pub struct $name { + encrypt: $name_enc, + decrypt: $name_dec, + } + + impl BlockCipher for $name {} + + impl KeySizeUser for $name { + type KeySize = $key_size; + } + + impl KeyInit for $name { + #[inline] + fn new(key: &Key) -> Self { + let encrypt = $name_enc::new(key); + let decrypt = $name_dec::from(&encrypt); + Self { encrypt, decrypt } + } + } + + impl From<$name_enc> for $name { + #[inline] + fn from(encrypt: $name_enc) -> $name { + let decrypt = (&encrypt).into(); + Self { encrypt, decrypt } + } + } + + impl From<&$name_enc> for $name { + #[inline] + fn from(encrypt: &$name_enc) -> $name { + let decrypt = encrypt.into(); + let encrypt = encrypt.clone(); + Self { encrypt, decrypt } + } + } + + impl BlockSizeUser for $name { + type BlockSize = U16; + } + + impl BlockCipherEncrypt for $name { + fn encrypt_with_backend(&self, f: impl BlockClosure) { + self.encrypt.encrypt_with_backend(f) + } + } + + impl BlockCipherDecrypt for $name { + fn decrypt_with_backend(&self, f: impl BlockClosure) { + self.decrypt.decrypt_with_backend(f) + } + } + + impl fmt::Debug for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name), " { .. }")) + } + } + + impl AlgorithmName for $name { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name)) + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name {} + + #[doc=$doc] + #[doc = "block cipher (encrypt-only)"] + #[derive(Clone)] + pub struct $name_enc { + round_keys: RoundKeys<$rounds>, + } + + impl $name_enc { + #[inline(always)] + pub(crate) fn get_enc_backend(&self) -> $name_back_enc<'_> { + $name_back_enc(self) + } + } + + impl BlockCipher for $name_enc {} + + impl KeySizeUser for $name_enc { + type KeySize = $key_size; + } + + impl KeyInit for $name_enc { + #[inline] + fn new(key: &Key) -> Self { + Self { + round_keys: KeySchedule::<$words, $rounds>::expand_key(key.as_ref()), + } + } + } + + impl BlockSizeUser for $name_enc { + type BlockSize = U16; + } + + impl BlockCipherEncrypt for $name_enc { + fn encrypt_with_backend(&self, f: impl BlockClosure) { + f.call(&mut self.get_enc_backend()) + } + } + + impl fmt::Debug for $name_enc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name_enc), " { .. 
}")) + } + } + + impl AlgorithmName for $name_enc { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name_enc)) + } + } + + impl Drop for $name_enc { + #[inline] + fn drop(&mut self) { + #[cfg(feature = "zeroize")] + zeroize::Zeroize::zeroize(&mut self.round_keys); + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name_enc {} + + #[doc=$doc] + #[doc = "block cipher (decrypt-only)"] + #[derive(Clone)] + pub struct $name_dec { + round_keys: RoundKeys<$rounds>, + } + + impl $name_dec { + #[inline(always)] + pub(crate) fn get_dec_backend(&self) -> $name_back_dec<'_> { + $name_back_dec(self) + } + } + + impl BlockCipher for $name_dec {} + + impl KeySizeUser for $name_dec { + type KeySize = $key_size; + } + + impl KeyInit for $name_dec { + #[inline] + fn new(key: &Key) -> Self { + $name_enc::new(key).into() + } + } + + impl From<$name_enc> for $name_dec { + #[inline] + fn from(enc: $name_enc) -> $name_dec { + Self::from(&enc) + } + } + + impl From<&$name_enc> for $name_dec { + fn from(enc: &$name_enc) -> $name_dec { + let mut round_keys = enc.round_keys; + inv_expanded_keys(&mut round_keys); + Self { round_keys } + } + } + + impl BlockSizeUser for $name_dec { + type BlockSize = U16; + } + + impl BlockCipherDecrypt for $name_dec { + fn decrypt_with_backend(&self, f: impl BlockClosure) { + f.call(&mut self.get_dec_backend()); + } + } + + impl fmt::Debug for $name_dec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name_dec), " { .. }")) + } + } + + impl AlgorithmName for $name_dec { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name_dec)) + } + } + + impl Drop for $name_dec { + #[inline] + fn drop(&mut self) { + #[cfg(feature = "zeroize")] + zeroize::Zeroize::zeroize(&mut self.round_keys); + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name_dec {} + + pub(crate) struct $name_back_enc<'a>(&'a $name_enc); + + impl<'a> BlockSizeUser for $name_back_enc<'a> { + type BlockSize = U16; + } + + impl<'a> ParBlocksSizeUser for $name_back_enc<'a> { + type ParBlocksSize = U8; + } + + impl<'a> BlockBackend for $name_back_enc<'a> { + #[inline(always)] + fn proc_block(&mut self, block: InOut<'_, '_, Block>) { + encrypt1(&self.0.round_keys, block); + } + + #[inline(always)] + fn proc_par_blocks(&mut self, blocks: InOut<'_, '_, Block8>) { + encrypt8(&self.0.round_keys, blocks) + } + } + + pub(crate) struct $name_back_dec<'a>(&'a $name_dec); + + impl<'a> BlockSizeUser for $name_back_dec<'a> { + type BlockSize = U16; + } + + impl<'a> ParBlocksSizeUser for $name_back_dec<'a> { + type ParBlocksSize = U8; + } + + impl<'a> BlockBackend for $name_back_dec<'a> { + #[inline(always)] + fn proc_block(&mut self, block: InOut<'_, '_, Block>) { + decrypt1(&self.0.round_keys, block); + } + + #[inline(always)] + fn proc_par_blocks(&mut self, blocks: InOut<'_, '_, Block8>) { + decrypt8(&self.0.round_keys, blocks) + } + } + }; +} + +define_aes_impl!( + Aes128, + Aes128Enc, + Aes128Dec, + Aes128BackEnc, + Aes128BackDec, + U16, + 2, + 11, + "AES-128", +); +define_aes_impl!( + Aes192, + Aes192Enc, + Aes192Dec, + Aes192BackEnc, + Aes192BackDec, + U24, + 3, + 13, + "AES-192", +); +define_aes_impl!( + Aes256, + Aes256Enc, + Aes256Dec, + Aes256BackEnc, + Aes256BackDec, + U32, + 4, + 15, + "AES-256", +); diff --git a/aes/src/riscv/rv64/encdec.rs b/aes/src/riscv/rv64/encdec.rs new file mode 100644 index 00000000..a68cf7f8 --- /dev/null +++ 
b/aes/src/riscv/rv64/encdec.rs @@ -0,0 +1,208 @@ +//! AES encryption support + +use crate::{ + riscv::rv64::{RoundKey, RoundKeys}, + Block, Block8, +}; +use cipher::inout::InOut; + +#[inline(always)] +pub(super) fn encrypt1(keys: &RoundKeys, mut block1: InOut<'_, '_, Block>) { + let rounds = N - 1; + let mut state1 = utils::CipherState1::load1(block1.get_in()); + for i in 0..rounds / 2 - 1 { + state1.enc1_two_more(keys[2 * i + 0], keys[2 * i + 1]); + } + state1.enc1_two_last(keys[rounds - 2], keys[rounds - 1]); + state1.xor1(&keys[rounds]); + state1.save1(block1.get_out()); +} + +#[inline(always)] +pub(super) fn encrypt8(keys: &RoundKeys, mut block8: InOut<'_, '_, Block8>) { + let rounds = N - 1; + let mut state8 = utils::CipherState8::load8(block8.get_in()); + for i in 0..rounds / 2 - 1 { + state8.enc8_two_more(keys[2 * i + 0], keys[2 * i + 1]); + } + state8.enc8_two_last(keys[rounds - 2], keys[rounds - 1]); + state8.xor8(&keys[rounds]); + state8.save8(block8.get_out()); +} + +#[inline(always)] +pub(super) fn decrypt1(keys: &RoundKeys, mut block1: InOut<'_, '_, Block>) { + let rounds = N - 1; + let mut state1 = utils::CipherState1::load1(block1.get_in()); + state1.xor1(&keys[rounds]); + for i in (1..rounds / 2).rev() { + state1.dec1_two_more(keys[2 * i + 0], keys[2 * i + 1]); + } + state1.dec1_two_last(keys[0], keys[1]); + state1.save1(block1.get_out()); +} + +#[inline(always)] +pub(super) fn decrypt8(keys: &RoundKeys, mut block8: InOut<'_, '_, Block8>) { + let rounds = N - 1; + let mut state8 = utils::CipherState8::load8(block8.get_in()); + state8.xor8(&keys[rounds]); + for i in (1..rounds / 2).rev() { + state8.dec8_two_more(keys[2 * i + 0], keys[2 * i + 1]); + } + state8.dec8_two_last(keys[0], keys[1]); + state8.save8(block8.get_out()); +} + +mod utils { + use super::*; + use core::arch::riscv64::*; + + pub(super) struct CipherState1 { + data: [u64; 2], + } + + impl CipherState1 { + #[inline(always)] + pub(super) fn load1(block: &Block) -> Self { + let ptr = block.as_ptr().cast::(); + let s0 = unsafe { ptr.add(0).read_unaligned() }; + let s1 = unsafe { ptr.add(1).read_unaligned() }; + Self { data: [s0, s1] } + } + + #[inline(always)] + pub(super) fn save1(self, block: &mut Block) { + let b0 = self.data[0].to_ne_bytes(); + let b1 = self.data[1].to_ne_bytes(); + block[00..08].copy_from_slice(&b0); + block[08..16].copy_from_slice(&b1); + } + + #[inline(always)] + pub(super) fn xor1(&mut self, key: &RoundKey) { + self.data[0] ^= key[0]; + self.data[1] ^= key[1]; + } + + #[inline(always)] + pub(super) fn enc1_two_more(&mut self, k0: RoundKey, k1: RoundKey) { + let mut n0; + let mut n1; + self.data[0] ^= k0[0]; + self.data[1] ^= k0[1]; + n0 = unsafe { aes64esm(self.data[0], self.data[1]) }; + n1 = unsafe { aes64esm(self.data[1], self.data[0]) }; + n0 ^= k1[0]; + n1 ^= k1[1]; + self.data[0] = unsafe { aes64esm(n0, n1) }; + self.data[1] = unsafe { aes64esm(n1, n0) }; + } + + #[inline(always)] + pub(super) fn enc1_two_last(&mut self, k0: RoundKey, k1: RoundKey) { + let mut n0; + let mut n1; + self.data[0] ^= k0[0]; + self.data[1] ^= k0[1]; + n0 = unsafe { aes64esm(self.data[0], self.data[1]) }; + n1 = unsafe { aes64esm(self.data[1], self.data[0]) }; + n0 ^= k1[0]; + n1 ^= k1[1]; + self.data[0] = unsafe { aes64es(n0, n1) }; + self.data[1] = unsafe { aes64es(n1, n0) }; + } + + #[inline(always)] + pub(super) fn dec1_two_more(&mut self, k0: RoundKey, k1: RoundKey) { + let mut n0; + let mut n1; + n0 = unsafe { aes64dsm(self.data[0], self.data[1]) }; + n1 = unsafe { aes64dsm(self.data[1], 
self.data[0]) }; + self.data[0] = n0 ^ k1[0]; + self.data[1] = n1 ^ k1[1]; + n0 = unsafe { aes64dsm(self.data[0], self.data[1]) }; + n1 = unsafe { aes64dsm(self.data[1], self.data[0]) }; + self.data[0] = n0 ^ k0[0]; + self.data[1] = n1 ^ k0[1]; + } + + #[inline(always)] + pub(super) fn dec1_two_last(&mut self, k0: RoundKey, k1: RoundKey) { + let mut n0; + let mut n1; + n0 = unsafe { aes64dsm(self.data[0], self.data[1]) }; + n1 = unsafe { aes64dsm(self.data[1], self.data[0]) }; + self.data[0] = n0 ^ k1[0]; + self.data[1] = n1 ^ k1[1]; + n0 = unsafe { aes64ds(self.data[0], self.data[1]) }; + n1 = unsafe { aes64ds(self.data[1], self.data[0]) }; + self.data[0] = n0 ^ k0[0]; + self.data[1] = n1 ^ k0[1]; + } + } + + pub(super) struct CipherState8 { + data: [CipherState1; 8], + } + + impl CipherState8 { + #[inline(always)] + pub(super) fn load8(blocks: &Block8) -> Self { + Self { + data: [ + CipherState1::load1(&blocks[0]), + CipherState1::load1(&blocks[1]), + CipherState1::load1(&blocks[2]), + CipherState1::load1(&blocks[3]), + CipherState1::load1(&blocks[4]), + CipherState1::load1(&blocks[5]), + CipherState1::load1(&blocks[6]), + CipherState1::load1(&blocks[7]), + ], + } + } + + #[inline(always)] + pub(super) fn save8(self, blocks: &mut Block8) { + for (i, state) in self.data.into_iter().enumerate() { + state.save1(&mut blocks[i]); + } + } + + #[inline(always)] + pub(super) fn xor8(&mut self, key: &RoundKey) { + for state in &mut self.data { + state.xor1(key); + } + } + + #[inline(always)] + pub(super) fn enc8_two_more(&mut self, k0: RoundKey, k1: RoundKey) { + for state in &mut self.data { + state.enc1_two_more(k0, k1); + } + } + + #[inline(always)] + pub(super) fn enc8_two_last(&mut self, k0: RoundKey, k1: RoundKey) { + for state in &mut self.data { + state.enc1_two_last(k0, k1); + } + } + + #[inline(always)] + pub(super) fn dec8_two_more(&mut self, k0: RoundKey, k1: RoundKey) { + for state in &mut self.data { + state.dec1_two_more(k0, k1); + } + } + + #[inline(always)] + pub(super) fn dec8_two_last(&mut self, k0: RoundKey, k1: RoundKey) { + for state in &mut self.data { + state.dec1_two_last(k0, k1); + } + } + } +} diff --git a/aes/src/riscv/rv64/expand.rs b/aes/src/riscv/rv64/expand.rs new file mode 100644 index 00000000..b5d9fd39 --- /dev/null +++ b/aes/src/riscv/rv64/expand.rs @@ -0,0 +1,216 @@ +use crate::riscv::rv64::{RoundKey, RoundKeys}; +use core::{ + arch::riscv64::*, + mem::{transmute, MaybeUninit}, + ptr::addr_of_mut, +}; + +// TODO(silvanshade): `COLUMNS` should be an associated constant once support for that is stable. +pub(crate) struct KeySchedule { + cols: [u64; COLUMNS], + keys: [MaybeUninit; ROUNDS], +} + +// COLUMNS: 4 x 32-bit words = 2 x 64-bit words +impl KeySchedule<{ 4 / 2 }, { 1 + 10 }> { + #[inline(always)] + fn load(ckey: &[u8; 16]) -> Self { + let ckey = ckey.as_ptr().cast::(); + let mut cols: [MaybeUninit; 2] = unsafe { MaybeUninit::uninit().assume_init() }; + unsafe { cols[0].write(ckey.add(0).read_unaligned()) }; + unsafe { cols[1].write(ckey.add(1).read_unaligned()) }; + let mut schedule = Self { + // SAFETY: `data` is fully initialized. 
+ cols: unsafe { transmute(cols) }, + keys: unsafe { MaybeUninit::uninit().assume_init() }, + }; + schedule.save_one_keys(0); + schedule + } + + #[inline(always)] + fn save_one_keys(&mut self, i: u8) { + let i = usize::from(i); + let keys = self.keys[i].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0]).write(self.cols[0]) }; + unsafe { addr_of_mut!((*keys)[1]).write(self.cols[1]) }; + } + + #[inline(always)] + fn one_key_rounds(&mut self) { + let s = unsafe { aes64ks1i(self.cols[1], RNUM) }; + self.cols[0] = unsafe { aes64ks2(s, self.cols[0]) }; + self.cols[1] = unsafe { aes64ks2(self.cols[0], self.cols[1]) }; + self.save_one_keys(RNUM + 1) + } + + #[inline(always)] + pub(crate) fn expand_key(ckey: &[u8; 16]) -> RoundKeys<11> { + let mut schedule = Self::load(ckey); + schedule.one_key_rounds::<0>(); + schedule.one_key_rounds::<1>(); + schedule.one_key_rounds::<2>(); + schedule.one_key_rounds::<3>(); + schedule.one_key_rounds::<4>(); + schedule.one_key_rounds::<5>(); + schedule.one_key_rounds::<6>(); + schedule.one_key_rounds::<7>(); + schedule.one_key_rounds::<8>(); + schedule.one_key_rounds::<9>(); + // SAFETY: `state.expanded_keys` is fully initialized. + unsafe { transmute(schedule.keys) } + } +} + +// COLUMNS: 6 x 32-bit words = 3 x 64-bit words +impl KeySchedule<{ 6 / 2 }, { 1 + 12 }> { + #[inline(always)] + fn load(ckey: &[u8; 24]) -> Self { + let ckey = ckey.as_ptr().cast::(); + let mut cols: [MaybeUninit; 3] = unsafe { MaybeUninit::uninit().assume_init() }; + unsafe { cols[0].write(ckey.add(0).read_unaligned()) }; + unsafe { cols[1].write(ckey.add(1).read_unaligned()) }; + unsafe { cols[2].write(ckey.add(2).read_unaligned()) }; + let mut schedule = Self { + // SAFETY: `data` is fully initialized. + cols: unsafe { transmute(cols) }, + keys: unsafe { MaybeUninit::uninit().assume_init() }, + }; + schedule.save_one_and_one_half_keys(0); + schedule + } + + #[inline(always)] + fn save_one_keys(&mut self, i: u8) { + let n = usize::from(i) * 3 / 2; + let k = usize::from(i) % 2; + let keys = self.keys[n + 0].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0 + k]).write(self.cols[0]) }; + let keys = self.keys[n + k].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[1 - k]).write(self.cols[1]) }; + } + + #[inline(always)] + fn save_one_and_one_half_keys(&mut self, i: u8) { + let n = usize::from(i) * 3 / 2; + let k = usize::from(i) % 2; + let keys = self.keys[n + 0].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0 + k]).write(self.cols[0]) }; + let keys = self.keys[n + k].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[1 - k]).write(self.cols[1]) }; + let keys = self.keys[n + 1].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0 + k]).write(self.cols[2]) }; + } + + #[inline(always)] + fn one_key_rounds(&mut self) { + let s = unsafe { aes64ks1i(self.cols[2], RNUM) }; + self.cols[0] = unsafe { aes64ks2(s, self.cols[0]) }; + self.cols[1] = unsafe { aes64ks2(self.cols[0], self.cols[1]) }; + self.save_one_keys(RNUM + 1) + } + + #[inline(always)] + fn one_and_one_half_key_rounds(&mut self) { + let s = unsafe { aes64ks1i(self.cols[2], RNUM) }; + self.cols[0] = unsafe { aes64ks2(s, self.cols[0]) }; + self.cols[1] = unsafe { aes64ks2(self.cols[0], self.cols[1]) }; + self.cols[2] = unsafe { aes64ks2(self.cols[1], self.cols[2]) }; + self.save_one_and_one_half_keys(RNUM + 1) + } + + #[inline(always)] + pub(crate) fn expand_key(ckey: &[u8; 24]) -> RoundKeys<13> { + let mut schedule = Self::load(ckey); + schedule.one_and_one_half_key_rounds::<0>(); + schedule.one_and_one_half_key_rounds::<1>(); + 
schedule.one_and_one_half_key_rounds::<2>(); + schedule.one_and_one_half_key_rounds::<3>(); + schedule.one_and_one_half_key_rounds::<4>(); + schedule.one_and_one_half_key_rounds::<5>(); + schedule.one_and_one_half_key_rounds::<6>(); + schedule.one_key_rounds::<7>(); + // SAFETY: `state.expanded_keys` is fully initialized. + unsafe { transmute(schedule.keys) } + } +} + +// COLUMNS: 8 x 32-bit words = 4 x 64-bit words +impl KeySchedule<{ 8 / 2 }, { 1 + 14 }> { + #[inline(always)] + fn load(ckey: &[u8; 32]) -> Self { + let ckey = ckey.as_ptr().cast::(); + let mut cols: [MaybeUninit; 4] = unsafe { MaybeUninit::uninit().assume_init() }; + unsafe { cols[0].write(ckey.add(0).read_unaligned()) }; + unsafe { cols[1].write(ckey.add(1).read_unaligned()) }; + unsafe { cols[2].write(ckey.add(2).read_unaligned()) }; + unsafe { cols[3].write(ckey.add(3).read_unaligned()) }; + let mut schedule = Self { + // SAFETY: `data` is fully initialized. + cols: unsafe { transmute(cols) }, + keys: unsafe { MaybeUninit::uninit().assume_init() }, + }; + schedule.save_two_keys(0); + schedule + } + + #[inline(always)] + fn save_one_keys(&mut self, i: u8) { + let i = usize::from(i); + let keys = self.keys[2 * i + 0].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0]).write(self.cols[0]) }; + unsafe { addr_of_mut!((*keys)[1]).write(self.cols[1]) }; + } + + #[inline(always)] + fn save_two_keys(&mut self, i: u8) { + let i = usize::from(i); + let keys = self.keys[2 * i + 0].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0]).write(self.cols[0]) }; + unsafe { addr_of_mut!((*keys)[1]).write(self.cols[1]) }; + let keys = self.keys[2 * i + 1].as_mut_ptr(); + unsafe { addr_of_mut!((*keys)[0]).write(self.cols[2]) }; + unsafe { addr_of_mut!((*keys)[1]).write(self.cols[3]) }; + } + + #[inline(always)] + fn two_key_rounds(&mut self) { + let s = unsafe { aes64ks1i(self.cols[3], RNUM) }; + self.cols[0] = unsafe { aes64ks2(s, self.cols[0]) }; + self.cols[1] = unsafe { aes64ks2(self.cols[0], self.cols[1]) }; + let s = unsafe { aes64ks1i(self.cols[1], 0xA) }; + self.cols[2] = unsafe { aes64ks2(s, self.cols[2]) }; + self.cols[3] = unsafe { aes64ks2(self.cols[2], self.cols[3]) }; + self.save_two_keys(RNUM + 1); + } + + #[inline(always)] + fn one_key_rounds(&mut self) { + let s = unsafe { aes64ks1i(self.cols[3], RNUM) }; + self.cols[0] = unsafe { aes64ks2(s, self.cols[0]) }; + self.cols[1] = unsafe { aes64ks2(self.cols[0], self.cols[1]) }; + self.save_one_keys(RNUM + 1); + } + + #[inline(always)] + pub(crate) fn expand_key(user_key: &[u8; 32]) -> RoundKeys<15> { + let mut schedule = Self::load(user_key); + schedule.two_key_rounds::<0>(); + schedule.two_key_rounds::<1>(); + schedule.two_key_rounds::<2>(); + schedule.two_key_rounds::<3>(); + schedule.two_key_rounds::<4>(); + schedule.two_key_rounds::<5>(); + schedule.one_key_rounds::<6>(); + // SAFETY: `state.expanded_keys` is fully initialized. 
+ unsafe { transmute(schedule.keys) } + } +} + +#[inline(always)] +pub(super) fn inv_expanded_keys(keys: &mut RoundKeys) { + for i in 1..N - 1 { + keys[i][0] = unsafe { aes64im(keys[i][0]) }; + keys[i][1] = unsafe { aes64im(keys[i][1]) }; + } +} diff --git a/aes/src/riscv/rv64/test_expand.rs b/aes/src/riscv/rv64/test_expand.rs new file mode 100644 index 00000000..9afd4a67 --- /dev/null +++ b/aes/src/riscv/rv64/test_expand.rs @@ -0,0 +1,65 @@ +use crate::riscv::rv64::{ + expand::{inv_expanded_keys, KeySchedule}, + RoundKey, RoundKeys, +}; +use crate::riscv::test::*; + +fn load_expanded_keys(input: [[u8; 16]; N]) -> RoundKeys { + let mut output = [RoundKey::from(<[u64; 2]>::default()); N]; + for (src, dst) in input.iter().zip(output.iter_mut()) { + let ptr = src.as_ptr().cast::(); + dst[0] = unsafe { ptr.add(0).read_unaligned() }; + dst[1] = unsafe { ptr.add(1).read_unaligned() }; + } + output +} + +pub(crate) fn store_expanded_keys(input: RoundKeys) -> [[u8; 16]; N] { + let mut output = [[0u8; 16]; N]; + for (src, dst) in input.iter().zip(output.iter_mut()) { + let b0 = src[0].to_ne_bytes(); + let b1 = src[1].to_ne_bytes(); + dst[00..08].copy_from_slice(&b0); + dst[08..16].copy_from_slice(&b1); + } + output +} + +#[test] +fn aes128_key_expansion() { + let ek = KeySchedule::<2, 11>::expand_key(&AES128_KEY); + assert_eq!(store_expanded_keys(ek), AES128_EXP_KEYS); +} + +#[test] +fn aes128_key_expansion_inv() { + let mut ek = load_expanded_keys(AES128_EXP_KEYS); + inv_expanded_keys(&mut ek); + assert_eq!(store_expanded_keys(ek), AES128_EXP_INVKEYS); +} + +#[test] +fn aes192_key_expansion() { + let ek = KeySchedule::<3, 13>::expand_key(&AES192_KEY); + assert_eq!(store_expanded_keys(ek), AES192_EXP_KEYS); +} + +#[test] +fn aes192_key_expansion_inv() { + let mut ek = load_expanded_keys(AES192_EXP_KEYS); + inv_expanded_keys(&mut ek); + assert_eq!(store_expanded_keys(ek), AES192_EXP_INVKEYS); +} + +#[test] +fn aes256_key_expansion() { + let ek = KeySchedule::<4, 15>::expand_key(&AES256_KEY); + assert_eq!(store_expanded_keys(ek), AES256_EXP_KEYS); +} + +#[test] +fn aes256_key_expansion_inv() { + let mut ek = load_expanded_keys(AES256_EXP_KEYS); + inv_expanded_keys(&mut ek); + assert_eq!(store_expanded_keys(ek), AES256_EXP_INVKEYS); +} diff --git a/aes/src/riscv/rvv.rs b/aes/src/riscv/rvv.rs new file mode 100644 index 00000000..09b4a9f2 --- /dev/null +++ b/aes/src/riscv/rvv.rs @@ -0,0 +1,344 @@ +//! AES block cipher implementation using the RISC-V Vector Cryptography Extensions: Zvkned +//! +//! RISC-V Vector Cryptography Extension v1.0.0: +//! https://github.com/riscv/riscv-crypto/releases/download/v1.0.0/riscv-crypto-spec-vector.pdf +//! +//! For reference, see the following other implementations: +//! +//! 1. The RISC-V Cryptography Extensions vector code samples AES-CBC proof of concept with Zvkned: +//! https://github.com/riscv/riscv-crypto/blob/main/doc/vector/code-samples/zvkned.s +//! +//! 2. The OpenSSL implementation for RISC-V 64 with Zvkned: +//! 
https://github.com/openssl/openssl/blob/master/crypto/aes/asm/aes-riscv64-zvkned.pl + +mod encdec; +mod expand; +#[cfg(test)] +mod test_expand; + +// TODO(silvanshade): +// - register allocation +// - use larger parallel block size +// - interleave key-schedule for parallel blocks (allows for larger LMUL) +// - use larger LMUL for parallel blocks + +use crate::{Block, Block8}; +use cipher::{ + consts::{U16, U24, U32, U8}, + inout::InOut, + AlgorithmName, BlockBackend, BlockCipher, BlockCipherDecrypt, BlockCipherEncrypt, BlockClosure, + BlockSizeUser, Key, KeyInit, KeySizeUser, ParBlocksSizeUser, +}; +use core::fmt; + +type RoundKey = [u32; 4]; +type RoundKeys = [RoundKey; N]; + +macro_rules! define_aes_impl { + ( + $module:ident, + $name:ident, + $name_enc:ident, + $name_dec:ident, + $name_back_enc:ident, + $name_back_dec:ident, + $key_size:ty, + $words:tt, + $rounds:tt, + $doc:expr $(,)? + ) => { + #[doc=$doc] + #[doc = "block cipher"] + #[derive(Clone)] + pub struct $name { + encrypt: $name_enc, + decrypt: $name_dec, + } + + impl BlockCipher for $name {} + + impl KeySizeUser for $name { + type KeySize = $key_size; + } + + impl KeyInit for $name { + #[inline] + fn new(key: &Key) -> Self { + let encrypt = $name_enc::new(key); + let decrypt = $name_dec::from(&encrypt); + Self { encrypt, decrypt } + } + } + + impl From<$name_enc> for $name { + #[inline] + fn from(encrypt: $name_enc) -> $name { + let decrypt = (&encrypt).into(); + Self { encrypt, decrypt } + } + } + + impl From<&$name_enc> for $name { + #[inline] + fn from(encrypt: &$name_enc) -> $name { + let decrypt = encrypt.into(); + let encrypt = encrypt.clone(); + Self { encrypt, decrypt } + } + } + + impl BlockSizeUser for $name { + type BlockSize = U16; + } + + impl BlockCipherEncrypt for $name { + fn encrypt_with_backend(&self, f: impl BlockClosure) { + self.encrypt.encrypt_with_backend(f) + } + } + + impl BlockCipherDecrypt for $name { + fn decrypt_with_backend(&self, f: impl BlockClosure) { + self.decrypt.decrypt_with_backend(f) + } + } + + impl fmt::Debug for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name), " { .. }")) + } + } + + impl AlgorithmName for $name { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name)) + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name {} + + #[doc=$doc] + #[doc = "block cipher (encrypt-only)"] + #[derive(Clone)] + pub struct $name_enc { + round_keys: RoundKeys<$rounds>, + } + + impl $name_enc { + #[inline(always)] + pub(crate) fn get_enc_backend(&self) -> $name_back_enc<'_> { + $name_back_enc(self) + } + } + + impl BlockCipher for $name_enc {} + + impl KeySizeUser for $name_enc { + type KeySize = $key_size; + } + + impl KeyInit for $name_enc { + #[inline] + fn new(key: &Key) -> Self { + Self { + round_keys: self::expand::$module::expand_key(key.as_ref()), + } + } + } + + impl BlockSizeUser for $name_enc { + type BlockSize = U16; + } + + impl BlockCipherEncrypt for $name_enc { + fn encrypt_with_backend(&self, f: impl BlockClosure) { + f.call(&mut self.get_enc_backend()) + } + } + + impl fmt::Debug for $name_enc { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name_enc), " { .. 
}")) + } + } + + impl AlgorithmName for $name_enc { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name_enc)) + } + } + + impl Drop for $name_enc { + #[inline] + fn drop(&mut self) { + #[cfg(feature = "zeroize")] + zeroize::Zeroize::zeroize(&mut self.round_keys); + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name_enc {} + + #[doc=$doc] + #[doc = "block cipher (decrypt-only)"] + #[derive(Clone)] + pub struct $name_dec { + round_keys: RoundKeys<$rounds>, + } + + impl $name_dec { + #[inline(always)] + pub(crate) fn get_dec_backend(&self) -> $name_back_dec<'_> { + $name_back_dec(self) + } + } + + impl BlockCipher for $name_dec {} + + impl KeySizeUser for $name_dec { + type KeySize = $key_size; + } + + impl KeyInit for $name_dec { + #[inline] + fn new(key: &Key) -> Self { + $name_enc::new(key).into() + } + } + + impl From<$name_enc> for $name_dec { + #[inline] + fn from(enc: $name_enc) -> $name_dec { + Self::from(&enc) + } + } + + impl From<&$name_enc> for $name_dec { + fn from(enc: &$name_enc) -> $name_dec { + let round_keys = enc.round_keys; + Self { round_keys } + } + } + + impl BlockSizeUser for $name_dec { + type BlockSize = U16; + } + + impl BlockCipherDecrypt for $name_dec { + fn decrypt_with_backend(&self, f: impl BlockClosure) { + f.call(&mut self.get_dec_backend()); + } + } + + impl fmt::Debug for $name_dec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str(concat!(stringify!($name_dec), " { .. }")) + } + } + + impl AlgorithmName for $name_dec { + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(stringify!($name_dec)) + } + } + + impl Drop for $name_dec { + #[inline] + fn drop(&mut self) { + #[cfg(feature = "zeroize")] + zeroize::Zeroize::zeroize(&mut self.round_keys); + } + } + + #[cfg(feature = "zeroize")] + impl zeroize::ZeroizeOnDrop for $name_dec {} + + pub(crate) struct $name_back_enc<'a>(&'a $name_enc); + + impl<'a> BlockSizeUser for $name_back_enc<'a> { + type BlockSize = U16; + } + + impl<'a> ParBlocksSizeUser for $name_back_enc<'a> { + type ParBlocksSize = U8; + } + + impl<'a> BlockBackend for $name_back_enc<'a> { + #[inline(always)] + fn proc_block(&mut self, block: InOut<'_, '_, Block>) { + self::encdec::$module::encrypt1(&self.0.round_keys, block); + } + + #[inline(always)] + fn proc_par_blocks(&mut self, blocks: InOut<'_, '_, Block8>) { + self::encdec::$module::encrypt8(&self.0.round_keys, blocks) + } + } + + pub(crate) struct $name_back_dec<'a>(&'a $name_dec); + + impl<'a> BlockSizeUser for $name_back_dec<'a> { + type BlockSize = U16; + } + + impl<'a> ParBlocksSizeUser for $name_back_dec<'a> { + type ParBlocksSize = U8; + } + + impl<'a> BlockBackend for $name_back_dec<'a> { + #[inline(always)] + fn proc_block(&mut self, block: InOut<'_, '_, Block>) { + self::encdec::$module::decrypt1(&self.0.round_keys, block); + } + + #[inline(always)] + fn proc_par_blocks(&mut self, blocks: InOut<'_, '_, Block8>) { + self::encdec::$module::decrypt8(&self.0.round_keys, blocks) + } + } + }; +} + +define_aes_impl!( + aes128, + Aes128, + Aes128Enc, + Aes128Dec, + Aes128BackEnc, + Aes128BackDec, + U16, + 2, + 11, + "AES-128", +); +// NOTE: AES-192 is only implemented if scalar-crypto is enabled. 
+#[cfg(all( + target_arch = "riscv64", + target_feature = "zknd", + target_feature = "zkne" +))] +define_aes_impl!( + aes192, + Aes192, + Aes192Enc, + Aes192Dec, + Aes192BackEnc, + Aes192BackDec, + U24, + 3, + 13, + "AES-192", +); +define_aes_impl!( + aes256, + Aes256, + Aes256Enc, + Aes256Dec, + Aes256BackEnc, + Aes256BackDec, + U32, + 4, + 15, + "AES-256", +); diff --git a/aes/src/riscv/rvv/encdec.rs b/aes/src/riscv/rvv/encdec.rs new file mode 100644 index 00000000..abb566f0 --- /dev/null +++ b/aes/src/riscv/rvv/encdec.rs @@ -0,0 +1,3 @@ +pub(super) mod aes128; +pub(super) mod aes192; +pub(super) mod aes256; diff --git a/aes/src/riscv/rvv/encdec/aes128.rs b/aes/src/riscv/rvv/encdec/aes128.rs new file mode 100644 index 00000000..4d7a24f8 --- /dev/null +++ b/aes/src/riscv/rvv/encdec/aes128.rs @@ -0,0 +1,209 @@ +use crate::{riscv::rvv::RoundKeys, Block, Block8}; +use cipher::inout::InOut; +use core::arch::global_asm; + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm! { + // INPUTS: + // a0: uint8_t * dst + // a1: uint8_t *const src + // a2: size_t len + // a3: uint32_t *const key + // SAFETY: + // - a0, a1 must be valid pointers to memory regions of at least len bytes + // - a3 must be valid pointers to memory regions of at least 176 bytes + // - on exit: a1 and a3 remain unchanged + // - on exit: a0 is overwritten with cipher-data + ".balign 4", + ".attribute arch, \"rv64gcv1p0_zkne_zknd_zvkned1p0\"", + ".global aes_riscv_rv64_vector_encdec_aes128_encrypt", + ".type aes_riscv_rv64_vector_encdec_aes128_encrypt, @function", + "aes_riscv_rv64_vector_encdec_aes128_encrypt:", + "andi t0, a2, -16", // t0 = len (round to multiple of 16) + "beqz t0, 2f", // if len == 0, exit + "srli t3, t0, 2", // t3 = len / 4 + + "vsetivli zero, 4, e32, m1, ta, ma", // configure RVV for vector shape: 4 x 32b x 1 + + "vle32.v v10, (a3)", "addi a3, a3, 16", // load round 00 key + "vle32.v v11, (a3)", "addi a3, a3, 16", // load round 01 key + "vle32.v v12, (a3)", "addi a3, a3, 16", // load round 02 key + "vle32.v v13, (a3)", "addi a3, a3, 16", // load round 03 key + "vle32.v v14, (a3)", "addi a3, a3, 16", // load round 04 key + "vle32.v v15, (a3)", "addi a3, a3, 16", // load round 05 key + "vle32.v v16, (a3)", "addi a3, a3, 16", // load round 06 key + "vle32.v v17, (a3)", "addi a3, a3, 16", // load round 07 key + "vle32.v v18, (a3)", "addi a3, a3, 16", // load round 08 key + "vle32.v v19, (a3)", "addi a3, a3, 16", // load round 09 key + "vle32.v v20, (a3)", // load round 10 key + "1:", + "vsetvli t2, t3, e32, m1, ta, ma", // configure RVV for vector shape: len x 32b x 1 + // t2 = vl4 <= len + + "vle32.v v1, (a1)", // load vl bytes of plain-data + "vaesz.vs v1, v10", // perform AES-128 round 00 encryption + "vaesem.vs v1, v11", // perform AES-128 round 01 encryption + "vaesem.vs v1, v12", // perform AES-128 round 02 encryption + "vaesem.vs v1, v13", // perform AES-128 round 03 encryption + "vaesem.vs v1, v14", // perform AES-128 round 04 encryption + "vaesem.vs v1, v15", // perform AES-128 round 05 encryption + "vaesem.vs v1, v16", // perform AES-128 round 06 encryption + "vaesem.vs v1, v17", // perform AES-128 round 07 encryption + "vaesem.vs v1, v18", // perform AES-128 round 08 encryption + "vaesem.vs v1, v19", // perform AES-128 round 09 encryption + "vaesef.vs v1, v20", // perform AES-128 round 10 encryption + "vse32.v v1, (a0)", // save vl bytes of cipher-data + + "sub t3, t3, t2", // len -= vl4 // vl (measuring 4-byte units) + + "slli t2, t2, 2", // vl16 = vl4 * 4 // vl 
(measuring 16-byte units)
+    "add a1, a1, t2",                       // src += vl16        // src += vl4 * 4
+    "add a0, a0, t2",                       // dst += vl16        // dst += vl4 * 4
+
+    "bnez t3, 1b",                          // if len != 0, loop
+    "2:",
+    "ret",
+}
+extern "C" {
+    fn aes_riscv_rv64_vector_encdec_aes128_encrypt(
+        dst: *mut u8,
+        src: *const u8,
+        len: usize,
+        key: *const u32,
+    );
+}
+
+// TODO(silvanshade): switch to intrinsics when available
+#[rustfmt::skip]
+global_asm! {
+    // INPUTS:
+    //   a0: uint8_t * dst
+    //   a1: uint8_t *const src
+    //   a2: size_t len
+    //   a3: uint32_t *const key
+    // SAFETY:
+    //   - a0, a1 must be valid pointers to memory regions of at least len bytes
+    //   - a3 must be a valid pointer to a memory region of at least 176 bytes
+    //   - on exit: a1, a3 are unchanged
+    //   - on exit: a0 is overwritten with plain-data
+    ".balign 4",
+    ".attribute arch, \"rv64gcv1p0_zkne_zknd_zvkned1p0\"",
+    ".global aes_riscv_rv64_vector_encdec_aes128_decrypt",
+    ".type aes_riscv_rv64_vector_encdec_aes128_decrypt, @function",
+    "aes_riscv_rv64_vector_encdec_aes128_decrypt:",
+    "andi t0, a2, -16",                     // t0 = len (round to multiple of 16)
+    "beqz t0, 2f",                          // if len == 0, exit
+    "srli t3, t0, 2",                       // t3 = len / 4
+
+    "vsetivli zero, 4, e32, m1, ta, ma",    // configure RVV for vector shape: 4 x 32b x 1
+
+    "vle32.v v10, (a3)", "addi a3, a3, 16", // load round 00 key
+    "vle32.v v11, (a3)", "addi a3, a3, 16", // load round 01 key
+    "vle32.v v12, (a3)", "addi a3, a3, 16", // load round 02 key
+    "vle32.v v13, (a3)", "addi a3, a3, 16", // load round 03 key
+    "vle32.v v14, (a3)", "addi a3, a3, 16", // load round 04 key
+    "vle32.v v15, (a3)", "addi a3, a3, 16", // load round 05 key
+    "vle32.v v16, (a3)", "addi a3, a3, 16", // load round 06 key
+    "vle32.v v17, (a3)", "addi a3, a3, 16", // load round 07 key
+    "vle32.v v18, (a3)", "addi a3, a3, 16", // load round 08 key
+    "vle32.v v19, (a3)", "addi a3, a3, 16", // load round 09 key
+    "vle32.v v20, (a3)",                    // load round 10 key
+    "1:",
+    "vsetvli t2, t3, e32, m1, ta, ma",      // configure RVV for vector shape: len x 32b x 1
+                                            // t2 = vl4 <= len
+
+    "vle32.v v0, (a1)",                     // load vl4 bytes of cipher-data
+    "vaesz.vs v0, v20",                     // perform AES-128 round 10 decryption
+    "vaesdm.vs v0, v19",                    // perform AES-128 round 09 decryption
+    "vaesdm.vs v0, v18",                    // perform AES-128 round 08 decryption
+    "vaesdm.vs v0, v17",                    // perform AES-128 round 07 decryption
+    "vaesdm.vs v0, v16",                    // perform AES-128 round 06 decryption
+    "vaesdm.vs v0, v15",                    // perform AES-128 round 05 decryption
+    "vaesdm.vs v0, v14",                    // perform AES-128 round 04 decryption
+    "vaesdm.vs v0, v13",                    // perform AES-128 round 03 decryption
+    "vaesdm.vs v0, v12",                    // perform AES-128 round 02 decryption
+    "vaesdm.vs v0, v11",                    // perform AES-128 round 01 decryption
+    "vaesdf.vs v0, v10",                    // perform AES-128 round 00 decryption
+    "vse32.v v0, (a0)",                     // save vl4 bytes of plain-data
+
+    "sub t3, t3, t2",                       // len -= vl4         // vl (measuring 4-byte units)
+
+    "slli t2, t2, 2",                       // vl16 = vl4 * 4     // vl (measuring 16-byte units)
+    "add a1, a1, t2",                       // src += vl16        // src += vl4 * 4
+    "add a0, a0, t2",                       // dst += vl16        // dst += vl4 * 4
+
+    "bnez t3, 1b",                          // if len != 0, loop
+    "2:",
+    "ret",
+}
+extern "C" {
+    fn aes_riscv_rv64_vector_encdec_aes128_decrypt(
+        dst: *mut u8,
+        src: *const u8,
+        len: usize,
+        key: *const u32,
+    );
+}
+
+#[inline(always)]
+fn encrypt_vla(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>, blocks: usize) {
+    let dst = data.get_out().as_mut_ptr();
+    let src = data.get_in().as_ptr();
+    let len = blocks * 16;
+    let key = keys.as_ptr().cast::<u32>();
+    unsafe { aes_riscv_rv64_vector_encdec_aes128_encrypt(dst, src, len, key) };
+}
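+
+// SAFETY (for the callers below): `InOut` guarantees that `src` and `dst` point
+// to `blocks * 16` valid bytes, and a `RoundKeys<11>` is exactly the 176 bytes
+// of round-key material the assembly reads, so the contracts stated above hold.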
+
+#[inline(always)]
+pub(crate) fn encrypt1(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    encrypt_vla(keys, data, 1)
+}
+
+#[inline(always)]
+pub(crate) fn encrypt8(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block8>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    encrypt_vla(keys, data, 8)
+}
+
+#[inline(always)]
+fn decrypt_vla(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>, blocks: usize) {
+    let dst = data.get_out().as_mut_ptr();
+    let src = data.get_in().as_ptr();
+    let len = blocks * 16;
+    let key = keys.as_ptr().cast::<u32>();
+    unsafe { aes_riscv_rv64_vector_encdec_aes128_decrypt(dst, src, len, key) };
+}
+
+#[inline(always)]
+pub(crate) fn decrypt1(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    decrypt_vla(keys, data, 1)
+}
+
+#[inline(always)]
+pub(crate) fn decrypt8(keys: &RoundKeys<11>, mut data: InOut<'_, '_, Block8>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    decrypt_vla(keys, data, 8)
+}
diff --git a/aes/src/riscv/rvv/encdec/aes192.rs b/aes/src/riscv/rvv/encdec/aes192.rs
new file mode 100644
index 00000000..9ea70b50
--- /dev/null
+++ b/aes/src/riscv/rvv/encdec/aes192.rs
@@ -0,0 +1,216 @@
+use crate::riscv::rvv::RoundKeys;
+use crate::{Block, Block8};
+use cipher::inout::InOut;
+use core::arch::global_asm;
+
+// TODO(silvanshade): switch to intrinsics when available
+#[rustfmt::skip]
+global_asm!{
+    // INPUTS:
+    //   a0: uint8_t * dst
+    //   a1: uint8_t *const src
+    //   a2: size_t len
+    //   a3: uint32_t *const key
+    // SAFETY:
+    //   - a0, a1 must be valid pointers to memory regions of at least len bytes
+    //   - a3 must be a valid pointer to a memory region of at least 208 bytes
+    //   - on exit: a1, a3 are unchanged
+    //   - on exit: a0 is overwritten with cipher-data
+    ".attribute arch, \"rv64gcv1p0_zvkned1p0\"",
+    ".balign 4",
+    ".global aes_riscv_rv64_vector_encdec_aes192_encrypt",
+    ".type aes_riscv_rv64_vector_encdec_aes192_encrypt, @function",
+    "aes_riscv_rv64_vector_encdec_aes192_encrypt:",
+    "andi t0, a2, -16",                     // t0 = len (round to multiple of 16)
+    "beqz t0, 2f",                          // if len == 0, exit
+    "srli t3, t0, 2",                       // t3 = len / 4
+
+    "vsetivli zero, 4, e32, m1, ta, ma",    // configure RVV for vector shape: 4 x 32b x 1
+
+    "vle32.v v10, (a3)", "addi a3, a3, 16", // load round 00 key
+    "vle32.v v11, (a3)", "addi a3, a3, 16", // load round 01 key
+    "vle32.v v12, (a3)", "addi a3, a3, 16", // load round 02 key
+    "vle32.v v13, (a3)", "addi a3, a3, 16", // load round 03 key
+    "vle32.v v14, (a3)", "addi a3, a3, 16", // load round 04 key
+    "vle32.v v15, (a3)", "addi a3, a3, 16", // load round 05 key
+    "vle32.v v16, (a3)", "addi a3, a3, 16", // load round 06 key
+    "vle32.v v17, (a3)", "addi a3, a3, 16", // load round 07 key
+    "vle32.v v18, (a3)", "addi a3, a3, 16", // load round 08 key
+    "vle32.v v19, (a3)", "addi a3, a3, 16", // load round 09 key
+    "vle32.v v20, (a3)", "addi a3, a3, 16", // load round 10 key
+    "vle32.v v21, (a3)", "addi a3, a3, 16", // load round 11 key
+    "vle32.v v22, (a3)",                    // load round 12 key
+    "1:",
+    "vsetvli t2, t3, e32, m1, ta, ma",      // configure RVV for vector shape: len x 32b x 1
+                                            // t2 = vl4 <= len
+
+    "vle32.v v1, 
(a1)", // load vl bytes of plain-data + "vaesz.vs v1, v10", // perform AES-192 round 00 encryption + "vaesem.vs v1, v11", // perform AES-192 round 01 encryption + "vaesem.vs v1, v12", // perform AES-192 round 02 encryption + "vaesem.vs v1, v13", // perform AES-192 round 03 encryption + "vaesem.vs v1, v14", // perform AES-192 round 04 encryption + "vaesem.vs v1, v15", // perform AES-192 round 05 encryption + "vaesem.vs v1, v16", // perform AES-192 round 06 encryption + "vaesem.vs v1, v17", // perform AES-192 round 07 encryption + "vaesem.vs v1, v18", // perform AES-192 round 08 encryption + "vaesem.vs v1, v19", // perform AES-192 round 09 encryption + "vaesem.vs v1, v20", // perform AES-192 round 10 encryption + "vaesem.vs v1, v21", // perform AES-192 round 11 encryption + "vaesef.vs v1, v22", // perform AES-192 round 12 encryption + "vse32.v v1, (a0)", // save vl bytes of cipher-data + + "sub t3, t3, t2", // len -= vl4 // vl (measuring 4-byte units) + + "slli t2, t2, 2", // vl16 = vl4 * 4 // vl (measuring 16-byte units) + "add a1, a1, t2", // src += vl16 // src += vl4 * 4 + "add a0, a0, t2", // dst += vl16 // dst += vl4 * 4 + + "bnez t3, 1b", // if len != 0, loop + "2:", + "ret", +} +extern "C" { + fn aes_riscv_rv64_vector_encdec_aes192_encrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u32, + ); +} + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm! { + // INPUTS: + // a0: uint8_t * dst + // a1: uint8_t *const src + // a2: size_t len + // a3: uint32_t *const key + // SAFETY: + // - a0, a1 must be valid pointers to memory regions of at least len bytes + // - a3 must be valid pointers to memory regions of at least 208 bytes + // - on exit: a1, a3 are unchanged + // - on exit: a0 is overwritten with plain-data + ".balign 4", + ".attribute arch, \"rv64gcv1p0_zkne_zknd_zvkned1p0\"", + ".global aes_riscv_rv64_vector_encdec_aes192_decrypt", + ".type aes_riscv_rv64_vector_encdec_aes192_decrypt, @function", + "aes_riscv_rv64_vector_encdec_aes192_decrypt:", + "andi t0, a2, -16", // t0 = len (round to multiple of 16) + "beqz t0, 2f", // if len == 0, exit + "srli t3, t0, 2", // a2 = len / 4 + + "vsetivli zero, 4, e32, m1, ta, ma", // configure RVV for vector shape: 4 x 32b x 1 + + "vle32.v v10, (a3)", "addi a3, a3, 16", // load round 00 key + "vle32.v v11, (a3)", "addi a3, a3, 16", // load round 01 key + "vle32.v v12, (a3)", "addi a3, a3, 16", // load round 02 key + "vle32.v v13, (a3)", "addi a3, a3, 16", // load round 03 key + "vle32.v v14, (a3)", "addi a3, a3, 16", // load round 04 key + "vle32.v v15, (a3)", "addi a3, a3, 16", // load round 05 key + "vle32.v v16, (a3)", "addi a3, a3, 16", // load round 06 key + "vle32.v v17, (a3)", "addi a3, a3, 16", // load round 07 key + "vle32.v v18, (a3)", "addi a3, a3, 16", // load round 08 key + "vle32.v v19, (a3)", "addi a3, a3, 16", // load round 09 key + "vle32.v v20, (a3)", "addi a3, a3, 16", // load round 10 key + "vle32.v v21, (a3)", "addi a3, a3, 16", // load round 11 key + "vle32.v v22, (a3)", // load round 12 key + "1:", + "vsetvli t2, t3, e32, m1, ta, ma", // configure RVV for vector shape: len x 32b x 1 + // t2 = vl4 <= len + + "vle32.v v0, (a1)", // load vl4 bytes of cipher-data + "vaesz.vs v0, v22", // perform AES-192 round 12 decryption + "vaesdm.vs v0, v21", // perform AES-192 round 11 decryption + "vaesdm.vs v0, v20", // perform AES-192 round 10 decryption + "vaesdm.vs v0, v19", // perform AES-192 round 09 decryption + "vaesdm.vs v0, v18", // perform AES-192 round 08 decryption + 
"vaesdm.vs v0, v17", // perform AES-192 round 07 decryption + "vaesdm.vs v0, v16", // perform AES-192 round 06 decryption + "vaesdm.vs v0, v15", // perform AES-192 round 05 decryption + "vaesdm.vs v0, v14", // perform AES-192 round 05 decryption + "vaesdm.vs v0, v13", // perform AES-192 round 03 decryption + "vaesdm.vs v0, v12", // perform AES-192 round 02 decryption + "vaesdm.vs v0, v11", // perform AES-192 round 01 decryption + "vaesdf.vs v0, v10", // perform AES-192 round 00 decryption + "vse32.v v0, (a0)", // save vl4 bytes of plain-data + + "sub t3, t3, t2", // len -= vl4 // vl (measuring 4-byte units) + + "slli t2, t2, 2", // vl16 = vl4 * 4 // vl (measuring 16-byte units) + "add a1, a1, t2", // src += vl16 // src += vl4 * 4 + "add a0, a0, t2", // dst += vl16 // dst += vl4 * 4 + + "bnez t3, 1b", // if len != 0, loop + "2:", + "ret", +} +extern "C" { + fn aes_riscv_rv64_vector_encdec_aes192_decrypt( + dst: *mut u8, + src: *const u8, + len: usize, + key: *const u32, + ); +} + +#[inline(always)] +fn encrypt_vla(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_riscv_rv64_vector_encdec_aes192_encrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn encrypt1(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn encrypt8(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + encrypt_vla(keys, data, 8) +} + +#[inline(always)] +fn decrypt_vla(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>, blocks: usize) { + let dst = data.get_out().as_mut_ptr(); + let src = data.get_in().as_ptr(); + let len = blocks * 16; + let key = keys.as_ptr().cast::(); + unsafe { aes_riscv_rv64_vector_encdec_aes192_decrypt(dst, src, len, key) }; +} + +#[inline(always)] +pub(crate) fn decrypt1(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 1) +} + +#[inline(always)] +pub(crate) fn decrypt8(keys: &RoundKeys<13>, mut data: InOut<'_, '_, Block8>) { + let data = unsafe { + InOut::from_raw( + data.get_in().as_ptr().cast::(), + data.get_out().as_mut_ptr().cast::(), + ) + }; + decrypt_vla(keys, data, 8) +} diff --git a/aes/src/riscv/rvv/encdec/aes256.rs b/aes/src/riscv/rvv/encdec/aes256.rs new file mode 100644 index 00000000..1e91a044 --- /dev/null +++ b/aes/src/riscv/rvv/encdec/aes256.rs @@ -0,0 +1,224 @@ +use crate::riscv::rvv::RoundKeys; +use crate::{Block, Block8}; +use cipher::inout::InOut; +use core::arch::global_asm; + +// TODO(silvanshade): switch to intrinsics when available +#[rustfmt::skip] +global_asm!{ + // INPUTS: + // a0: uint8_t exp[240] + // a1: uint32_t key[4] + // SAFETY: + // - a0 must be valid pointers to memory regions of at least 240 bytes + // - a1 must be valid pointers to memory regions of at least 16 bytes + // - on exit: a0 is overwritten with expanded round keys + // - on exit: a1 is unchanged + ".attribute arch, \"rv64gcv1p0_zvkned1p0\"", + ".balign 4", + ".global aes_riscv_rv64_vector_encdec_aes256_encrypt", + ".type 
diff --git a/aes/src/riscv/rvv/encdec/aes256.rs b/aes/src/riscv/rvv/encdec/aes256.rs
new file mode 100644
index 00000000..1e91a044
--- /dev/null
+++ b/aes/src/riscv/rvv/encdec/aes256.rs
@@ -0,0 +1,224 @@
+use crate::riscv::rvv::RoundKeys;
+use crate::{Block, Block8};
+use cipher::inout::InOut;
+use core::arch::global_asm;
+
+// TODO(silvanshade): switch to intrinsics when available
+#[rustfmt::skip]
+global_asm!{
+    // INPUTS:
+    //  a0: uint8_t *dst
+    //  a1: uint8_t *const src
+    //  a2: size_t len
+    //  a3: uint32_t *const key
+    // SAFETY:
+    //  - a0, a1 must be valid pointers to memory regions of at least len bytes
+    //  - a3 must be a valid pointer to a memory region of at least 240 bytes
+    //  - on exit: a1, a3 are unchanged
+    //  - on exit: a0 is overwritten with cipher-data
+    ".attribute arch, \"rv64gcv1p0_zvkned1p0\"",
+    ".balign 4",
+    ".global aes_riscv_rv64_vector_encdec_aes256_encrypt",
+    ".type aes_riscv_rv64_vector_encdec_aes256_encrypt, @function",
+    "aes_riscv_rv64_vector_encdec_aes256_encrypt:",
+    "andi t0, a2, -16",                     // t0 = len (round to multiple of 16)
+    "beqz t0, 2f",                          // if len == 0, exit
+    "srli t3, t0, 2",                       // t3 = len / 4
+
+    "vsetivli zero, 4, e32, m1, ta, ma",    // configure RVV for vector shape: 4 x 32b x 1
+
+    "vle32.v v10, (a3)", "addi a3, a3, 16", // load round 00 key
+    "vle32.v v11, (a3)", "addi a3, a3, 16", // load round 01 key
+    "vle32.v v12, (a3)", "addi a3, a3, 16", // load round 02 key
+    "vle32.v v13, (a3)", "addi a3, a3, 16", // load round 03 key
+    "vle32.v v14, (a3)", "addi a3, a3, 16", // load round 04 key
+    "vle32.v v15, (a3)", "addi a3, a3, 16", // load round 05 key
+    "vle32.v v16, (a3)", "addi a3, a3, 16", // load round 06 key
+    "vle32.v v17, (a3)", "addi a3, a3, 16", // load round 07 key
+    "vle32.v v18, (a3)", "addi a3, a3, 16", // load round 08 key
+    "vle32.v v19, (a3)", "addi a3, a3, 16", // load round 09 key
+    "vle32.v v20, (a3)", "addi a3, a3, 16", // load round 10 key
+    "vle32.v v21, (a3)", "addi a3, a3, 16", // load round 11 key
+    "vle32.v v22, (a3)", "addi a3, a3, 16", // load round 12 key
+    "vle32.v v23, (a3)", "addi a3, a3, 16", // load round 13 key
+    "vle32.v v24, (a3)",                    // load round 14 key
+    "1:",
+    "vsetvli t2, t3, e32, m1, ta, ma",      // configure RVV for vector shape: len x 32b x 1
+                                            // t2 = vl4 <= len
+
+    "vle32.v v1, (a1)",                     // load vl4 bytes of plain-data
+    "vaesz.vs v1, v10",                     // perform AES-256 round 00 encryption
+    "vaesem.vs v1, v11",                    // perform AES-256 round 01 encryption
+    "vaesem.vs v1, v12",                    // perform AES-256 round 02 encryption
+    "vaesem.vs v1, v13",                    // perform AES-256 round 03 encryption
+    "vaesem.vs v1, v14",                    // perform AES-256 round 04 encryption
+    "vaesem.vs v1, v15",                    // perform AES-256 round 05 encryption
+    "vaesem.vs v1, v16",                    // perform AES-256 round 06 encryption
+    "vaesem.vs v1, v17",                    // perform AES-256 round 07 encryption
+    "vaesem.vs v1, v18",                    // perform AES-256 round 08 encryption
+    "vaesem.vs v1, v19",                    // perform AES-256 round 09 encryption
+    "vaesem.vs v1, v20",                    // perform AES-256 round 10 encryption
+    "vaesem.vs v1, v21",                    // perform AES-256 round 11 encryption
+    "vaesem.vs v1, v22",                    // perform AES-256 round 12 encryption
+    "vaesem.vs v1, v23",                    // perform AES-256 round 13 encryption
+    "vaesef.vs v1, v24",                    // perform AES-256 round 14 encryption
+    "vse32.v v1, (a0)",                     // save vl4 bytes of cipher-data
+
+    "sub t3, t3, t2",                       // len -= vl4 // vl (measuring 4-byte units)
+
+    "slli t2, t2, 2",                       // vl16 = vl4 * 4 // vl (measuring 16-byte units)
+    "add a1, a1, t2",                       // src += vl16 // src += vl4 * 4
+    "add a0, a0, t2",                       // dst += vl16 // dst += vl4 * 4
+
+    "bnez t3, 1b",                          // if len != 0, loop
+    "2:",
+    "ret",
+}
+extern "C" {
+    fn aes_riscv_rv64_vector_encdec_aes256_encrypt(
+        dst: *mut u8,
+        src: *const u8,
+        len: usize,
+        key: *const u32,
+    );
+}
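`vaesz.vs` in the round-00 position is simply AddRoundKey. A scalar model of that single step (illustrative; plain arrays stand in for the vector element groups the asm operates on):

```rust
/// AddRoundKey: XOR the 16-byte state with a 16-byte round key, which is
/// all `vaesz.vs` does for round 00 in the loop above.
fn add_round_key(state: &mut [u8; 16], round_key: &[u8; 16]) {
    for (s, k) in state.iter_mut().zip(round_key) {
        *s ^= k;
    }
}
```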
+
+// TODO(silvanshade): switch to intrinsics when available
+#[rustfmt::skip]
+global_asm! {
+    // INPUTS:
+    //  a0: uint8_t *dst
+    //  a1: uint8_t *const src
+    //  a2: size_t len
+    //  a3: uint32_t *const key
+    // SAFETY:
+    //  - a0, a1 must be valid pointers to memory regions of at least len bytes
+    //  - a3 must be a valid pointer to a memory region of at least 240 bytes
+    //  - on exit: a1, a3 are unchanged
+    //  - on exit: a0 is overwritten with plain-data
+    ".balign 4",
+    ".attribute arch, \"rv64gcv1p0_zkne_zknd_zvkned1p0\"",
+    ".global aes_riscv_rv64_vector_encdec_aes256_decrypt",
+    ".type aes_riscv_rv64_vector_encdec_aes256_decrypt, @function",
+    "aes_riscv_rv64_vector_encdec_aes256_decrypt:",
+    "andi t0, a2, -16",                     // t0 = len (round to multiple of 16)
+    "beqz t0, 2f",                          // if len == 0, exit
+    "srli t3, t0, 2",                       // t3 = len / 4
+
+    "vsetivli zero, 4, e32, m1, ta, ma",    // configure RVV for vector shape: 4 x 32b x 1
+
+    "vle32.v v10, (a3)", "addi a3, a3, 16", // load round 00 key
+    "vle32.v v11, (a3)", "addi a3, a3, 16", // load round 01 key
+    "vle32.v v12, (a3)", "addi a3, a3, 16", // load round 02 key
+    "vle32.v v13, (a3)", "addi a3, a3, 16", // load round 03 key
+    "vle32.v v14, (a3)", "addi a3, a3, 16", // load round 04 key
+    "vle32.v v15, (a3)", "addi a3, a3, 16", // load round 05 key
+    "vle32.v v16, (a3)", "addi a3, a3, 16", // load round 06 key
+    "vle32.v v17, (a3)", "addi a3, a3, 16", // load round 07 key
+    "vle32.v v18, (a3)", "addi a3, a3, 16", // load round 08 key
+    "vle32.v v19, (a3)", "addi a3, a3, 16", // load round 09 key
+    "vle32.v v20, (a3)", "addi a3, a3, 16", // load round 10 key
+    "vle32.v v21, (a3)", "addi a3, a3, 16", // load round 11 key
+    "vle32.v v22, (a3)", "addi a3, a3, 16", // load round 12 key
+    "vle32.v v23, (a3)", "addi a3, a3, 16", // load round 13 key
+    "vle32.v v24, (a3)",                    // load round 14 key
+    "1:",
+    "vsetvli t2, t3, e32, m1, ta, ma",      // configure RVV for vector shape: len x 32b x 1
+                                            // t2 = vl4 <= len
+
+    "vle32.v v0, (a1)",                     // load vl4 bytes of cipher-data
+    "vaesz.vs v0, v24",                     // perform AES-256 round 14 decryption
+    "vaesdm.vs v0, v23",                    // perform AES-256 round 13 decryption
+    "vaesdm.vs v0, v22",                    // perform AES-256 round 12 decryption
+    "vaesdm.vs v0, v21",                    // perform AES-256 round 11 decryption
+    "vaesdm.vs v0, v20",                    // perform AES-256 round 10 decryption
+    "vaesdm.vs v0, v19",                    // perform AES-256 round 09 decryption
+    "vaesdm.vs v0, v18",                    // perform AES-256 round 08 decryption
+    "vaesdm.vs v0, v17",                    // perform AES-256 round 07 decryption
+    "vaesdm.vs v0, v16",                    // perform AES-256 round 06 decryption
+    "vaesdm.vs v0, v15",                    // perform AES-256 round 05 decryption
+    "vaesdm.vs v0, v14",                    // perform AES-256 round 04 decryption
+    "vaesdm.vs v0, v13",                    // perform AES-256 round 03 decryption
+    "vaesdm.vs v0, v12",                    // perform AES-256 round 02 decryption
+    "vaesdm.vs v0, v11",                    // perform AES-256 round 01 decryption
+    "vaesdf.vs v0, v10",                    // perform AES-256 round 00 decryption
+    "vse32.v v0, (a0)",                     // save vl4 bytes of plain-data
+
+    "sub t3, t3, t2",                       // len -= vl4 // vl (measuring 4-byte units)
+
+    "slli t2, t2, 2",                       // vl16 = vl4 * 4 // vl (measuring 16-byte units)
+    "add a1, a1, t2",                       // src += vl16 // src += vl4 * 4
+    "add a0, a0, t2",                       // dst += vl16 // dst += vl4 * 4
+
+    "bnez t3, 1b",                          // if len != 0, loop
+    "2:",
+    "ret",
+}
+extern "C" {
+    fn aes_riscv_rv64_vector_encdec_aes256_decrypt(
+        dst: *mut u8,
+        src: *const u8,
+        len: usize,
+        key: *const u32,
+    );
+}
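The decrypt path above consumes the same expanded schedule in reverse: `vaesz.vs` with the final round key, `vaesdm.vs` through the middle rounds, and `vaesdf.vs` with round key 00. A sketch of that pairing for AES-256 (keys live in v10..=v24):

```rust
/// Enumerates (instruction, key register) pairs in the order the AES-256
/// decrypt loop above issues them.
fn aes256_inverse_schedule() -> Vec<(&'static str, u8)> {
    let mut ops = vec![("vaesz.vs", 24u8)];
    ops.extend((11u8..=23).rev().map(|k| ("vaesdm.vs", k)));
    ops.push(("vaesdf.vs", 10));
    ops
}
```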
+
+#[inline(always)]
+fn encrypt_vla(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>, blocks: usize) {
+    let dst = data.get_out().as_mut_ptr();
+    let src = data.get_in().as_ptr();
+    let len = blocks * 16;
+    let key = keys.as_ptr().cast::<u32>();
+    unsafe { aes_riscv_rv64_vector_encdec_aes256_encrypt(dst, src, len, key) };
+}
+
+#[inline(always)]
+pub(crate) fn encrypt1(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    encrypt_vla(keys, data, 1)
+}
+
+#[inline(always)]
+pub(crate) fn encrypt8(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block8>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    encrypt_vla(keys, data, 8)
+}
+
+#[inline(always)]
+fn decrypt_vla(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>, blocks: usize) {
+    let dst = data.get_out().as_mut_ptr();
+    let src = data.get_in().as_ptr();
+    let len = blocks * 16;
+    let key = keys.as_ptr().cast::<u32>();
+    unsafe { aes_riscv_rv64_vector_encdec_aes256_decrypt(dst, src, len, key) };
+}
+
+#[inline(always)]
+pub(crate) fn decrypt1(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    decrypt_vla(keys, data, 1)
+}
+
+#[inline(always)]
+pub(crate) fn decrypt8(keys: &RoundKeys<15>, mut data: InOut<'_, '_, Block8>) {
+    let data = unsafe {
+        InOut::from_raw(
+            data.get_in().as_ptr().cast::<Block>(),
+            data.get_out().as_mut_ptr().cast::<Block>(),
+        )
+    };
+    decrypt_vla(keys, data, 8)
+}
diff --git a/aes/src/riscv/rvv/expand.rs b/aes/src/riscv/rvv/expand.rs
new file mode 100644
index 00000000..d056609a
--- /dev/null
+++ b/aes/src/riscv/rvv/expand.rs
@@ -0,0 +1,11 @@
+use super::{RoundKey, RoundKeys};
+
+pub(super) mod aes128;
+// NOTE: AES-192 is only implemented if scalar-crypto is enabled.
+#[cfg(all(
+    target_arch = "riscv64",
+    target_feature = "zknd",
+    target_feature = "zkne"
+))]
+pub(super) mod aes192;
+pub(super) mod aes256;
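The `cfg` gate above decides AES-192 support entirely at compile time. A minimal sketch of the same gating pattern, using a hypothetical constant that is not part of the patch:

```rust
/// True when this build includes the scalar-crypto AES-192 key schedule.
#[cfg(all(target_arch = "riscv64", target_feature = "zkne", target_feature = "zknd"))]
const HAS_AES192_EXPAND: bool = true;
#[cfg(not(all(target_arch = "riscv64", target_feature = "zkne", target_feature = "zknd")))]
const HAS_AES192_EXPAND: bool = false;
```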
diff --git a/aes/src/riscv/rvv/expand/aes128.rs b/aes/src/riscv/rvv/expand/aes128.rs
new file mode 100644
index 00000000..9484d4a1
--- /dev/null
+++ b/aes/src/riscv/rvv/expand/aes128.rs
@@ -0,0 +1,56 @@
+use super::{RoundKey, RoundKeys};
+use core::{
+    arch::global_asm,
+    mem::{transmute, MaybeUninit},
+};
+
+// TODO(silvanshade): switch to intrinsics when available
+#[rustfmt::skip]
+global_asm! {
+    // INPUTS:
+    //  a0: uint8_t exp[176]
+    //  a1: uint32_t key[4]
+    // SAFETY:
+    //  - a0 must be a valid pointer to a memory region of at least 176 bytes
+    //  - a1 must be a valid pointer to a memory region of at least 16 bytes
+    //  - on exit: a0 is overwritten with expanded round keys
+    //  - on exit: a1 is unchanged
+    ".attribute arch, \"rv64gcv1p0_zvkned1p0\"",
+    ".balign 4",
+    ".global aes_riscv_rv64_vector_expand_aes128_expand_key",
+    ".type aes_riscv_rv64_vector_expand_aes128_expand_key, @function",
+    "aes_riscv_rv64_vector_expand_aes128_expand_key:",
+    "vsetivli zero, 4, e32, m1, ta, ma",    // configure RVV for vector shape: 4 x 32b x 1
+
+    "vle32.v v4, (a1)",                     // load user-key
+    "vse32.v v4, (a0)",                     // save round 00 key (user-key)
+
+    "vaeskf1.vi v4, v4, 1",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 01 key
+    "vaeskf1.vi v4, v4, 2",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 02 key
+    "vaeskf1.vi v4, v4, 3",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 03 key
+    "vaeskf1.vi v4, v4, 4",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 04 key
+    "vaeskf1.vi v4, v4, 5",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 05 key
+    "vaeskf1.vi v4, v4, 6",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 06 key
+    "vaeskf1.vi v4, v4, 7",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 07 key
+    "vaeskf1.vi v4, v4, 8",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 08 key
+    "vaeskf1.vi v4, v4, 9",  "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 09 key
+    "vaeskf1.vi v4, v4, 10", "addi a0, a0, 16", "vse32.v v4, (a0)", // expand and save round 10 key
+
+    "ret",
+}
+extern "C" {
+    fn aes_riscv_rv64_vector_expand_aes128_expand_key(dst: *mut u32, src: *const u8);
+}
+
+#[inline(always)]
+pub fn expand_key(key: &[u8; 16]) -> RoundKeys<11> {
+    let mut exp: [MaybeUninit<RoundKey>; 11] = unsafe { MaybeUninit::uninit().assume_init() };
+    unsafe {
+        let exp = exp.as_mut_ptr().cast::<u32>();
+        let key = key.as_ptr();
+        aes_riscv_rv64_vector_expand_aes128_expand_key(exp, key);
+    };
+    // SAFETY: All positions have been initialized.
+    let out: RoundKeys<11> = unsafe { transmute(exp) };
+    out
+}
diff --git a/aes/src/riscv/rvv/expand/aes192.rs b/aes/src/riscv/rvv/expand/aes192.rs
new file mode 100644
index 00000000..712b90ff
--- /dev/null
+++ b/aes/src/riscv/rvv/expand/aes192.rs
@@ -0,0 +1,9 @@
+use super::RoundKeys;
+use core::mem::transmute;
+
+#[inline(always)]
+pub fn expand_key(key: &[u8; 24]) -> RoundKeys<13> {
+    let output = crate::riscv::rv64::expand::KeySchedule::<3, 13>::expand_key(key);
+    // SAFETY: Sizes match, and on this little-endian target the [u64]-based
+    // schedule has the same byte layout as [u32]-based round keys.
+    unsafe { transmute(output) }
+}
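The `transmute` above relies on riscv64 being little-endian, so a `u64`-based schedule and a `[u32; 4]`-based one share the same in-memory byte order. The layout argument in isolation (sketch, not part of the patch):

```rust
/// On a little-endian target, one u64 occupies the same bytes as its low
/// u32 word followed by its high u32 word, which is what the transmute
/// from the u64-based KeySchedule relies on.
fn u64_as_u32_pair(x: u64) -> [u32; 2] {
    [x as u32, (x >> 32) as u32]
}
// e.g. u64_as_u32_pair(0x1111_2222_3333_4444) == [0x3333_4444, 0x1111_2222]
```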
diff --git a/aes/src/riscv/rvv/expand/aes256.rs b/aes/src/riscv/rvv/expand/aes256.rs
new file mode 100644
index 00000000..426b2a09
--- /dev/null
+++ b/aes/src/riscv/rvv/expand/aes256.rs
@@ -0,0 +1,69 @@
+use crate::riscv::rvv::expand::{RoundKey, RoundKeys};
+use core::{
+    arch::global_asm,
+    mem::{transmute, MaybeUninit},
+};
+
+// TODO(silvanshade): switch to intrinsics when available
+#[rustfmt::skip]
+global_asm! {
+    // INPUTS:
+    //  a0: uint8_t exp[240]
+    //  a1: uint32_t key[8]
+    // SAFETY:
+    //  - a0 must be a valid pointer to a memory region of at least 240 bytes
+    //  - a1 must be a valid pointer to a memory region of at least 32 bytes
+    //  - on exit: a0 is overwritten with expanded round keys
+    //  - on exit: a1 is unchanged
+    ".attribute arch, \"rv64gcv1p0_zvkned1p0\"",
+    ".balign 4",
+    ".global aes_riscv_rv64_vector_expand_aes256_expand_key",
+    ".type aes_riscv_rv64_vector_expand_aes256_expand_key, @function",
+    "aes_riscv_rv64_vector_expand_aes256_expand_key:",
+    "vsetivli zero, 4, e32, m4, ta, ma",    // configure RVV for vector shape: 4 x 32b x 4
+
+    "vle32.v v4, (a1)",                     // load 1st 16-bytes of user-key [128:000]
+    "addi a1, a1, 16",
+    "vle32.v v8, (a1)",                     // load 2nd 16-bytes of user-key [256:128]
+
+    "vse32.v v4, (a0)", "addi a0, a0, 16",  // save round 00 key (user-key [128:000])
+    "vse32.v v8, (a0)", "addi a0, a0, 16",  // save round 01 key (user-key [256:128])
+
+    "vaeskf2.vi v4, v8, 2",  "vse32.v v4, (a0)", "addi a0, a0, 16", // expand and save round 02 key
+    "vaeskf2.vi v8, v4, 3",  "vse32.v v8, (a0)", "addi a0, a0, 16", // expand and save round 03 key
+
+    "vaeskf2.vi v4, v8, 4",  "vse32.v v4, (a0)", "addi a0, a0, 16", // expand and save round 04 key
+    "vaeskf2.vi v8, v4, 5",  "vse32.v v8, (a0)", "addi a0, a0, 16", // expand and save round 05 key
+
+    "vaeskf2.vi v4, v8, 6",  "vse32.v v4, (a0)", "addi a0, a0, 16", // expand and save round 06 key
+    "vaeskf2.vi v8, v4, 7",  "vse32.v v8, (a0)", "addi a0, a0, 16", // expand and save round 07 key
+
+    "vaeskf2.vi v4, v8, 8",  "vse32.v v4, (a0)", "addi a0, a0, 16", // expand and save round 08 key
+    "vaeskf2.vi v8, v4, 9",  "vse32.v v8, (a0)", "addi a0, a0, 16", // expand and save round 09 key
+
+    "vaeskf2.vi v4, v8, 10", "vse32.v v4, (a0)", "addi a0, a0, 16", // expand and save round 10 key
+    "vaeskf2.vi v8, v4, 11", "vse32.v v8, (a0)", "addi a0, a0, 16", // expand and save round 11 key
+
+    "vaeskf2.vi v4, v8, 12", "vse32.v v4, (a0)", "addi a0, a0, 16", // expand and save round 12 key
+    "vaeskf2.vi v8, v4, 13", "vse32.v v8, (a0)", "addi a0, a0, 16", // expand and save round 13 key
+
+    "vaeskf2.vi v4, v8, 14", "vse32.v v4, (a0)", // expand and save round 14 key
+
+    "ret",
+}
+extern "C" {
+    fn aes_riscv_rv64_vector_expand_aes256_expand_key(dst: *mut u32, src: *const u8);
+}
+
+#[inline(always)]
+pub fn expand_key(key: &[u8; 32]) -> RoundKeys<15> {
+    let mut exp: [MaybeUninit<RoundKey>; 15] = unsafe { MaybeUninit::uninit().assume_init() };
+    unsafe {
+        let exp = exp.as_mut_ptr().cast::<u32>();
+        let key = key.as_ptr();
+        aes_riscv_rv64_vector_expand_aes256_expand_key(exp, key);
+    };
+    // SAFETY: All positions have been initialized.
+    let out: RoundKeys<15> = unsafe { transmute(exp) };
+    out
+}
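The `MaybeUninit` idiom used by the `expand_key` functions above, shown in isolation (a sketch, with `RoundKey` assumed to be `[u32; 4]` and a generic writer standing in for the expand-key asm):

```rust
use core::mem::{transmute, MaybeUninit};

/// Storage is created uninitialized, handed to a writer that fills every
/// word (here, the expand-key asm), and only then assumed initialized.
fn expand_with(write_all: impl FnOnce(*mut u32)) -> [[u32; 4]; 15] {
    let mut exp: [MaybeUninit<[u32; 4]>; 15] = unsafe { MaybeUninit::uninit().assume_init() };
    write_all(exp.as_mut_ptr().cast::<u32>());
    // SAFETY: the writer's contract requires all 15 * 4 words initialized.
    unsafe { transmute(exp) }
}
```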
diff --git a/aes/src/riscv/rvv/test_expand.rs b/aes/src/riscv/rvv/test_expand.rs
new file mode 100644
index 00000000..e7451f8c
--- /dev/null
+++ b/aes/src/riscv/rvv/test_expand.rs
@@ -0,0 +1,45 @@
+use super::RoundKeys;
+use crate::riscv::test::*;
+
+fn store_expanded_keys<const N: usize>(input: RoundKeys<N>) -> [[u8; 16]; N] {
+    let mut output = [[0u8; 16]; N];
+    for (src, dst) in input.iter().zip(output.iter_mut()) {
+        let b0 = src[0].to_ne_bytes();
+        let b1 = src[1].to_ne_bytes();
+        let b2 = src[2].to_ne_bytes();
+        let b3 = src[3].to_ne_bytes();
+        dst[00..04].copy_from_slice(&b0);
+        dst[04..08].copy_from_slice(&b1);
+        dst[08..12].copy_from_slice(&b2);
+        dst[12..16].copy_from_slice(&b3);
+    }
+    output
+}
+
+// NOTE: Unlike the RISC-V scalar crypto instructions, the RISC-V vector crypto instructions
+// implicitly perform key inversion as part of the cipher decryption instructions. There are no
+// distinct vector instructions for key inversion, hence no `inv_expanded_keys` is defined here.
+
+#[test]
+fn aes128_key_expansion() {
+    let ek = super::expand::aes128::expand_key(&AES128_KEY);
+    assert_eq!(store_expanded_keys(ek), AES128_EXP_KEYS);
+}
+
+// NOTE: AES-192 is only implemented if scalar-crypto is enabled.
+#[cfg(all(
+    target_arch = "riscv64",
+    target_feature = "zknd",
+    target_feature = "zkne"
+))]
+#[test]
+fn aes192_key_expansion() {
+    let ek = super::expand::aes192::expand_key(&AES192_KEY);
+    assert_eq!(store_expanded_keys(ek), AES192_EXP_KEYS);
+}
+
+#[test]
+fn aes256_key_expansion() {
+    let ek = super::expand::aes256::expand_key(&AES256_KEY);
+    assert_eq!(store_expanded_keys(ek), AES256_EXP_KEYS);
+}
diff --git a/aes/tests/mod.rs b/aes/tests/mod.rs
index 4164e4f2..d665cbc2 100644
--- a/aes/tests/mod.rs
+++ b/aes/tests/mod.rs
@@ -2,5 +2,13 @@
 //! https://www.cosic.esat.kuleuven.be/nessie/testvectors/
 
 cipher::block_cipher_test!(aes128_test, "aes128", aes::Aes128);
+#[cfg(any(
+    not(target_arch = "riscv64"),
+    all(
+        target_arch = "riscv64",
+        target_feature = "zknd",
+        target_feature = "zkne"
+    )
+))]
 cipher::block_cipher_test!(aes192_test, "aes192", aes::Aes192);
 cipher::block_cipher_test!(aes256_test, "aes256", aes::Aes256);
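A round-trip smoke test in the spirit of the gated NESSIE vector tests above, exercising only the crate's public API (sketch; key and plaintext values are arbitrary):

```rust
#[test]
fn aes128_roundtrip() {
    use aes::cipher::{BlockDecrypt, BlockEncrypt, KeyInit};
    let cipher = aes::Aes128::new(&[0u8; 16].into());
    let mut block = [0x42u8; 16].into();
    let original = block;
    cipher.encrypt_block(&mut block);
    assert_ne!(block, original);
    cipher.decrypt_block(&mut block);
    assert_eq!(block, original);
}
```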