Skip to content

Commit 8f63e23

Browse files
authored
chacha20 - minor improvements (#456)
- remove unnecessary endian conversions - update docs - consolidate `StreamId` and `BlockPos` with type aliases - update `counter_overflow_1` test to be more descriptive when the 64-bit counter is not implemented correctly for the 4th parblock - may be unnecessary, but it provides a better error message than the other test that would fail in the same scenario fixes #453
1 parent 55f2007 commit 8f63e23

File tree

1 file changed

+94
-60
lines changed

1 file changed

+94
-60
lines changed

chacha20/src/rng.rs

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -81,33 +81,33 @@ impl Debug for Seed {
8181
}
8282
}
8383

84-
/// A wrapper for `stream_id`.
85-
///
86-
/// Can be constructed from any of the following:
87-
/// * `[u32; 3]`
88-
/// * `[u8; 12]`
89-
/// * `u128`
84+
/// A wrapper around 64 bits of data that can be constructed from any of the
85+
/// following:
86+
/// * `u64`
87+
/// * `[u32; 2]`
88+
/// * `[u8; 8]`
9089
///
91-
/// The arrays should be in little endian order.
92-
pub struct StreamId([u32; Self::LEN]);
90+
/// The arrays should be in little endian order. You should not need to use
91+
/// this directly, as the methods in this crate that use this type call
92+
/// `.into()` for you, so you only need to supply any of the above types.
93+
pub struct U32x2([u32; Self::LEN]);
9394

94-
impl StreamId {
95-
/// Amount of raw bytes backing a `StreamId` instance.
95+
impl U32x2 {
96+
/// Amount of raw bytes backing a `U32x2` instance.
9697
const BYTES: usize = size_of::<Self>();
9798

98-
/// The length of the array contained within `StreamId`.
99+
/// The length of the array contained within `U32x2`.
99100
const LEN: usize = 2;
100101
}
101102

102-
impl From<[u32; Self::LEN]> for StreamId {
103+
impl From<[u32; Self::LEN]> for U32x2 {
103104
#[inline]
104105
fn from(value: [u32; Self::LEN]) -> Self {
105-
let result = value.map(|v| v.to_le());
106-
Self(result)
106+
Self(value)
107107
}
108108
}
109109

110-
impl From<[u8; Self::BYTES]> for StreamId {
110+
impl From<[u8; Self::BYTES]> for U32x2 {
111111
#[inline]
112112
fn from(value: [u8; Self::BYTES]) -> Self {
113113
let mut result = Self(Default::default());
@@ -116,50 +116,39 @@ impl From<[u8; Self::BYTES]> for StreamId {
116116
.iter_mut()
117117
.zip(value.chunks_exact(size_of::<u32>()))
118118
{
119-
*cur = u32::from_le_bytes(chunk.try_into().unwrap()).to_le();
119+
*cur = u32::from_le_bytes(chunk.try_into().unwrap());
120120
}
121121
result
122122
}
123123
}
124124

125-
impl From<u64> for StreamId {
125+
impl From<u64> for U32x2 {
126126
#[inline]
127127
fn from(value: u64) -> Self {
128128
let result: [u8; Self::BYTES] = value.to_le_bytes()[..Self::BYTES].try_into().unwrap();
129129
result.into()
130130
}
131131
}
132132

133-
/// A wrapper for `block_pos`.
133+
/// A wrapper for `stream_id`.
134134
///
135135
/// Can be constructed from any of the following:
136136
/// * `u64`
137-
/// * `[u8; 8]`
138137
/// * `[u32; 2]`
138+
/// * `[u8; 8]`
139139
///
140140
/// The arrays should be in little endian order.
141-
pub struct BlockPos([u32; 2]);
141+
pub type StreamId = U32x2;
142142

143-
impl From<u64> for BlockPos {
144-
#[inline]
145-
fn from(value: u64) -> Self {
146-
Self([value as u32, (value >> 32) as u32])
147-
}
148-
}
149-
150-
impl From<[u8; 8]> for BlockPos {
151-
#[inline]
152-
fn from(value: [u8; 8]) -> Self {
153-
u64::from_le_bytes(value).into()
154-
}
155-
}
156-
157-
impl From<[u32; 2]> for BlockPos {
158-
#[inline]
159-
fn from(value: [u32; 2]) -> Self {
160-
Self(value)
161-
}
162-
}
143+
/// A wrapper for `block_pos`.
144+
///
145+
/// Can be constructed from any of the following:
146+
/// * `u64`
147+
/// * `[u32; 2]`
148+
/// * `[u8; 8]`
149+
///
150+
/// The arrays should be in little endian order.
151+
pub type BlockPos = U32x2;
163152

164153
/// The results buffer that zeroizes on drop when the `zeroize` feature is enabled.
165154
#[derive(Clone)]
@@ -392,9 +381,9 @@ macro_rules! impl_chacha_rng {
392381

393382
/// Set the offset from the start of the stream, in 32-bit words.
394383
///
395-
/// As with `get_word_pos`, we use a 36-bit number. When given a `u64`, we use
396-
/// the least significant 4 bits as the RNG's index, and the 32 bits before it
397-
/// as the block position.
384+
/// As with `get_word_pos`, we use a 68-bit number. Since the generator
385+
/// simply cycles at the end of its period (1 ZiB), we ignore the upper
386+
/// 60 bits.
398387
#[inline]
399388
pub fn set_word_pos(&mut self, word_offset: u128) {
400389
let index = (word_offset & 0b1111) as usize;
@@ -411,8 +400,8 @@ macro_rules! impl_chacha_rng {
411400
///
412401
/// This method takes any of the following:
413402
/// * `u64`
414-
/// * `[u8; 8]`
415403
/// * `[u32; 2]`
404+
/// * `[u8; 8]`
416405
///
417406
/// Note: the arrays should be in little endian order.
418407
#[inline]
@@ -431,15 +420,15 @@ macro_rules! impl_chacha_rng {
431420
self.core.core.0.state[12] as u64 | ((self.core.core.0.state[13] as u64) << 32)
432421
}
433422

434-
/// Set the stream number. The lower 96 bits are used and the rest are
423+
/// Set the stream number. The lower 64 bits are used and the rest are
435424
/// discarded. This method takes any of the following:
436425
/// * `u64`
437-
/// * `[u8; 8]`
438426
/// * `[u32; 2]`
427+
/// * `[u8; 8]`
439428
///
440429
/// Note: the arrays should be in little endian order.
441430
///
442-
/// This is initialized to zero; 2<sup>96</sup> unique streams of output
431+
/// This is initialized to zero; 2<sup>64</sup> unique streams of output
443432
/// are available per seed/key. In theory a 96-bit nonce can be used by
444433
/// passing the last 64-bits to this function and using the first 32-bits as
445434
/// the most significant half of the 64-bit counter, which may be set
@@ -463,13 +452,7 @@ macro_rules! impl_chacha_rng {
463452
#[inline]
464453
pub fn set_stream<S: Into<StreamId>>(&mut self, stream: S) {
465454
let stream: StreamId = stream.into();
466-
for (n, val) in self.core.core.0.state[14..BLOCK_WORDS as usize]
467-
.as_mut()
468-
.iter_mut()
469-
.zip(stream.0.iter())
470-
{
471-
*n = val.to_le();
472-
}
455+
self.core.core.0.state[14..].copy_from_slice(&stream.0);
473456
if self.core.index() != BUFFER_SIZE {
474457
self.core.generate_and_set(self.core.index());
475458
}
@@ -1200,20 +1183,71 @@ pub(crate) mod tests {
12001183
assert_ne!(&first_blocks[0..64 * 4], &result[64..]);
12011184
}
12021185

1186+
/// Counts how many bytes were incorrect, and returns:
1187+
///
1188+
/// (`index_of_first_incorrect_word`, `num_incorrect_bytes`)
1189+
fn count_incorrect_bytes(expected: &[u8], output: &[u8]) -> (Option<usize>, u32) {
1190+
assert_eq!(expected.len(), output.len());
1191+
let mut num_incorrect_bytes = 0;
1192+
let mut index_of_first_incorrect_word = None;
1193+
expected
1194+
.iter()
1195+
.enumerate()
1196+
.zip(output.iter())
1197+
.for_each(|((i, a), b)| {
1198+
if a.ne(b) {
1199+
if index_of_first_incorrect_word.is_none() {
1200+
index_of_first_incorrect_word = Some(i / 4)
1201+
}
1202+
num_incorrect_bytes += 1;
1203+
}
1204+
});
1205+
(index_of_first_incorrect_word, num_incorrect_bytes)
1206+
}
1207+
12031208
/// Test vector 8 from https://github.com/pyca/cryptography/blob/main/vectors/cryptography_vectors/ciphers/ChaCha20/counter-overflow.txt
12041209
#[test]
1205-
fn counter_overflow_1() {
1210+
fn counter_overflow_and_diagnostics() {
12061211
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
12071212
let block_pos = 4294967295;
12081213
assert_eq!(block_pos, u32::MAX as u64);
12091214
rng.set_block_pos(4294967295);
12101215

1211-
let mut output = [0u8; 64 * 3];
1212-
rng.fill_bytes(&mut output);
1213-
let expected = hex!(
1214-
"ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d3db41d3aa0d329285de6f225e6e24bd59c9a17006943d5c9b680e3873bdc683a5819469899989690c281cd17c96159af0682b5b903468a61f50228cf09622b5a46f0f6efee15c8f1b198cb49d92b990867905159440cc723916dc0012826981039ce1766aa2542b05db3bd809ab142489d5dbfe1273e7399637b4b3213768aaa"
1216+
let mut output = [0u8; 64 * 4];
1217+
rng.fill_bytes(&mut output[..64 * 3]);
1218+
let block_before_overflow = hex!(
1219+
"ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d"
12151220
);
1216-
assert_eq!(expected, output);
1221+
let first_block_after_overflow = hex!(
1222+
"3db41d3aa0d329285de6f225e6e24bd59c9a17006943d5c9b680e3873bdc683a5819469899989690c281cd17c96159af0682b5b903468a61f50228cf09622b5a"
1223+
);
1224+
let second_block_after_overflow = hex!(
1225+
"46f0f6efee15c8f1b198cb49d92b990867905159440cc723916dc0012826981039ce1766aa2542b05db3bd809ab142489d5dbfe1273e7399637b4b3213768aaa"
1226+
);
1227+
assert!(
1228+
output[..64].eq(&block_before_overflow),
1229+
"The first parblock was incorrect before overflow, indicating that ChaCha was not implemented correctly for this backend. Check the rounds() fn or the functions that it calls"
1230+
);
1231+
1232+
rng.set_block_pos(u32::MAX as u64 - 1);
1233+
let mut skipped_blocks = [0u8; 64 * 3];
1234+
rng.fill_bytes(&mut skipped_blocks);
1235+
rng.fill_bytes(&mut output[64 * 3..]);
1236+
1237+
output.chunks_exact(64).enumerate().skip(1).zip(&[first_block_after_overflow, second_block_after_overflow, second_block_after_overflow]).for_each(|((i, a), b)| {
1238+
let (index_of_first_incorrect_word, num_incorrect_bytes) = count_incorrect_bytes(a, b);
1239+
let msg = if num_incorrect_bytes == 0 {
1240+
"The block was correct and this will not be shown"
1241+
} else if num_incorrect_bytes > 32 {
1242+
"Most of the block was incorrect, indicating an issue with the counter using 32-bit addition towards the beginning of fn rounds()"
1243+
} else if num_incorrect_bytes <= 8 && matches!(index_of_first_incorrect_word, Some(12 | 13)) {
1244+
"When the state was added to the results/res buffer at the end of fn rounds, the counter was probably incremented in 32-bit fashion for this parblock"
1245+
} else {
1246+
// this is probably unreachable in the event of a failed assertion, but it depends on the seed
1247+
"Some of the block was incorrect"
1248+
};
1249+
assert!(a.eq(b), "PARBLOCK #{} uses incorrect counter addition\nDiagnostic = {}\nnum_incorrect_bytes = {}\nindex_of_first_incorrect_word = {:?}", i + 1, msg, num_incorrect_bytes, index_of_first_incorrect_word);
1250+
});
12171251
}
12181252

12191253
/// Test vector 9 from https://github.com/pyca/cryptography/blob/main/vectors/cryptography_vectors/ciphers/ChaCha20/counter-overflow.txt

0 commit comments

Comments
 (0)