Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 94 additions & 60 deletions chacha20/src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,33 @@ impl Debug for Seed {
}
}

/// A wrapper for `stream_id`.
///
/// Can be constructed from any of the following:
/// * `[u32; 3]`
/// * `[u8; 12]`
/// * `u128`
/// A wrapper around 64 bits of data that can be constructed from any of the
/// following:
/// * `u64`
/// * `[u32; 2]`
/// * `[u8; 8]`
///
/// The arrays should be in little endian order.
pub struct StreamId([u32; Self::LEN]);
/// The arrays should be in little endian order. You should not need to use
/// this directly, as the methods in this crate that use this type call
/// `.into()` for you, so you only need to supply any of the above types.
pub struct U32x2([u32; Self::LEN]);
Copy link
Member

@tarcieri tarcieri Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a bit of an odd name for something with a specific purpose that's only used in a single part of the API, particularly since it's public. I'll admit "stream ID" also seemed like an odd name for a nonce to me, but seemed OK in the context of the RNG API, particularly given the legacy of rand_chacha.

I guess this is because it now has a dual role as both a nonce and a block position identifier?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes; this was discussed in #453. No comments on the name.


impl StreamId {
/// Amount of raw bytes backing a `StreamId` instance.
impl U32x2 {
/// Number of raw bytes backing a `U32x2` instance.
const BYTES: usize = size_of::<Self>();

/// The length of the array contained within `StreamId`.
/// The length of the array contained within `U32x2`.
const LEN: usize = 2;
}

impl From<[u32; Self::LEN]> for StreamId {
impl From<[u32; Self::LEN]> for U32x2 {
#[inline]
fn from(value: [u32; Self::LEN]) -> Self {
let result = value.map(|v| v.to_le());
Self(result)
Self(value)
}
}

impl From<[u8; Self::BYTES]> for StreamId {
impl From<[u8; Self::BYTES]> for U32x2 {
#[inline]
fn from(value: [u8; Self::BYTES]) -> Self {
let mut result = Self(Default::default());
Expand All @@ -116,50 +116,39 @@ impl From<[u8; Self::BYTES]> for StreamId {
.iter_mut()
.zip(value.chunks_exact(size_of::<u32>()))
{
*cur = u32::from_le_bytes(chunk.try_into().unwrap()).to_le();
*cur = u32::from_le_bytes(chunk.try_into().unwrap());
}
result
}
}

impl From<u64> for StreamId {
impl From<u64> for U32x2 {
#[inline]
fn from(value: u64) -> Self {
let result: [u8; Self::BYTES] = value.to_le_bytes()[..Self::BYTES].try_into().unwrap();
result.into()
}
}

/// A wrapper for `block_pos`.
/// A wrapper for `stream_id`.
///
/// Can be constructed from any of the following:
/// * `u64`
/// * `[u8; 8]`
/// * `[u32; 2]`
/// * `[u8; 8]`
///
/// The arrays should be in little endian order.
pub struct BlockPos([u32; 2]);
pub type StreamId = U32x2;

impl From<u64> for BlockPos {
#[inline]
fn from(value: u64) -> Self {
Self([value as u32, (value >> 32) as u32])
}
}

impl From<[u8; 8]> for BlockPos {
#[inline]
fn from(value: [u8; 8]) -> Self {
u64::from_le_bytes(value).into()
}
}

impl From<[u32; 2]> for BlockPos {
#[inline]
fn from(value: [u32; 2]) -> Self {
Self(value)
}
}
/// A wrapper for `block_pos`.
///
/// Can be constructed from any of the following:
/// * `u64`
/// * `[u32; 2]`
/// * `[u8; 8]`
///
/// The arrays should be in little endian order.
pub type BlockPos = U32x2;

/// The results buffer that zeroizes on drop when the `zeroize` feature is enabled.
#[derive(Clone)]
Expand Down Expand Up @@ -392,9 +381,9 @@ macro_rules! impl_chacha_rng {

/// Set the offset from the start of the stream, in 32-bit words.
///
/// As with `get_word_pos`, we use a 36-bit number. When given a `u64`, we use
/// the least significant 4 bits as the RNG's index, and the 32 bits before it
/// as the block position.
/// As with `get_word_pos`, we use a 68-bit number. Since the generator
/// simply cycles at the end of its period (1 ZiB), we ignore the upper
/// 60 bits.
#[inline]
pub fn set_word_pos(&mut self, word_offset: u128) {
let index = (word_offset & 0b1111) as usize;
Expand All @@ -411,8 +400,8 @@ macro_rules! impl_chacha_rng {
///
/// This method takes any of the following:
/// * `u64`
/// * `[u8; 8]`
/// * `[u32; 2]`
/// * `[u8; 8]`
///
/// Note: the arrays should be in little endian order.
#[inline]
Expand All @@ -431,15 +420,15 @@ macro_rules! impl_chacha_rng {
self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32)
}

/// Set the stream number. The lower 96 bits are used and the rest are
/// Set the stream number. The lower 64 bits are used and the rest are
/// discarded. This method takes any of the following:
/// * `u64`
/// * `[u8; 8]`
/// * `[u32; 2]`
/// * `[u8; 8]`
///
/// Note: the arrays should be in little endian order.
///
/// This is initialized to zero; 2<sup>96</sup> unique streams of output
/// This is initialized to zero; 2<sup>64</sup> unique streams of output
/// are available per seed/key. In theory a 96-bit nonce can be used by
/// passing the last 64 bits to this function and using the first 32 bits as
/// the most significant half of the 64-bit counter, which may be set
Expand All @@ -463,13 +452,7 @@ macro_rules! impl_chacha_rng {
#[inline]
pub fn set_stream<S: Into<StreamId>>(&mut self, stream: S) {
let stream: StreamId = stream.into();
for (n, val) in self.core.core.0.state[14..BLOCK_WORDS as usize]
.as_mut()
.iter_mut()
.zip(stream.0.iter())
{
*n = val.to_le();
}
self.core.core.0.state[14..].copy_from_slice(&stream.0);
if self.core.index() != BUFFER_SIZE {
self.core.generate_and_set(self.core.index());
}
Expand Down Expand Up @@ -1200,20 +1183,71 @@ pub(crate) mod tests {
assert_ne!(&first_blocks[0..64 * 4], &result[64..]);
}

/// Compares `expected` against `output` byte by byte and summarizes the
/// mismatches.
///
/// Returns `(index_of_first_incorrect_word, num_incorrect_bytes)`:
/// * the first element is the index of the first mismatching byte divided
///   by 4 (i.e. counted in 4-byte words), or `None` when the slices are
///   identical;
/// * the second element is the total number of mismatching bytes.
///
/// # Panics
///
/// Panics if the two slices do not have the same length.
fn count_incorrect_bytes(expected: &[u8], output: &[u8]) -> (Option<usize>, u32) {
    assert_eq!(expected.len(), output.len());

    let mut first_bad_word = None;
    let mut mismatches = 0;
    for (pos, (want, got)) in expected.iter().zip(output).enumerate() {
        if want == got {
            continue;
        }
        // Only the earliest mismatch is recorded; later ones keep the value.
        first_bad_word = first_bad_word.or(Some(pos / 4));
        mismatches += 1;
    }

    (first_bad_word, mismatches)
}

/// Test vector 8 from https://github.com/pyca/cryptography/blob/main/vectors/cryptography_vectors/ciphers/ChaCha20/counter-overflow.txt
#[test]
fn counter_overflow_1() {
fn counter_overflow_and_diagnostics() {
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let block_pos = 4294967295;
assert_eq!(block_pos, u32::MAX as u64);
rng.set_block_pos(4294967295);

let mut output = [0u8; 64 * 3];
rng.fill_bytes(&mut output);
let expected = hex!(
"ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d3db41d3aa0d329285de6f225e6e24bd59c9a17006943d5c9b680e3873bdc683a5819469899989690c281cd17c96159af0682b5b903468a61f50228cf09622b5a46f0f6efee15c8f1b198cb49d92b990867905159440cc723916dc0012826981039ce1766aa2542b05db3bd809ab142489d5dbfe1273e7399637b4b3213768aaa"
let mut output = [0u8; 64 * 4];
rng.fill_bytes(&mut output[..64 * 3]);
let block_before_overflow = hex!(
"ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d"
);
assert_eq!(expected, output);
let first_block_after_overflow = hex!(
"3db41d3aa0d329285de6f225e6e24bd59c9a17006943d5c9b680e3873bdc683a5819469899989690c281cd17c96159af0682b5b903468a61f50228cf09622b5a"
);
let second_block_after_overflow = hex!(
"46f0f6efee15c8f1b198cb49d92b990867905159440cc723916dc0012826981039ce1766aa2542b05db3bd809ab142489d5dbfe1273e7399637b4b3213768aaa"
);
assert!(
output[..64].eq(&block_before_overflow),
"The first parblock was incorrect before overflow, indicating that ChaCha was not implemented correctly for this backend. Check the rounds() fn or the functions that it calls"
);

rng.set_block_pos(u32::MAX as u64 - 1);
let mut skipped_blocks = [0u8; 64 * 3];
rng.fill_bytes(&mut skipped_blocks);
rng.fill_bytes(&mut output[64 * 3..]);

output.chunks_exact(64).enumerate().skip(1).zip(&[first_block_after_overflow, second_block_after_overflow, second_block_after_overflow]).for_each(|((i, a), b)| {
let (index_of_first_incorrect_word, num_incorrect_bytes) = count_incorrect_bytes(a, b);
let msg = if num_incorrect_bytes == 0 {
"The block was correct and this will not be shown"
} else if num_incorrect_bytes > 32 {
"Most of the block was incorrect, indicating an issue with the counter using 32-bit addition towards the beginning of fn rounds()"
} else if num_incorrect_bytes <= 8 && matches!(index_of_first_incorrect_word, Some(12 | 13)) {
"When the state was added to the results/res buffer at the end of fn rounds, the counter was probably incremented in 32-bit fashion for this parblock"
} else {
// this is probably unreachable in the event of a failed assertion, but it depends on the seed
"Some of the block was incorrect"
};
assert!(a.eq(b), "PARBLOCK #{} uses incorrect counter addition\nDiagnostic = {}\nnum_incorrect_bytes = {}\nindex_of_first_incorrect_word = {:?}", i + 1, msg, num_incorrect_bytes, index_of_first_incorrect_word);
});
}

/// Test vector 9 from https://github.com/pyca/cryptography/blob/main/vectors/cryptography_vectors/ciphers/ChaCha20/counter-overflow.txt
Expand Down