@@ -20,13 +20,15 @@ use crate::util::test_utils;
2020use bitcoin:: network:: constants:: Network ;
2121use bitcoin:: secp256k1:: { PublicKey , Secp256k1 } ;
2222
23+ use core:: sync:: atomic:: { AtomicU16 , Ordering } ;
2324use crate :: io;
2425use crate :: io_extras:: read_to_end;
2526use crate :: sync:: Arc ;
2627
2728struct MessengerNode {
2829 keys_manager : Arc < test_utils:: TestKeysInterface > ,
2930 messenger : OnionMessenger < Arc < test_utils:: TestKeysInterface > , Arc < test_utils:: TestKeysInterface > , Arc < test_utils:: TestLogger > , Arc < TestCustomMessageHandler > > ,
31+ custom_message_handler : Arc < TestCustomMessageHandler > ,
3032 logger : Arc < test_utils:: TestLogger > ,
3133}
3234
@@ -54,11 +56,32 @@ impl Writeable for TestCustomMessage {
5456 }
5557}
5658
57- struct TestCustomMessageHandler { }
59+ struct TestCustomMessageHandler {
60+ num_messages_expected : AtomicU16 ,
61+ }
62+
63+ impl TestCustomMessageHandler {
64+ fn new ( ) -> Self {
65+ Self { num_messages_expected : AtomicU16 :: new ( 0 ) }
66+ }
67+ }
68+
69+ impl Drop for TestCustomMessageHandler {
70+ fn drop ( & mut self ) {
71+ #[ cfg( feature = "std" ) ] {
72+ if std:: thread:: panicking ( ) {
73+ return ;
74+ }
75+ }
76+ assert_eq ! ( self . num_messages_expected. load( Ordering :: SeqCst ) , 0 ) ;
77+ }
78+ }
5879
5980impl CustomOnionMessageHandler for TestCustomMessageHandler {
6081 type CustomMessage = TestCustomMessage ;
61- fn handle_custom_message ( & self , _msg : Self :: CustomMessage ) { }
82+ fn handle_custom_message ( & self , _msg : Self :: CustomMessage ) {
83+ self . num_messages_expected . fetch_sub ( 1 , Ordering :: SeqCst ) ;
84+ }
6285 fn read_custom_message < R : io:: Read > ( & self , message_type : u64 , buffer : & mut R ) -> Result < Option < Self :: CustomMessage > , DecodeError > where Self : Sized {
6386 if message_type == CUSTOM_MESSAGE_TYPE {
6487 let buf = read_to_end ( buffer) ?;
@@ -75,9 +98,11 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
7598 let logger = Arc :: new ( test_utils:: TestLogger :: with_id ( format ! ( "node {}" , i) ) ) ;
7699 let seed = [ i as u8 ; 32 ] ;
77100 let keys_manager = Arc :: new ( test_utils:: TestKeysInterface :: new ( & seed, Network :: Testnet ) ) ;
101+ let custom_message_handler = Arc :: new ( TestCustomMessageHandler :: new ( ) ) ;
78102 nodes. push ( MessengerNode {
79103 keys_manager : keys_manager. clone ( ) ,
80- messenger : OnionMessenger :: new ( keys_manager. clone ( ) , keys_manager. clone ( ) , logger. clone ( ) , Arc :: new ( TestCustomMessageHandler { } ) ) ,
104+ messenger : OnionMessenger :: new ( keys_manager. clone ( ) , keys_manager. clone ( ) , logger. clone ( ) , custom_message_handler. clone ( ) ) ,
105+ custom_message_handler,
81106 logger,
82107 } ) ;
83108 }
@@ -92,22 +117,17 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
92117 nodes
93118}
94119
95- fn pass_along_path ( path : & Vec < MessengerNode > , expected_path_id : Option < [ u8 ; 32 ] > ) {
120+ fn pass_along_path ( path : & Vec < MessengerNode > ) {
121+ path[ path. len ( ) - 1 ] . custom_message_handler . num_messages_expected . fetch_add ( 1 , Ordering :: SeqCst ) ;
96122 let mut prev_node = & path[ 0 ] ;
97- let num_nodes = path. len ( ) ;
98- for ( idx, node) in path. into_iter ( ) . skip ( 1 ) . enumerate ( ) {
123+ for node in path. into_iter ( ) . skip ( 1 ) {
99124 let events = prev_node. messenger . release_pending_msgs ( ) ;
100125 let onion_msg = {
101126 let msgs = events. get ( & node. get_node_pk ( ) ) . unwrap ( ) ;
102127 assert_eq ! ( msgs. len( ) , 1 ) ;
103128 msgs[ 0 ] . clone ( )
104129 } ;
105130 node. messenger . handle_onion_message ( & prev_node. get_node_pk ( ) , & onion_msg) ;
106- if idx == num_nodes - 1 {
107- node. logger . assert_log_contains (
108- "lightning::onion_message::messenger" ,
109- & format ! ( "Received an onion message with path_id: {:02x?}" , expected_path_id) , 1 ) ;
110- }
111131 prev_node = node;
112132 }
113133}
@@ -118,7 +138,7 @@ fn one_hop() {
118138 let test_msg = OnionMessageContents :: Custom ( TestCustomMessage { } ) ;
119139
120140 nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: Node ( nodes[ 1 ] . get_node_pk ( ) ) , test_msg, None ) . unwrap ( ) ;
121- pass_along_path ( & nodes, None ) ;
141+ pass_along_path ( & nodes) ;
122142}
123143
124144#[ test]
@@ -127,7 +147,7 @@ fn two_unblinded_hops() {
127147 let test_msg = OnionMessageContents :: Custom ( TestCustomMessage { } ) ;
128148
129149 nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) ] , Destination :: Node ( nodes[ 2 ] . get_node_pk ( ) ) , test_msg, None ) . unwrap ( ) ;
130- pass_along_path ( & nodes, None ) ;
150+ pass_along_path ( & nodes) ;
131151}
132152
133153#[ test]
@@ -139,7 +159,7 @@ fn two_unblinded_two_blinded() {
139159 let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 3 ] . get_node_pk ( ) , nodes[ 4 ] . get_node_pk ( ) ] , & * nodes[ 4 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
140160
141161 nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) ] , Destination :: BlindedPath ( blinded_path) , test_msg, None ) . unwrap ( ) ;
142- pass_along_path ( & nodes, None ) ;
162+ pass_along_path ( & nodes) ;
143163}
144164
145165#[ test]
@@ -151,7 +171,7 @@ fn three_blinded_hops() {
151171 let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) , nodes[ 3 ] . get_node_pk ( ) ] , & * nodes[ 3 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
152172
153173 nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , test_msg, None ) . unwrap ( ) ;
154- pass_along_path ( & nodes, None ) ;
174+ pass_along_path ( & nodes) ;
155175}
156176
157177#[ test]
@@ -177,13 +197,13 @@ fn we_are_intro_node() {
177197 let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 0 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) ] , & * nodes[ 2 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
178198
179199 nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , OnionMessageContents :: Custom ( test_msg. clone ( ) ) , None ) . unwrap ( ) ;
180- pass_along_path ( & nodes, None ) ;
200+ pass_along_path ( & nodes) ;
181201
182202 // Try with a two-hop blinded path where we are the introduction node.
183203 let blinded_path = BlindedPath :: new_for_message ( & [ nodes[ 0 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) ] , & * nodes[ 1 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
184204 nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , OnionMessageContents :: Custom ( test_msg) , None ) . unwrap ( ) ;
185205 nodes. remove ( 2 ) ;
186- pass_along_path ( & nodes, None ) ;
206+ pass_along_path ( & nodes) ;
187207}
188208
189209#[ test]
@@ -216,7 +236,7 @@ fn reply_path() {
216236 // Destination::Node
217237 let reply_path = BlindedPath :: new_for_message ( & [ nodes[ 2 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) , nodes[ 0 ] . get_node_pk ( ) ] , & * nodes[ 0 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
218238 nodes[ 0 ] . messenger . send_onion_message ( & [ nodes[ 1 ] . get_node_pk ( ) , nodes[ 2 ] . get_node_pk ( ) ] , Destination :: Node ( nodes[ 3 ] . get_node_pk ( ) ) , OnionMessageContents :: Custom ( test_msg. clone ( ) ) , Some ( reply_path) ) . unwrap ( ) ;
219- pass_along_path ( & nodes, None ) ;
239+ pass_along_path ( & nodes) ;
220240 // Make sure the last node successfully decoded the reply path.
221241 nodes[ 3 ] . logger . assert_log_contains (
222242 "lightning::onion_message::messenger" ,
@@ -227,7 +247,7 @@ fn reply_path() {
227247 let reply_path = BlindedPath :: new_for_message ( & [ nodes[ 2 ] . get_node_pk ( ) , nodes[ 1 ] . get_node_pk ( ) , nodes[ 0 ] . get_node_pk ( ) ] , & * nodes[ 0 ] . keys_manager , & secp_ctx) . unwrap ( ) ;
228248
229249 nodes[ 0 ] . messenger . send_onion_message ( & [ ] , Destination :: BlindedPath ( blinded_path) , OnionMessageContents :: Custom ( test_msg) , Some ( reply_path) ) . unwrap ( ) ;
230- pass_along_path ( & nodes, None ) ;
250+ pass_along_path ( & nodes) ;
231251 nodes[ 3 ] . logger . assert_log_contains (
232252 "lightning::onion_message::messenger" ,
233253 & format ! ( "Received an onion message with path_id None and a reply_path" ) , 2 ) ;
0 commit comments