diff --git a/Cargo.lock b/Cargo.lock index 7e7a691a..c2f67d57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,6 +99,7 @@ dependencies = [ "aead", "ascon", "hex-literal", + "inout", "subtle", "zeroize", ] diff --git a/ascon-aead/Cargo.toml b/ascon-aead/Cargo.toml index a20e14df..5120782d 100644 --- a/ascon-aead/Cargo.toml +++ b/ascon-aead/Cargo.toml @@ -19,6 +19,7 @@ aead = { version = "0.6.0-rc.0", default-features = false } subtle = { version = "2", default-features = false } zeroize = { version = "1.6", optional = true, default-features = false, features = ["derive"] } ascon = "0.4" +inout = { version = "0.2.0-rc.4", default-features = false } [dev-dependencies] hex-literal = "0.4" diff --git a/ascon-aead/src/asconcore.rs b/ascon-aead/src/asconcore.rs index 2cd828bf..142c411a 100644 --- a/ascon-aead/src/asconcore.rs +++ b/ascon-aead/src/asconcore.rs @@ -7,6 +7,7 @@ use aead::{ consts::U16, }; use ascon::State; +use inout::InOutBuf; use subtle::ConstantTimeEq; /// Produce mask for padding. @@ -170,63 +171,68 @@ impl<'a, P: Parameters> AsconCore<'a, P> { self.state[4] ^= 0x8000000000000000; } - fn process_encrypt_inplace(&mut self, message: &mut [u8]) { - let mut blocks = message.chunks_exact_mut(16); - for block in blocks.by_ref() { + fn process_encrypt_inout(&mut self, message: InOutBuf<'_, '_, u8>) { + let (blocks, mut last_block) = message.into_chunks::(); + + for mut block in blocks { // process full block of message - self.state[0] ^= u64_from_bytes(&block[..8]); - block[..8].copy_from_slice(&u64::to_le_bytes(self.state[0])); - self.state[1] ^= u64_from_bytes(&block[8..16]); - block[8..16].copy_from_slice(&u64::to_le_bytes(self.state[1])); + self.state[0] ^= u64_from_bytes(&block.get_in()[..8]); + block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0])); + self.state[1] ^= u64_from_bytes(&block.get_in()[8..16]); + block.get_out()[8..16].copy_from_slice(&u64::to_le_bytes(self.state[1])); self.permute_state(); } // process partial block if it exists - let mut last_block = blocks.into_remainder(); let sidx = if last_block.len() >= 8 { - self.state[0] ^= u64_from_bytes(&last_block[..8]); - last_block[..8].copy_from_slice(&u64::to_le_bytes(self.state[0])); - last_block = &mut last_block[8..]; + self.state[0] ^= u64_from_bytes(&last_block.get_in()[..8]); + last_block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0])); + (_, last_block) = last_block.split_at(8); 1 } else { 0 }; self.state[sidx] ^= pad(last_block.len()); if !last_block.is_empty() { - self.state[sidx] ^= u64_from_bytes_partial(last_block); - last_block.copy_from_slice(&u64::to_le_bytes(self.state[sidx])[0..last_block.len()]); + self.state[sidx] ^= u64_from_bytes_partial(last_block.get_in()); + let last_block_len = last_block.len(); + last_block + .get_out() + .copy_from_slice(&u64::to_le_bytes(self.state[sidx])[0..last_block_len]); } } - fn process_decrypt_inplace(&mut self, ciphertext: &mut [u8]) { - let mut blocks = ciphertext.chunks_exact_mut(16); - for block in blocks.by_ref() { + fn process_decrypt_inout(&mut self, ciphertext: InOutBuf<'_, '_, u8>) { + let (blocks, mut last_block) = ciphertext.into_chunks::(); + for mut block in blocks { // process full block of ciphertext - let cx = u64_from_bytes(&block[..8]); - block[..8].copy_from_slice(&u64::to_le_bytes(self.state[0] ^ cx)); + let cx = u64_from_bytes(&block.get_in()[..8]); + block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0] ^ cx)); self.state[0] = cx; - let cx = u64_from_bytes(&block[8..16]); - block[8..16].copy_from_slice(&u64::to_le_bytes(self.state[1] ^ cx)); + let cx = u64_from_bytes(&block.get_in()[8..16]); + block.get_out()[8..16].copy_from_slice(&u64::to_le_bytes(self.state[1] ^ cx)); self.state[1] = cx; self.permute_state(); } // process partial block if it exists - let mut last_block = blocks.into_remainder(); let sidx = if last_block.len() >= 8 { - let cx = u64_from_bytes(&last_block[..8]); - last_block[..8].copy_from_slice(&u64::to_le_bytes(self.state[0] ^ cx)); + let cx = u64_from_bytes(&last_block.get_in()[..8]); + last_block.get_out()[..8].copy_from_slice(&u64::to_le_bytes(self.state[0] ^ cx)); self.state[0] = cx; - last_block = &mut last_block[8..]; + (_, last_block) = last_block.split_at(8); 1 } else { 0 }; self.state[sidx] ^= pad(last_block.len()); if !last_block.is_empty() { - let cx = u64_from_bytes_partial(last_block); + let cx = u64_from_bytes_partial(last_block.get_in()); self.state[sidx] ^= cx; - last_block.copy_from_slice(&u64::to_le_bytes(self.state[sidx])[0..last_block.len()]); + let last_block_len = last_block.len(); + last_block + .get_out() + .copy_from_slice(&u64::to_le_bytes(self.state[sidx])[0..last_block_len]); self.state[sidx] = clear(self.state[sidx], last_block.len()) ^ cx; } } @@ -248,7 +254,7 @@ impl<'a, P: Parameters> AsconCore<'a, P> { associated_data: &[u8], ) -> Array { self.process_associated_data(associated_data); - self.process_encrypt_inplace(message); + self.process_encrypt_inout(message.into()); Array::from(self.process_final()) } @@ -259,7 +265,7 @@ impl<'a, P: Parameters> AsconCore<'a, P> { expected_tag: &Array, ) -> Result<(), Error> { self.process_associated_data(associated_data); - self.process_decrypt_inplace(ciphertext); + self.process_decrypt_inout(ciphertext.into()); let tag = self.process_final(); if bool::from(tag.ct_eq(expected_tag)) {