@@ -169,6 +169,18 @@ impl PeerChannelEncryptor {
169169 res. extend_from_slice ( & tag) ;
170170 }
171171
172+ fn decrypt_in_place_with_ad ( inout : & mut [ u8 ] , n : u64 , key : & [ u8 ; 32 ] , h : & [ u8 ] ) -> Result < ( ) , LightningError > {
173+ let mut nonce = [ 0 ; 12 ] ;
174+ nonce[ 4 ..] . copy_from_slice ( & n. to_le_bytes ( ) [ ..] ) ;
175+
176+ let mut chacha = ChaCha20Poly1305RFC :: new ( key, & nonce, h) ;
177+ let ( inout, tag) = inout. split_at_mut ( inout. len ( ) - 16 ) ;
178+ if chacha. decrypt_in_place ( inout, tag) . is_err ( ) {
179+ return Err ( LightningError { err : "Bad MAC" . to_owned ( ) , action : msgs:: ErrorAction :: DisconnectPeer { msg : None } } ) ;
180+ }
181+ Ok ( ( ) )
182+ }
183+
172184 #[ inline]
173185 fn decrypt_with_ad ( res : & mut [ u8 ] , n : u64 , key : & [ u8 ; 32 ] , h : & [ u8 ] , cyphertext : & [ u8 ] ) -> Result < ( ) , LightningError > {
174186 let mut nonce = [ 0 ; 12 ] ;
@@ -505,21 +517,20 @@ impl PeerChannelEncryptor {
505517 }
506518 }
507519
508- /// Decrypts the given message.
520+ /// Decrypts the given message up to msg.len() - 16. Bytes after msg.len() - 16 will be left
521+ /// undefined (as they contain the Poly1305 tag bytes).
522+ ///
509523 /// panics if msg.len() > 65535 + 16
510- pub fn decrypt_message ( & mut self , msg : & [ u8 ] ) -> Result < Vec < u8 > , LightningError > {
524+ pub fn decrypt_message ( & mut self , msg : & mut [ u8 ] ) -> Result < ( ) , LightningError > {
511525 if msg. len ( ) > LN_MAX_MSG_LEN + 16 {
512526 panic ! ( "Attempted to decrypt message longer than 65535 + 16 bytes!" ) ;
513527 }
514528
515529 match self . noise_state {
516530 NoiseState :: Finished { sk : _, sn : _, sck : _, ref rk, ref mut rn, rck : _ } => {
517- let mut res = Vec :: with_capacity ( msg. len ( ) - 16 ) ;
518- res. resize ( msg. len ( ) - 16 , 0 ) ;
519- Self :: decrypt_with_ad ( & mut res[ ..] , * rn, rk, & [ 0 ; 0 ] , msg) ?;
531+ Self :: decrypt_in_place_with_ad ( & mut msg[ ..] , * rn, rk, & [ 0 ; 0 ] ) ?;
520532 * rn += 1 ;
521-
522- Ok ( res)
533+ Ok ( ( ) )
523534 } ,
524535 _ => panic ! ( "Tried to decrypt a message prior to noise handshake completion" ) ,
525536 }
@@ -764,12 +775,11 @@ mod tests {
764775
765776 for i in 0 ..1005 {
766777 let msg = [ 0x68 , 0x65 , 0x6c , 0x6c , 0x6f ] ;
767- let res = outbound_peer. encrypt_buffer ( & msg) ;
778+ let mut res = outbound_peer. encrypt_buffer ( & msg) ;
768779 assert_eq ! ( res. len( ) , 5 + 2 * 16 + 2 ) ;
769780
770781 let len_header = res[ 0 ..2 +16 ] . to_vec ( ) ;
771782 assert_eq ! ( inbound_peer. decrypt_length_header( & len_header[ ..] ) . unwrap( ) as usize , msg. len( ) ) ;
772- assert_eq ! ( inbound_peer. decrypt_message( & res[ 2 +16 ..] ) . unwrap( ) [ ..] , msg[ ..] ) ;
773783
774784 if i == 0 {
775785 assert_eq ! ( res, hex:: decode( "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cbcf25d2f214cf9ea1d95" ) . unwrap( ) ) ;
@@ -784,6 +794,9 @@ mod tests {
784794 } else if i == 1001 {
785795 assert_eq ! ( res, hex:: decode( "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e268338b1a16cf4ef2d36" ) . unwrap( ) ) ;
786796 }
797+
798+ inbound_peer. decrypt_message ( & mut res[ 2 +16 ..] ) . unwrap ( ) ;
799+ assert_eq ! ( res[ 2 + 16 ..res. len( ) - 16 ] , msg[ ..] ) ;
787800 }
788801 }
789802
@@ -807,7 +820,7 @@ mod tests {
807820 let mut inbound_peer = get_inbound_peer_for_test_vectors ( ) ;
808821
809822 // MSG should not exceed LN_MAX_MSG_LEN + 16
810- let msg = [ 4u8 ; LN_MAX_MSG_LEN + 17 ] ;
811- inbound_peer. decrypt_message ( & msg) . unwrap ( ) ;
823+ let mut msg = [ 4u8 ; LN_MAX_MSG_LEN + 17 ] ;
824+ inbound_peer. decrypt_message ( & mut msg) . unwrap ( ) ;
812825 }
813826}
0 commit comments