@@ -34,6 +34,10 @@ use core::ops::Deref;
3434/// and [BOLT-1](https://github.com/lightning/bolts/blob/master/01-messaging.md#lightning-message-format): 
3535pub  const  LN_MAX_MSG_LEN :  usize  = :: core:: u16:: MAX  as  usize ;  // Must be equal to 65535 
3636
37+ /// The (rough) size buffer to pre-allocate when encoding a message. Messages should reliably be 
38+ /// smaller than this size by at least 32 bytes or so. 
39+ pub  const  MSG_BUF_ALLOC_SIZE :  usize  = 2048 ; 
40+ 
3741// Sha256("Noise_XK_secp256k1_ChaChaPoly_SHA256") 
3842const  NOISE_CK :  [ u8 ;  32 ]  = [ 0x26 ,  0x40 ,  0xf5 ,  0x2e ,  0xeb ,  0xcd ,  0x9e ,  0x88 ,  0x29 ,  0x58 ,  0x95 ,  0x1c ,  0x79 ,  0x42 ,  0x50 ,  0xee ,  0xdb ,  0x28 ,  0x00 ,  0x2c ,  0x05 ,  0xd7 ,  0xdc ,  0x2e ,  0xa0 ,  0xf1 ,  0x95 ,  0x40 ,  0x60 ,  0x42 ,  0xca ,  0xf1 ] ; 
3943// Sha256(NOISE_CK || "lightning") 
@@ -165,6 +169,18 @@ impl PeerChannelEncryptor {
165169		res. extend_from_slice ( & tag) ; 
166170	} 
167171
172+ 	fn  decrypt_in_place_with_ad ( inout :  & mut  [ u8 ] ,  n :  u64 ,  key :  & [ u8 ;  32 ] ,  h :  & [ u8 ] )  -> Result < ( ) ,  LightningError >  { 
173+ 		let  mut  nonce = [ 0 ;  12 ] ; 
174+ 		nonce[ 4 ..] . copy_from_slice ( & n. to_le_bytes ( ) [ ..] ) ; 
175+ 
176+ 		let  mut  chacha = ChaCha20Poly1305RFC :: new ( key,  & nonce,  h) ; 
177+ 		let  ( inout,  tag)  = inout. split_at_mut ( inout. len ( )  - 16 ) ; 
178+ 		if  chacha. check_decrypt_in_place ( inout,  tag) . is_err ( )  { 
179+ 			return  Err ( LightningError { err :  "Bad MAC" . to_owned ( ) ,  action :  msgs:: ErrorAction :: DisconnectPeer {  msg :  None  } } ) ; 
180+ 		} 
181+ 		Ok ( ( ) ) 
182+ 	} 
183+ 
168184	#[ inline]  
169185	fn  decrypt_with_ad ( res :  & mut [ u8 ] ,  n :  u64 ,  key :  & [ u8 ;  32 ] ,  h :  & [ u8 ] ,  cyphertext :  & [ u8 ] )  -> Result < ( ) ,  LightningError >  { 
170186		let  mut  nonce = [ 0 ;  12 ] ; 
@@ -411,16 +427,20 @@ impl PeerChannelEncryptor {
411427		Ok ( self . their_node_id . unwrap ( ) . clone ( ) ) 
412428	} 
413429
414- 	/// Encrypts the given pre-serialized message, returning the encrypted version. 
415- /// panics if msg.len() > 65535 or Noise handshake has not finished. 
416- pub  fn  encrypt_buffer ( & mut  self ,  msg :  & [ u8 ] )  -> Vec < u8 >  { 
417- 		if  msg. len ( )  > LN_MAX_MSG_LEN  { 
430+ 	/// Builds sendable bytes for a message. 
431+ /// 
432+ /// `msgbuf` must begin with 16 + 2 dummy/0 bytes, which will be filled with the encrypted 
433+ /// message length and its MAC. It should then be followed by the message bytes themselves 
434+ /// (including the two byte message type). 
435+ /// 
436+ /// For effeciency, the [`Vec::capacity`] should be at least 16 bytes larger than the 
437+ /// [`Vec::len`], to avoid reallocating for the message MAC, which will be appended to the vec. 
438+ fn  encrypt_message_with_header_0s ( & mut  self ,  msgbuf :  & mut  Vec < u8 > )  { 
439+ 		let  msg_len = msgbuf. len ( )  - 16  - 2 ; 
440+ 		if  msg_len > LN_MAX_MSG_LEN  { 
418441			panic ! ( "Attempted to encrypt message longer than 65535 bytes!" ) ; 
419442		} 
420443
421- 		let  mut  res = Vec :: with_capacity ( msg. len ( )  + 16 * 2  + 2 ) ; 
422- 		res. resize ( msg. len ( )  + 16 * 2  + 2 ,  0 ) ; 
423- 
424444		match  self . noise_state  { 
425445			NoiseState :: Finished  {  ref  mut  sk,  ref  mut  sn,  ref  mut  sck,  rk :  _,  rn :  _,  rck :  _ }  => { 
426446				if  * sn >= 1000  { 
@@ -430,16 +450,21 @@ impl PeerChannelEncryptor {
430450					* sn = 0 ; 
431451				} 
432452
433- 				Self :: encrypt_with_ad ( & mut  res [ 0 ..16 +2 ] ,  * sn,  sk,  & [ 0 ;  0 ] ,  & ( msg . len ( )  as  u16 ) . to_be_bytes ( ) ) ; 
453+ 				Self :: encrypt_with_ad ( & mut  msgbuf [ 0 ..16 +2 ] ,  * sn,  sk,  & [ 0 ;  0 ] ,  & ( msg_len  as  u16 ) . to_be_bytes ( ) ) ; 
434454				* sn += 1 ; 
435455
436- 				Self :: encrypt_with_ad ( & mut  res [ 16 +2 .. ] ,  * sn,  sk,  & [ 0 ;  0 ] ,  msg ) ; 
456+ 				Self :: encrypt_in_place_with_ad ( msgbuf ,   16 +2 ,  * sn,  sk,  & [ 0 ;  0 ] ) ; 
437457				* sn += 1 ; 
438458			} , 
439459			_ => panic ! ( "Tried to encrypt a message prior to noise handshake completion" ) , 
440460		} 
461+ 	} 
441462
442- 		res
463+ 	/// Encrypts the given pre-serialized message, returning the encrypted version. 
464+ /// panics if msg.len() > 65535 or Noise handshake has not finished. 
465+ pub  fn  encrypt_buffer ( & mut  self ,  mut  msg :  MessageBuf )  -> Vec < u8 >  { 
466+ 		self . encrypt_message_with_header_0s ( & mut  msg. 0 ) ; 
467+ 		msg. 0 
443468	} 
444469
445470	/// Encrypts the given message, returning the encrypted version. 
@@ -448,33 +473,11 @@ impl PeerChannelEncryptor {
448473pub  fn  encrypt_message < M :  wire:: Type > ( & mut  self ,  message :  & M )  -> Vec < u8 >  { 
449474		// Allocate a buffer with 2KB, fitting most common messages. Reserve the first 16+2 bytes 
450475		// for the 2-byte message type prefix and its MAC. 
451- 		let  mut  res = VecWriter ( Vec :: with_capacity ( 2048 ) ) ; 
476+ 		let  mut  res = VecWriter ( Vec :: with_capacity ( MSG_BUF_ALLOC_SIZE ) ) ; 
452477		res. 0 . resize ( 16  + 2 ,  0 ) ; 
453478		wire:: write ( message,  & mut  res) . expect ( "In-memory messages must never fail to serialize" ) ; 
454479
455- 		let  msg_len = res. 0 . len ( )  - 16  - 2 ; 
456- 		if  msg_len > LN_MAX_MSG_LEN  { 
457- 			panic ! ( "Attempted to encrypt message longer than 65535 bytes!" ) ; 
458- 		} 
459- 
460- 		match  self . noise_state  { 
461- 			NoiseState :: Finished  {  ref  mut  sk,  ref  mut  sn,  ref  mut  sck,  rk :  _,  rn :  _,  rck :  _ }  => { 
462- 				if  * sn >= 1000  { 
463- 					let  ( new_sck,  new_sk)  = hkdf_extract_expand_twice ( sck,  sk) ; 
464- 					* sck = new_sck; 
465- 					* sk = new_sk; 
466- 					* sn = 0 ; 
467- 				} 
468- 
469- 				Self :: encrypt_with_ad ( & mut  res. 0 [ 0 ..16 +2 ] ,  * sn,  sk,  & [ 0 ;  0 ] ,  & ( msg_len as  u16 ) . to_be_bytes ( ) ) ; 
470- 				* sn += 1 ; 
471- 
472- 				Self :: encrypt_in_place_with_ad ( & mut  res. 0 ,  16 +2 ,  * sn,  sk,  & [ 0 ;  0 ] ) ; 
473- 				* sn += 1 ; 
474- 			} , 
475- 			_ => panic ! ( "Tried to encrypt a message prior to noise handshake completion" ) , 
476- 		} 
477- 
480+ 		self . encrypt_message_with_header_0s ( & mut  res. 0 ) ; 
478481		res. 0 
479482	} 
480483
@@ -501,21 +504,20 @@ impl PeerChannelEncryptor {
501504		} 
502505	} 
503506
504- 	/// Decrypts the given message. 
507+ 	/// Decrypts the given message up to msg.len() - 16. Bytes after msg.len() - 16 will be left 
508+ /// undefined (as they contain the Poly1305 tag bytes). 
509+ /// 
505510/// panics if msg.len() > 65535 + 16 
506- pub  fn  decrypt_message ( & mut  self ,  msg :  & [ u8 ] )  -> Result < Vec < u8 > ,  LightningError >  { 
511+ pub  fn  decrypt_message ( & mut  self ,  msg :  & mut   [ u8 ] )  -> Result < ( ) ,  LightningError >  { 
507512		if  msg. len ( )  > LN_MAX_MSG_LEN  + 16  { 
508513			panic ! ( "Attempted to decrypt message longer than 65535 + 16 bytes!" ) ; 
509514		} 
510515
511516		match  self . noise_state  { 
512517			NoiseState :: Finished  {  sk :  _,  sn :  _,  sck :  _,  ref  rk,  ref  mut  rn,  rck :  _ }  => { 
513- 				let  mut  res = Vec :: with_capacity ( msg. len ( )  - 16 ) ; 
514- 				res. resize ( msg. len ( )  - 16 ,  0 ) ; 
515- 				Self :: decrypt_with_ad ( & mut  res[ ..] ,  * rn,  rk,  & [ 0 ;  0 ] ,  msg) ?; 
518+ 				Self :: decrypt_in_place_with_ad ( & mut  msg[ ..] ,  * rn,  rk,  & [ 0 ;  0 ] ) ?; 
516519				* rn += 1 ; 
517- 
518- 				Ok ( res) 
520+ 				Ok ( ( ) ) 
519521			} , 
520522			_ => panic ! ( "Tried to decrypt a message prior to noise handshake completion" ) , 
521523		} 
@@ -542,9 +544,30 @@ impl PeerChannelEncryptor {
542544	} 
543545} 
544546
547+ /// A buffer which stores an encoded message (including the two message-type bytes) with some 
548+ /// padding to allow for future encryption/MACing. 
549+ pub  struct  MessageBuf ( Vec < u8 > ) ; 
550+ impl  MessageBuf  { 
551+ 	/// Creates a new buffer from an encoded message (i.e. the two message-type bytes followed by 
552+ /// the message contents). 
553+ /// 
554+ /// Panics if the message is longer than 2^16. 
555+ pub  fn  from_encoded ( encoded_msg :  & [ u8 ] )  -> Self  { 
556+ 		if  encoded_msg. len ( )  > LN_MAX_MSG_LEN  { 
557+ 			panic ! ( "Attempted to encrypt message longer than 65535 bytes!" ) ; 
558+ 		} 
559+ 		// In addition to the message (continaing the two message type bytes), we also have to add 
560+ 		// the message length header (and its MAC) and the message MAC. 
561+ 		let  mut  res = Vec :: with_capacity ( encoded_msg. len ( )  + 16 * 2  + 2 ) ; 
562+ 		res. resize ( encoded_msg. len ( )  + 16  + 2 ,  0 ) ; 
563+ 		res[ 16  + 2 ..] . copy_from_slice ( & encoded_msg) ; 
564+ 		Self ( res) 
565+ 	} 
566+ } 
567+ 
545568#[ cfg( test) ]  
546569mod  tests { 
547- 	use  super :: LN_MAX_MSG_LEN ; 
570+ 	use  super :: { MessageBuf ,   LN_MAX_MSG_LEN } ; 
548571
549572	use  bitcoin:: secp256k1:: { PublicKey ,  SecretKey } ; 
550573	use  bitcoin:: secp256k1:: Secp256k1 ; 
@@ -760,12 +783,11 @@ mod tests {
760783
761784		for  i in  0 ..1005  { 
762785			let  msg = [ 0x68 ,  0x65 ,  0x6c ,  0x6c ,  0x6f ] ; 
763- 			let  res = outbound_peer. encrypt_buffer ( & msg) ; 
786+ 			let  mut   res = outbound_peer. encrypt_buffer ( MessageBuf :: from_encoded ( & msg) ) ; 
764787			assert_eq ! ( res. len( ) ,  5  + 2 * 16  + 2 ) ; 
765788
766789			let  len_header = res[ 0 ..2 +16 ] . to_vec ( ) ; 
767790			assert_eq ! ( inbound_peer. decrypt_length_header( & len_header[ ..] ) . unwrap( )  as  usize ,  msg. len( ) ) ; 
768- 			assert_eq ! ( inbound_peer. decrypt_message( & res[ 2 +16 ..] ) . unwrap( ) [ ..] ,  msg[ ..] ) ; 
769791
770792			if  i == 0  { 
771793				assert_eq ! ( res,  hex:: decode( "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cbcf25d2f214cf9ea1d95" ) . unwrap( ) ) ; 
@@ -780,6 +802,9 @@ mod tests {
780802			}  else  if  i == 1001  { 
781803				assert_eq ! ( res,  hex:: decode( "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e268338b1a16cf4ef2d36" ) . unwrap( ) ) ; 
782804			} 
805+ 
806+ 			inbound_peer. decrypt_message ( & mut  res[ 2 +16 ..] ) . unwrap ( ) ; 
807+ 			assert_eq ! ( res[ 2  + 16 ..res. len( )  - 16 ] ,  msg[ ..] ) ; 
783808		} 
784809	} 
785810
@@ -794,7 +819,7 @@ mod tests {
794819	fn  max_message_len_encryption ( )  { 
795820		let  mut  outbound_peer = get_outbound_peer_for_initiator_test_vectors ( ) ; 
796821		let  msg = [ 4u8 ;  LN_MAX_MSG_LEN  + 1 ] ; 
797- 		outbound_peer. encrypt_buffer ( & msg) ; 
822+ 		outbound_peer. encrypt_buffer ( MessageBuf :: from_encoded ( & msg) ) ; 
798823	} 
799824
800825	#[ test]  
@@ -803,7 +828,7 @@ mod tests {
803828		let  mut  inbound_peer = get_inbound_peer_for_test_vectors ( ) ; 
804829
805830		// MSG should not exceed LN_MAX_MSG_LEN + 16 
806- 		let  msg = [ 4u8 ;  LN_MAX_MSG_LEN  + 17 ] ; 
807- 		inbound_peer. decrypt_message ( & msg) . unwrap ( ) ; 
831+ 		let  mut   msg = [ 4u8 ;  LN_MAX_MSG_LEN  + 17 ] ; 
832+ 		inbound_peer. decrypt_message ( & mut   msg) . unwrap ( ) ; 
808833	} 
809834} 
0 commit comments