Skip to content

Commit bba01cf

Browse files
committed
routing+routerrpc: cancelable context in SendPaymentV2
In this commit we set up the payment loop context according to user-provided parameters. The `cancelable` parameter indicates whether the user is able to interrupt the payment loop by cancelling the server stream context. We'll additionally wrap the context in a deadline if the user provided a payment timeout. We remove the timeout channel of the payment_lifecycle.go and in favor of the deadline context.
1 parent e729084 commit bba01cf

File tree

4 files changed

+163
-69
lines changed

4 files changed

+163
-69
lines changed

lnrpc/routerrpc/router_server.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,25 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
360360
return err
361361
}
362362

363+
// The payment context is influenced by two user-provided parameters,
364+
// the cancelable flag and the payment attempt timeout.
365+
// If the payment is cancelable, we will use the stream context as the
366+
// payment context. That way, if the user ends the stream, the payment
367+
// loop will be canceled.
368+
// The second context parameter is the timeout. If the user provides a
369+
// timeout, we will additionally wrap the context in a deadline. If the
370+
// user provided 'cancelable' and ends the stream before the timeout is
371+
// reached the payment will be canceled.
372+
ctx := context.Background()
373+
if req.Cancelable {
374+
ctx = stream.Context()
375+
}
376+
363377
// Send the payment asynchronously.
364-
s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
378+
s.cfg.Router.SendPaymentAsync(ctx, payment, paySession, shardTracker)
365379

366380
// Track the payment and return.
367-
return s.trackPayment(
368-
sub, payHash, stream, req.NoInflightUpdates,
369-
)
381+
return s.trackPayment(sub, payHash, stream, req.NoInflightUpdates)
370382
}
371383

372384
// EstimateRouteFee allows callers to obtain an expected value w.r.t how much it

routing/payment_lifecycle.go

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package routing
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"time"
@@ -29,7 +30,6 @@ type paymentLifecycle struct {
2930
identifier lntypes.Hash
3031
paySession PaymentSession
3132
shardTracker shards.ShardTracker
32-
timeoutChan <-chan time.Time
3333
currentHeight int32
3434

3535
// quit is closed to signal the sub goroutines of the payment lifecycle
@@ -52,7 +52,7 @@ type paymentLifecycle struct {
5252
// newPaymentLifecycle initiates a new payment lifecycle and returns it.
5353
func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
5454
identifier lntypes.Hash, paySession PaymentSession,
55-
shardTracker shards.ShardTracker, timeout time.Duration,
55+
shardTracker shards.ShardTracker,
5656
currentHeight int32) *paymentLifecycle {
5757

5858
p := &paymentLifecycle{
@@ -69,13 +69,6 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi,
6969
// Mount the result collector.
7070
p.resultCollector = p.collectResultAsync
7171

72-
// If a timeout is specified, create a timeout channel. If no timeout is
73-
// specified, the channel is left nil and will never abort the payment
74-
// loop.
75-
if timeout != 0 {
76-
p.timeoutChan = time.After(timeout)
77-
}
78-
7972
return p
8073
}
8174

@@ -167,7 +160,9 @@ func (p *paymentLifecycle) decideNextStep(
167160
}
168161

169162
// resumePayment resumes the paymentLifecycle from the current state.
170-
func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
163+
func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte,
164+
*route.Route, error) {
165+
171166
// When the payment lifecycle loop exits, we make sure to signal any
172167
// sub goroutine of the HTLC attempt to exit, then wait for them to
173168
// return.
@@ -221,18 +216,17 @@ lifecycle:
221216

222217
// We now proceed our lifecycle with the following tasks in
223218
// order,
224-
// 1. check timeout.
219+
// 1. check context.
225220
// 2. request route.
226221
// 3. create HTLC attempt.
227222
// 4. send HTLC attempt.
228223
// 5. collect HTLC attempt result.
229224
//
230-
// Before we attempt any new shard, we'll check to see if
231-
// either we've gone past the payment attempt timeout, or the
232-
// router is exiting. In either case, we'll stop this payment
233-
// attempt short. If a timeout is not applicable, timeoutChan
234-
// will be nil.
235-
if err := p.checkTimeout(); err != nil {
225+
// Before we attempt any new shard, we'll check to see if we've
226+
// gone past the payment attempt timeout, or if the context was
227+
// cancelled, or the router is exiting. In any of these cases,
228+
// we'll stop this payment attempt short.
229+
if err := p.checkContext(ctx); err != nil {
236230
return exitWithErr(err)
237231
}
238232

@@ -318,19 +312,30 @@ lifecycle:
318312
return [32]byte{}, nil, *failure
319313
}
320314

321-
// checkTimeout checks whether the payment has reached its timeout.
322-
func (p *paymentLifecycle) checkTimeout() error {
315+
// checkContext checks whether the payment context has been canceled.
316+
// Cancellation occurs manually or if the context times out.
317+
func (p *paymentLifecycle) checkContext(ctx context.Context) error {
323318
select {
324-
case <-p.timeoutChan:
325-
log.Warnf("payment attempt not completed before timeout")
319+
case <-ctx.Done():
320+
// If the context was canceled, we'll mark the payment as
321+
// failed. There are two cases to distinguish here: Either a
322+
// user-provided timeout was reached, or the context was
323+
// canceled, either to a manual cancellation or due to an
324+
// unknown error.
325+
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
326+
log.Warnf("Payment attempt not completed before "+
327+
"timeout, id=%s", p.identifier.String())
328+
} else {
329+
log.Warnf("Payment attempt context canceled, id=%s",
330+
p.identifier.String())
331+
}
326332

327333
// By marking the payment failed, depending on whether it has
328334
// inflight HTLCs or not, its status will now either be
329335
// `StatusInflight` or `StatusFailed`. In either case, no more
330336
// HTLCs will be attempted.
331-
err := p.router.cfg.Control.FailPayment(
332-
p.identifier, channeldb.FailureReasonTimeout,
333-
)
337+
reason := channeldb.FailureReasonTimeout
338+
err := p.router.cfg.Control.FailPayment(p.identifier, reason)
334339
if err != nil {
335340
return fmt.Errorf("FailPayment got %w", err)
336341
}

routing/payment_lifecycle_test.go

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package routing
22

33
import (
4+
"context"
45
"sync/atomic"
56
"testing"
67
"time"
@@ -88,7 +89,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) {
8889
// Create a test payment lifecycle with no fee limit and no timeout.
8990
p := newPaymentLifecycle(
9091
rt, noFeeLimit, paymentHash, mockPaymentSession,
91-
mockShardTracker, 0, 0,
92+
mockShardTracker, 0,
9293
)
9394

9495
// Create a mock payment which is returned from mockControlTower.
@@ -151,17 +152,17 @@ type resumePaymentResult struct {
151152
err error
152153
}
153154

154-
// sendPaymentAndAssertFailed calls `resumePayment` and asserts that an error
155-
// is returned.
156-
func sendPaymentAndAssertFailed(t *testing.T,
155+
// sendPaymentAndAssertError calls `resumePayment` and asserts that an error is
156+
// returned.
157+
func sendPaymentAndAssertError(t *testing.T, ctx context.Context,
157158
p *paymentLifecycle, errExpected error) {
158159

159160
resultChan := make(chan *resumePaymentResult, 1)
160161

161162
// We now make a call to `resumePayment` and expect it to return the
162163
// error.
163164
go func() {
164-
preimage, _, err := p.resumePayment()
165+
preimage, _, err := p.resumePayment(ctx)
165166
resultChan <- &resumePaymentResult{
166167
preimage: preimage,
167168
err: err,
@@ -189,7 +190,7 @@ func sendPaymentAndAssertSucceeded(t *testing.T,
189190
// We now make a call to `resumePayment` and expect it to return the
190191
// preimage.
191192
go func() {
192-
preimage, _, err := p.resumePayment()
193+
preimage, _, err := p.resumePayment(context.Background())
193194
resultChan <- &resumePaymentResult{
194195
preimage: preimage,
195196
err: err,
@@ -278,6 +279,10 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo {
278279
func TestCheckTimeoutTimedOut(t *testing.T) {
279280
t.Parallel()
280281

282+
deadline := time.Now().Add(time.Nanosecond)
283+
ctx, cancel := context.WithDeadline(context.Background(), deadline)
284+
defer cancel()
285+
281286
p := createTestPaymentLifecycle()
282287

283288
// Mock the control tower's `FailPayment` method.
@@ -288,14 +293,11 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
288293
// Mount the mocked control tower.
289294
p.router.cfg.Control = ct
290295

291-
// Make the timeout happens instantly.
292-
p.timeoutChan = time.After(1 * time.Nanosecond)
293-
294296
// Sleep one millisecond to make sure it timed out.
295297
time.Sleep(1 * time.Millisecond)
296298

297299
// Call the function and expect no error.
298-
err := p.checkTimeout()
300+
err := p.checkContext(ctx)
299301
require.NoError(t, err)
300302

301303
// Assert that `FailPayment` is called as expected.
@@ -313,13 +315,15 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
313315
p.router.cfg.Control = ct
314316

315317
// Make the timeout happens instantly.
316-
p.timeoutChan = time.After(1 * time.Nanosecond)
318+
deadline = time.Now().Add(time.Nanosecond)
319+
ctx, cancel = context.WithDeadline(context.Background(), deadline)
320+
defer cancel()
317321

318322
// Sleep one millisecond to make sure it timed out.
319323
time.Sleep(1 * time.Millisecond)
320324

321325
// Call the function and expect an error.
322-
err = p.checkTimeout()
326+
err = p.checkContext(ctx)
323327
require.ErrorIs(t, err, errDummy)
324328

325329
// Assert that `FailPayment` is called as expected.
@@ -331,10 +335,13 @@ func TestCheckTimeoutTimedOut(t *testing.T) {
331335
func TestCheckTimeoutOnRouterQuit(t *testing.T) {
332336
t.Parallel()
333337

338+
ctx, cancel := context.WithCancel(context.Background())
339+
defer cancel()
340+
334341
p := createTestPaymentLifecycle()
335342

336343
close(p.router.quit)
337-
err := p.checkTimeout()
344+
err := p.checkContext(ctx)
338345
require.ErrorIs(t, err, ErrRouterShuttingDown)
339346
}
340347

@@ -627,7 +634,7 @@ func TestResumePaymentFailOnFetchPayment(t *testing.T) {
627634
m.control.On("FetchPayment", p.identifier).Return(nil, errDummy)
628635

629636
// Send the payment and assert it failed.
630-
sendPaymentAndAssertFailed(t, p, errDummy)
637+
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
631638

632639
// Expected collectResultAsync to not be called.
633640
require.Zero(t, m.collectResultsCount)
@@ -656,14 +663,15 @@ func TestResumePaymentFailOnTimeout(t *testing.T) {
656663
}
657664
m.payment.On("GetState").Return(ps).Once()
658665

659-
// NOTE: GetStatus is only used to populate the logs which is
660-
// not critical so we loosen the checks on how many times it's
661-
// been called.
666+
// NOTE: GetStatus is only used to populate the logs which is not
667+
// critical, so we loosen the checks on how many times it's been called.
662668
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
663669

664670
// 3. make the timeout happens instantly and sleep one millisecond to
665671
// make sure it timed out.
666-
p.timeoutChan = time.After(1 * time.Nanosecond)
672+
deadline := time.Now().Add(time.Nanosecond)
673+
ctx, cancel := context.WithDeadline(context.Background(), deadline)
674+
defer cancel()
667675
time.Sleep(1 * time.Millisecond)
668676

669677
// 4. the payment should be failed with reason timeout.
@@ -683,7 +691,7 @@ func TestResumePaymentFailOnTimeout(t *testing.T) {
683691
m.payment.On("TerminalInfo").Return(nil, &reason)
684692

685693
// Send the payment and assert it failed with the timeout reason.
686-
sendPaymentAndAssertFailed(t, p, reason)
694+
sendPaymentAndAssertError(t, ctx, p, reason)
687695

688696
// Expected collectResultAsync to not be called.
689697
require.Zero(t, m.collectResultsCount)
@@ -721,7 +729,65 @@ func TestResumePaymentFailOnTimeoutErr(t *testing.T) {
721729
close(p.router.quit)
722730

723731
// Send the payment and assert it failed when router is shutting down.
724-
sendPaymentAndAssertFailed(t, p, ErrRouterShuttingDown)
732+
sendPaymentAndAssertError(
733+
t, context.Background(), p, ErrRouterShuttingDown,
734+
)
735+
736+
// Expected collectResultAsync to not be called.
737+
require.Zero(t, m.collectResultsCount)
738+
}
739+
740+
// TestResumePaymentFailContextCancel checks that the lifecycle fails when the
741+
// context is canceled.
742+
//
743+
// NOTE: No parallel test because it overwrites global variables.
744+
//
745+
//nolint:paralleltest
746+
func TestResumePaymentFailContextCancel(t *testing.T) {
747+
// Create a test paymentLifecycle with the initial two calls mocked.
748+
p, m := setupTestPaymentLifecycle(t)
749+
750+
// Create the cancelable payment context.
751+
ctx, cancel := context.WithCancel(context.Background())
752+
753+
paymentAmt := lnwire.MilliSatoshi(10000)
754+
755+
// We now enter the payment lifecycle loop.
756+
//
757+
// 1. calls `FetchPayment` and return the payment.
758+
m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once()
759+
760+
// 2. calls `GetState` and return the state.
761+
ps := &channeldb.MPPaymentState{
762+
RemainingAmt: paymentAmt,
763+
}
764+
m.payment.On("GetState").Return(ps).Once()
765+
766+
// NOTE: GetStatus is only used to populate the logs which is not
767+
// critical, so we loosen the checks on how many times it's been called.
768+
m.payment.On("GetStatus").Return(channeldb.StatusInFlight)
769+
770+
// 3. Cancel the context and skip the FailPayment error to trigger the
771+
// context cancellation of the payment.
772+
cancel()
773+
774+
m.control.On(
775+
"FailPayment", p.identifier, channeldb.FailureReasonTimeout,
776+
).Return(nil).Once()
777+
778+
// 5. decideNextStep now returns stepExit.
779+
m.payment.On("AllowMoreAttempts").Return(false, nil).Once().
780+
On("NeedWaitAttempts").Return(false, nil).Once()
781+
782+
// 6. Control tower deletes failed attempts.
783+
m.control.On("DeleteFailedAttempts", p.identifier).Return(nil).Once()
784+
785+
// 7. We will observe FailureReasonError if the context was cancelled.
786+
reason := channeldb.FailureReasonError
787+
m.payment.On("TerminalInfo").Return(nil, &reason)
788+
789+
// Send the payment and assert it failed with the timeout reason.
790+
sendPaymentAndAssertError(t, ctx, p, reason)
725791

726792
// Expected collectResultAsync to not be called.
727793
require.Zero(t, m.collectResultsCount)
@@ -759,7 +825,7 @@ func TestResumePaymentFailOnStepErr(t *testing.T) {
759825
m.payment.On("AllowMoreAttempts").Return(false, errDummy).Once()
760826

761827
// Send the payment and assert it failed.
762-
sendPaymentAndAssertFailed(t, p, errDummy)
828+
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
763829

764830
// Expected collectResultAsync to not be called.
765831
require.Zero(t, m.collectResultsCount)
@@ -803,7 +869,7 @@ func TestResumePaymentFailOnRequestRouteErr(t *testing.T) {
803869
).Return(nil, errDummy).Once()
804870

805871
// Send the payment and assert it failed.
806-
sendPaymentAndAssertFailed(t, p, errDummy)
872+
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
807873

808874
// Expected collectResultAsync to not be called.
809875
require.Zero(t, m.collectResultsCount)
@@ -863,7 +929,7 @@ func TestResumePaymentFailOnRegisterAttemptErr(t *testing.T) {
863929
).Return(nil, errDummy).Once()
864930

865931
// Send the payment and assert it failed.
866-
sendPaymentAndAssertFailed(t, p, errDummy)
932+
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
867933

868934
// Expected collectResultAsync to not be called.
869935
require.Zero(t, m.collectResultsCount)
@@ -955,7 +1021,7 @@ func TestResumePaymentFailOnSendAttemptErr(t *testing.T) {
9551021
).Return(nil, errDummy).Once()
9561022

9571023
// Send the payment and assert it failed.
958-
sendPaymentAndAssertFailed(t, p, errDummy)
1024+
sendPaymentAndAssertError(t, context.Background(), p, errDummy)
9591025

9601026
// Expected collectResultAsync to not be called.
9611027
require.Zero(t, m.collectResultsCount)

0 commit comments

Comments
 (0)