@@ -117,6 +117,10 @@ func TestInvoiceRegistry(t *testing.T) {
117117 name : "FailPartialAMPPayment" ,
118118 test : testFailPartialAMPPayment ,
119119 },
120+ {
121+ name : "CancelAMPInvoicePendingHTLCs" ,
122+ test : testCancelAMPInvoicePendingHTLCs ,
123+ },
120124 }
121125
122126 makeKeyValueDB := func (t * testing.T ) (invpkg.InvoiceDB ,
@@ -2441,3 +2445,130 @@ func testFailPartialAMPPayment(t *testing.T,
24412445 "expected HTLC to be canceled" )
24422446 }
24432447}
2448+
2449+ // testCancelAMPInvoicePendingHTLCs tests the case where an AMP invoice is
2450+ // canceled and the remaining HTLCs are also canceled so that no HTLCs are left
2451+ // in the accepted state.
2452+ func testCancelAMPInvoicePendingHTLCs (t * testing.T ,
2453+ makeDB func (t * testing.T ) (invpkg.InvoiceDB , * clock.TestClock )) {
2454+
2455+ t .Parallel ()
2456+
2457+ ctx := newTestContext (t , nil , makeDB )
2458+ ctxb := context .Background ()
2459+
2460+ const (
2461+ expiry = uint32 (testCurrentHeight + 20 )
2462+ numShards = 4
2463+ )
2464+
2465+ var (
2466+ shardAmt = testInvoiceAmount / lnwire .MilliSatoshi (numShards )
2467+ payAddr [32 ]byte
2468+ )
2469+ _ , err := rand .Read (payAddr [:])
2470+ require .NoError (t , err )
2471+
2472+ // Create an AMP invoice we are going to pay via a multi-part payment.
2473+ ampInvoice := newInvoice (t , false , true )
2474+
2475+ // An AMP invoice is referenced by the payment address.
2476+ ampInvoice .Terms .PaymentAddr = payAddr
2477+
2478+ _ , err = ctx .registry .AddInvoice (
2479+ ctxb , ampInvoice , testInvoicePaymentHash ,
2480+ )
2481+ require .NoError (t , err )
2482+
2483+ htlcPayloadSet1 := & mockPayload {
2484+ mpp : record .NewMPP (testInvoiceAmount , payAddr ),
2485+ // We are not interested in settling the AMP HTLC so we don't
2486+ // use valid shares.
2487+ amp : record .NewAMP ([32 ]byte {1 }, [32 ]byte {1 }, 1 ),
2488+ }
2489+
2490+ // Send first HTLC which pays part of the invoice.
2491+ hodlChan1 := make (chan interface {}, 1 )
2492+ resolution , err := ctx .registry .NotifyExitHopHtlc (
2493+ lntypes.Hash {1 }, shardAmt , expiry , testCurrentHeight ,
2494+ getCircuitKey (1 ), hodlChan1 , nil , htlcPayloadSet1 ,
2495+ )
2496+ require .NoError (t , err )
2497+ require .Nil (t , resolution , "did not expect direct resolution" )
2498+
2499+ htlcPayloadSet2 := & mockPayload {
2500+ mpp : record .NewMPP (testInvoiceAmount , payAddr ),
2501+ // We are not interested in settling the AMP HTLC so we don't
2502+ // use valid shares.
2503+ amp : record .NewAMP ([32 ]byte {2 }, [32 ]byte {2 }, 1 ),
2504+ }
2505+
2506+ // Send htlc 2 which should be added to the invoice as expected.
2507+ hodlChan2 := make (chan interface {}, 1 )
2508+ resolution , err = ctx .registry .NotifyExitHopHtlc (
2509+ lntypes.Hash {2 }, shardAmt , expiry , testCurrentHeight ,
2510+ getCircuitKey (2 ), hodlChan2 , nil , htlcPayloadSet2 ,
2511+ )
2512+ require .NoError (t , err )
2513+ require .Nil (t , resolution , "did not expect direct resolution" )
2514+
2515+ require .Eventuallyf (t , func () bool {
2516+ inv , err := ctx .registry .LookupInvoice (
2517+ ctxb , testInvoicePaymentHash ,
2518+ )
2519+ require .NoError (t , err )
2520+
2521+ return len (inv .Htlcs ) == 2
2522+ }, testTimeout , time .Millisecond * 100 , "HTLCs not added to invoice" )
2523+
2524+ // expire the invoice here.
2525+ ctx .clock .SetTime (testTime .Add (65 * time .Minute ))
2526+
2527+ // Expect HLTC 1 to be canceled via the MPPTimeout fail resolution.
2528+ select {
2529+ case resolution := <- hodlChan1 :
2530+ htlcResolution , _ := resolution .(invpkg.HtlcResolution )
2531+ _ , ok := htlcResolution .(* invpkg.HtlcFailResolution )
2532+ require .True (
2533+ t , ok , "expected fail resolution, got: %T" , resolution ,
2534+ )
2535+
2536+ case <- time .After (testTimeout ):
2537+ t .Fatal ("timeout waiting for HTLC resolution" )
2538+ }
2539+
2540+ // Expect HLTC 2 to be canceled via the MPPTimeout fail resolution.
2541+ select {
2542+ case resolution := <- hodlChan2 :
2543+ htlcResolution , _ := resolution .(invpkg.HtlcResolution )
2544+ _ , ok := htlcResolution .(* invpkg.HtlcFailResolution )
2545+ require .True (
2546+ t , ok , "expected fail resolution, got: %T" , resolution ,
2547+ )
2548+
2549+ case <- time .After (testTimeout ):
2550+ t .Fatal ("timeout waiting for HTLC resolution" )
2551+ }
2552+
2553+ require .Eventuallyf (t , func () bool {
2554+ inv , err := ctx .registry .LookupInvoice (
2555+ ctxb , testInvoicePaymentHash ,
2556+ )
2557+ require .NoError (t , err )
2558+
2559+ return inv .State == invpkg .ContractCanceled
2560+ }, testTimeout , time .Millisecond * 100 , "invoice not canceled" )
2561+
2562+ // Fetch the invoice again and compare the number of cancelled HTLCs.
2563+ inv , err := ctx .registry .LookupInvoice (
2564+ ctxb , testInvoicePaymentHash ,
2565+ )
2566+ require .NoError (t , err )
2567+
2568+ // Make sure all HTLCs are in the cancelled state.
2569+ require .Len (t , inv .Htlcs , 2 )
2570+ for _ , htlc := range inv .Htlcs {
2571+ require .Equal (t , invpkg .HtlcStateCanceled , htlc .State ,
2572+ "expected HTLC to be canceled" )
2573+ }
2574+ }
0 commit comments