Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chacha20/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ rand_core-compatible RNGs based on those ciphers.
[dependencies]
cfg-if = "1"
cipher = { version = "0.5.0-rc.3", optional = true, features = ["stream-wrapper"] }
rand_core = { version = "0.10.0-rc-3", optional = true, default-features = false }
# rand_core = { version = "0.10.0-rc-3", optional = true, default-features = false }
rand_core = { path = "../../rand_core", optional = true, default-features = false}

# `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate
zeroize = { version = "1.8.1", optional = true, default-features = false }
Expand Down
106 changes: 87 additions & 19 deletions chacha20/src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ pub type BlockPos = U32x2;
const BUFFER_SIZE: usize = 64;

// NB. this must remain consistent with some currently hard-coded numbers in this module
const BUF_BLOCKS: u8 = BUFFER_SIZE as u8 >> 4;
const BUF_BLOCKS: u8 = BUFFER_SIZE as u8 / BLOCK_WORDS;

impl<R: Rounds, V: Variant> ChaChaCore<R, V> {
/// Generates 4 blocks in parallel with avx2 & neon, but merely fills
Expand Down Expand Up @@ -335,29 +335,37 @@ macro_rules! impl_chacha_rng {
pub fn get_word_pos(&self) -> u128 {
let mut block_counter = (u64::from(self.core.core.0.state[13]) << 32)
| u64::from(self.core.core.0.state[12]);
block_counter = block_counter.wrapping_sub(BUF_BLOCKS as u64);
if self.core.word_offset() != 0 {
block_counter = block_counter.wrapping_sub(4);
}
let word_pos =
block_counter as u128 * BLOCK_WORDS as u128 + self.core.index() as u128;
block_counter as u128 * BLOCK_WORDS as u128 + self.core.word_offset() as u128;
// eliminate bits above the 68th bit
word_pos & ((1 << 68) - 1)
}

/// Set the offset from the start of the stream, in 32-bit words.
/// Set the offset from the start of the stream, in 32-bit words. **This
/// value will be erased when calling `set_stream()`, so call
/// `set_stream()` before calling `set_word_pos()`** if you intend on
/// using both of them together.
///
/// As with `get_word_pos`, we use a 68-bit number. Since the generator
/// simply cycles at the end of its period (1 ZiB), we ignore the upper
/// 60 bits.
#[inline]
pub fn set_word_pos(&mut self, word_offset: u128) {
let index = (word_offset & 0b1111) as usize;
let counter = word_offset >> 4;
let index = (word_offset % BLOCK_WORDS as u128) as usize;
let counter = word_offset / BLOCK_WORDS as u128;
//self.set_block_pos(counter as u64);
self.core.core.0.state[12] = counter as u32;
self.core.core.0.state[13] = (counter >> 32) as u32;
self.core.generate_and_set(index);
self.core.reset_and_skip(index);
}

/// Set the block pos and reset the RNG's index.
/// Sets the block pos and resets the RNG's index. **This value will be
/// erased when calling `set_stream()`, so call `set_stream()` before
/// calling `set_block_pos()`** if you intend on using both of them
/// together.
///
/// The word pos will be equal to `block_pos * 16 words per block`.
///
Expand All @@ -370,7 +378,7 @@ macro_rules! impl_chacha_rng {
#[inline]
#[allow(unused)]
pub fn set_block_pos<B: Into<BlockPos>>(&mut self, block_pos: B) {
self.core.reset();
self.core.reset_and_skip(0);
let block_pos = block_pos.into().0;
self.core.core.0.state[12] = block_pos[0];
self.core.core.0.state[13] = block_pos[1]
Expand All @@ -380,11 +388,19 @@ macro_rules! impl_chacha_rng {
#[inline]
#[allow(unused)]
pub fn get_block_pos(&self) -> u64 {
self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32)
let counter = self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32);
if self.core.word_offset() != 0 {
counter - 4 + self.core.word_offset() as u64 / 16
} else {
counter
}
}

/// Set the stream number. The lower 64 bits are used and the rest are
/// discarded. This method takes any of the following:
/// Sets the stream number, resetting the `index` and `block_pos` to 0,
/// effectively setting the `word_pos` to 0 as well. Consider storing
/// the `word_pos` prior to calling this method.
///
/// This method takes any of the following:
/// * `u64`
/// * `[u32; 2]`
/// * `[u8; 8]`
Expand All @@ -405,20 +421,23 @@ macro_rules! impl_chacha_rng {
/// let mut rng = ChaCha20Rng::from_seed(seed);
///
/// // set state[12] to 0, state[13] to 1, state[14] to 2, state[15] to 3
/// rng.set_block_pos([0u32, 1u32]);
/// rng.set_stream([2u32, 3u32]);
/// rng.set_block_pos([0u32, 1u32]);
///
/// // confirm that state is set correctly
/// assert_eq!(rng.get_block_pos(), 1 << 32);
/// assert_eq!(rng.get_stream(), (3 << 32) + 2);
///
/// // restoring `word_pos`/`index` after calling `set_stream`:
/// let word_pos = rng.get_word_pos();
/// rng.set_stream(4);
/// rng.set_word_pos(word_pos);
/// ```
#[inline]
pub fn set_stream<S: Into<StreamId>>(&mut self, stream: S) {
let stream: StreamId = stream.into();
self.core.core.0.state[14..].copy_from_slice(&stream.0);
if self.core.index() != BUFFER_SIZE {
self.core.generate_and_set(self.core.index());
}
self.set_block_pos(0);
}

/// Get the stream number.
Expand Down Expand Up @@ -864,6 +883,11 @@ pub(crate) mod tests {
}
rng2.set_stream(51); // switch part way through block
for _ in 7..16 {
assert_ne!(rng1.next_u64(), rng2.next_u64());
}
rng1.set_stream(51);
rng2.set_stream(51);
for _ in 0..16 {
assert_eq!(rng1.next_u64(), rng2.next_u64());
}
}
Expand Down Expand Up @@ -892,7 +916,7 @@ pub(crate) mod tests {
fn test_chacha_word_pos_zero() {
let mut rng = ChaChaRng::from_seed(Default::default());
assert_eq!(rng.core.core.0.state[12], 0);
assert_eq!(rng.core.index(), 64);
assert_eq!(rng.core.word_offset(), 0);
assert_eq!(rng.get_word_pos(), 0);
rng.set_word_pos(0);
assert_eq!(rng.get_word_pos(), 0);
Expand Down Expand Up @@ -1015,15 +1039,59 @@ pub(crate) mod tests {
#[test]
fn stream_id_endianness() {
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
assert_eq!(rng.get_word_pos(), 0);
rng.set_stream([3, 3333]);
assert_eq!(rng.get_word_pos(), 0);
let expected = 1152671828;
assert_eq!(rng.next_u32(), expected);
let mut word_pos = rng.get_word_pos();

assert_eq!(word_pos, 1);

rng.set_stream(1234567);
assert_eq!(rng.get_word_pos(), 0);
let mut block = [0u32; 16];
for word in 0..block.len() {
block[word] = rng.next_u32();
}
assert_eq!(rng.get_word_pos(), 16);
assert_eq!(rng.core.word_offset(), 16);
assert_eq!(rng.get_block_pos(), 1);
rng.set_stream(1234567);
let mut block_2 = [0u32; 16];
for word in 0..block_2.len() {
block_2[word] = rng.next_u32();
}
assert_eq!(rng.get_word_pos(), 16);
assert_eq!(rng.core.word_offset(), 16);
assert_eq!(rng.get_block_pos(), 1);
assert_eq!(block, block_2);
rng.set_stream(1234567);
assert_eq!(rng.get_block_pos(), 0);
assert_eq!(rng.get_word_pos(), 0);
let _ = rng.next_u32();

word_pos = rng.get_word_pos();
assert_eq!(word_pos, 1);
let test = rng.next_u32();
let expected = 3110319182;
assert_eq!(rng.next_u32(), expected);
rng.set_word_pos(65); // old set_stream added 64 to the word_pos
assert!(rng.next_u32() == expected);
rng.set_word_pos(word_pos);
assert_eq!(rng.next_u32(), test);


word_pos = rng.get_word_pos();
assert_eq!(word_pos, 2);
rng.set_stream([1, 2, 3, 4, 5, 6, 7, 8]);
rng.next_u32();
rng.next_u32();
let test = rng.next_u32();
rng.set_word_pos(130); // old set_stream added another 64 to the word_pos
let expected = 3790367479;
assert_eq!(rng.next_u32(), expected);
rng.set_word_pos(word_pos);
assert_eq!(rng.next_u32(), test);
}

/// If this test fails, the backend may not be
Expand All @@ -1043,7 +1111,7 @@ pub(crate) mod tests {
let mut result = [0u8; 64 * 5];
rng.fill_bytes(&mut result);
assert_eq!(first_blocks_end_word_pos, rng.get_word_pos());
assert_eq!(first_blocks_end_block_counter, rng.get_block_pos() - 3);
assert_eq!(first_blocks_end_block_counter, rng.get_block_pos());

if first_blocks[0..64 * 4].ne(&result[64..]) {
for (i, (a, b)) in first_blocks.iter().zip(result.iter().skip(64)).enumerate() {
Expand Down
Loading