@@ -779,67 +779,98 @@ where
779779fn calculate_amount_to_forward_per_htlc (
780780 htlcs : & [ InterceptedHTLC ] , total_amt_to_forward_msat : u64 ,
781781) -> Vec < ( InterceptId , u64 ) > {
782+ // TODO: we should eventually make sure the HTLCs are all above ChannelDetails::next_outbound_minimum_msat
782783 let total_received_msat: u64 =
783784 htlcs. iter ( ) . map ( |htlc| htlc. expected_outbound_amount_msat ) . sum ( ) ;
784785
785- let mut fee_remaining_msat = total_received_msat - total_amt_to_forward_msat;
786- let total_fee_msat = fee_remaining_msat;
786+ match total_received_msat. checked_sub ( total_amt_to_forward_msat) {
787+ Some ( total_fee_msat) => {
788+ let mut fee_remaining_msat = total_fee_msat;
787789
788- let mut per_htlc_forwards = vec ! [ ] ;
790+ let mut per_htlc_forwards = vec ! [ ] ;
789791
790- for ( index, htlc) in htlcs. iter ( ) . enumerate ( ) {
791- let proportional_fee_amt_msat =
792- total_fee_msat * htlc. expected_outbound_amount_msat / total_received_msat;
792+ for ( index, htlc) in htlcs. iter ( ) . enumerate ( ) {
793+ let proportional_fee_amt_msat =
794+ total_fee_msat * ( htlc. expected_outbound_amount_msat / total_received_msat) ;
793795
794- let mut actual_fee_amt_msat = core:: cmp:: min ( fee_remaining_msat, proportional_fee_amt_msat) ;
795- fee_remaining_msat -= actual_fee_amt_msat;
796+ let mut actual_fee_amt_msat =
797+ core:: cmp:: min ( fee_remaining_msat, proportional_fee_amt_msat) ;
798+ fee_remaining_msat -= actual_fee_amt_msat;
796799
797- if index == htlcs. len ( ) - 1 {
798- actual_fee_amt_msat += fee_remaining_msat;
799- }
800+ if index == htlcs. len ( ) - 1 {
801+ actual_fee_amt_msat += fee_remaining_msat;
802+ }
800803
801- let amount_to_forward_msat = htlc. expected_outbound_amount_msat - actual_fee_amt_msat;
804+ let amount_to_forward_msat =
805+ htlc. expected_outbound_amount_msat . saturating_sub ( actual_fee_amt_msat) ;
802806
803- per_htlc_forwards. push ( ( htlc. intercept_id , amount_to_forward_msat) )
804- }
807+ per_htlc_forwards. push ( ( htlc. intercept_id , amount_to_forward_msat) )
808+ }
805809
806- per_htlc_forwards
810+ per_htlc_forwards
811+ }
812+ None => Vec :: new ( ) ,
813+ }
807814}
808815
809816#[ cfg( test) ]
810817mod tests {
811818
812819 use super :: * ;
820+ use proptest:: prelude:: * ;
813821
814- #[ test]
815- fn test_calculate_amount_to_forward ( ) {
816- // TODO: Use proptest to generate random allocations
817- let htlcs = vec ! [
818- InterceptedHTLC {
819- intercept_id: InterceptId ( [ 0 ; 32 ] ) ,
820- expected_outbound_amount_msat: 1000 ,
821- } ,
822- InterceptedHTLC {
823- intercept_id: InterceptId ( [ 1 ; 32 ] ) ,
824- expected_outbound_amount_msat: 2000 ,
825- } ,
826- InterceptedHTLC {
827- intercept_id: InterceptId ( [ 2 ; 32 ] ) ,
828- expected_outbound_amount_msat: 3000 ,
829- } ,
830- ] ;
831-
832- let total_amt_to_forward_msat = 5000 ;
833-
834- let result = calculate_amount_to_forward_per_htlc ( & htlcs, total_amt_to_forward_msat) ;
822+ const MAX_VALUE_MSAT : u64 = 21_000_000_0000_0000_000 ;
835823
836- assert_eq ! ( result[ 0 ] . 0 , htlcs[ 0 ] . intercept_id) ;
837- assert_eq ! ( result[ 0 ] . 1 , 834 ) ;
824+ fn arb_forward_amounts ( ) -> impl Strategy < Value = ( u64 , u64 , u64 , u64 ) > {
825+ ( 1u64 ..MAX_VALUE_MSAT , 1u64 ..MAX_VALUE_MSAT , 1u64 ..MAX_VALUE_MSAT , 1u64 ..MAX_VALUE_MSAT )
826+ . prop_map ( |( a, b, c, d) | {
827+ ( a, b, c, core:: cmp:: min ( d, a. saturating_add ( b) . saturating_add ( c) ) )
828+ } )
829+ }
838830
839- assert_eq ! ( result[ 1 ] . 0 , htlcs[ 1 ] . intercept_id) ;
840- assert_eq ! ( result[ 1 ] . 1 , 1667 ) ;
831+ proptest ! {
832+ #[ test]
833+ fn test_calculate_amount_to_forward( ( o_0, o_1, o_2, total_amt_to_forward_msat) in arb_forward_amounts( ) ) {
834+ let htlcs = vec![
835+ InterceptedHTLC {
836+ intercept_id: InterceptId ( [ 0 ; 32 ] ) ,
837+ expected_outbound_amount_msat: o_0
838+ } ,
839+ InterceptedHTLC {
840+ intercept_id: InterceptId ( [ 1 ; 32 ] ) ,
841+ expected_outbound_amount_msat: o_1
842+ } ,
843+ InterceptedHTLC {
844+ intercept_id: InterceptId ( [ 2 ; 32 ] ) ,
845+ expected_outbound_amount_msat: o_2
846+ } ,
847+ ] ;
848+
849+ let result = calculate_amount_to_forward_per_htlc( & htlcs, total_amt_to_forward_msat) ;
850+ let total_received_msat = o_0 + o_1 + o_2;
851+
852+ if total_received_msat < total_amt_to_forward_msat {
853+ assert_eq!( result. len( ) , 0 ) ;
854+ } else {
855+ assert_ne!( result. len( ) , 0 ) ;
856+ assert_eq!( result[ 0 ] . 0 , htlcs[ 0 ] . intercept_id) ;
857+ assert_eq!( result[ 1 ] . 0 , htlcs[ 1 ] . intercept_id) ;
858+ assert_eq!( result[ 2 ] . 0 , htlcs[ 2 ] . intercept_id) ;
859+ assert!( result[ 0 ] . 1 <= o_0) ;
860+ assert!( result[ 1 ] . 1 <= o_1) ;
861+ assert!( result[ 2 ] . 1 <= o_2) ;
862+
863+ let result_sum = result. iter( ) . map( |( _, f) | f) . sum:: <u64 >( ) ;
864+ assert!( result_sum >= total_amt_to_forward_msat) ;
865+ let five_pct = result_sum as f32 * 0.1 ;
866+ let fair_share_0 = ( ( o_0 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max( o_0 as f32 ) ;
867+ assert!( result[ 0 ] . 1 as f32 <= fair_share_0 + five_pct) ;
868+ let fair_share_1 = ( ( o_1 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max( o_1 as f32 ) ;
869+ assert!( result[ 1 ] . 1 as f32 <= fair_share_1 + five_pct) ;
870+ let fair_share_2 = ( ( o_2 as f32 / total_received_msat as f32 ) * result_sum as f32 ) . max( o_2 as f32 ) ;
871+ assert!( result[ 2 ] . 1 as f32 <= fair_share_2 + five_pct) ;
872+ }
841873
842- assert_eq ! ( result[ 2 ] . 0 , htlcs[ 2 ] . intercept_id) ;
843- assert_eq ! ( result[ 2 ] . 1 , 2499 ) ;
874+ }
844875 }
845876}
0 commit comments