@@ -18,6 +18,7 @@ use lightning::util::logger::Logger;
1818use secp256k1:: PublicKey ;
1919use core:: ops:: Deref ;
2020use core:: time:: Duration ;
21+ use core:: iter:: Iterator ;
2122
2223/// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
2324/// See [`PhantomKeysManager`] for more information on phantom node payments.
@@ -132,6 +133,8 @@ where
132133 )
133134}
134135
136+ const MAX_CHANNEL_HINTS : usize = 3 ;
137+
135138fn _create_phantom_invoice < ES : Deref , NS : Deref , L : Deref > (
136139 amt_msat : Option < u64 > , payment_hash : Option < PaymentHash > , description : InvoiceDescription ,
137140 invoice_expiry_delta_secs : u32 , phantom_route_hints : Vec < PhantomRouteHints > , entropy_source : ES ,
@@ -202,7 +205,8 @@ where
202205 invoice = invoice. amount_milli_satoshis ( amt) ;
203206 }
204207
205- for route_hint in select_phantom_hints ( amt_msat, phantom_route_hints, logger) {
208+
209+ for route_hint in select_phantom_hints ( amt_msat, phantom_route_hints, logger) . take ( MAX_CHANNEL_HINTS ) {
206210 invoice = invoice. private_route ( route_hint) ;
207211 }
208212
@@ -229,36 +233,48 @@ where
229233///
230234/// [`PhantomKeysManager`]: lightning::sign::PhantomKeysManager
231235fn select_phantom_hints < L : Deref > ( amt_msat : Option < u64 > , phantom_route_hints : Vec < PhantomRouteHints > ,
232- logger : L ) -> Vec < RouteHint >
236+ logger : L ) -> impl Iterator < Item = RouteHint >
233237where
234238 L :: Target : Logger ,
235239{
236- let mut phantom_hints: Vec < Vec < RouteHint > > = Vec :: new ( ) ;
240+ let mut phantom_hints: Vec < _ > = Vec :: new ( ) ;
237241
238242 for PhantomRouteHints { channels, phantom_scid, real_node_pubkey } in phantom_route_hints {
239243 log_trace ! ( logger, "Generating phantom route hints for node {}" ,
240244 log_pubkey!( real_node_pubkey) ) ;
241- let mut route_hints = sort_and_filter_channels ( channels, amt_msat, & logger) ;
245+ let route_hints = sort_and_filter_channels ( channels, amt_msat, & logger) ;
242246
243247 // If we have any public channel, the route hints from `sort_and_filter_channels` will be
244248 // empty. In that case we create a RouteHint on which we will push a single hop with the
245249 // phantom route into the invoice, and let the sender find the path to the `real_node_pubkey`
246250 // node by looking at our public channels.
247- if route_hints. is_empty ( ) {
248- route_hints. push ( RouteHint ( vec ! [ ] ) )
249- }
250- for route_hint in & mut route_hints {
251- route_hint. 0 . push ( RouteHintHop {
252- src_node_id : real_node_pubkey,
253- short_channel_id : phantom_scid,
254- fees : RoutingFees {
255- base_msat : 0 ,
256- proportional_millionths : 0 ,
257- } ,
258- cltv_expiry_delta : MIN_CLTV_EXPIRY_DELTA ,
259- htlc_minimum_msat : None ,
260- htlc_maximum_msat : None , } ) ;
261- }
251+ let empty_route_hints = route_hints. len ( ) == 0 ;
252+ let mut have_pushed_empty = false ;
253+ let route_hints = route_hints
254+ . chain ( core:: iter:: from_fn ( move || {
255+ if empty_route_hints && !have_pushed_empty {
256+ // set flag of having handled the empty route_hints and ensure empty vector
257+ // returned only once
258+ have_pushed_empty = true ;
259+ Some ( RouteHint ( Vec :: new ( ) ) )
260+ } else {
261+ None
262+ }
263+ } ) )
264+ . map ( move |mut hint| {
265+ hint. 0 . push ( RouteHintHop {
266+ src_node_id : real_node_pubkey,
267+ short_channel_id : phantom_scid,
268+ fees : RoutingFees {
269+ base_msat : 0 ,
270+ proportional_millionths : 0 ,
271+ } ,
272+ cltv_expiry_delta : MIN_CLTV_EXPIRY_DELTA ,
273+ htlc_minimum_msat : None ,
274+ htlc_maximum_msat : None ,
275+ } ) ;
276+ hint
277+ } ) ;
262278
263279 phantom_hints. push ( route_hints) ;
264280 }
@@ -267,29 +283,34 @@ where
267283 // the hints across our real nodes we add one hint from each in turn until no node has any hints
268284 // left (if one node has more hints than any other, these will accumulate at the end of the
269285 // vector).
270- let mut invoice_hints : Vec < RouteHint > = Vec :: new ( ) ;
271- let mut hint_idx = 0 ;
286+ rotate_through_iterators ( phantom_hints )
287+ }
272288
273- loop {
274- let mut remaining_hints = false ;
289+ /// Draw items iteratively from multiple iterators. The items are retrieved by index and
290+ /// rotates through the iterators - first the zero index then the first index then second index, etc.
291+ fn rotate_through_iterators < T , I : Iterator < Item = T > > ( mut vecs : Vec < I > ) -> impl Iterator < Item = T > {
292+ let mut iterations = 0 ;
275293
276- for hints in phantom_hints. iter ( ) {
277- if invoice_hints. len ( ) == 3 {
278- return invoice_hints
294+ core:: iter:: from_fn ( move || {
295+ let mut exhausted_iterators = 0 ;
296+ loop {
297+ if vecs. is_empty ( ) {
298+ return None ;
279299 }
280-
281- if hint_idx < hints. len ( ) {
282- invoice_hints. push ( hints[ hint_idx] . clone ( ) ) ;
283- remaining_hints = true
300+ let next_idx = iterations % vecs. len ( ) ;
301+ iterations += 1 ;
302+ if let Some ( item) = vecs[ next_idx] . next ( ) {
303+ return Some ( item) ;
304+ }
305+ // exhausted_vectors increase when the "next_idx" vector is exhausted
306+ exhausted_iterators += 1 ;
307+ // The check for exhausted iterators gets reset to 0 after each yield of `Some()`
308+ // The loop will return None when all of the nested iterators are exhausted
309+ if exhausted_iterators == vecs. len ( ) {
310+ return None ;
284311 }
285312 }
286-
287- if !remaining_hints {
288- return invoice_hints
289- }
290-
291- hint_idx +=1 ;
292- }
313+ } )
293314}
294315
295316#[ cfg( feature = "std" ) ]
@@ -575,15 +596,34 @@ fn _create_invoice_from_channelmanager_and_duration_since_epoch_with_payment_has
575596/// * Sorted by lowest inbound capacity if an online channel with the minimum amount requested exists,
576597/// otherwise sort by highest inbound capacity to give the payment the best chance of succeeding.
577598fn sort_and_filter_channels < L : Deref > (
578- channels : Vec < ChannelDetails > , min_inbound_capacity_msat : Option < u64 > , logger : & L
579- ) -> Vec < RouteHint > where L :: Target : Logger {
599+ channels : Vec < ChannelDetails > ,
600+ min_inbound_capacity_msat : Option < u64 > ,
601+ logger : & L ,
602+ ) -> impl ExactSizeIterator < Item = RouteHint >
603+ where
604+ L :: Target : Logger ,
605+ {
580606 let mut filtered_channels: HashMap < PublicKey , ChannelDetails > = HashMap :: new ( ) ;
581607 let min_inbound_capacity = min_inbound_capacity_msat. unwrap_or ( 0 ) ;
582608 let mut min_capacity_channel_exists = false ;
583609 let mut online_channel_exists = false ;
584610 let mut online_min_capacity_channel_exists = false ;
585611 let mut has_pub_unconf_chan = false ;
586612
613+ let route_hint_from_channel = |channel : ChannelDetails | {
614+ let forwarding_info = channel. counterparty . forwarding_info . as_ref ( ) . unwrap ( ) ;
615+ RouteHint ( vec ! [ RouteHintHop {
616+ src_node_id: channel. counterparty. node_id,
617+ short_channel_id: channel. get_inbound_payment_scid( ) . unwrap( ) ,
618+ fees: RoutingFees {
619+ base_msat: forwarding_info. fee_base_msat,
620+ proportional_millionths: forwarding_info. fee_proportional_millionths,
621+ } ,
622+ cltv_expiry_delta: forwarding_info. cltv_expiry_delta,
623+ htlc_minimum_msat: channel. inbound_htlc_minimum_msat,
624+ htlc_maximum_msat: channel. inbound_htlc_maximum_msat, } ] )
625+ } ;
626+
587627 log_trace ! ( logger, "Considering {} channels for invoice route hints" , channels. len( ) ) ;
588628 for channel in channels. into_iter ( ) . filter ( |chan| chan. is_channel_ready ) {
589629 if channel. get_inbound_payment_scid ( ) . is_none ( ) || channel. counterparty . forwarding_info . is_none ( ) {
@@ -602,7 +642,7 @@ fn sort_and_filter_channels<L: Deref>(
602642 // look at the public channels instead.
603643 log_trace ! ( logger, "Not including channels in invoice route hints on account of public channel {}" ,
604644 log_bytes!( channel. channel_id) ) ;
605- return vec ! [ ]
645+ return vec ! [ ] . into_iter ( ) . take ( MAX_CHANNEL_HINTS ) . map ( route_hint_from_channel ) ;
606646 }
607647 }
608648
@@ -662,19 +702,6 @@ fn sort_and_filter_channels<L: Deref>(
662702 }
663703 }
664704
665- let route_hint_from_channel = |channel : ChannelDetails | {
666- let forwarding_info = channel. counterparty . forwarding_info . as_ref ( ) . unwrap ( ) ;
667- RouteHint ( vec ! [ RouteHintHop {
668- src_node_id: channel. counterparty. node_id,
669- short_channel_id: channel. get_inbound_payment_scid( ) . unwrap( ) ,
670- fees: RoutingFees {
671- base_msat: forwarding_info. fee_base_msat,
672- proportional_millionths: forwarding_info. fee_proportional_millionths,
673- } ,
674- cltv_expiry_delta: forwarding_info. cltv_expiry_delta,
675- htlc_minimum_msat: channel. inbound_htlc_minimum_msat,
676- htlc_maximum_msat: channel. inbound_htlc_maximum_msat, } ] )
677- } ;
678705 // If all channels are private, prefer to return route hints which have a higher capacity than
679706 // the payment value and where we're currently connected to the channel counterparty.
680707 // Even if we cannot satisfy both goals, always ensure we include *some* hints, preferring
@@ -724,7 +751,8 @@ fn sort_and_filter_channels<L: Deref>(
724751 } else {
725752 b. inbound_capacity_msat . cmp ( & a. inbound_capacity_msat )
726753 } } ) ;
727- eligible_channels. into_iter ( ) . take ( 3 ) . map ( route_hint_from_channel) . collect :: < Vec < RouteHint > > ( )
754+
755+ eligible_channels. into_iter ( ) . take ( MAX_CHANNEL_HINTS ) . map ( route_hint_from_channel)
728756}
729757
730758/// prefer_current_channel chooses a channel to use for route hints between a currently selected and candidate
@@ -777,7 +805,7 @@ mod test {
777805 use lightning:: routing:: router:: { PaymentParameters , RouteParameters } ;
778806 use lightning:: util:: test_utils;
779807 use lightning:: util:: config:: UserConfig ;
780- use crate :: utils:: create_invoice_from_channelmanager_and_duration_since_epoch;
808+ use crate :: utils:: { create_invoice_from_channelmanager_and_duration_since_epoch, rotate_through_iterators } ;
781809 use std:: collections:: HashSet ;
782810
783811 #[ test]
@@ -1886,4 +1914,111 @@ mod test {
18861914 _ => panic ! ( ) ,
18871915 }
18881916 }
1917+
1918+ #[ test]
1919+ fn test_rotate_through_iterators ( ) {
1920+ // two nested vectors
1921+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) , vec![ "a1" , "b1" ] . into_iter( ) ] ;
1922+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1923+
1924+ let expected = vec ! [ "a0" , "a1" , "b0" , "b1" , "c0" ] ;
1925+ assert_eq ! ( expected, result) ;
1926+
1927+ // test single nested vector
1928+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) ] ;
1929+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1930+
1931+ let expected = vec ! [ "a0" , "b0" , "c0" ] ;
1932+ assert_eq ! ( expected, result) ;
1933+
1934+ // test second vector with only one element
1935+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) , vec![ "a1" ] . into_iter( ) ] ;
1936+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1937+
1938+ let expected = vec ! [ "a0" , "a1" , "b0" , "c0" ] ;
1939+ assert_eq ! ( expected, result) ;
1940+
1941+ // test three nestend vectors
1942+ let a = vec ! [ vec![ "a0" ] . into_iter( ) , vec![ "a1" , "b1" , "c1" ] . into_iter( ) , vec![ "a2" ] . into_iter( ) ] ;
1943+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1944+
1945+ let expected = vec ! [ "a0" , "a1" , "a2" , "b1" , "c1" ] ;
1946+ assert_eq ! ( expected, result) ;
1947+
1948+ // test single nested vector with a single value
1949+ let a = vec ! [ vec![ "a0" ] . into_iter( ) ] ;
1950+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1951+
1952+ let expected = vec ! [ "a0" ] ;
1953+ assert_eq ! ( expected, result) ;
1954+
1955+ // test single empty nested vector
1956+ let a: Vec < std:: vec:: IntoIter < & str > > = vec ! [ vec![ ] . into_iter( ) ] ;
1957+ let result = rotate_through_iterators ( a) . collect :: < Vec < & str > > ( ) ;
1958+ let expected: Vec < & str > = vec ! [ ] ;
1959+
1960+ assert_eq ! ( expected, result) ;
1961+
1962+ // test first nested vector is empty
1963+ let a: Vec < std:: vec:: IntoIter < & str > > = vec ! [ vec![ ] . into_iter( ) , vec![ "a1" , "b1" , "c1" ] . into_iter( ) ] ;
1964+ let result = rotate_through_iterators ( a) . collect :: < Vec < & str > > ( ) ;
1965+
1966+ let expected = vec ! [ "a1" , "b1" , "c1" ] ;
1967+ assert_eq ! ( expected, result) ;
1968+
1969+ // test two empty vectors
1970+ let a: Vec < std:: vec:: IntoIter < & str > > = vec ! [ vec![ ] . into_iter( ) , vec![ ] . into_iter( ) ] ;
1971+ let result = rotate_through_iterators ( a) . collect :: < Vec < & str > > ( ) ;
1972+
1973+ let expected: Vec < & str > = vec ! [ ] ;
1974+ assert_eq ! ( expected, result) ;
1975+
1976+ // test an empty vector amongst other filled vectors
1977+ let a = vec ! [
1978+ vec![ "a0" , "b0" , "c0" ] . into_iter( ) ,
1979+ vec![ ] . into_iter( ) ,
1980+ vec![ "a1" , "b1" , "c1" ] . into_iter( ) ,
1981+ vec![ "a2" , "b2" , "c2" ] . into_iter( ) ,
1982+ ] ;
1983+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1984+
1985+ let expected = vec ! [ "a0" , "a1" , "a2" , "b0" , "b1" , "b2" , "c0" , "c1" , "c2" ] ;
1986+ assert_eq ! ( expected, result) ;
1987+
1988+ // test a filled vector between two empty vectors
1989+ let a = vec ! [ vec![ ] . into_iter( ) , vec![ "a1" , "b1" , "c1" ] . into_iter( ) , vec![ ] . into_iter( ) ] ;
1990+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1991+
1992+ let expected = vec ! [ "a1" , "b1" , "c1" ] ;
1993+ assert_eq ! ( expected, result) ;
1994+
1995+ // test an empty vector at the end of the vectors
1996+ let a = vec ! [ vec![ "a0" , "b0" , "c0" ] . into_iter( ) , vec![ ] . into_iter( ) ] ;
1997+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
1998+
1999+ let expected = vec ! [ "a0" , "b0" , "c0" ] ;
2000+ assert_eq ! ( expected, result) ;
2001+
2002+ // test multiple empty vectors amongst multiple filled vectors
2003+ let a = vec ! [
2004+ vec![ ] . into_iter( ) ,
2005+ vec![ "a1" , "b1" , "c1" ] . into_iter( ) ,
2006+ vec![ ] . into_iter( ) ,
2007+ vec![ "a3" , "b3" ] . into_iter( ) ,
2008+ vec![ ] . into_iter( ) ,
2009+ ] ;
2010+
2011+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
2012+
2013+ let expected = vec ! [ "a1" , "a3" , "b1" , "b3" , "c1" ] ;
2014+ assert_eq ! ( expected, result) ;
2015+
2016+ // test one element in the first nested vectore and two elements in the second nested
2017+ // vector
2018+ let a = vec ! [ vec![ "a0" ] . into_iter( ) , vec![ "a1" , "b1" ] . into_iter( ) ] ;
2019+ let result = rotate_through_iterators ( a) . collect :: < Vec < _ > > ( ) ;
2020+
2021+ let expected = vec ! [ "a0" , "a1" , "b1" ] ;
2022+ assert_eq ! ( expected, result) ;
2023+ }
18892024}
0 commit comments