@@ -21,11 +21,12 @@ pub use aead::{
2121use aead:: { PostfixTagged , array:: ArraySize } ;
2222use cipher:: {
2323 BlockCipherDecrypt , BlockCipherEncrypt , BlockSizeUser ,
24- consts:: { U12 , U16 } ,
25- typenum:: Unsigned ,
24+ consts:: { U2 , U12 , U16 } ,
25+ typenum:: Prod ,
2626} ;
2727use core:: marker:: PhantomData ;
2828use dbl:: Dbl ;
29+ use inout:: { InOut , InOutBuf } ;
2930use subtle:: ConstantTimeEq ;
3031
3132/// Number of L values to be precomputed. Precomputing m values, allows
@@ -55,7 +56,9 @@ pub type Nonce<NonceSize> = Array<u8, NonceSize>;
5556/// OCB3 tag
5657pub type Tag < TagSize > = Array < u8 , TagSize > ;
5758
58- pub ( crate ) type Block = Array < u8 , U16 > ;
59+ type BlockSize = U16 ;
60+ pub ( crate ) type Block = Array < u8 , BlockSize > ;
61+ type DoubleBlock = Array < u8 , Prod < BlockSize , U2 > > ;
5962
6063mod sealed {
6164 use aead:: array:: {
@@ -210,34 +213,36 @@ where
210213 associated_data : & [ u8 ] ,
211214 buffer : & mut [ u8 ] ,
212215 ) -> aead:: Result < aead:: Tag < Self > > {
216+ let buffer = InOutBuf :: from ( buffer) ;
213217 if ( buffer. len ( ) > P_MAX ) || ( associated_data. len ( ) > A_MAX ) {
214218 unimplemented ! ( )
215219 }
216220
217221 // First, try to process many blocks at once.
218- let ( processed_bytes , mut offset_i, mut checksum_i) = self . wide_encrypt ( nonce, buffer) ;
222+ let ( tail , index , mut offset_i, mut checksum_i) = self . wide_encrypt ( nonce, buffer) ;
219223
220- let mut i = ( processed_bytes / 16 ) + 1 ;
224+ let mut i = index ;
221225
222226 // Then, process the remaining blocks.
223- for p_i in Block :: slice_as_chunks_mut ( & mut buffer[ processed_bytes..] ) . 0 {
227+ let ( blocks, mut tail) : ( InOutBuf < ' _ , ' _ , Block > , _ ) = tail. into_chunks ( ) ;
228+
229+ for p_i in blocks {
224230 // offset_i = offset_{i-1} xor L_{ntz(i)}
225231 inplace_xor ( & mut offset_i, & self . ll [ ntz ( i) ] ) ;
226232 // checksum_i = checksum_{i-1} xor p_i
227- inplace_xor ( & mut checksum_i, p_i) ;
233+ inplace_xor ( & mut checksum_i, p_i. get_in ( ) ) ;
228234 // c_i = offset_i xor ENCIPHER(K, p_i xor offset_i)
229- let c_i = p_i;
230- inplace_xor ( c_i, & offset_i) ;
231- self . cipher . encrypt_block ( c_i) ;
232- inplace_xor ( c_i, & offset_i) ;
235+ let mut c_i = p_i;
236+ c_i. xor_in2out ( & offset_i) ;
237+ self . cipher . encrypt_block ( c_i. get_out ( ) ) ;
238+ inplace_xor ( c_i. get_out ( ) , & offset_i) ;
233239
234240 i += 1 ;
235241 }
236242
237243 // Process any partial blocks.
238- if ( buffer. len ( ) % 16 ) != 0 {
239- let processed_bytes = ( i - 1 ) * 16 ;
240- let remaining_bytes = buffer. len ( ) - processed_bytes;
244+ if !tail. is_empty ( ) {
245+ let remaining_bytes = tail. len ( ) ;
241246
242247 // offset_* = offset_m xor L_*
243248 inplace_xor ( & mut offset_i, & self . ll_star ) ;
@@ -247,15 +252,13 @@ where
247252 self . cipher . encrypt_block ( & mut pad) ;
248253 // checksum_* = checksum_m xor (P_* || 1 || zeros(127-bitlen(P_*)))
249254 let checksum_rhs = & mut [ 0u8 ; 16 ] ;
250- checksum_rhs[ ..remaining_bytes] . copy_from_slice ( & buffer [ processed_bytes.. ] ) ;
255+ checksum_rhs[ ..remaining_bytes] . copy_from_slice ( tail . get_in ( ) ) ;
251256 checksum_rhs[ remaining_bytes] = 0b1000_0000 ;
252257 inplace_xor ( & mut checksum_i, checksum_rhs. as_ref ( ) ) ;
253258 // C_* = P_* xor Pad[1..bitlen(P_*)]
254- let p_star = & mut buffer [ processed_bytes.. ] ;
259+ let p_star = tail . get_out ( ) ;
255260 let pad = & mut pad[ ..p_star. len ( ) ] ;
256- for ( aa, bb) in p_star. iter_mut ( ) . zip ( pad) {
257- * aa ^= * bb;
258- }
261+ tail. xor_in2out ( pad) ;
259262 }
260263
261264 let tag = self . compute_tag ( associated_data, & mut checksum_i, & offset_i) ;
@@ -295,32 +298,32 @@ where
295298 if ( buffer. len ( ) > C_MAX ) || ( associated_data. len ( ) > A_MAX ) {
296299 unimplemented ! ( )
297300 }
301+ let buffer = InOutBuf :: from ( buffer) ;
298302
299303 // First, try to process many blocks at once.
300- let ( processed_bytes , mut offset_i, mut checksum_i) = self . wide_decrypt ( nonce, buffer) ;
304+ let ( tail , index , mut offset_i, mut checksum_i) = self . wide_decrypt ( nonce, buffer) ;
301305
302- let mut i = ( processed_bytes / 16 ) + 1 ;
306+ let mut i = index ;
303307
304308 // Then, process the remaining blocks.
305- let ( blocks, _remaining ) = Block :: slice_as_chunks_mut ( & mut buffer [ processed_bytes.. ] ) ;
309+ let ( blocks, mut tail ) : ( InOutBuf < ' _ , ' _ , Block > , _ ) = tail . into_chunks ( ) ;
306310 for c_i in blocks {
307311 // offset_i = offset_{i-1} xor L_{ntz(i)}
308312 inplace_xor ( & mut offset_i, & self . ll [ ntz ( i) ] ) ;
309313 // p_i = offset_i xor DECIPHER(K, c_i xor offset_i)
310- let p_i = c_i;
311- inplace_xor ( p_i, & offset_i) ;
312- self . cipher . decrypt_block ( p_i) ;
313- inplace_xor ( p_i, & offset_i) ;
314+ let mut p_i = c_i;
315+ p_i. xor_in2out ( & offset_i) ;
316+ self . cipher . decrypt_block ( p_i. get_out ( ) ) ;
317+ inplace_xor ( p_i. get_out ( ) , & offset_i) ;
314318 // checksum_i = checksum_{i-1} xor p_i
315- inplace_xor ( & mut checksum_i, p_i) ;
319+ inplace_xor ( & mut checksum_i, p_i. get_out ( ) ) ;
316320
317321 i += 1 ;
318322 }
319323
320324 // Process any partial blocks.
321- if ( buffer. len ( ) % 16 ) != 0 {
322- let processed_bytes = ( i - 1 ) * 16 ;
323- let remaining_bytes = buffer. len ( ) - processed_bytes;
325+ if !tail. is_empty ( ) {
326+ let remaining_bytes = tail. len ( ) ;
324327
325328 // offset_* = offset_m xor L_*
326329 inplace_xor ( & mut offset_i, & self . ll_star ) ;
@@ -329,14 +332,12 @@ where
329332 inplace_xor ( & mut pad, & offset_i) ;
330333 self . cipher . encrypt_block ( & mut pad) ;
331334 // P_* = C_* xor Pad[1..bitlen(C_*)]
332- let c_star = & mut buffer [ processed_bytes.. ] ;
335+ let c_star = tail . get_in ( ) ;
333336 let pad = & mut pad[ ..c_star. len ( ) ] ;
334- for ( aa, bb) in c_star. iter_mut ( ) . zip ( pad) {
335- * aa ^= * bb;
336- }
337+ tail. xor_in2out ( pad) ;
337338 // checksum_* = checksum_m xor (P_* || 1 || zeros(127-bitlen(P_*)))
338339 let checksum_rhs = & mut [ 0u8 ; 16 ] ;
339- checksum_rhs[ ..remaining_bytes] . copy_from_slice ( & buffer [ processed_bytes.. ] ) ;
340+ checksum_rhs[ ..remaining_bytes] . copy_from_slice ( tail . get_out ( ) ) ;
340341 checksum_rhs[ remaining_bytes] = 0b1000_0000 ;
341342 inplace_xor ( & mut checksum_i, checksum_rhs. as_ref ( ) ) ;
342343 }
@@ -347,81 +348,85 @@ where
347348 /// Encrypts plaintext in groups of two.
348349 ///
349350 /// Adapted from https://www.cs.ucdavis.edu/~rogaway/ocb/news/code/ocb.c
350- fn wide_encrypt ( & self , nonce : & Nonce < NonceSize > , buffer : & mut [ u8 ] ) -> ( usize , Block , Block ) {
351+ fn wide_encrypt < ' i , ' o > (
352+ & self ,
353+ nonce : & Nonce < NonceSize > ,
354+ buffer : InOutBuf < ' i , ' o , u8 > ,
355+ ) -> ( InOutBuf < ' i , ' o , u8 > , usize , Block , Block ) {
351356 const WIDTH : usize = 2 ;
352357
353358 let mut i = 1 ;
354359
355360 let mut offset_i = [ Block :: default ( ) ; WIDTH ] ;
356- offset_i[ offset_i . len ( ) - 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
361+ offset_i[ 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
357362 let mut checksum_i = Block :: default ( ) ;
358- for wide_blocks in buffer. chunks_exact_mut ( <Block as AssocArraySize >:: Size :: USIZE * WIDTH ) {
359- let p_i = split_into_two_blocks ( wide_blocks) ;
360363
364+ let ( wide_blocks, tail) : ( InOutBuf < ' _ , ' _ , DoubleBlock > , _ ) = buffer. into_chunks ( ) ;
365+ for wide_block in wide_blocks. into_iter ( ) {
366+ let mut p_i = split_into_two_blocks ( wide_block) ;
361367 // checksum_i = checksum_{i-1} xor p_i
362368 for p_ij in & p_i {
363- inplace_xor ( & mut checksum_i, p_ij) ;
369+ inplace_xor ( & mut checksum_i, p_ij. get_in ( ) ) ;
364370 }
365371
366372 // offset_i = offset_{i-1} xor L_{ntz(i)}
367- offset_i[ 0 ] = offset_i[ offset_i . len ( ) - 1 ] ;
373+ offset_i[ 0 ] = offset_i[ 1 ] ;
368374 inplace_xor ( & mut offset_i[ 0 ] , & self . ll [ ntz ( i) ] ) ;
369- for j in 1 ..p_i. len ( ) {
370- offset_i[ j] = offset_i[ j - 1 ] ;
371- inplace_xor ( & mut offset_i[ j] , & self . ll [ ntz ( i + j) ] ) ;
372- }
375+ offset_i[ 1 ] = offset_i[ 0 ] ;
376+ inplace_xor ( & mut offset_i[ 1 ] , & self . ll [ ntz ( i + 1 ) ] ) ;
373377
374378 // c_i = offset_i xor ENCIPHER(K, p_i xor offset_i)
375379 for j in 0 ..p_i. len ( ) {
376- inplace_xor ( p_i[ j] , & offset_i[ j] ) ;
377- self . cipher . encrypt_block ( p_i[ j] ) ;
378- inplace_xor ( p_i[ j] , & offset_i[ j] )
380+ p_i[ j] . xor_in2out ( & offset_i[ j] ) ;
381+ self . cipher . encrypt_block ( p_i[ j] . get_out ( ) ) ;
382+ inplace_xor ( p_i[ j] . get_out ( ) , & offset_i[ j] ) ;
379383 }
380384
381385 i += WIDTH ;
382386 }
383387
384- let processed_bytes = ( buffer. len ( ) / ( WIDTH * 16 ) ) * ( WIDTH * 16 ) ;
385-
386- ( processed_bytes, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
388+ ( tail, i, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
387389 }
388390
389391 /// Decrypts plaintext in groups of two.
390392 ///
391393 /// Adapted from https://www.cs.ucdavis.edu/~rogaway/ocb/news/code/ocb.c
392- fn wide_decrypt ( & self , nonce : & Nonce < NonceSize > , buffer : & mut [ u8 ] ) -> ( usize , Block , Block ) {
394+ fn wide_decrypt < ' i , ' o > (
395+ & self ,
396+ nonce : & Nonce < NonceSize > ,
397+ buffer : InOutBuf < ' i , ' o , u8 > ,
398+ ) -> ( InOutBuf < ' i , ' o , u8 > , usize , Block , Block ) {
393399 const WIDTH : usize = 2 ;
394400
395401 let mut i = 1 ;
396402
397403 let mut offset_i = [ Block :: default ( ) ; WIDTH ] ;
398- offset_i[ offset_i . len ( ) - 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
404+ offset_i[ 1 ] = initial_offset ( & self . cipher , nonce, TagSize :: to_u32 ( ) ) ;
399405 let mut checksum_i = Block :: default ( ) ;
400- for wide_blocks in buffer. chunks_exact_mut ( 16 * WIDTH ) {
401- let c_i = split_into_two_blocks ( wide_blocks) ;
406+
407+ let ( wide_blocks, tail) : ( InOutBuf < ' _ , ' _ , DoubleBlock > , _ ) = buffer. into_chunks ( ) ;
408+ for wide_block in wide_blocks. into_iter ( ) {
409+ let mut c_i = split_into_two_blocks ( wide_block) ;
402410
403411 // offset_i = offset_{i-1} xor L_{ntz(i)}
404- offset_i[ 0 ] = offset_i[ offset_i . len ( ) - 1 ] ;
412+ offset_i[ 0 ] = offset_i[ 1 ] ;
405413 inplace_xor ( & mut offset_i[ 0 ] , & self . ll [ ntz ( i) ] ) ;
406- for j in 1 ..c_i. len ( ) {
407- offset_i[ j] = offset_i[ j - 1 ] ;
408- inplace_xor ( & mut offset_i[ j] , & self . ll [ ntz ( i + j) ] ) ;
409- }
414+ offset_i[ 1 ] = offset_i[ 0 ] ;
415+ inplace_xor ( & mut offset_i[ 1 ] , & self . ll [ ntz ( i + 1 ) ] ) ;
410416
411417 // p_i = offset_i xor DECIPHER(K, c_i xor offset_i)
412418 // checksum_i = checksum_{i-1} xor p_i
413419 for j in 0 ..c_i. len ( ) {
414- inplace_xor ( c_i[ j] , & offset_i[ j] ) ;
415- self . cipher . decrypt_block ( c_i[ j] ) ;
416- inplace_xor ( c_i[ j] , & offset_i[ j] ) ;
417- inplace_xor ( & mut checksum_i, c_i[ j] ) ;
420+ c_i[ j] . xor_in2out ( & offset_i[ j] ) ;
421+ self . cipher . decrypt_block ( c_i[ j] . get_out ( ) ) ;
422+ inplace_xor ( c_i[ j] . get_out ( ) , & offset_i[ j] ) ;
423+ inplace_xor ( & mut checksum_i, c_i[ j] . get_out ( ) ) ;
418424 }
419425
420426 i += WIDTH ;
421427 }
422428
423- let processed_bytes = ( buffer. len ( ) / ( WIDTH * 16 ) ) * ( WIDTH * 16 ) ;
424- ( processed_bytes, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
429+ ( tail, i, offset_i[ offset_i. len ( ) - 1 ] , checksum_i)
425430 }
426431
427432 /// Computes HASH function defined in https://www.rfc-editor.org/rfc/rfc7253.html#section-4.1
@@ -580,11 +585,10 @@ pub(crate) fn ntz(n: usize) -> usize {
580585}
581586
582587#[ inline]
583- pub ( crate ) fn split_into_two_blocks ( two_blocks : & mut [ u8 ] ) -> [ & mut Block ; 2 ] {
584- const BLOCK_SIZE : usize = 16 ;
585- debug_assert_eq ! ( two_blocks. len( ) , BLOCK_SIZE * 2 ) ;
586- let ( b0, b1) = two_blocks. split_at_mut ( BLOCK_SIZE ) ;
587- [ b0. try_into ( ) . unwrap ( ) , b1. try_into ( ) . unwrap ( ) ]
588+ pub ( crate ) fn split_into_two_blocks < ' i , ' o > (
589+ two_blocks : InOut < ' i , ' o , DoubleBlock > ,
590+ ) -> [ InOut < ' i , ' o , Block > ; 2 ] {
591+ Array :: < InOut < ' i , ' o , Block > , U2 > :: from ( two_blocks) . into ( )
588592}
589593
590594#[ cfg( test) ]
0 commit comments