@@ -20,13 +20,15 @@ use crate::util::test_utils;
2020use  bitcoin:: network:: constants:: Network ; 
2121use  bitcoin:: secp256k1:: { PublicKey ,  Secp256k1 } ; 
2222
23+ use  core:: sync:: atomic:: { AtomicU16 ,  Ordering } ; 
2324use  crate :: io; 
2425use  crate :: io_extras:: read_to_end; 
2526use  crate :: sync:: Arc ; 
2627
2728struct  MessengerNode  { 
2829	keys_manager :  Arc < test_utils:: TestKeysInterface > , 
2930	messenger :  OnionMessenger < Arc < test_utils:: TestKeysInterface > ,  Arc < test_utils:: TestKeysInterface > ,  Arc < test_utils:: TestLogger > ,  Arc < TestCustomMessageHandler > > , 
31+ 	custom_message_handler :  Arc < TestCustomMessageHandler > , 
3032	logger :  Arc < test_utils:: TestLogger > , 
3133} 
3234
@@ -54,11 +56,32 @@ impl Writeable for TestCustomMessage {
5456	} 
5557} 
5658
57- struct  TestCustomMessageHandler  { } 
59+ struct  TestCustomMessageHandler  { 
60+ 	num_messages_expected :  AtomicU16 , 
61+ } 
62+ 
63+ impl  TestCustomMessageHandler  { 
64+ 	fn  new ( )  -> Self  { 
65+ 		Self  {  num_messages_expected :  AtomicU16 :: new ( 0 )  } 
66+ 	} 
67+ } 
68+ 
69+ impl  Drop  for  TestCustomMessageHandler  { 
70+ 	fn  drop ( & mut  self )  { 
71+ 		#[ cfg( feature = "std" ) ]   { 
72+ 			if  std:: thread:: panicking ( )  { 
73+ 				return ; 
74+ 			} 
75+ 		} 
76+ 		assert_eq ! ( self . num_messages_expected. load( Ordering :: SeqCst ) ,  0 ) ; 
77+ 	} 
78+ } 
5879
5980impl  CustomOnionMessageHandler  for  TestCustomMessageHandler  { 
6081	type  CustomMessage  = TestCustomMessage ; 
61- 	fn  handle_custom_message ( & self ,  _msg :  Self :: CustomMessage )  { } 
82+ 	fn  handle_custom_message ( & self ,  _msg :  Self :: CustomMessage )  { 
83+ 		self . num_messages_expected . fetch_sub ( 1 ,  Ordering :: SeqCst ) ; 
84+ 	} 
6285	fn  read_custom_message < R :  io:: Read > ( & self ,  message_type :  u64 ,  buffer :  & mut  R )  -> Result < Option < Self :: CustomMessage > ,  DecodeError >  where  Self :  Sized  { 
6386		if  message_type == CUSTOM_MESSAGE_TYPE  { 
6487			let  buf = read_to_end ( buffer) ?; 
@@ -75,9 +98,11 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
7598		let  logger = Arc :: new ( test_utils:: TestLogger :: with_id ( format ! ( "node {}" ,  i) ) ) ; 
7699		let  seed = [ i as  u8 ;  32 ] ; 
77100		let  keys_manager = Arc :: new ( test_utils:: TestKeysInterface :: new ( & seed,  Network :: Testnet ) ) ; 
101+ 		let  custom_message_handler = Arc :: new ( TestCustomMessageHandler :: new ( ) ) ; 
78102		nodes. push ( MessengerNode  { 
79103			keys_manager :  keys_manager. clone ( ) , 
80- 			messenger :  OnionMessenger :: new ( keys_manager. clone ( ) ,  keys_manager. clone ( ) ,  logger. clone ( ) ,  Arc :: new ( TestCustomMessageHandler  { } ) ) , 
104+ 			messenger :  OnionMessenger :: new ( keys_manager. clone ( ) ,  keys_manager. clone ( ) ,  logger. clone ( ) ,  custom_message_handler. clone ( ) ) , 
105+ 			custom_message_handler, 
81106			logger, 
82107		} ) ; 
83108	} 
@@ -92,22 +117,17 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
92117	nodes
93118} 
94119
95- fn  pass_along_path ( path :  & Vec < MessengerNode > ,  expected_path_id :  Option < [ u8 ;  32 ] > )  { 
120+ fn  pass_along_path ( path :  & Vec < MessengerNode > )  { 
121+ 	path[ path. len ( )  - 1 ] . custom_message_handler . num_messages_expected . fetch_add ( 1 ,  Ordering :: SeqCst ) ; 
96122	let  mut  prev_node = & path[ 0 ] ; 
97- 	let  num_nodes = path. len ( ) ; 
98- 	for  ( idx,  node)  in  path. into_iter ( ) . skip ( 1 ) . enumerate ( )  { 
123+ 	for  node in  path. into_iter ( ) . skip ( 1 )  { 
99124		let  events = prev_node. messenger . release_pending_msgs ( ) ; 
100125		let  onion_msg =  { 
101126			let  msgs = events. get ( & node. get_node_pk ( ) ) . unwrap ( ) ; 
102127			assert_eq ! ( msgs. len( ) ,  1 ) ; 
103128			msgs[ 0 ] . clone ( ) 
104129		} ; 
105130		node. messenger . handle_onion_message ( & prev_node. get_node_pk ( ) ,  & onion_msg) ; 
106- 		if  idx == num_nodes - 1  { 
107- 			node. logger . assert_log_contains ( 
108- 				"lightning::onion_message::messenger" , 
109- 				& format ! ( "Received an onion message with path_id: {:02x?}" ,  expected_path_id) ,  1 ) ; 
110- 		} 
111131		prev_node = node; 
112132	} 
113133} 
@@ -118,7 +138,7 @@ fn one_hop() {
118138	let  test_msg = OnionMessageContents :: Custom ( TestCustomMessage  { } ) ; 
119139
120140	nodes[ 0 ] . messenger . send_onion_message ( & [ ] ,  Destination :: Node ( nodes[ 1 ] . get_node_pk ( ) ) ,  test_msg,  None ) . unwrap ( ) ; 
121- 	pass_along_path ( & nodes,   None ) ; 
141+ 	pass_along_path ( & nodes) ; 
122142} 
123143
124144#[ test]  
@@ -127,7 +147,7 @@ fn two_unblinded_hops() {
127147	let  test_msg = OnionMessageContents :: Custom ( TestCustomMessage  { } ) ; 
128148
129149	nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) ] ,  Destination :: Node ( nodes[ 2 ] . get_node_pk ( ) ) ,  test_msg,  None ) . unwrap ( ) ; 
130- 	pass_along_path ( & nodes,   None ) ; 
150+ 	pass_along_path ( & nodes) ; 
131151} 
132152
133153#[ test]  
@@ -139,7 +159,7 @@ fn two_unblinded_two_blinded() {
139159	let  blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 3 ] . get_node_pk ( ) ,  nodes[ 4 ] . get_node_pk ( ) ] ,  & * nodes[ 4 ] . keys_manager ,  & secp_ctx) . unwrap ( ) ; 
140160
141161	nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) ,  nodes[ 2 ] . get_node_pk ( ) ] ,  Destination :: BlindedPath ( blinded_path) ,  test_msg,  None ) . unwrap ( ) ; 
142- 	pass_along_path ( & nodes,   None ) ; 
162+ 	pass_along_path ( & nodes) ; 
143163} 
144164
145165#[ test]  
@@ -151,7 +171,7 @@ fn three_blinded_hops() {
151171	let  blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 1 ] . get_node_pk ( ) ,  nodes[ 2 ] . get_node_pk ( ) ,  nodes[ 3 ] . get_node_pk ( ) ] ,  & * nodes[ 3 ] . keys_manager ,  & secp_ctx) . unwrap ( ) ; 
152172
153173	nodes[ 0 ] . messenger . send_onion_message ( & [ ] ,  Destination :: BlindedPath ( blinded_path) ,  test_msg,  None ) . unwrap ( ) ; 
154- 	pass_along_path ( & nodes,   None ) ; 
174+ 	pass_along_path ( & nodes) ; 
155175} 
156176
157177#[ test]  
@@ -177,13 +197,13 @@ fn we_are_intro_node() {
177197	let  blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 0 ] . get_node_pk ( ) ,  nodes[ 1 ] . get_node_pk ( ) ,  nodes[ 2 ] . get_node_pk ( ) ] ,  & * nodes[ 2 ] . keys_manager ,  & secp_ctx) . unwrap ( ) ; 
178198
179199	nodes[ 0 ] . messenger . send_onion_message ( & [ ] ,  Destination :: BlindedPath ( blinded_path) ,  OnionMessageContents :: Custom ( test_msg. clone ( ) ) ,  None ) . unwrap ( ) ; 
180- 	pass_along_path ( & nodes,   None ) ; 
200+ 	pass_along_path ( & nodes) ; 
181201
182202	// Try with a two-hop blinded path where we are the introduction node. 
183203	let  blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 0 ] . get_node_pk ( ) ,  nodes[ 1 ] . get_node_pk ( ) ] ,  & * nodes[ 1 ] . keys_manager ,  & secp_ctx) . unwrap ( ) ; 
184204	nodes[ 0 ] . messenger . send_onion_message ( & [ ] ,  Destination :: BlindedPath ( blinded_path) ,  OnionMessageContents :: Custom ( test_msg) ,  None ) . unwrap ( ) ; 
185205	nodes. remove ( 2 ) ; 
186- 	pass_along_path ( & nodes,   None ) ; 
206+ 	pass_along_path ( & nodes) ; 
187207} 
188208
189209#[ test]  
@@ -216,7 +236,7 @@ fn reply_path() {
216236	// Destination::Node 
217237	let  reply_path = BlindedPath :: new_for_message ( & [ nodes[ 2 ] . get_node_pk ( ) ,  nodes[ 1 ] . get_node_pk ( ) ,  nodes[ 0 ] . get_node_pk ( ) ] ,  & * nodes[ 0 ] . keys_manager ,  & secp_ctx) . unwrap ( ) ; 
218238	nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) ,  nodes[ 2 ] . get_node_pk ( ) ] ,  Destination :: Node ( nodes[ 3 ] . get_node_pk ( ) ) ,  OnionMessageContents :: Custom ( test_msg. clone ( ) ) ,  Some ( reply_path) ) . unwrap ( ) ; 
219- 	pass_along_path ( & nodes,   None ) ; 
239+ 	pass_along_path ( & nodes) ; 
220240	// Make sure the last node successfully decoded the reply path. 
221241	nodes[ 3 ] . logger . assert_log_contains ( 
222242		"lightning::onion_message::messenger" , 
@@ -227,7 +247,7 @@ fn reply_path() {
227247	let  reply_path = BlindedPath :: new_for_message ( & [ nodes[ 2 ] . get_node_pk ( ) ,  nodes[ 1 ] . get_node_pk ( ) ,  nodes[ 0 ] . get_node_pk ( ) ] ,  & * nodes[ 0 ] . keys_manager ,  & secp_ctx) . unwrap ( ) ; 
228248
229249	nodes[ 0 ] . messenger . send_onion_message ( & [ ] ,  Destination :: BlindedPath ( blinded_path) ,  OnionMessageContents :: Custom ( test_msg) ,  Some ( reply_path) ) . unwrap ( ) ; 
230- 	pass_along_path ( & nodes,   None ) ; 
250+ 	pass_along_path ( & nodes) ; 
231251	nodes[ 3 ] . logger . assert_log_contains ( 
232252		"lightning::onion_message::messenger" , 
233253		& format ! ( "Received an onion message with path_id None and a reply_path" ) ,  2 ) ; 
@@ -264,3 +284,20 @@ fn peer_buffer_full() {
264284	let  err = nodes[ 0 ] . messenger . send_onion_message ( & [ ] ,  Destination :: Node ( nodes[ 1 ] . get_node_pk ( ) ) ,  OnionMessageContents :: Custom ( test_msg) ,  None ) . unwrap_err ( ) ; 
265285	assert_eq ! ( err,  SendError :: BufferFull ) ; 
266286} 
287+ 
288+ #[ test]  
289+ fn  many_hops ( )  { 
290+ 	// Check we can send over a route with many hops. This will exercise our logic for onion messages 
291+ 	// of size [`crate::onion_message::packet::BIG_PACKET_HOP_DATA_LEN`]. 
292+ 	let  num_nodes:  usize  = 25 ; 
293+ 	let  nodes = create_nodes ( num_nodes as  u8 ) ; 
294+ 	let  test_msg = OnionMessageContents :: Custom ( TestCustomMessage  { } ) ; 
295+ 
296+ 	let  mut  intermediates = vec ! [ ] ; 
297+ 	for  i in  1 ..( num_nodes-1 )  { 
298+ 		intermediates. push ( nodes[ i] . get_node_pk ( ) ) ; 
299+ 	} 
300+ 
301+ 	nodes[ 0 ] . messenger . send_onion_message ( & intermediates,  Destination :: Node ( nodes[ num_nodes-1 ] . get_node_pk ( ) ) ,  test_msg,  None ) . unwrap ( ) ; 
302+ 	pass_along_path ( & nodes) ; 
303+ } 
0 commit comments