@@ -81,33 +81,33 @@ impl Debug for Seed {
8181 }
8282}
8383
84- /// A wrapper for `stream_id`.
85- ///
86- /// Can be constructed from any of the following:
87- /// * `[u32; 3]`
88- /// * `[u8; 12]`
89- /// * `u128`
84+ /// A wrapper around 64 bits of data that can be constructed from any of the
85+ /// following:
86+ /// * `u64`
87+ /// * `[u32; 2]`
88+ /// * `[u8; 8]`
9089///
91- /// The arrays should be in little endian order.
92- pub struct StreamId ( [ u32 ; Self :: LEN ] ) ;
90+ /// The arrays should be in little endian order. You should not need to use
91+ /// this directly, as the methods in this crate that use this type call
92+ /// `.into()` for you, so you only need to supply any of the above types.
93+ pub struct U32x2 ( [ u32 ; Self :: LEN ] ) ;
9394
94- impl StreamId {
95- /// Amount of raw bytes backing a `StreamId ` instance.
95+ impl U32x2 {
96+ /// Amount of raw bytes backing a `U32x2 ` instance.
9697 const BYTES : usize = size_of :: < Self > ( ) ;
9798
98- /// The length of the array contained within `StreamId `.
99+ /// The length of the array contained within `U32x2 `.
99100 const LEN : usize = 2 ;
100101}
101102
102- impl From < [ u32 ; Self :: LEN ] > for StreamId {
103+ impl From < [ u32 ; Self :: LEN ] > for U32x2 {
103104 #[ inline]
104105 fn from ( value : [ u32 ; Self :: LEN ] ) -> Self {
105- let result = value. map ( |v| v. to_le ( ) ) ;
106- Self ( result)
106+ Self ( value)
107107 }
108108}
109109
110- impl From < [ u8 ; Self :: BYTES ] > for StreamId {
110+ impl From < [ u8 ; Self :: BYTES ] > for U32x2 {
111111 #[ inline]
112112 fn from ( value : [ u8 ; Self :: BYTES ] ) -> Self {
113113 let mut result = Self ( Default :: default ( ) ) ;
@@ -116,50 +116,39 @@ impl From<[u8; Self::BYTES]> for StreamId {
116116 . iter_mut ( )
117117 . zip ( value. chunks_exact ( size_of :: < u32 > ( ) ) )
118118 {
119- * cur = u32:: from_le_bytes ( chunk. try_into ( ) . unwrap ( ) ) . to_le ( ) ;
119+ * cur = u32:: from_le_bytes ( chunk. try_into ( ) . unwrap ( ) ) ;
120120 }
121121 result
122122 }
123123}
124124
125- impl From < u64 > for StreamId {
125+ impl From < u64 > for U32x2 {
126126 #[ inline]
127127 fn from ( value : u64 ) -> Self {
128128 let result: [ u8 ; Self :: BYTES ] = value. to_le_bytes ( ) [ ..Self :: BYTES ] . try_into ( ) . unwrap ( ) ;
129129 result. into ( )
130130 }
131131}
132132
133- /// A wrapper for `block_pos `.
133+ /// A wrapper for `stream_id `.
134134///
135135/// Can be constructed from any of the following:
136136/// * `u64`
137- /// * `[u8; 8]`
138137/// * `[u32; 2]`
138+ /// * `[u8; 8]`
139139///
140140/// The arrays should be in little endian order.
141- pub struct BlockPos ( [ u32 ; 2 ] ) ;
141+ pub type StreamId = U32x2 ;
142142
143- impl From < u64 > for BlockPos {
144- #[ inline]
145- fn from ( value : u64 ) -> Self {
146- Self ( [ value as u32 , ( value >> 32 ) as u32 ] )
147- }
148- }
149-
150- impl From < [ u8 ; 8 ] > for BlockPos {
151- #[ inline]
152- fn from ( value : [ u8 ; 8 ] ) -> Self {
153- u64:: from_le_bytes ( value) . into ( )
154- }
155- }
156-
157- impl From < [ u32 ; 2 ] > for BlockPos {
158- #[ inline]
159- fn from ( value : [ u32 ; 2 ] ) -> Self {
160- Self ( value)
161- }
162- }
143+ /// A wrapper for `block_pos`.
144+ ///
145+ /// Can be constructed from any of the following:
146+ /// * `u64`
147+ /// * `[u32; 2]`
148+ /// * `[u8; 8]`
149+ ///
150+ /// The arrays should be in little endian order.
151+ pub type BlockPos = U32x2 ;
163152
164153/// The results buffer that zeroizes on drop when the `zeroize` feature is enabled.
165154#[ derive( Clone ) ]
@@ -392,9 +381,9 @@ macro_rules! impl_chacha_rng {
392381
393382 /// Set the offset from the start of the stream, in 32-bit words.
394383 ///
395- /// As with `get_word_pos`, we use a 36 -bit number. When given a `u64`, we use
396- /// the least significant 4 bits as the RNG's index, and the 32 bits before it
397- /// as the block position .
384+ /// As with `get_word_pos`, we use a 68 -bit number. Since the generator
385+ /// simply cycles at the end of its period (1 ZiB), we ignore the upper
386+ /// 60 bits .
398387 #[ inline]
399388 pub fn set_word_pos( & mut self , word_offset: u128 ) {
400389 let index = ( word_offset & 0b1111 ) as usize ;
@@ -411,8 +400,8 @@ macro_rules! impl_chacha_rng {
411400 ///
412401 /// This method takes any of the following:
413402 /// * `u64`
414- /// * `[u8; 8]`
415403 /// * `[u32; 2]`
404+ /// * `[u8; 8]`
416405 ///
417406 /// Note: the arrays should be in little endian order.
418407 #[ inline]
@@ -431,15 +420,15 @@ macro_rules! impl_chacha_rng {
431420 self . core. core. 0 . state[ 12 ] as u64 | ( ( self . core. core. 0 . state[ 13 ] as u64 ) << 32 )
432421 }
433422
434- /// Set the stream number. The lower 96 bits are used and the rest are
423+ /// Set the stream number. The lower 64 bits are used and the rest are
435424 /// discarded. This method takes any of the following:
436425 /// * `u64`
437- /// * `[u8; 8]`
438426 /// * `[u32; 2]`
427+ /// * `[u8; 8]`
439428 ///
440429 /// Note: the arrays should be in little endian order.
441430 ///
442- /// This is initialized to zero; 2<sup>96 </sup> unique streams of output
431+ /// This is initialized to zero; 2<sup>64 </sup> unique streams of output
443432 /// are available per seed/key. In theory a 96-bit nonce can be used by
444433 /// passing the last 64-bits to this function and using the first 32-bits as
445434 /// the most significant half of the 64-bit counter, which may be set
@@ -463,13 +452,7 @@ macro_rules! impl_chacha_rng {
463452 #[ inline]
464453 pub fn set_stream<S : Into <StreamId >>( & mut self , stream: S ) {
465454 let stream: StreamId = stream. into( ) ;
466- for ( n, val) in self . core. core. 0 . state[ 14 ..BLOCK_WORDS as usize ]
467- . as_mut( )
468- . iter_mut( )
469- . zip( stream. 0 . iter( ) )
470- {
471- * n = val. to_le( ) ;
472- }
455+ self . core. core. 0 . state[ 14 ..] . copy_from_slice( & stream. 0 ) ;
473456 if self . core. index( ) != BUFFER_SIZE {
474457 self . core. generate_and_set( self . core. index( ) ) ;
475458 }
@@ -1200,20 +1183,71 @@ pub(crate) mod tests {
12001183 assert_ne ! ( & first_blocks[ 0 ..64 * 4 ] , & result[ 64 ..] ) ;
12011184 }
12021185
1186+ /// Counts how many bytes were incorrect, and returns:
1187+ ///
1188+ /// (`index_of_first_incorrect_word`, `num_incorrect_bytes`)
1189+ fn count_incorrect_bytes ( expected : & [ u8 ] , output : & [ u8 ] ) -> ( Option < usize > , u32 ) {
1190+ assert_eq ! ( expected. len( ) , output. len( ) ) ;
1191+ let mut num_incorrect_bytes = 0 ;
1192+ let mut index_of_first_incorrect_word = None ;
1193+ expected
1194+ . iter ( )
1195+ . enumerate ( )
1196+ . zip ( output. iter ( ) )
1197+ . for_each ( |( ( i, a) , b) | {
1198+ if a. ne ( b) {
1199+ if index_of_first_incorrect_word. is_none ( ) {
1200+ index_of_first_incorrect_word = Some ( i / 4 )
1201+ }
1202+ num_incorrect_bytes += 1 ;
1203+ }
1204+ } ) ;
1205+ ( index_of_first_incorrect_word, num_incorrect_bytes)
1206+ }
1207+
12031208 /// Test vector 8 from https://github.com/pyca/cryptography/blob/main/vectors/cryptography_vectors/ciphers/ChaCha20/counter-overflow.txt
12041209 #[ test]
1205- fn counter_overflow_1 ( ) {
1210+ fn counter_overflow_and_diagnostics ( ) {
12061211 let mut rng = ChaCha20Rng :: from_seed ( [ 0u8 ; 32 ] ) ;
12071212 let block_pos = 4294967295 ;
12081213 assert_eq ! ( block_pos, u32 :: MAX as u64 ) ;
12091214 rng. set_block_pos ( 4294967295 ) ;
12101215
1211- let mut output = [ 0u8 ; 64 * 3 ] ;
1212- rng. fill_bytes ( & mut output) ;
1213- let expected = hex ! (
1214- "ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d3db41d3aa0d329285de6f225e6e24bd59c9a17006943d5c9b680e3873bdc683a5819469899989690c281cd17c96159af0682b5b903468a61f50228cf09622b5a46f0f6efee15c8f1b198cb49d92b990867905159440cc723916dc0012826981039ce1766aa2542b05db3bd809ab142489d5dbfe1273e7399637b4b3213768aaa "
1216+ let mut output = [ 0u8 ; 64 * 4 ] ;
1217+ rng. fill_bytes ( & mut output[ .. 64 * 3 ] ) ;
1218+ let block_before_overflow = hex ! (
1219+ "ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d "
12151220 ) ;
1216- assert_eq ! ( expected, output) ;
1221+ let first_block_after_overflow = hex ! (
1222+ "3db41d3aa0d329285de6f225e6e24bd59c9a17006943d5c9b680e3873bdc683a5819469899989690c281cd17c96159af0682b5b903468a61f50228cf09622b5a"
1223+ ) ;
1224+ let second_block_after_overflow = hex ! (
1225+ "46f0f6efee15c8f1b198cb49d92b990867905159440cc723916dc0012826981039ce1766aa2542b05db3bd809ab142489d5dbfe1273e7399637b4b3213768aaa"
1226+ ) ;
1227+ assert ! (
1228+ output[ ..64 ] . eq( & block_before_overflow) ,
1229+ "The first parblock was incorrect before overflow, indicating that ChaCha was not implemented correctly for this backend. Check the rounds() fn or the functions that it calls"
1230+ ) ;
1231+
1232+ rng. set_block_pos ( u32:: MAX as u64 - 1 ) ;
1233+ let mut skipped_blocks = [ 0u8 ; 64 * 3 ] ;
1234+ rng. fill_bytes ( & mut skipped_blocks) ;
1235+ rng. fill_bytes ( & mut output[ 64 * 3 ..] ) ;
1236+
1237+ output. chunks_exact ( 64 ) . enumerate ( ) . skip ( 1 ) . zip ( & [ first_block_after_overflow, second_block_after_overflow, second_block_after_overflow] ) . for_each ( |( ( i, a) , b) | {
1238+ let ( index_of_first_incorrect_word, num_incorrect_bytes) = count_incorrect_bytes ( a, b) ;
1239+ let msg = if num_incorrect_bytes == 0 {
1240+ "The block was correct and this will not be shown"
1241+ } else if num_incorrect_bytes > 32 {
1242+ "Most of the block was incorrect, indicating an issue with the counter using 32-bit addition towards the beginning of fn rounds()"
1243+ } else if num_incorrect_bytes <= 8 && matches ! ( index_of_first_incorrect_word, Some ( 12 | 13 ) ) {
1244+ "When the state was added to the results/res buffer at the end of fn rounds, the counter was probably incremented in 32-bit fashion for this parblock"
1245+ } else {
1246+ // this is probably unreachable in the event of a failed assertion, but it depends on the seed
1247+ "Some of the block was incorrect"
1248+ } ;
1249+ assert ! ( a. eq( b) , "PARBLOCK #{} uses incorrect counter addition\n Diagnostic = {}\n num_incorrect_bytes = {}\n index_of_first_incorrect_word = {:?}" , i + 1 , msg, num_incorrect_bytes, index_of_first_incorrect_word) ;
1250+ } ) ;
12171251 }
12181252
12191253 /// Test vector 9 from https://github.com/pyca/cryptography/blob/main/vectors/cryptography_vectors/ciphers/ChaCha20/counter-overflow.txt
0 commit comments