diff --git a/Cargo.lock b/Cargo.lock index aa21ea4c..fe00bbf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -288,8 +288,7 @@ dependencies = [ [[package]] name = "rand_core" version = "0.10.0-rc-3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f66ee92bc15280519ef199a274fe0cafff4245d31bc39aaa31c011ad56cb1f05" +source = "git+https://github.com/rust-random/rand_core?branch=push-xzmlvzwurnrl#8935760efbe9a2323ebb37cce60efc7c837c3bfe" [[package]] name = "rand_xorshift" diff --git a/chacha20/Cargo.toml b/chacha20/Cargo.toml index b72ac425..bc15feb3 100644 --- a/chacha20/Cargo.toml +++ b/chacha20/Cargo.toml @@ -21,7 +21,8 @@ rand_core-compatible RNGs based on those ciphers. [dependencies] cfg-if = "1" cipher = { version = "0.5.0-rc.3", optional = true, features = ["stream-wrapper"] } -rand_core = { version = "0.10.0-rc-3", optional = true, default-features = false } +# rand_core = { version = "0.10.0-rc-3", optional = true, default-features = false } +rand_core = { git = "https://github.com/rust-random/rand_core", branch = "push-xzmlvzwurnrl", optional = true, default-features = false} # `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate zeroize = { version = "1.8.1", optional = true, default-features = false } diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs index eade9265..952e6dac 100644 --- a/chacha20/src/rng.rs +++ b/chacha20/src/rng.rs @@ -149,7 +149,7 @@ pub type BlockPos = U32x2; const BUFFER_SIZE: usize = 64; // NB. this must remain consistent with some currently hard-coded numbers in this module -const BUF_BLOCKS: u8 = BUFFER_SIZE as u8 >> 4; +const BUF_BLOCKS: u8 = BUFFER_SIZE as u8 / BLOCK_WORDS; impl ChaChaCore { /// Generates 4 blocks in parallel with avx2 & neon, but merely fills @@ -335,29 +335,37 @@ macro_rules! impl_chacha_rng { pub fn get_word_pos(&self) -> u128 { let mut block_counter = (u64::from(self.core.core.0.state[13]) << 32) | u64::from(self.core.core.0.state[12]); - block_counter = block_counter.wrapping_sub(BUF_BLOCKS as u64); + if self.core.word_offset() != 0 { + block_counter = block_counter.wrapping_sub(BUF_BLOCKS as u64); + } let word_pos = - block_counter as u128 * BLOCK_WORDS as u128 + self.core.index() as u128; + block_counter as u128 * BLOCK_WORDS as u128 + self.core.word_offset() as u128; // eliminate bits above the 68th bit word_pos & ((1 << 68) - 1) } - /// Set the offset from the start of the stream, in 32-bit words. + /// Set the offset from the start of the stream, in 32-bit words. **This + /// value will be erased when calling `set_stream()`, so call + /// `set_stream()` before calling `set_word_pos()`** if you intend on + /// using both of them together. /// /// As with `get_word_pos`, we use a 68-bit number. Since the generator /// simply cycles at the end of its period (1 ZiB), we ignore the upper /// 60 bits. #[inline] pub fn set_word_pos(&mut self, word_offset: u128) { - let index = (word_offset & 0b1111) as usize; - let counter = word_offset >> 4; + let index = (word_offset % BLOCK_WORDS as u128) as usize; + let counter = word_offset / BLOCK_WORDS as u128; //self.set_block_pos(counter as u64); self.core.core.0.state[12] = counter as u32; self.core.core.0.state[13] = (counter >> 32) as u32; - self.core.generate_and_set(index); + self.core.reset_and_skip(index); } - /// Set the block pos and reset the RNG's index. + /// Sets the block pos and resets the RNG's index. **This value will be + /// erased when calling `set_stream()`, so call `set_stream()` before + /// calling `set_block_pos()`** if you intend on using both of them + /// together. /// /// The word pos will be equal to `block_pos * 16 words per block`. /// @@ -370,7 +378,7 @@ macro_rules! impl_chacha_rng { #[inline] #[allow(unused)] pub fn set_block_pos>(&mut self, block_pos: B) { - self.core.reset(); + self.core.reset_and_skip(0); let block_pos = block_pos.into().0; self.core.core.0.state[12] = block_pos[0]; self.core.core.0.state[13] = block_pos[1] @@ -380,11 +388,20 @@ macro_rules! impl_chacha_rng { #[inline] #[allow(unused)] pub fn get_block_pos(&self) -> u64 { - self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32) + let counter = + self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32); + if self.core.word_offset() != 0 { + counter - BUF_BLOCKS as u64 + self.core.word_offset() as u64 / 16 + } else { + counter + } } - /// Set the stream number. The lower 64 bits are used and the rest are - /// discarded. This method takes any of the following: + /// Sets the stream number, resetting the `index` and `block_pos` to 0, + /// effectively setting the `word_pos` to 0 as well. Consider storing + /// the `word_pos` prior to calling this method. + /// + /// This method takes any of the following: /// * `u64` /// * `[u32; 2]` /// * `[u8; 8]` @@ -405,20 +422,23 @@ macro_rules! impl_chacha_rng { /// let mut rng = ChaCha20Rng::from_seed(seed); /// /// // set state[12] to 0, state[13] to 1, state[14] to 2, state[15] to 3 - /// rng.set_block_pos([0u32, 1u32]); /// rng.set_stream([2u32, 3u32]); + /// rng.set_block_pos([0u32, 1u32]); /// /// // confirm that state is set correctly /// assert_eq!(rng.get_block_pos(), 1 << 32); /// assert_eq!(rng.get_stream(), (3 << 32) + 2); + /// + /// // restoring `word_pos`/`index` after calling `set_stream`: + /// let word_pos = rng.get_word_pos(); + /// rng.set_stream(4); + /// rng.set_word_pos(word_pos); /// ``` #[inline] pub fn set_stream>(&mut self, stream: S) { let stream: StreamId = stream.into(); self.core.core.0.state[14..].copy_from_slice(&stream.0); - if self.core.index() != BUFFER_SIZE { - self.core.generate_and_set(self.core.index()); - } + self.set_block_pos(0); } /// Get the stream number. @@ -864,6 +884,11 @@ pub(crate) mod tests { } rng2.set_stream(51); // switch part way through block for _ in 7..16 { + assert_ne!(rng1.next_u64(), rng2.next_u64()); + } + rng1.set_stream(51); + rng2.set_stream(51); + for _ in 0..16 { assert_eq!(rng1.next_u64(), rng2.next_u64()); } } @@ -892,7 +917,7 @@ pub(crate) mod tests { fn test_chacha_word_pos_zero() { let mut rng = ChaChaRng::from_seed(Default::default()); assert_eq!(rng.core.core.0.state[12], 0); - assert_eq!(rng.core.index(), 64); + assert_eq!(rng.core.word_offset(), 0); assert_eq!(rng.get_word_pos(), 0); rng.set_word_pos(0); assert_eq!(rng.get_word_pos(), 0); @@ -1015,15 +1040,58 @@ pub(crate) mod tests { #[test] fn stream_id_endianness() { let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + assert_eq!(rng.get_word_pos(), 0); rng.set_stream([3, 3333]); + assert_eq!(rng.get_word_pos(), 0); let expected = 1152671828; assert_eq!(rng.next_u32(), expected); + let mut word_pos = rng.get_word_pos(); + + assert_eq!(word_pos, 1); + + rng.set_stream(1234567); + assert_eq!(rng.get_word_pos(), 0); + let mut block = [0u32; 16]; + for word in 0..block.len() { + block[word] = rng.next_u32(); + } + assert_eq!(rng.get_word_pos(), 16); + assert_eq!(rng.core.word_offset(), 16); + assert_eq!(rng.get_block_pos(), 1); rng.set_stream(1234567); + let mut block_2 = [0u32; 16]; + for word in 0..block_2.len() { + block_2[word] = rng.next_u32(); + } + assert_eq!(rng.get_word_pos(), 16); + assert_eq!(rng.core.word_offset(), 16); + assert_eq!(rng.get_block_pos(), 1); + assert_eq!(block, block_2); + rng.set_stream(1234567); + assert_eq!(rng.get_block_pos(), 0); + assert_eq!(rng.get_word_pos(), 0); + let _ = rng.next_u32(); + + word_pos = rng.get_word_pos(); + assert_eq!(word_pos, 1); + let test = rng.next_u32(); let expected = 3110319182; - assert_eq!(rng.next_u32(), expected); + rng.set_word_pos(65); // old set_stream added 64 to the word_pos + assert!(rng.next_u32() == expected); + rng.set_word_pos(word_pos); + assert_eq!(rng.next_u32(), test); + + word_pos = rng.get_word_pos(); + assert_eq!(word_pos, 2); rng.set_stream([1, 2, 3, 4, 5, 6, 7, 8]); + rng.next_u32(); + rng.next_u32(); + let test = rng.next_u32(); + rng.set_word_pos(130); // old set_stream added another 64 to the word_pos let expected = 3790367479; assert_eq!(rng.next_u32(), expected); + rng.set_word_pos(word_pos); + assert_eq!(rng.next_u32(), test); } /// If this test fails, the backend may not be @@ -1043,7 +1111,7 @@ pub(crate) mod tests { let mut result = [0u8; 64 * 5]; rng.fill_bytes(&mut result); assert_eq!(first_blocks_end_word_pos, rng.get_word_pos()); - assert_eq!(first_blocks_end_block_counter, rng.get_block_pos() - 3); + assert_eq!(first_blocks_end_block_counter, rng.get_block_pos()); if first_blocks[0..64 * 4].ne(&result[64..]) { for (i, (a, b)) in first_blocks.iter().zip(result.iter().skip(64)).enumerate() {