diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index bb02b9c276c..4ccfc5ce6df 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -54,6 +54,17 @@ refacotor the payment related LND code to make it more modular. * Implement the SQL backend for the [payments database](https://github.com/lightningnetwork/lnd/pull/9147) + * Implement query methods for the [payments db + SQL Backend](https://github.com/lightningnetwork/lnd/pull/10287) + * Implement insert methods for the [payments db + SQL Backend](https://github.com/lightningnetwork/lnd/pull/10291) + * Finalize SQL payments implementation [enabling unit and itests + for SQL backend](https://github.com/lightningnetwork/lnd/pull/10292) + * [Thread context through payment + db functions Part 1](https://github.com/lightningnetwork/lnd/pull/10307) + * [Thread context through payment + db functions Part 2](https://github.com/lightningnetwork/lnd/pull/10308) + ## Code Health diff --git a/itest/lnd_payment_test.go b/itest/lnd_payment_test.go index 37aff052264..f683cd44a80 100644 --- a/itest/lnd_payment_test.go +++ b/itest/lnd_payment_test.go @@ -504,61 +504,86 @@ func testListPayments(ht *lntest.HarnessTest) { expected bool } - // Create test cases to check the timestamp filters. - createCases := func(createTimeSeconds uint64) []testCase { + // Create test cases with proper rounding for start and end dates. + createCases := func(startTimeSeconds, + endTimeSeconds uint64) []testCase { + return []testCase{ { // Use a start date same as the creation date - // should return us the item. + // (truncated) should return us the item. name: "exact start date", - startDate: createTimeSeconds, + startDate: startTimeSeconds, expected: true, }, { // Use an earlier start date should return us // the item. name: "earlier start date", - startDate: createTimeSeconds - 1, + startDate: startTimeSeconds - 1, expected: true, }, { // Use a future start date should return us // nothing. name: "future start date", - startDate: createTimeSeconds + 1, + startDate: startTimeSeconds + 1, expected: false, }, { // Use an end date same as the creation date - // should return us the item. + // (ceiling) should return us the item. name: "exact end date", - endDate: createTimeSeconds, + endDate: endTimeSeconds, expected: true, }, { // Use an end date in the future should return // us the item. name: "future end date", - endDate: createTimeSeconds + 1, + endDate: endTimeSeconds + 1, expected: true, }, { // Use an earlier end date should return us // nothing. - name: "earlier end date", - endDate: createTimeSeconds - 1, + name: "earlier end date", + // The native sql backend has a higher + // precision than the kv backend, the native sql + // backend uses microseconds, the kv backend + // when filtering uses seconds so we need to + // subtract 2 seconds to ensure the payment is + // not included. + // We could also truncate before inserting + // into the sql db but I rather relax this test + // here. + endDate: endTimeSeconds - 2, expected: false, }, } } - // Get the payment creation time in seconds. - paymentCreateSeconds := uint64( - p.CreationTimeNs / time.Second.Nanoseconds(), + // Get the payment creation time in seconds, using different approaches + // for start and end date comparisons to avoid rounding issues. + creationTime := time.Unix(0, p.CreationTimeNs) + + // For start date comparisons: use truncation (floor) to include + // payments from the beginning of that second. + paymentCreateSecondsStart := uint64( + creationTime.Truncate(time.Second).Unix(), + ) + + // For end date comparisons: use ceiling to include payments up to the + // end of that second. + paymentCreateSecondsEnd := uint64( + (p.CreationTimeNs + time.Second.Nanoseconds() - 1) / + time.Second.Nanoseconds(), ) // Create test cases from the payment creation time. - testCases := createCases(paymentCreateSeconds) + testCases := createCases( + paymentCreateSecondsStart, paymentCreateSecondsEnd, + ) // We now check the timestamp filters in `ListPayments`. for _, tc := range testCases { @@ -578,7 +603,9 @@ func testListPayments(ht *lntest.HarnessTest) { } // Create test cases from the invoice creation time. - testCases = createCases(uint64(invoice.CreationDate)) + testCases = createCases( + uint64(invoice.CreationDate), uint64(invoice.CreationDate), + ) // We now do the same check for `ListInvoices`. for _, tc := range testCases { diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 377d0be3ea0..04faa7840b5 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1745,6 +1745,8 @@ func (r *RouterBackend) MarshallPayment(payment *paymentsdb.MPPayment) ( // If any of the htlcs have settled, extract a valid // preimage. if htlc.Settle != nil { + // For AMP payments all hashes will be different so we + // will depict the last htlc preimage. preimage = htlc.Settle.Preimage fee += htlc.Route.TotalFees() } diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 1dbc19e47f9..5db1fb9584a 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -936,11 +936,11 @@ func (s *Server) SendToRouteV2(ctx context.Context, // db. if req.SkipTempErr { attempt, err = s.cfg.Router.SendToRouteSkipTempErr( - hash, route, firstHopRecords, + ctx, hash, route, firstHopRecords, ) } else { attempt, err = s.cfg.Router.SendToRoute( - hash, route, firstHopRecords, + ctx, hash, route, firstHopRecords, ) } if attempt != nil { diff --git a/payments/db/errors.go b/payments/db/errors.go index 40e37d95c6a..424b63c7bd3 100644 --- a/payments/db/errors.go +++ b/payments/db/errors.go @@ -130,4 +130,8 @@ var ( // NOTE: Only used for the kv backend. ErrNoSequenceNrIndex = errors.New("payment sequence number index " + "does not exist") + + // errMaxPaymentsReached is used internally to signal that the maximum + // number of payments has been reached during a paginated query. + errMaxPaymentsReached = errors.New("max payments reached") ) diff --git a/payments/db/interface.go b/payments/db/interface.go index c41dc371f89..d788b7a2211 100644 --- a/payments/db/interface.go +++ b/payments/db/interface.go @@ -21,21 +21,24 @@ type PaymentReader interface { // FetchPayment fetches the payment corresponding to the given payment // hash. - FetchPayment(paymentHash lntypes.Hash) (*MPPayment, error) + FetchPayment(ctx context.Context, + paymentHash lntypes.Hash) (*MPPayment, error) // FetchInFlightPayments returns all payments with status InFlight. - FetchInFlightPayments() ([]*MPPayment, error) + FetchInFlightPayments(ctx context.Context) ([]*MPPayment, error) } // PaymentWriter represents the interface to write operations to the payments // database. type PaymentWriter interface { // DeletePayment deletes a payment from the DB given its payment hash. - DeletePayment(paymentHash lntypes.Hash, failedAttemptsOnly bool) error + DeletePayment(ctx context.Context, paymentHash lntypes.Hash, + failedAttemptsOnly bool) error // DeletePayments deletes all payments from the DB given the specified // flags. - DeletePayments(failedOnly, failedAttemptsOnly bool) (int, error) + DeletePayments(ctx context.Context, failedOnly, + failedAttemptsOnly bool) (int, error) PaymentControl } @@ -58,10 +61,11 @@ type PaymentControl interface { // exists in the database before creating a new payment. However, it // should allow the user making a subsequent payment if the payment is // in a Failed state. - InitPayment(lntypes.Hash, *PaymentCreationInfo) error + InitPayment(context.Context, lntypes.Hash, *PaymentCreationInfo) error // RegisterAttempt atomically records the provided HTLCAttemptInfo. - RegisterAttempt(lntypes.Hash, *HTLCAttemptInfo) (*MPPayment, error) + RegisterAttempt(context.Context, lntypes.Hash, + *HTLCAttemptInfo) (*MPPayment, error) // SettleAttempt marks the given attempt settled with the preimage. If // this is a multi shard payment, this might implicitly mean the @@ -71,10 +75,12 @@ type PaymentControl interface { // error to prevent us from making duplicate payments to the same // payment hash. The provided preimage is atomically saved to the DB // for record keeping. - SettleAttempt(lntypes.Hash, uint64, *HTLCSettleInfo) (*MPPayment, error) + SettleAttempt(context.Context, lntypes.Hash, uint64, + *HTLCSettleInfo) (*MPPayment, error) // FailAttempt marks the given payment attempt failed. - FailAttempt(lntypes.Hash, uint64, *HTLCFailInfo) (*MPPayment, error) + FailAttempt(context.Context, lntypes.Hash, uint64, + *HTLCFailInfo) (*MPPayment, error) // Fail transitions a payment into the Failed state, and records // the ultimate reason the payment failed. Note that this should only @@ -82,12 +88,12 @@ type PaymentControl interface { // invoking this method, InitPayment should return nil on its next call // for this payment hash, allowing the user to make a subsequent // payment. - Fail(lntypes.Hash, FailureReason) (*MPPayment, error) + Fail(context.Context, lntypes.Hash, FailureReason) (*MPPayment, error) // DeleteFailedAttempts removes all failed HTLCs from the db. It should // be called for a given payment whenever all inflight htlcs are // completed, and the payment has reached a final terminal state. - DeleteFailedAttempts(lntypes.Hash) error + DeleteFailedAttempts(context.Context, lntypes.Hash) error } // DBMPPayment is an interface that represents the payment state during a diff --git a/payments/db/kv_store.go b/payments/db/kv_store.go index e5915a00347..f38f3235b67 100644 --- a/payments/db/kv_store.go +++ b/payments/db/kv_store.go @@ -186,7 +186,7 @@ func initKVStore(db kvdb.Backend) error { // making sure it does not already exist as an in-flight payment. When this // method returns successfully, the payment is guaranteed to be in the InFlight // state. -func (p *KVStore) InitPayment(paymentHash lntypes.Hash, +func (p *KVStore) InitPayment(_ context.Context, paymentHash lntypes.Hash, info *PaymentCreationInfo) error { // Obtain a new sequence number for this payment. This is used @@ -290,10 +290,12 @@ func (p *KVStore) InitPayment(paymentHash lntypes.Hash, // DeleteFailedAttempts deletes all failed htlcs for a payment if configured // by the KVStore db. -func (p *KVStore) DeleteFailedAttempts(hash lntypes.Hash) error { +func (p *KVStore) DeleteFailedAttempts(ctx context.Context, + hash lntypes.Hash) error { + if !p.keepFailedPaymentAttempts { const failedHtlcsOnly = true - err := p.DeletePayment(hash, failedHtlcsOnly) + err := p.DeletePayment(ctx, hash, failedHtlcsOnly) if err != nil { return err } @@ -357,7 +359,7 @@ func deserializePaymentIndex(r io.Reader) (lntypes.Hash, error) { // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. -func (p *KVStore) RegisterAttempt(paymentHash lntypes.Hash, +func (p *KVStore) RegisterAttempt(_ context.Context, paymentHash lntypes.Hash, attempt *HTLCAttemptInfo) (*MPPayment, error) { // Serialize the information before opening the db transaction. @@ -428,7 +430,7 @@ func (p *KVStore) RegisterAttempt(paymentHash lntypes.Hash, // After invoking this method, InitPayment should always return an error to // prevent us from making duplicate payments to the same payment hash. The // provided preimage is atomically saved to the DB for record keeping. -func (p *KVStore) SettleAttempt(hash lntypes.Hash, +func (p *KVStore) SettleAttempt(_ context.Context, hash lntypes.Hash, attemptID uint64, settleInfo *HTLCSettleInfo) (*MPPayment, error) { var b bytes.Buffer @@ -441,7 +443,7 @@ func (p *KVStore) SettleAttempt(hash lntypes.Hash, } // FailAttempt marks the given payment attempt failed. -func (p *KVStore) FailAttempt(hash lntypes.Hash, +func (p *KVStore) FailAttempt(_ context.Context, hash lntypes.Hash, attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) { var b bytes.Buffer @@ -526,7 +528,7 @@ func (p *KVStore) updateHtlcKey(paymentHash lntypes.Hash, // payment failed. After invoking this method, InitPayment should return nil on // its next call for this payment hash, allowing the switch to make a // subsequent payment. -func (p *KVStore) Fail(paymentHash lntypes.Hash, +func (p *KVStore) Fail(_ context.Context, paymentHash lntypes.Hash, reason FailureReason) (*MPPayment, error) { var ( @@ -583,8 +585,8 @@ func (p *KVStore) Fail(paymentHash lntypes.Hash, } // FetchPayment returns information about a payment from the database. -func (p *KVStore) FetchPayment(paymentHash lntypes.Hash) ( - *MPPayment, error) { +func (p *KVStore) FetchPayment(_ context.Context, + paymentHash lntypes.Hash) (*MPPayment, error) { var payment *MPPayment err := kvdb.View(p.db, func(tx kvdb.RTx) error { @@ -739,7 +741,9 @@ func fetchPaymentStatus(bucket kvdb.RBucket) (PaymentStatus, error) { } // FetchInFlightPayments returns all payments with status InFlight. -func (p *KVStore) FetchInFlightPayments() ([]*MPPayment, error) { +func (p *KVStore) FetchInFlightPayments(_ context.Context) ([]*MPPayment, + error) { + var ( inFlights []*MPPayment start = time.Now() @@ -1273,7 +1277,7 @@ func fetchPaymentWithSequenceNumber(tx kvdb.RTx, paymentHash lntypes.Hash, // DeletePayment deletes a payment from the DB given its payment hash. If // failedHtlcsOnly is set, only failed HTLC attempts of the payment will be // deleted. -func (p *KVStore) DeletePayment(paymentHash lntypes.Hash, +func (p *KVStore) DeletePayment(_ context.Context, paymentHash lntypes.Hash, failedHtlcsOnly bool) error { return kvdb.Update(p.db, func(tx kvdb.RwTx) error { @@ -1370,7 +1374,7 @@ func (p *KVStore) DeletePayment(paymentHash lntypes.Hash, // failedHtlcsOnly is set, the payment itself won't be deleted, only failed HTLC // attempts. The method returns the number of deleted payments, which is always // 0 if failedHtlcsOnly is set. -func (p *KVStore) DeletePayments(failedOnly, +func (p *KVStore) DeletePayments(_ context.Context, failedOnly, failedHtlcsOnly bool) (int, error) { var numPayments int diff --git a/payments/db/kv_store_test.go b/payments/db/kv_store_test.go index 73df8ead4ec..b969d29d888 100644 --- a/payments/db/kv_store_test.go +++ b/payments/db/kv_store_test.go @@ -2,6 +2,7 @@ package paymentsdb import ( "bytes" + "crypto/sha256" "math" "reflect" "testing" @@ -24,6 +25,8 @@ import ( func TestKVStoreDeleteNonInFlight(t *testing.T) { t.Parallel() + ctx := t.Context() + paymentDB := NewKVTestDB(t) // Create a sequence number for duplicate payments that will not collide @@ -61,18 +64,23 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { var numSuccess, numInflight int for _, p := range payments { - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash( + t, 0, genSessionKey(t), rhash, + ) + require.NoError(t, err) // Sends base htlc message which initiate StatusInFlight. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } _, err = paymentDB.RegisterAttempt( - info.PaymentIdentifier, attempt, + ctx, info.PaymentIdentifier, attempt, ) if err != nil { t.Fatalf("unable to send htlc message: %v", err) @@ -87,7 +95,7 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { // Fail the payment attempt. htlcFailure := HTLCFailUnreadable _, err := paymentDB.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCFailInfo{ Reason: htlcFailure, }, @@ -99,7 +107,7 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { // Fail the payment, which should moved it to Failed. failReason := FailureReasonNoRoute _, err = paymentDB.Fail( - info.PaymentIdentifier, failReason, + ctx, info.PaymentIdentifier, failReason, ) if err != nil { t.Fatalf("unable to fail payment hash: %v", err) @@ -120,7 +128,7 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { case p.success: // Verifies that status was changed to StatusSucceeded. _, err := paymentDB.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCSettleInfo{ Preimage: preimg, }, @@ -169,7 +177,7 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { } // Delete all failed payments. - numPayments, err := paymentDB.DeletePayments(true, false) + numPayments, err := paymentDB.DeletePayments(ctx, true, false) require.NoError(t, err) require.EqualValues(t, 1, numPayments) @@ -205,7 +213,7 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { } // Now delete all payments except in-flight. - numPayments, err = paymentDB.DeletePayments(false, false) + numPayments, err = paymentDB.DeletePayments(ctx, false, false) require.NoError(t, err) require.EqualValues(t, 2, numPayments) @@ -287,6 +295,8 @@ func fetchPaymentIndexEntry(t *testing.T, p *KVStore, func assertPaymentIndex(t *testing.T, p DB, expectedHash lntypes.Hash) { t.Helper() + ctx := t.Context() + // Only the kv implementation uses the index so we exit early if the // payment db is not a kv implementation. This helps us to reuse the // same test for both implementations. @@ -297,7 +307,7 @@ func assertPaymentIndex(t *testing.T, p DB, expectedHash lntypes.Hash) { // Lookup the payment so that we have its sequence number and check // that is has correctly been indexed in the payment indexes bucket. - pmt, err := kvPaymentDB.FetchPayment(expectedHash) + pmt, err := kvPaymentDB.FetchPayment(ctx, expectedHash) require.NoError(t, err) hash, err := fetchPaymentIndexEntry(t, kvPaymentDB, pmt.SequenceNum) @@ -473,35 +483,37 @@ func deletePayment(t *testing.T, db kvdb.Backend, paymentHash lntypes.Hash, func TestFetchPaymentWithSequenceNumber(t *testing.T) { paymentDB := NewKVTestDB(t) + ctx := t.Context() + // Generate a test payment which does not have duplicates. - noDuplicates, _, _, err := genInfo(t) + noDuplicates, _, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. err = paymentDB.InitPayment( - noDuplicates.PaymentIdentifier, noDuplicates, + ctx, noDuplicates.PaymentIdentifier, noDuplicates, ) require.NoError(t, err) // Fetch the payment so we can get its sequence nr. noDuplicatesPayment, err := paymentDB.FetchPayment( - noDuplicates.PaymentIdentifier, + ctx, noDuplicates.PaymentIdentifier, ) require.NoError(t, err) // Generate a test payment which we will add duplicates to. - hasDuplicates, _, preimg, err := genInfo(t) + hasDuplicates, preimg, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. err = paymentDB.InitPayment( - hasDuplicates.PaymentIdentifier, hasDuplicates, + ctx, hasDuplicates.PaymentIdentifier, hasDuplicates, ) require.NoError(t, err) // Fetch the payment so we can get its sequence nr. hasDuplicatesPayment, err := paymentDB.FetchPayment( - hasDuplicates.PaymentIdentifier, + ctx, hasDuplicates.PaymentIdentifier, ) require.NoError(t, err) @@ -648,7 +660,7 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) // Generate fake information for the duplicate payment. - info, _, _, err := genInfo(t) + info, _, err := genInfo(t) require.NoError(t, err) // Write the payment info to disk under the creation info key. This code @@ -680,17 +692,19 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) } -// TestQueryPayments tests retrieval of payments with forwards and reversed -// queries. -// -// TODO(ziggie): Make this test db agnostic. -func TestQueryPayments(t *testing.T) { - // Define table driven test for QueryPayments. +// TestKVStoreQueryPaymentsDuplicates tests the KV store's legacy duplicate +// payment handling. This tests the specific case where duplicate payments +// are stored in a nested bucket within the parent payment bucket. +func TestKVStoreQueryPaymentsDuplicates(t *testing.T) { + t.Parallel() + // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. // Note that the payment with index 7 has the same payment hash as 6, // and is stored in a nested bucket within payment 6 rather than being - // its own entry in the payments bucket. We do this to test retrieval - // of legacy payments. + // its own entry in the payments bucket. This tests retrieval of legacy + // duplicate payments which is KV-store specific. + // These test cases focus on validating that duplicate payments (seq 7, + // nested under payment 6) are correctly returned in queries. tests := []struct { name string query Query @@ -702,31 +716,20 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs []uint64 }{ { - name: "IndexOffset at the end of the payments range", + name: "query includes duplicate payment in forward " + + "order", query: Query{ - IndexOffset: 7, - MaxPayments: 7, + IndexOffset: 5, + MaxPayments: 3, Reversed: false, IncludeIncomplete: true, }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "query in forwards order, start at beginning", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, + firstIndex: 6, + lastIndex: 7, + expectedSeqNrs: []uint64{6, 7}, }, { - name: "query in forwards order, start at end, overflow", + name: "query duplicate payment at end", query: Query{ IndexOffset: 6, MaxPayments: 2, @@ -738,44 +741,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{7}, }, { - name: "start at offset index outside of payments", - query: Query{ - IndexOffset: 20, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "overflow in forwards order", - query: Query{ - IndexOffset: 4, - MaxPayments: math.MaxUint64, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 5, - lastIndex: 7, - expectedSeqNrs: []uint64{5, 6, 7}, - }, - { - name: "start at offset index outside of payments, " + - "reversed order", - query: Query{ - IndexOffset: 9, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 6, - lastIndex: 7, - expectedSeqNrs: []uint64{6, 7}, - }, - { - name: "query in reverse order, start at end", + name: "query includes duplicate in reverse order", query: Query{ IndexOffset: 0, MaxPayments: 2, @@ -787,36 +753,11 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{6, 7}, }, { - name: "query in reverse order, starting in middle", - query: Query{ - IndexOffset: 4, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "query in reverse order, starting in middle, " + - "with underflow", - query: Query{ - IndexOffset: 4, - MaxPayments: 5, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "all payments in reverse, order maintained", + name: "query all payments includes duplicate", query: Query{ IndexOffset: 0, - MaxPayments: 7, - Reversed: true, + MaxPayments: math.MaxUint64, + Reversed: false, IncludeIncomplete: true, }, firstIndex: 1, @@ -824,7 +765,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{1, 3, 4, 5, 6, 7}, }, { - name: "exclude incomplete payments", + name: "exclude incomplete includes duplicate", query: Query{ IndexOffset: 0, MaxPayments: 7, @@ -835,96 +776,6 @@ func TestQueryPayments(t *testing.T) { lastIndex: 7, expectedSeqNrs: []uint64{7}, }, - { - name: "query payments at index gap", - query: Query{ - IndexOffset: 1, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 7, - expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "query payments reverse before index gap", - query: Query{ - IndexOffset: 3, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments reverse on index gap", - query: Query{ - IndexOffset: 2, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments forward on index gap", - query: Query{ - IndexOffset: 2, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 4, - expectedSeqNrs: []uint64{3, 4}, - }, - { - name: "query in forwards order, with start creation " + - "time", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 5, - }, - firstIndex: 5, - lastIndex: 6, - expectedSeqNrs: []uint64{5, 6}, - }, - { - name: "query in forwards order, with start creation " + - "time at end, overflow", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 7, - }, - firstIndex: 7, - lastIndex: 7, - expectedSeqNrs: []uint64{7}, - }, - { - name: "query with start and end creation time", - query: Query{ - IndexOffset: 9, - MaxPayments: math.MaxUint64, - Reversed: true, - IncludeIncomplete: true, - CreationDateStart: 3, - CreationDateEnd: 5, - }, - firstIndex: 3, - lastIndex: 5, - expectedSeqNrs: []uint64{3, 4, 5}, - }, } for _, tt := range tests { @@ -956,7 +807,7 @@ func TestQueryPayments(t *testing.T) { for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. - info, _, preimg, err := genInfo(t) + info, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to create test "+ "payment: %v", err) @@ -967,14 +818,14 @@ func TestQueryPayments(t *testing.T) { // Create a new payment entry in the database. err = paymentDB.InitPayment( - info.PaymentIdentifier, info, + ctx, info.PaymentIdentifier, info, ) require.NoError(t, err) // Immediately delete the payment with index 2. if i == 1 { pmt, err := paymentDB.FetchPayment( - info.PaymentIdentifier, + ctx, info.PaymentIdentifier, ) require.NoError(t, err) @@ -991,7 +842,7 @@ func TestQueryPayments(t *testing.T) { // duplicate payments will always be succeeded. if i == (nonDuplicatePayments - 1) { pmt, err := paymentDB.FetchPayment( - info.PaymentIdentifier, + ctx, info.PaymentIdentifier, ) require.NoError(t, err) diff --git a/payments/db/payment_test.go b/payments/db/payment_test.go index 534a1b1e5ee..d511bfa6fcc 100644 --- a/payments/db/payment_test.go +++ b/payments/db/payment_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math" "reflect" "testing" "time" @@ -58,7 +59,10 @@ var ( ChannelID: 12345, OutgoingTimeLock: 111, AmtToForward: 555, - LegacyPayload: true, + + // Only tlv payloads are now supported in LND therefore we set + // LegacyPayload to false. + LegacyPayload: false, } testRoute = route.Route{ @@ -113,29 +117,38 @@ type payment struct { func createTestPayments(t *testing.T, p DB, payments []*payment) { t.Helper() + ctx := t.Context() + attemptID := uint64(0) for i := 0; i < len(payments); i++ { - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Set the payment id accordingly in the payments slice. payments[i].id = info.PaymentIdentifier - attempt.AttemptID = attemptID + attempt, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + attemptID++ // Init the payment. - err = p.InitPayment(info.PaymentIdentifier, info) + err = p.InitPayment(ctx, info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") // Register and fail the first attempt for all payments. - _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) + _, err = p.RegisterAttempt(ctx, info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") htlcFailure := HTLCFailUnreadable _, err = p.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCFailInfo{ Reason: htlcFailure, }, @@ -148,10 +161,13 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { // Depending on the test case, fail or succeed the next // attempt. - attempt.AttemptID = attemptID + attempt, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) attemptID++ - _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) + _, err = p.RegisterAttempt(ctx, info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") switch payments[i].status { @@ -159,7 +175,7 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { case StatusFailed: htlcFailure := HTLCFailUnreadable _, err = p.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCFailInfo{ Reason: htlcFailure, }, @@ -167,14 +183,15 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { require.NoError(t, err, "unable to fail htlc") failReason := FailureReasonNoRoute - _, err = p.Fail(info.PaymentIdentifier, - failReason) + _, err = p.Fail( + ctx, info.PaymentIdentifier, failReason, + ) require.NoError(t, err, "unable to fail payment hash") // Settle the attempt case StatusSucceeded: _, err := p.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCSettleInfo{ Preimage: preimg, }, @@ -213,7 +230,9 @@ func assertPaymentInfo(t *testing.T, p DB, hash lntypes.Hash, t.Helper() - payment, err := p.FetchPayment(hash) + ctx := t.Context() + + payment, err := p.FetchPayment(ctx, hash) if err != nil { t.Fatal(err) } @@ -281,7 +300,9 @@ func assertDBPaymentstatus(t *testing.T, p DB, hash lntypes.Hash, t.Helper() - payment, err := p.FetchPayment(hash) + ctx := t.Context() + + payment, err := p.FetchPayment(ctx, hash) if errors.Is(err, ErrPaymentNotInitiated) { return } @@ -334,7 +355,7 @@ func assertDBPayments(t *testing.T, paymentDB DB, payments []*payment) { } // genPreimage generates a random preimage. -func genPreimage(t *testing.T) ([32]byte, error) { +func genPreimage(t *testing.T) (lntypes.Preimage, error) { t.Helper() var preimage [32]byte @@ -345,31 +366,85 @@ func genPreimage(t *testing.T) ([32]byte, error) { return preimage, nil } -// genInfo generates a payment creation info, an attempt info and a preimage. -func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, - lntypes.Preimage, error) { +// genSessionKey generates a new random private key for use as a session key. +func genSessionKey(t *testing.T) *btcec.PrivateKey { + t.Helper() - preimage, err := genPreimage(t) - if err != nil { - return nil, nil, preimage, fmt.Errorf("unable to "+ - "generate preimage: %v", err) + key, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return key +} + +// genPaymentCreationInfo generates a payment creation info. +func genPaymentCreationInfo(t *testing.T, + paymentHash lntypes.Hash) *PaymentCreationInfo { + + t.Helper() + + // Add constant first hop custom records for testing for testing + // purposes. + firstHopCustomRecords := lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 1: []byte("test_record_1"), + lnwire.MinCustomRecordsTlvType + 2: []byte("test_record_2"), + lnwire.MinCustomRecordsTlvType + 3: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, + }, } + return &PaymentCreationInfo{ + PaymentIdentifier: paymentHash, + Value: testRoute.ReceiverAmt(), + CreationTime: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + FirstHopCustomRecords: firstHopCustomRecords, + } +} + +// genPreimageAndHash generates a random preimage and its corresponding hash. +func genPreimageAndHash(t *testing.T) (lntypes.Preimage, lntypes.Hash, error) { + t.Helper() + + preimage, err := genPreimage(t) + require.NoError(t, err) + rhash := sha256.Sum256(preimage[:]) var hash lntypes.Hash copy(hash[:], rhash[:]) + return preimage, hash, nil +} + +// genAttemptWithPreimage generates an HTLC attempt and returns both the +// attempt and preimage. +func genAttemptWithHash(t *testing.T, attemptID uint64, + sessionKey *btcec.PrivateKey, hash lntypes.Hash) (*HTLCAttemptInfo, + error) { + + t.Helper() + attempt, err := NewHtlcAttempt( - 0, priv, *testRoute.Copy(), time.Time{}, &hash, + attemptID, sessionKey, *testRoute.Copy(), time.Time{}, + &hash, ) - require.NoError(t, err) + if err != nil { + return nil, err + } - return &PaymentCreationInfo{ - PaymentIdentifier: rhash, - Value: testRoute.ReceiverAmt(), - CreationTime: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("hola"), - }, &attempt.HTLCAttemptInfo, preimage, nil + return &attempt.HTLCAttemptInfo, nil +} + +// genInfo generates a payment creation info and the corresponding preimage. +func genInfo(t *testing.T) (*PaymentCreationInfo, lntypes.Preimage, error) { + preimage, _, err := genPreimageAndHash(t) + if err != nil { + return nil, preimage, err + } + + rhash := sha256.Sum256(preimage[:]) + creationInfo := genPaymentCreationInfo(t, rhash) + + return creationInfo, preimage, nil } // TestDeleteFailedAttempts checks that DeleteFailedAttempts properly removes @@ -420,7 +495,9 @@ func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { // Calling DeleteFailedAttempts on a failed payment should delete all // HTLCs. - require.NoError(t, paymentDB.DeleteFailedAttempts(payments[0].id)) + require.NoError(t, paymentDB.DeleteFailedAttempts( + t.Context(), payments[0].id, + )) // Expect all HTLCs to be deleted if the config is set to delete them. if !keepFailedPaymentAttempts { @@ -435,11 +512,15 @@ func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { // operation are performed in general therefore we do NOT expect an // error in this case. if keepFailedPaymentAttempts { - require.NoError( - t, paymentDB.DeleteFailedAttempts(payments[1].id), + err := paymentDB.DeleteFailedAttempts( + t.Context(), payments[1].id, ) + require.NoError(t, err) } else { - require.Error(t, paymentDB.DeleteFailedAttempts(payments[1].id)) + err := paymentDB.DeleteFailedAttempts( + t.Context(), payments[1].id, + ) + require.Error(t, err) } // Since DeleteFailedAttempts returned an error, we should expect the @@ -447,7 +528,9 @@ func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { assertDBPayments(t, paymentDB, payments) // Cleaning up a successful payment should remove failed htlcs. - require.NoError(t, paymentDB.DeleteFailedAttempts(payments[2].id)) + require.NoError(t, paymentDB.DeleteFailedAttempts( + t.Context(), payments[2].id, + )) // Expect all HTLCs except for the settled one to be deleted if the // config is set to delete them. @@ -464,13 +547,17 @@ func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { // payments, if the control tower is configured to keep failed // HTLCs. require.NoError( - t, paymentDB.DeleteFailedAttempts(lntypes.ZeroHash), + t, paymentDB.DeleteFailedAttempts( + t.Context(), lntypes.ZeroHash, + ), ) } else { // Attempting to cleanup a non-existent payment returns an // error. require.Error( - t, paymentDB.DeleteFailedAttempts(lntypes.ZeroHash), + t, paymentDB.DeleteFailedAttempts( + t.Context(), lntypes.ZeroHash, + ), ) } } @@ -479,13 +566,25 @@ func testDeleteFailedAttempts(t *testing.T, keepFailedPaymentAttempts bool) { func TestMPPRecordValidation(t *testing.T) { t.Parallel() + ctx := t.Context() + paymentDB := NewTestDB(t) - info, attempt, _, err := genInfo(t) + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + + attemptID := uint64(0) + + attempt, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) require.NoError(t, err, "unable to generate htlc message") // Init the payment. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") // Create three unique attempts we'll use for the test, and @@ -498,50 +597,80 @@ func TestMPPRecordValidation(t *testing.T) { info.Value, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + _, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") // Now try to register a non-MPP attempt, which should fail. - b := *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = nil - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + attemptID++ + attempt2, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + attempt2.Route.FinalHop().MPP = nil + + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, attempt2, + ) require.ErrorIs(t, err, ErrMPPayment) // Try to register attempt one with a different payment address. - b.Route.FinalHop().MPP = record.NewMPP( + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{2}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, attempt2, + ) require.ErrorIs(t, err, ErrMPPPaymentAddrMismatch) // Try registering one with a different total amount. - b.Route.FinalHop().MPP = record.NewMPP( + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value/2, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, attempt2, + ) require.ErrorIs(t, err, ErrMPPTotalAmountMismatch) // Create and init a new payment. This time we'll check that we cannot // register an MPP attempt if we already registered a non-MPP one. - info, attempt, _, err = genInfo(t) + preimg, err = genPreimage(t) + require.NoError(t, err) + + rhash = sha256.Sum256(preimg[:]) + info = genPaymentCreationInfo(t, rhash) + + attemptID++ + attempt, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + require.NoError(t, err, "unable to generate htlc message") - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") attempt.Route.FinalHop().MPP = nil - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, attempt, + ) require.NoError(t, err, "unable to send htlc message") // Attempt to register an MPP attempt, which should fail. - b = *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = record.NewMPP( + attemptID++ + attempt2, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, attempt2, + ) require.ErrorIs(t, err, ErrNonMPPayment) } @@ -550,6 +679,8 @@ func TestMPPRecordValidation(t *testing.T) { func TestDeleteSinglePayment(t *testing.T) { t.Parallel() + ctx := t.Context() + paymentDB := NewTestDB(t) // Register four payments: @@ -581,7 +712,9 @@ func TestDeleteSinglePayment(t *testing.T) { assertDBPayments(t, paymentDB, payments) // Delete HTLC attempts for first payment only. - require.NoError(t, paymentDB.DeletePayment(payments[0].id, true)) + require.NoError(t, paymentDB.DeletePayment( + ctx, payments[0].id, true, + )) // The first payment is the only altered one as its failed HTLC should // have been removed but is still present as payment. @@ -589,19 +722,25 @@ func TestDeleteSinglePayment(t *testing.T) { assertDBPayments(t, paymentDB, payments) // Delete the first payment completely. - require.NoError(t, paymentDB.DeletePayment(payments[0].id, false)) + require.NoError(t, paymentDB.DeletePayment( + ctx, payments[0].id, false, + )) // The first payment should have been deleted. assertDBPayments(t, paymentDB, payments[1:]) // Now delete the second payment completely. - require.NoError(t, paymentDB.DeletePayment(payments[1].id, false)) + require.NoError(t, paymentDB.DeletePayment( + ctx, payments[1].id, false, + )) // The Second payment should have been deleted. assertDBPayments(t, paymentDB, payments[2:]) // Delete failed HTLC attempts for the third payment. - require.NoError(t, paymentDB.DeletePayment(payments[2].id, true)) + require.NoError(t, paymentDB.DeletePayment( + ctx, payments[2].id, true, + )) // Only the successful HTLC attempt should be left for the third // payment. @@ -609,21 +748,27 @@ func TestDeleteSinglePayment(t *testing.T) { assertDBPayments(t, paymentDB, payments[2:]) // Now delete the third payment completely. - require.NoError(t, paymentDB.DeletePayment(payments[2].id, false)) + require.NoError(t, paymentDB.DeletePayment( + ctx, payments[2].id, false, + )) // Only the last payment should be left. assertDBPayments(t, paymentDB, payments[3:]) // Deleting HTLC attempts from InFlight payments should not work and an // error returned. - require.Error(t, paymentDB.DeletePayment(payments[3].id, true)) + require.Error(t, paymentDB.DeletePayment( + ctx, payments[3].id, true, + )) // The payment is InFlight and therefore should not have been altered. assertDBPayments(t, paymentDB, payments[3:]) // Finally deleting the InFlight payment should also not work and an // error returned. - require.Error(t, paymentDB.DeletePayment(payments[3].id, false)) + require.Error(t, paymentDB.DeletePayment( + ctx, payments[3].id, false, + )) // The payment is InFlight and therefore should not have been altered. assertDBPayments(t, paymentDB, payments[3:]) @@ -1197,11 +1342,15 @@ func TestSuccessesWithoutInFlight(t *testing.T) { paymentDB := NewTestDB(t) - info, _, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Attempt to complete the payment should fail. _, err = paymentDB.SettleAttempt( + t.Context(), info.PaymentIdentifier, 0, &HTLCSettleInfo{ Preimage: preimg, @@ -1217,12 +1366,15 @@ func TestFailsWithoutInFlight(t *testing.T) { paymentDB := NewTestDB(t) - info, _, _, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Calling Fail should return an error. _, err = paymentDB.Fail( - info.PaymentIdentifier, FailureReasonNoRoute, + t.Context(), info.PaymentIdentifier, FailureReasonNoRoute, ) require.ErrorIs(t, err, ErrPaymentNotInitiated) } @@ -1232,6 +1384,8 @@ func TestFailsWithoutInFlight(t *testing.T) { func TestDeletePayments(t *testing.T) { t.Parallel() + ctx := t.Context() + paymentDB := NewTestDB(t) // Register three payments: @@ -1252,7 +1406,7 @@ func TestDeletePayments(t *testing.T) { assertDBPayments(t, paymentDB, payments) // Delete HTLC attempts for failed payments only. - numPayments, err := paymentDB.DeletePayments(true, true) + numPayments, err := paymentDB.DeletePayments(ctx, true, true) require.NoError(t, err) require.EqualValues(t, 0, numPayments) @@ -1261,7 +1415,7 @@ func TestDeletePayments(t *testing.T) { assertDBPayments(t, paymentDB, payments) // Delete failed attempts for all payments. - numPayments, err = paymentDB.DeletePayments(false, true) + numPayments, err = paymentDB.DeletePayments(ctx, false, true) require.NoError(t, err) require.EqualValues(t, 0, numPayments) @@ -1271,14 +1425,14 @@ func TestDeletePayments(t *testing.T) { assertDBPayments(t, paymentDB, payments) // Now delete all failed payments. - numPayments, err = paymentDB.DeletePayments(true, false) + numPayments, err = paymentDB.DeletePayments(ctx, true, false) require.NoError(t, err) require.EqualValues(t, 1, numPayments) assertDBPayments(t, paymentDB, payments[1:]) // Finally delete all completed payments. - numPayments, err = paymentDB.DeletePayments(false, false) + numPayments, err = paymentDB.DeletePayments(ctx, false, false) require.NoError(t, err) require.EqualValues(t, 1, numPayments) @@ -1290,14 +1444,21 @@ func TestDeletePayments(t *testing.T) { func TestSwitchDoubleSend(t *testing.T) { t.Parallel() + ctx := t.Context() + paymentDB := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash(t, 0, genSessionKey(t), rhash) + require.NoError(t, err) // Sends base htlc message which initiate base status and move it to // StatusInFlight and verifies that it was changed. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) @@ -1311,11 +1472,11 @@ func TestSwitchDoubleSend(t *testing.T) { // Try to initiate double sending of htlc message with the same // payment hash, should result in error indicating that payment has // already been sent. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) require.ErrorIs(t, err, ErrPaymentExists) // Record an attempt. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + _, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") assertDBPaymentstatus( t, paymentDB, info.PaymentIdentifier, StatusInFlight, @@ -1329,7 +1490,7 @@ func TestSwitchDoubleSend(t *testing.T) { ) // Sends base htlc message which initiate StatusInFlight. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) if !errors.Is(err, ErrPaymentInFlight) { t.Fatalf("payment control wrong behaviour: " + "double sending must trigger ErrPaymentInFlight error") @@ -1337,7 +1498,7 @@ func TestSwitchDoubleSend(t *testing.T) { // After settling, the error should be ErrAlreadyPaid. _, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCSettleInfo{ Preimage: preimg, }, @@ -1352,7 +1513,7 @@ func TestSwitchDoubleSend(t *testing.T) { t, paymentDB, info.PaymentIdentifier, info, nil, htlc, ) - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) if !errors.Is(err, ErrAlreadyPaid) { t.Fatalf("unable to send htlc message: %v", err) } @@ -1363,13 +1524,20 @@ func TestSwitchDoubleSend(t *testing.T) { func TestSwitchFail(t *testing.T) { t.Parallel() + ctx := t.Context() + paymentDB := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash(t, 0, genSessionKey(t), rhash) + require.NoError(t, err) // Sends base htlc message which initiate StatusInFlight. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) @@ -1382,7 +1550,7 @@ func TestSwitchFail(t *testing.T) { // Fail the payment, which should moved it to Failed. failReason := FailureReasonNoRoute - _, err = paymentDB.Fail(info.PaymentIdentifier, failReason) + _, err = paymentDB.Fail(ctx, info.PaymentIdentifier, failReason) require.NoError(t, err, "unable to fail payment hash") // Verify the status is indeed Failed. @@ -1396,12 +1564,12 @@ func TestSwitchFail(t *testing.T) { // Lookup the payment so we can get its old sequence number before it is // overwritten. - payment, err := paymentDB.FetchPayment(info.PaymentIdentifier) + payment, err := paymentDB.FetchPayment(ctx, info.PaymentIdentifier) require.NoError(t, err) // Sends the htlc again, which should succeed since the prior payment // failed. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) require.NoError(t, err, "unable to send htlc message") // Check that our index has been updated, and the old index has been @@ -1419,12 +1587,12 @@ func TestSwitchFail(t *testing.T) { // Record a new attempt. In this test scenario, the attempt fails. // However, this is not communicated to control tower in the current // implementation. It only registers the initiation of the attempt. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + _, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to register attempt") htlcReason := HTLCFailUnreadable _, err = paymentDB.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCFailInfo{ Reason: htlcReason, }, @@ -1444,8 +1612,12 @@ func TestSwitchFail(t *testing.T) { assertPaymentInfo(t, paymentDB, info.PaymentIdentifier, info, nil, htlc) // Record another attempt. - attempt.AttemptID = 1 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) + attempt, err = genAttemptWithHash( + t, 1, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt(ctx, info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") assertDBPaymentstatus( t, paymentDB, info.PaymentIdentifier, StatusInFlight, @@ -1462,7 +1634,7 @@ func TestSwitchFail(t *testing.T) { // Settle the attempt and verify that status was changed to // StatusSucceeded. payment, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, + ctx, info.PaymentIdentifier, attempt.AttemptID, &HTLCSettleInfo{ Preimage: preimg, }, @@ -1492,7 +1664,7 @@ func TestSwitchFail(t *testing.T) { // Attempt a final payment, which should now fail since the prior // payment succeed. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) if !errors.Is(err, ErrAlreadyPaid) { t.Fatalf("unable to send htlc message: %v", err) } @@ -1503,6 +1675,8 @@ func TestSwitchFail(t *testing.T) { func TestMultiShard(t *testing.T) { t.Parallel() + ctx := t.Context() + // We will register three HTLC attempts, and always fail the second // one. We'll generate all combinations of settling/failing the first // and third HTLC, and assert that the payment status end up as we @@ -1522,16 +1696,15 @@ func TestMultiShard(t *testing.T) { runSubTest := func(t *testing.T, test testCase) { paymentDB := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Init the payment, moving it to the StatusInFlight state. - err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } + err = paymentDB.InitPayment(ctx, info.PaymentIdentifier, info) + require.NoError(t, err) assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) assertDBPaymentstatus( @@ -1546,19 +1719,23 @@ func TestMultiShard(t *testing.T) { // attempts's value to one third of the payment amount, and // populate the MPP options. shardAmt := info.Value / 3 - attempt.Route.FinalHop().AmtToForward = shardAmt - attempt.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{1}, - ) var attempts []*HTLCAttemptInfo for i := uint64(0); i < 3; i++ { - a := *attempt - a.AttemptID = i - attempts = append(attempts, &a) + a, err := genAttemptWithHash( + t, i, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + a.Route.FinalHop().AmtToForward = shardAmt + a.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + attempts = append(attempts, a) _, err = paymentDB.RegisterAttempt( - info.PaymentIdentifier, &a, + ctx, info.PaymentIdentifier, a, ) if err != nil { t.Fatalf("unable to send htlc message: %v", err) @@ -1569,7 +1746,7 @@ func TestMultiShard(t *testing.T) { ) htlc := &htlcStatus{ - HTLCAttemptInfo: &a, + HTLCAttemptInfo: a, } assertPaymentInfo( t, paymentDB, info.PaymentIdentifier, info, nil, @@ -1580,16 +1757,26 @@ func TestMultiShard(t *testing.T) { // For a fourth attempt, check that attempting to // register it will fail since the total sent amount // will be too large. - b := *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + b, err := genAttemptWithHash( + t, 3, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + b.Route.FinalHop().AmtToForward = shardAmt + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, b, + ) require.ErrorIs(t, err, ErrValueExceedsAmt) // Fail the second attempt. a := attempts[1] htlcFail := HTLCFailUnreadable _, err = paymentDB.FailAttempt( - info.PaymentIdentifier, a.AttemptID, + ctx, info.PaymentIdentifier, a.AttemptID, &HTLCFailInfo{ Reason: htlcFail, }, @@ -1620,7 +1807,7 @@ func TestMultiShard(t *testing.T) { var firstFailReason *FailureReason if test.settleFirst { _, err := paymentDB.SettleAttempt( - info.PaymentIdentifier, a.AttemptID, + ctx, info.PaymentIdentifier, a.AttemptID, &HTLCSettleInfo{ Preimage: preimg, }, @@ -1638,7 +1825,7 @@ func TestMultiShard(t *testing.T) { ) } else { _, err := paymentDB.FailAttempt( - info.PaymentIdentifier, a.AttemptID, + ctx, info.PaymentIdentifier, a.AttemptID, &HTLCFailInfo{ Reason: htlcFail, }, @@ -1659,7 +1846,7 @@ func TestMultiShard(t *testing.T) { // a terminal state. failReason := FailureReasonNoRoute _, err = paymentDB.Fail( - info.PaymentIdentifier, failReason, + ctx, info.PaymentIdentifier, failReason, ) if err != nil { t.Fatalf("unable to fail payment hash: %v", err) @@ -1679,9 +1866,19 @@ func TestMultiShard(t *testing.T) { // Try to register yet another attempt. This should fail now // that the payment has reached a terminal condition. - b = *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + b, err = genAttemptWithHash( + t, 3, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + b.Route.FinalHop().AmtToForward = shardAmt + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, b, + ) if test.settleFirst { require.ErrorIs( t, err, ErrPaymentPendingSettled, @@ -1704,7 +1901,7 @@ func TestMultiShard(t *testing.T) { if test.settleLast { // Settle the last outstanding attempt. _, err = paymentDB.SettleAttempt( - info.PaymentIdentifier, a.AttemptID, + ctx, info.PaymentIdentifier, a.AttemptID, &HTLCSettleInfo{ Preimage: preimg, }, @@ -1719,7 +1916,7 @@ func TestMultiShard(t *testing.T) { } else { // Fail the attempt. _, err := paymentDB.FailAttempt( - info.PaymentIdentifier, a.AttemptID, + ctx, info.PaymentIdentifier, a.AttemptID, &HTLCFailInfo{ Reason: htlcFail, }, @@ -1742,7 +1939,7 @@ func TestMultiShard(t *testing.T) { // syncing. failReason := FailureReasonPaymentDetails _, err = paymentDB.Fail( - info.PaymentIdentifier, failReason, + ctx, info.PaymentIdentifier, failReason, ) require.NoError(t, err, "unable to fail") } @@ -1780,8 +1977,10 @@ func TestMultiShard(t *testing.T) { ) // Finally assert we cannot register more attempts. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.Equal(t, registerErr, err) + _, err = paymentDB.RegisterAttempt( + ctx, info.PaymentIdentifier, b, + ) + require.ErrorIs(t, err, registerErr) } for _, test := range tests { @@ -1793,3 +1992,477 @@ func TestMultiShard(t *testing.T) { }) } } + +// TestQueryPayments tests retrieval of payments with forwards and reversed +// queries. +func TestQueryPayments(t *testing.T) { + // Define table driven test for QueryPayments. + // Test payments have sequence indices [1, 3, 4, 5, 6]. + // Note that payment with index 2 is deleted to create a gap in the + // sequence numbers. + tests := []struct { + name string + query Query + firstIndex uint64 + lastIndex uint64 + + // expectedSeqNrs contains the set of sequence numbers we expect + // our query to return. + expectedSeqNrs []uint64 + }{ + { + name: "IndexOffset at the end of the payments range", + query: Query{ + IndexOffset: 6, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "query in forwards order, start at beginning", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in forwards order, start at end, overflow", + query: Query{ + IndexOffset: 5, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "start at offset index outside of payments", + query: Query{ + IndexOffset: 20, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "overflow in forwards order", + query: Query{ + IndexOffset: 4, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "start at offset index outside of payments, " + + "reversed order", + query: Query{ + IndexOffset: 9, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in reverse order, start at end", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in reverse order, starting in middle", + query: Query{ + IndexOffset: 4, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in reverse order, starting in middle, " + + "with underflow", + query: Query{ + IndexOffset: 4, + MaxPayments: 5, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "all payments in reverse, order maintained", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 6, + expectedSeqNrs: []uint64{1, 3, 4, 5, 6}, + }, + { + name: "exclude incomplete payments", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: false, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "query payments at index gap", + query: Query{ + IndexOffset: 1, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 6, + expectedSeqNrs: []uint64{3, 4, 5, 6}, + }, + { + name: "query payments reverse before index gap", + query: Query{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, + { + name: "query in forwards order, with start creation " + + "time", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 5, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in forwards order, with start creation " + + "time at end, overflow", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 6, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "query with start and end creation time", + query: Query{ + IndexOffset: 9, + MaxPayments: math.MaxUint64, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + CreationDateEnd: 5, + }, + firstIndex: 3, + lastIndex: 5, + expectedSeqNrs: []uint64{3, 4, 5}, + }, + { + name: "query with only end creation time", + query: Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + CreationDateEnd: 4, + }, + firstIndex: 1, + lastIndex: 4, + expectedSeqNrs: []uint64{1, 3, 4}, + }, + { + name: "query reversed with creation date start", + query: Query{ + IndexOffset: 0, + MaxPayments: 3, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + }, + firstIndex: 4, + lastIndex: 6, + expectedSeqNrs: []uint64{4, 5, 6}, + }, + { + name: "count total with forward pagination", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CountTotal: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "count total with reverse pagination", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + CountTotal: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "count total with filters", + query: Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: false, + CountTotal: true, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + paymentDB := NewTestDB(t) + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := paymentDB.QueryPayments(ctx, tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) + + // Populate the database with a set of test payments. + // We create 6 payments, deleting the payment at index + // 2 so that we cover the case where sequence numbers + // are missing. + numberOfPayments := 6 + + // Store payment info for all payments so we can delete + // one after all are created. + var paymentInfos []*PaymentCreationInfo + + // First, create all payments. + for i := range numberOfPayments { + // Generate a test payment. + info, _, err := genInfo(t) + require.NoError(t, err) + + // Override creation time to allow for testing + // of CreationDateStart and CreationDateEnd. + info.CreationTime = time.Unix(int64(i+1), 0) + + paymentInfos = append(paymentInfos, info) + + // Create a new payment entry in the database. + err = paymentDB.InitPayment( + ctx, info.PaymentIdentifier, info, + ) + require.NoError(t, err) + } + + // Now delete the payment at index 1 (the second + // payment). + pmt, err := paymentDB.FetchPayment( + ctx, paymentInfos[1].PaymentIdentifier, + ) + require.NoError(t, err) + + // We delete the whole payment. + err = paymentDB.DeletePayment( + ctx, paymentInfos[1].PaymentIdentifier, false, + ) + require.NoError(t, err) + + // Verify the payment is deleted. + _, err = paymentDB.FetchPayment( + ctx, paymentInfos[1].PaymentIdentifier, + ) + require.ErrorIs( + t, err, ErrPaymentNotInitiated, + ) + + // Verify the index is removed (KV store only). + assertNoIndex( + t, paymentDB, pmt.SequenceNum, + ) + + // For the last payment, settle it so we have at least + // one completed payment for the "exclude incomplete" + // test case. + lastPaymentInfo := paymentInfos[numberOfPayments-1] + attempt, err := NewHtlcAttempt( + 1, priv, testRoute, + time.Unix(100, 0), + &lastPaymentInfo.PaymentIdentifier, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt( + ctx, lastPaymentInfo.PaymentIdentifier, + &attempt.HTLCAttemptInfo, + ) + require.NoError(t, err) + + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + _, err = paymentDB.SettleAttempt( + ctx, lastPaymentInfo.PaymentIdentifier, + attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err) + + // Fetch all payments in the database. + resp, err = paymentDB.QueryPayments( + ctx, Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + IncludeIncomplete: true, + }, + ) + require.NoError(t, err) + + allPayments := resp.Payments + + if len(allPayments) != 5 { + t.Fatalf("Number of payments received does "+ + "not match expected one. Got %v, "+ + "want %v.", len(allPayments), 5) + } + + querySlice, err := paymentDB.QueryPayments( + ctx, tt.query, + ) + require.NoError(t, err) + + if tt.firstIndex != querySlice.FirstIndexOffset || + tt.lastIndex != querySlice.LastIndexOffset { + + t.Errorf("First or last index does not match "+ + "expected index. Want (%d, %d), "+ + "got (%d, %d).", + tt.firstIndex, tt.lastIndex, + querySlice.FirstIndexOffset, + querySlice.LastIndexOffset) + } + + if len(querySlice.Payments) != len(tt.expectedSeqNrs) { + t.Errorf("expected: %v payments, got: %v", + len(tt.expectedSeqNrs), + len(querySlice.Payments)) + } + + for i, seqNr := range tt.expectedSeqNrs { + q := querySlice.Payments[i] + if seqNr != q.SequenceNum { + t.Errorf("sequence numbers do not "+ + "match, got %v, want %v", + q.SequenceNum, seqNr) + } + } + + // Verify CountTotal is set correctly when requested. + if tt.query.CountTotal { + // We should have 5 total payments + // (6 created - 1 deleted). + expectedTotal := uint64(5) + if querySlice.TotalCount != expectedTotal { + t.Errorf("expected total count %v, "+ + "got %v", expectedTotal, + querySlice.TotalCount) + } + } else if querySlice.TotalCount != 0 { + t.Errorf("expected total count 0 when "+ + "CountTotal=false, got %v", + querySlice.TotalCount) + } + }) + } +} diff --git a/payments/db/sql_converters.go b/payments/db/sql_converters.go new file mode 100644 index 00000000000..7f6d222defb --- /dev/null +++ b/payments/db/sql_converters.go @@ -0,0 +1,277 @@ +package paymentsdb + +import ( + "bytes" + "fmt" + "sort" + "strconv" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/lightningnetwork/lnd/tlv" +) + +// dbPaymentToCreationInfo converts database payment data to +// PaymentCreationInfo. +func dbPaymentToCreationInfo(paymentIdentifier []byte, amountMsat int64, + createdAt time.Time, intentPayload []byte, + firstHopCustomRecords lnwire.CustomRecords) *PaymentCreationInfo { + + // This is the payment hash for non-AMP payments and the SetID for AMP + // payments. + var identifier lntypes.Hash + copy(identifier[:], paymentIdentifier) + + return &PaymentCreationInfo{ + PaymentIdentifier: identifier, + Value: lnwire.MilliSatoshi(amountMsat), + // The creation time is stored in the database as UTC but here + // we convert it to local time. + CreationTime: createdAt.Local(), + PaymentRequest: intentPayload, + FirstHopCustomRecords: firstHopCustomRecords, + } +} + +// dbAttemptToHTLCAttempt converts a database HTLC attempt to an HTLCAttempt. +func dbAttemptToHTLCAttempt( + dbAttempt sqlc.FetchHtlcAttemptsForPaymentRow, + hops []sqlc.FetchHopsForAttemptsRow, + hopCustomRecords map[int64][]sqlc.PaymentHopCustomRecord, + routeCustomRecords []sqlc.PaymentAttemptFirstHopCustomRecord) ( + *HTLCAttempt, error) { + + // Convert route-level first hop custom records to CustomRecords map. + var firstHopWireCustomRecords lnwire.CustomRecords + if len(routeCustomRecords) > 0 { + firstHopWireCustomRecords = make(lnwire.CustomRecords) + for _, record := range routeCustomRecords { + firstHopWireCustomRecords[uint64(record.Key)] = + record.Value + } + } + + // Build the route from the database data. + route, err := dbDataToRoute( + hops, hopCustomRecords, dbAttempt.FirstHopAmountMsat, + dbAttempt.RouteTotalTimeLock, dbAttempt.RouteTotalAmount, + dbAttempt.RouteSourceKey, firstHopWireCustomRecords, + ) + if err != nil { + return nil, fmt.Errorf("failed to convert to route: %w", + err) + } + + hash, err := lntypes.MakeHash(dbAttempt.PaymentHash) + if err != nil { + return nil, fmt.Errorf("failed to parse payment "+ + "hash: %w", err) + } + + // Create the attempt info. + var sessionKey [32]byte + copy(sessionKey[:], dbAttempt.SessionKey) + + info := HTLCAttemptInfo{ + AttemptID: uint64(dbAttempt.AttemptIndex), + sessionKey: sessionKey, + Route: *route, + AttemptTime: dbAttempt.AttemptTime, + Hash: &hash, + } + + attempt := &HTLCAttempt{ + HTLCAttemptInfo: info, + } + + // Add settlement info if present. + if dbAttempt.ResolutionType.Valid && + HTLCAttemptResolutionType(dbAttempt.ResolutionType.Int32) == + HTLCAttemptResolutionSettled { + + var preimage lntypes.Preimage + copy(preimage[:], dbAttempt.SettlePreimage) + + attempt.Settle = &HTLCSettleInfo{ + Preimage: preimage, + SettleTime: dbAttempt.ResolutionTime.Time, + } + } + + // Add failure info if present. + if dbAttempt.ResolutionType.Valid && + HTLCAttemptResolutionType(dbAttempt.ResolutionType.Int32) == + HTLCAttemptResolutionFailed { + + failure := &HTLCFailInfo{ + FailTime: dbAttempt.ResolutionTime.Time, + } + + if dbAttempt.HtlcFailReason.Valid { + failure.Reason = HTLCFailReason( + dbAttempt.HtlcFailReason.Int32, + ) + } + + if dbAttempt.FailureSourceIndex.Valid { + failure.FailureSourceIndex = uint32( + dbAttempt.FailureSourceIndex.Int32, + ) + } + + // Decode the failure message if present. + if len(dbAttempt.FailureMsg) > 0 { + msg, err := lnwire.DecodeFailureMessage( + bytes.NewReader(dbAttempt.FailureMsg), 0, + ) + if err != nil { + return nil, fmt.Errorf("failed to decode "+ + "failure message: %w", err) + } + failure.Message = msg + } + + attempt.Failure = failure + } + + return attempt, nil +} + +// dbDataToRoute converts database route data to a route.Route. +func dbDataToRoute(hops []sqlc.FetchHopsForAttemptsRow, + hopCustomRecords map[int64][]sqlc.PaymentHopCustomRecord, + firstHopAmountMsat int64, totalTimeLock int32, totalAmount int64, + sourceKey []byte, firstHopWireCustomRecords lnwire.CustomRecords) ( + *route.Route, error) { + + if len(hops) == 0 { + return nil, fmt.Errorf("no hops provided") + } + + // Sort hops by hop index. + sort.Slice(hops, func(i, j int) bool { + return hops[i].HopIndex < hops[j].HopIndex + }) + + routeHops := make([]*route.Hop, len(hops)) + + for i, hop := range hops { + pubKey, err := route.NewVertexFromBytes(hop.PubKey) + if err != nil { + return nil, fmt.Errorf("failed to parse pub key: %w", + err) + } + + var channelID uint64 + if hop.Scid != "" { + // The SCID is stored as a string representation + // of the uint64. + var err error + channelID, err = strconv.ParseUint(hop.Scid, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse "+ + "scid: %w", err) + } + } + + routeHop := &route.Hop{ + PubKeyBytes: pubKey, + ChannelID: channelID, + OutgoingTimeLock: uint32(hop.OutgoingTimeLock), + AmtToForward: lnwire.MilliSatoshi(hop.AmtToForward), + } + + // Add MPP record if present. + if len(hop.MppPaymentAddr) > 0 { + var paymentAddr [32]byte + copy(paymentAddr[:], hop.MppPaymentAddr) + routeHop.MPP = record.NewMPP( + lnwire.MilliSatoshi(hop.MppTotalMsat.Int64), + paymentAddr, + ) + } + + // Add AMP record if present. + if len(hop.AmpRootShare) > 0 { + var rootShare [32]byte + copy(rootShare[:], hop.AmpRootShare) + var setID [32]byte + copy(setID[:], hop.AmpSetID) + + routeHop.AMP = record.NewAMP( + rootShare, setID, + uint32(hop.AmpChildIndex.Int32), + ) + } + + // Add blinding point if present (only for introduction node + // in blinded route). + if len(hop.BlindingPoint) > 0 { + pubKey, err := btcec.ParsePubKey(hop.BlindingPoint) + if err != nil { + return nil, fmt.Errorf("failed to parse "+ + "blinding point: %w", err) + } + routeHop.BlindingPoint = pubKey + } + + // Add encrypted data if present (for all blinded hops). + if len(hop.EncryptedData) > 0 { + routeHop.EncryptedData = hop.EncryptedData + } + + // Add total amount if present (only for final hop in blinded + // route). + if hop.BlindedPathTotalAmt.Valid { + routeHop.TotalAmtMsat = lnwire.MilliSatoshi( + hop.BlindedPathTotalAmt.Int64, + ) + } + + // Add hop-level custom records. + if records, ok := hopCustomRecords[hop.ID]; ok { + routeHop.CustomRecords = make( + record.CustomSet, + ) + for _, rec := range records { + routeHop.CustomRecords[uint64(rec.Key)] = + rec.Value + } + } + + // Add metadata if present. + if len(hop.MetaData) > 0 { + routeHop.Metadata = hop.MetaData + } + + routeHops[i] = routeHop + } + + // Parse the source node public key. + var sourceNode route.Vertex + copy(sourceNode[:], sourceKey) + + route := &route.Route{ + TotalTimeLock: uint32(totalTimeLock), + TotalAmount: lnwire.MilliSatoshi(totalAmount), + SourcePubKey: sourceNode, + Hops: routeHops, + FirstHopWireCustomRecords: firstHopWireCustomRecords, + } + + // Set the first hop amount if it is set. + if firstHopAmountMsat != 0 { + route.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0]( + tlv.NewBigSizeT(lnwire.MilliSatoshi( + firstHopAmountMsat, + )), + ) + } + + return route, nil +} diff --git a/payments/db/sql_store.go b/payments/db/sql_store.go index 12585caf64e..a7efba8afc4 100644 --- a/payments/db/sql_store.go +++ b/payments/db/sql_store.go @@ -1,14 +1,91 @@ package paymentsdb import ( + "bytes" + "context" + "database/sql" + "errors" "fmt" + "math" + "strconv" + "time" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" +) + +// PaymentIntentType represents the type of payment intent. +type PaymentIntentType int16 + +const ( + // PaymentIntentTypeBolt11 indicates a BOLT11 invoice payment. + PaymentIntentTypeBolt11 PaymentIntentType = 0 +) + +// HTLCAttemptResolutionType represents the type of HTLC attempt resolution. +type HTLCAttemptResolutionType int32 + +const ( + // HTLCAttemptResolutionSettled indicates the HTLC attempt was settled + // successfully with a preimage. + HTLCAttemptResolutionSettled HTLCAttemptResolutionType = 1 + + // HTLCAttemptResolutionFailed indicates the HTLC attempt failed. + HTLCAttemptResolutionFailed HTLCAttemptResolutionType = 2 ) // SQLQueries is a subset of the sqlc.Querier interface that can be used to // execute queries against the SQL payments tables. +// +//nolint:ll,interfacebloat type SQLQueries interface { + /* + Payment DB read operations. + */ + FilterPayments(ctx context.Context, query sqlc.FilterPaymentsParams) ([]sqlc.FilterPaymentsRow, error) + FetchPayment(ctx context.Context, paymentIdentifier []byte) (sqlc.FetchPaymentRow, error) + FetchPaymentsByIDs(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchPaymentsByIDsRow, error) + + CountPayments(ctx context.Context) (int64, error) + + FetchHtlcAttemptsForPayment(ctx context.Context, paymentID int64) ([]sqlc.FetchHtlcAttemptsForPaymentRow, error) + FetchAllInflightAttempts(ctx context.Context) ([]sqlc.PaymentHtlcAttempt, error) + FetchHopsForAttempt(ctx context.Context, htlcAttemptIndex int64) ([]sqlc.FetchHopsForAttemptRow, error) + FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.FetchHopsForAttemptsRow, error) + + FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentID int64) ([]sqlc.PaymentFirstHopCustomRecord, error) + FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.PaymentAttemptFirstHopCustomRecord, error) + FetchHopLevelCustomRecords(ctx context.Context, hopIDs []int64) ([]sqlc.PaymentHopCustomRecord, error) + + /* + Payment DB write operations. + */ + InsertPaymentIntent(ctx context.Context, arg sqlc.InsertPaymentIntentParams) (int64, error) + InsertPayment(ctx context.Context, arg sqlc.InsertPaymentParams) (int64, error) + InsertPaymentFirstHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentFirstHopCustomRecordParams) error + + InsertHtlcAttempt(ctx context.Context, arg sqlc.InsertHtlcAttemptParams) (int64, error) + InsertRouteHop(ctx context.Context, arg sqlc.InsertRouteHopParams) (int64, error) + InsertRouteHopMpp(ctx context.Context, arg sqlc.InsertRouteHopMppParams) error + InsertRouteHopAmp(ctx context.Context, arg sqlc.InsertRouteHopAmpParams) error + InsertRouteHopBlinded(ctx context.Context, arg sqlc.InsertRouteHopBlindedParams) error + + InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentAttemptFirstHopCustomRecordParams) error + InsertPaymentHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentHopCustomRecordParams) error + + SettleAttempt(ctx context.Context, arg sqlc.SettleAttemptParams) error + FailAttempt(ctx context.Context, arg sqlc.FailAttemptParams) error + + FailPayment(ctx context.Context, arg sqlc.FailPaymentParams) (sql.Result, error) + + DeletePayment(ctx context.Context, paymentID int64) error + + // DeleteFailedAttempts removes all failed HTLCs from the db for a + // given payment. + DeleteFailedAttempts(ctx context.Context, paymentID int64) error } // BatchedSQLQueries is a version of the SQLQueries that's capable @@ -20,10 +97,6 @@ type BatchedSQLQueries interface { // SQLStore represents a storage backend. type SQLStore struct { - // TODO(ziggie): Remove the KVStore once all the interface functions are - // implemented. - KVStore - cfg *SQLStoreConfig db BatchedSQLQueries @@ -65,3 +138,1106 @@ func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, // A compile-time constraint to ensure SQLStore implements DB. var _ DB = (*SQLStore)(nil) + +// fetchPaymentWithCompleteData fetches a payment with all its related data +// including attempts, hops, and custom records from the database. +func (s *SQLStore) fetchPaymentWithCompleteData(ctx context.Context, + db SQLQueries, dbPayment sqlc.PaymentAndIntent) (*MPPayment, error) { + + // The query will only return BOLT 11 payment intents or intents with + // no intent type set. + paymentIntent := dbPayment.GetPaymentIntent() + paymentRequest := paymentIntent.IntentPayload + + // Fetch payment-level first hop custom records. + payment := dbPayment.GetPayment() + customRecords, err := db.FetchPaymentLevelFirstHopCustomRecords( + ctx, payment.ID, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch payment level custom "+ + "records: %w", err) + } + + // Convert to the FirstHopCustomRecords map. + var firstHopCustomRecords lnwire.CustomRecords + if len(customRecords) > 0 { + firstHopCustomRecords = make(lnwire.CustomRecords) + for _, record := range customRecords { + firstHopCustomRecords[uint64(record.Key)] = record.Value + } + } + + // Convert the basic payment info. + info := dbPaymentToCreationInfo( + payment.PaymentIdentifier, payment.AmountMsat, + payment.CreatedAt, paymentRequest, firstHopCustomRecords, + ) + + // Fetch all HTLC attempts for this payment. + attempts, err := s.fetchHTLCAttemptsForPayment( + ctx, db, payment.ID, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch HTLC attempts: %w", + err) + } + + // Set the failure reason if present. + var failureReason *FailureReason + if payment.FailReason.Valid { + reason := FailureReason(payment.FailReason.Int32) + failureReason = &reason + } + + mpPayment := &MPPayment{ + SequenceNum: uint64(payment.ID), + Info: info, + HTLCs: attempts, + FailureReason: failureReason, + } + + // The status and state will be determined by calling + // SetState after construction. + if err := mpPayment.SetState(); err != nil { + return nil, fmt.Errorf("failed to set payment state: %w", err) + } + + return mpPayment, nil +} + +// fetchHTLCAttemptsForPayment fetches all HTLC attempts for a payment and +// uses ExecuteBatchQuery to efficiently fetch hops and custom records. +func (s *SQLStore) fetchHTLCAttemptsForPayment(ctx context.Context, + db SQLQueries, paymentID int64) ([]HTLCAttempt, error) { + + // Fetch all HTLC attempts for this payment. + dbAttempts, err := db.FetchHtlcAttemptsForPayment( + ctx, paymentID, + ) + + if err != nil { + return nil, fmt.Errorf("failed to fetch HTLC attempts: %w", + err) + } + + if len(dbAttempts) == 0 { + return nil, nil + } + + // Collect all attempt indices for batch fetching. + attemptIndices := make([]int64, len(dbAttempts)) + for i, attempt := range dbAttempts { + attemptIndices[i] = attempt.AttemptIndex + } + + // Fetch all hops for all attempts using ExecuteBatchQuery. + hopsByAttempt := make(map[int64][]sqlc.FetchHopsForAttemptsRow) + err = sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, attemptIndices, + func(idx int64) int64 { return idx }, + func(ctx context.Context, indices []int64) ( + []sqlc.FetchHopsForAttemptsRow, error) { + + return db.FetchHopsForAttempts(ctx, indices) + }, + func(ctx context.Context, + hop sqlc.FetchHopsForAttemptsRow) error { + + hopsByAttempt[hop.HtlcAttemptIndex] = append( + hopsByAttempt[hop.HtlcAttemptIndex], hop, + ) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch hops for attempts: %w", + err) + } + + // Collect all hop IDs for fetching hop-level custom records. + var hopIDs []int64 + for _, hops := range hopsByAttempt { + for _, hop := range hops { + hopIDs = append(hopIDs, hop.ID) + } + } + + // Fetch all hop-level custom records using ExecuteBatchQuery. + hopCustomRecords := make(map[int64][]sqlc.PaymentHopCustomRecord) + if len(hopIDs) > 0 { + err = sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, hopIDs, + func(id int64) int64 { return id }, + func(ctx context.Context, ids []int64) ( + []sqlc.PaymentHopCustomRecord, error) { + + return db.FetchHopLevelCustomRecords(ctx, ids) + }, + func(ctx context.Context, + record sqlc.PaymentHopCustomRecord) error { + + // TODO(ziggie): Can we get rid of this? + // This has to be in place otherwise the + // comparison will not match. + if record.Value == nil { + record.Value = []byte{} + } + + hopCustomRecords[record.HopID] = append( + hopCustomRecords[record.HopID], record, + ) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch hop custom "+ + "records: %w", err) + } + } + + // Fetch route-level first hop custom records using ExecuteBatchQuery. + routeCustomRecords := make( + map[int64][]sqlc.PaymentAttemptFirstHopCustomRecord, + ) + err = sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, attemptIndices, + func(idx int64) int64 { return idx }, + func(ctx context.Context, indices []int64) ( + []sqlc.PaymentAttemptFirstHopCustomRecord, error) { + + return db.FetchRouteLevelFirstHopCustomRecords( + ctx, indices, + ) + }, + func(ctx context.Context, + record sqlc.PaymentAttemptFirstHopCustomRecord) error { + + routeCustomRecords[record.HtlcAttemptIndex] = append( + routeCustomRecords[record.HtlcAttemptIndex], + record, + ) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch route custom "+ + "records: %w", err) + } + + // Now convert all attempts to HTLCAttempt structs. + attempts := make([]HTLCAttempt, 0, len(dbAttempts)) + for _, dbAttempt := range dbAttempts { + attemptIndex := dbAttempt.AttemptIndex + attempt, err := dbAttemptToHTLCAttempt( + dbAttempt, hopsByAttempt[attemptIndex], + hopCustomRecords, + routeCustomRecords[attemptIndex], + ) + if err != nil { + return nil, fmt.Errorf("failed to convert attempt "+ + "%d: %w", attemptIndex, err) + } + attempts = append(attempts, *attempt) + } + + return attempts, nil +} + +// QueryPayments queries the payments from the database. +// +// This is part of the DB interface. +func (s *SQLStore) QueryPayments(ctx context.Context, + query Query) (Response, error) { + + if query.MaxPayments == 0 { + return Response{}, fmt.Errorf("max payments must be non-zero") + } + + var ( + allPayments []*MPPayment + totalCount int64 + initialCursor int64 + ) + + extractCursor := func( + row sqlc.FilterPaymentsRow) int64 { + + return row.Payment.ID + } + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // We first count all payments to determine the total count + // if requested. + if query.CountTotal { + totalPayments, err := db.CountPayments(ctx) + if err != nil { + return fmt.Errorf("failed to count "+ + "payments: %w", err) + } + totalCount = totalPayments + } + + processPayment := func(ctx context.Context, + dbPayment sqlc.FilterPaymentsRow) error { + + // Fetch all the additional data for the payment. + mpPayment, err := s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + + // To keep compatibility with the old API, we only + // return non-succeeded payments if requested. + if mpPayment.Status != StatusSucceeded && + !query.IncludeIncomplete { + + return nil + } + + if uint64(len(allPayments)) >= query.MaxPayments { + return errMaxPaymentsReached + } + + allPayments = append(allPayments, mpPayment) + + return nil + } + + queryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.FilterPaymentsRow, error) { + + filterParams := sqlc.FilterPaymentsParams{ + NumLimit: limit, + Reverse: query.Reversed, + // For now there only BOLT 11 payment intents + // exist. + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + } + + if query.Reversed { + filterParams.IndexOffsetLet = sqldb.SQLInt64( + lastID, + ) + } else { + filterParams.IndexOffsetGet = sqldb.SQLInt64( + lastID, + ) + } + + // Add potential date filters if specified. + if query.CreationDateStart != 0 { + filterParams.CreatedAfter = sqldb.SQLTime( + time.Unix(query.CreationDateStart, 0). + UTC(), + ) + } + if query.CreationDateEnd != 0 { + filterParams.CreatedBefore = sqldb.SQLTime( + time.Unix(query.CreationDateEnd, 0). + UTC(), + ) + } + + return db.FilterPayments(ctx, filterParams) + } + + if query.Reversed { + if query.IndexOffset == 0 { + initialCursor = int64(math.MaxInt64) + } else { + initialCursor = int64(query.IndexOffset) + } + } else { + initialCursor = int64(query.IndexOffset) + } + + return sqldb.ExecutePaginatedQuery( + ctx, s.cfg.QueryCfg, initialCursor, queryFunc, + extractCursor, processPayment, + ) + }, func() { + allPayments = nil + }) + + // We make sure we don't return an error if we reached the maximum + // number of payments. Which is the pagination limit for the query + // itself. + if err != nil && !errors.Is(err, errMaxPaymentsReached) { + return Response{}, fmt.Errorf("failed to query payments: %w", + err) + } + + // Handle case where no payments were found + if len(allPayments) == 0 { + return Response{ + Payments: allPayments, + FirstIndexOffset: 0, + LastIndexOffset: 0, + TotalCount: uint64(totalCount), + }, nil + } + + // If the query was reversed, we need to reverse the payment list + // to match the kvstore behavior and return payments in forward order. + if query.Reversed { + for i, j := 0, len(allPayments)-1; i < j; i, j = i+1, j-1 { + allPayments[i], allPayments[j] = allPayments[j], + allPayments[i] + } + } + + return Response{ + Payments: allPayments, + FirstIndexOffset: allPayments[0].SequenceNum, + LastIndexOffset: allPayments[len(allPayments)-1].SequenceNum, + TotalCount: uint64(totalCount), + }, nil +} + +// FetchPayment fetches the payment corresponding to the given payment +// hash. +// +// This is part of the DB interface. +func (s *SQLStore) FetchPayment(ctx context.Context, + paymentHash lntypes.Hash) (*MPPayment, error) { + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + dbPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to fetch payment: %w", err) + } + + if errors.Is(err, sql.ErrNoRows) { + return ErrPaymentNotInitiated + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch payment: %w", err) + } + + return mpPayment, nil +} + +// FetchInFlightPayments fetches all payments with status InFlight. +// +// TODO(ziggie): Add pagination (LIMIT)) to this function? +// +// This is part of the DB interface. +func (s *SQLStore) FetchInFlightPayments(ctx context.Context) ([]*MPPayment, + error) { + + var mpPayments []*MPPayment + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + inflightDBAttempts, err := db.FetchAllInflightAttempts(ctx) + if err != nil { + return fmt.Errorf("failed to fetch inflight "+ + "attempts: %w", err) + } + + paymentIDs := make([]int64, len(inflightDBAttempts)) + for i, attempt := range inflightDBAttempts { + paymentIDs[i] = attempt.PaymentID + } + + dbPayments, err := db.FetchPaymentsByIDs(ctx, paymentIDs) + if err != nil { + return fmt.Errorf("failed to fetch payments by IDs: %w", + err) + } + + mpPayments = make([]*MPPayment, len(dbPayments)) + for i, dbPayment := range dbPayments { + mpPayment, err := s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + mpPayments[i] = mpPayment + } + + return nil + }, func() { + mpPayments = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch inflight "+ + "attempts: %w", err) + } + + return mpPayments, nil +} + +// DeletePayment deletes a payment from the DB given its payment hash. If +// failedHtlcsOnly is set, only failed HTLC attempts of the payment will be +// deleted. +func (s *SQLStore) DeletePayment(ctx context.Context, paymentHash lntypes.Hash, + failedHtlcsOnly bool) error { + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + fetchPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + completePayment, err := s.fetchPaymentWithCompleteData( + ctx, db, fetchPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := completePayment.Status.removable(); err != nil { + return fmt.Errorf("payment %v cannot be deleted: %w", + paymentHash, err) + } + + // If we are only deleting failed HTLCs, we delete them. + if failedHtlcsOnly { + return db.DeleteFailedAttempts( + ctx, fetchPayment.Payment.ID, + ) + } + + // Be careful to not use s.db here, because we are in a + // transaction, is there a way to make this more secure? + return db.DeletePayment(ctx, fetchPayment.Payment.ID) + }, func() { + }) + if err != nil { + return fmt.Errorf("failed to delete payment "+ + "(failedHtlcsOnly: %v, paymentHash: %v): %w", + failedHtlcsOnly, paymentHash, err) + } + + return nil +} + +// DeleteFailedAttempts removes all failed HTLCs from the db. It should +// be called for a given payment whenever all inflight htlcs are +// completed, and the payment has reached a final terminal state. +func (s *SQLStore) DeleteFailedAttempts(ctx context.Context, + paymentHash lntypes.Hash) error { + + // In case we are configured to keep failed payment attempts, we exit + // early. + if s.keepFailedPaymentAttempts { + return nil + } + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // We first fetch the payment to get the payment ID. + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + + completePayment, err := s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := completePayment.Status.removable(); err != nil { + return fmt.Errorf("payment %v cannot be deleted: %w", + paymentHash, err) + } + + // Then we delete the failed attempts for this payment. + return db.DeleteFailedAttempts(ctx, payment.Payment.ID) + }, func() { + }) + if err != nil { + return fmt.Errorf("failed to delete failed attempts for "+ + "payment %v: %w", paymentHash, err) + } + + return nil +} + +// InitPayment initializes a payment. +// +// This is part of the DB interface. +func (s *SQLStore) InitPayment(ctx context.Context, paymentHash lntypes.Hash, + paymentCreationInfo *PaymentCreationInfo) error { + + // Create the payment in the database. + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + existingPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err == nil { + completePayment, err := s.fetchPaymentWithCompleteData( + ctx, db, existingPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + + // Check if the payment is initializable otherwise + // we'll return early. + err = completePayment.Status.initializable() + if err != nil { + return err + } + } else if !errors.Is(err, sql.ErrNoRows) { + // Some other error occurred + return fmt.Errorf("failed to check existing "+ + "payment: %w", err) + } + + // If payment exists and is failed, delete it first. + if existingPayment.Payment.ID != 0 { + err := db.DeletePayment(ctx, existingPayment.Payment.ID) + if err != nil { + return fmt.Errorf("failed to delete "+ + "payment: %w", err) + } + } + + var intentID *int64 + if len(paymentCreationInfo.PaymentRequest) > 0 { + intentIDValue, err := db.InsertPaymentIntent(ctx, + sqlc.InsertPaymentIntentParams{ + IntentType: int16( + PaymentIntentTypeBolt11, + ), + IntentPayload: paymentCreationInfo. + PaymentRequest, + }) + if err != nil { + return fmt.Errorf("failed to initialize "+ + "payment intent: %w", err) + } + intentID = &intentIDValue + } + + // Only set the intent ID if it's not nil. + var intentIDParam sql.NullInt64 + if intentID != nil { + intentIDParam = sqldb.SQLInt64(*intentID) + } + + paymentID, err := db.InsertPayment(ctx, + sqlc.InsertPaymentParams{ + IntentID: intentIDParam, + AmountMsat: int64( + paymentCreationInfo.Value, + ), + CreatedAt: paymentCreationInfo. + CreationTime.UTC(), + PaymentIdentifier: paymentHash[:], + }) + if err != nil { + return fmt.Errorf("failed to insert payment: %w", err) + } + + firstHopCustomRecords := paymentCreationInfo. + FirstHopCustomRecords + + for key, value := range firstHopCustomRecords { + err = db.InsertPaymentFirstHopCustomRecord(ctx, + sqlc.InsertPaymentFirstHopCustomRecordParams{ + PaymentID: paymentID, + Key: int64(key), + Value: value, + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment first hop custom "+ + "record: %w", err) + } + } + + return nil + }, func() { + }) + if err != nil { + return fmt.Errorf("failed to initialize payment: %w", err) + } + + return nil +} + +// insertRouteHops inserts all route hop data for a given set of hops. +func (s *SQLStore) insertRouteHops(ctx context.Context, db SQLQueries, + hops []*route.Hop, attemptID uint64) error { + + for i, hop := range hops { + // Insert the basic route hop data and get the generated ID + hopID, err := db.InsertRouteHop(ctx, sqlc.InsertRouteHopParams{ + HtlcAttemptIndex: int64(attemptID), + HopIndex: int32(i), + PubKey: hop.PubKeyBytes[:], + Scid: strconv.FormatUint( + hop.ChannelID, 10, + ), + OutgoingTimeLock: int32(hop.OutgoingTimeLock), + AmtToForward: int64(hop.AmtToForward), + MetaData: hop.Metadata, + }) + if err != nil { + return fmt.Errorf("failed to insert route hop: %w", err) + } + + // Insert the per-hop custom records + if len(hop.CustomRecords) > 0 { + for key, value := range hop.CustomRecords { + err = db.InsertPaymentHopCustomRecord(ctx, + sqlc.InsertPaymentHopCustomRecordParams{ + HopID: hopID, + Key: int64(key), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment hop custom "+ + "records: %w", err) + } + } + } + + // Insert MPP data if present + if hop.MPP != nil { + paymentAddr := hop.MPP.PaymentAddr() + err = db.InsertRouteHopMpp(ctx, + sqlc.InsertRouteHopMppParams{ + HopID: hopID, + PaymentAddr: paymentAddr[:], + TotalMsat: int64(hop.MPP.TotalMsat()), + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop MPP: %w", err) + } + } + + // Insert AMP data if present + if hop.AMP != nil { + rootShare := hop.AMP.RootShare() + setID := hop.AMP.SetID() + err = db.InsertRouteHopAmp(ctx, + sqlc.InsertRouteHopAmpParams{ + HopID: hopID, + RootShare: rootShare[:], + SetID: setID[:], + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop AMP: %w", err) + } + } + + // Insert blinded route data if present + if hop.EncryptedData != nil || hop.BlindingPoint != nil { + var blindingPointBytes []byte + if hop.BlindingPoint != nil { + blindingPointBytes = hop.BlindingPoint. + SerializeCompressed() + } + + err = db.InsertRouteHopBlinded(ctx, + sqlc.InsertRouteHopBlindedParams{ + HopID: hopID, + EncryptedData: hop.EncryptedData, + BlindingPoint: blindingPointBytes, + BlindedPathTotalAmt: sqldb.SQLInt64( + hop.TotalAmtMsat, + ), + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop blinded: %w", err) + } + } + } + + return nil +} + +// RegisterAttempt registers an attempt for a payment. +// +// This is part of the DB interface. +func (s *SQLStore) RegisterAttempt(ctx context.Context, + paymentHash lntypes.Hash, attempt *HTLCAttemptInfo) (*MPPayment, + error) { + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // 1. First Fetch the payment and check if it is registrable. + existingPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, existingPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := mpPayment.Registrable(); err != nil { + return fmt.Errorf("htlc attempt not registrable: %w", + err) + } + + // Verify the attempt is compatible with the existing payment. + if err := verifyAttempt(mpPayment, attempt); err != nil { + return fmt.Errorf("failed to verify attempt: %w", err) + } + + // Fist register the plain HTLC attempt. + // Prepare the session key. + sessionKey := attempt.SessionKey() + sessionKeyBytes := sessionKey.Serialize() + + _, err = db.InsertHtlcAttempt(ctx, sqlc.InsertHtlcAttemptParams{ + PaymentID: existingPayment.Payment.ID, + AttemptIndex: int64(attempt.AttemptID), + SessionKey: sessionKeyBytes, + AttemptTime: attempt.AttemptTime, + PaymentHash: paymentHash[:], + FirstHopAmountMsat: int64( + attempt.Route.FirstHopAmount.Val.Int(), + ), + RouteTotalTimeLock: int32(attempt.Route.TotalTimeLock), + RouteTotalAmount: int64(attempt.Route.TotalAmount), + RouteSourceKey: attempt.Route.SourcePubKey[:], + }) + if err != nil { + return fmt.Errorf("failed to insert HTLC "+ + "attempt: %w", err) + } + + // Insert the route level first hop custom records. + attemptFirstHopCustomRecords := attempt.Route. + FirstHopWireCustomRecords + + for key, value := range attemptFirstHopCustomRecords { + err = db.InsertPaymentAttemptFirstHopCustomRecord(ctx, + //nolint:ll + sqlc.InsertPaymentAttemptFirstHopCustomRecordParams{ + HtlcAttemptIndex: int64(attempt.AttemptID), + Key: int64(key), + Value: value, + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment attempt first hop custom "+ + "record: %w", err) + } + } + + // Insert the route hops. + err = s.insertRouteHops( + ctx, db, attempt.Route.Hops, attempt.AttemptID, + ) + if err != nil { + return fmt.Errorf("failed to insert route hops: %w", + err) + } + + // Add the attempt to the payment without fetching it from the + // DB again. + mpPayment.HTLCs = append(mpPayment.HTLCs, HTLCAttempt{ + HTLCAttemptInfo: *attempt, + }) + + if err := mpPayment.SetState(); err != nil { + return fmt.Errorf("failed to set payment state: %w", + err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to register attempt: %w", err) + } + + return mpPayment, nil +} + +// SettleAttempt marks the given attempt settled with the preimage. +func (s *SQLStore) SettleAttempt(ctx context.Context, paymentHash lntypes.Hash, + attemptID uint64, settleInfo *HTLCSettleInfo) (*MPPayment, error) { + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // Before updating the attempt, we fetch the payment to get the + // payment ID. + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to fetch payment: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + return ErrPaymentNotInitiated + } + + err = db.SettleAttempt(ctx, sqlc.SettleAttemptParams{ + AttemptIndex: int64(attemptID), + ResolutionTime: time.Now(), + ResolutionType: int32(HTLCAttemptResolutionSettled), + SettlePreimage: settleInfo.Preimage[:], + }) + if err != nil { + return fmt.Errorf("failed to settle attempt: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to settle attempt: %w", err) + } + + return mpPayment, nil +} + +// FailAttempt marks the given attempt failed. +func (s *SQLStore) FailAttempt(ctx context.Context, paymentHash lntypes.Hash, + attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) { + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // Before updating the attempt, we fetch the payment to get the + // payment ID. + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to fetch payment: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + return ErrPaymentNotInitiated + } + + var failureMsg bytes.Buffer + if failInfo.Message != nil { + err := lnwire.EncodeFailureMessage( + &failureMsg, failInfo.Message, 0, + ) + if err != nil { + return fmt.Errorf("failed to encode "+ + "failure message: %w", err) + } + } + + err = db.FailAttempt(ctx, sqlc.FailAttemptParams{ + AttemptIndex: int64(attemptID), + ResolutionTime: time.Now(), + ResolutionType: int32(HTLCAttemptResolutionFailed), + FailureSourceIndex: sqldb.SQLInt32( + failInfo.FailureSourceIndex, + ), + HtlcFailReason: sqldb.SQLInt32(failInfo.Reason), + FailureMsg: failureMsg.Bytes(), + }) + if err != nil { + return fmt.Errorf("failed to fail attempt: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with"+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fail attempt: %w", err) + } + + return mpPayment, nil +} + +// Fail transitions a payment into the Failed state, and records the ultimate +// reason the payment failed. Note that this should only be called when all +// active attempts are already failed. After invoking this method, InitPayment +// should return nil on its next call for this payment hash, allowing the user +// to make a subsequent payments for the same payment hash. +func (s *SQLStore) Fail(ctx context.Context, paymentHash lntypes.Hash, + reason FailureReason) (*MPPayment, error) { + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + result, err := db.FailPayment(ctx, sqlc.FailPaymentParams{ + PaymentIdentifier: paymentHash[:], + FailReason: sqldb.SQLInt32(reason), + }) + if err != nil { + return fmt.Errorf("failed to fail payment: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", + err) + } + if rowsAffected == 0 { + return ErrPaymentNotInitiated + } + + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fail payment: %w", err) + } + + return mpPayment, nil +} + +// DeletePayments deletes all payments from the DB given the specified flags. +// +// TODO(ziggie): batch and use iterator instead. +func (s *SQLStore) DeletePayments(ctx context.Context, failedOnly, + failedHtlcsOnly bool) (int, error) { + + var numPayments int + + extractCursor := func( + row sqlc.FilterPaymentsRow) int64 { + + return row.Payment.ID + } + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + processPayment := func(ctx context.Context, + dbPayment sqlc.FilterPaymentsRow) error { + + // Fetch all the additional data for the payment. + mpPayment, err := s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + + // Payments which are not final yet cannot be deleted. + // we skip them. + if err := mpPayment.Status.removable(); err != nil { + return nil + } + + // If we are only deleting failed payments, we skip + // if the payment is not failed. + if failedOnly && mpPayment.Status != StatusFailed { + return nil + } + + // If we are only deleting failed HTLCs, we delete them + // and return early. + if failedHtlcsOnly { + return db.DeleteFailedAttempts( + ctx, dbPayment.Payment.ID, + ) + } + + // Otherwise we delete the payment. + err = db.DeletePayment(ctx, dbPayment.Payment.ID) + if err != nil { + return fmt.Errorf("failed to delete "+ + "payment: %w", err) + } + + numPayments++ + + return nil + } + + queryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.FilterPaymentsRow, error) { + + filterParams := sqlc.FilterPaymentsParams{ + NumLimit: limit, + // For now there are only BOLT 11 payment + // intents. + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + IndexOffsetGet: sqldb.SQLInt64( + lastID, + ), + } + + return db.FilterPayments(ctx, filterParams) + } + + return sqldb.ExecutePaginatedQuery( + ctx, s.cfg.QueryCfg, int64(-1), queryFunc, + extractCursor, processPayment, + ) + }, func() { + numPayments = 0 + }) + if err != nil { + return 0, fmt.Errorf("failed to delete payments "+ + "(failedOnly: %v, failedHtlcsOnly: %v): %w", + failedOnly, failedHtlcsOnly, err) + } + + return numPayments, nil +} diff --git a/payments/db/test_kvdb.go b/payments/db/test_kvdb.go index e0ee1738d7a..a4bbfccbd97 100644 --- a/payments/db/test_kvdb.go +++ b/payments/db/test_kvdb.go @@ -1,3 +1,5 @@ +//go:build !test_db_sqlite && !test_db_postgres + package paymentsdb import ( diff --git a/payments/db/test_postgres.go b/payments/db/test_postgres.go new file mode 100644 index 00000000000..1a20c7c0dfb --- /dev/null +++ b/payments/db/test_postgres.go @@ -0,0 +1,95 @@ +//go:build test_db_postgres && !test_db_sqlite + +package paymentsdb + +import ( + "database/sql" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB, opts ...OptionModifier) DB { + return NewTestDBWithFixture(t, nil, opts...) +} + +// NewTestDBFixture creates a new sqldb.TestPgFixture for testing purposes. +func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture { + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + return pgFixture +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture, opts ...OptionModifier) DB { + + var querier BatchedSQLQueries + if pgFixture == nil { + querier = newBatchQuerier(t) + } else { + querier = newBatchQuerierWithFixture(t, pgFixture) + } + + store, err := NewSQLStore( + &SQLStoreConfig{ + QueryCfg: sqldb.DefaultPostgresConfig(), + }, querier, opts..., + ) + require.NoError(t, err) + + return store +} + +// newBatchQuerier creates a new BatchedSQLQueries instance for testing +// using a PostgreSQL database fixture. +func newBatchQuerier(t testing.TB) BatchedSQLQueries { + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return newBatchQuerierWithFixture(t, pgFixture) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a PostgreSQL database fixture. +func newBatchQuerierWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture) BatchedSQLQueries { + + db := sqldb.NewTestPostgresDB(t, pgFixture).BaseDB + + return sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLQueries { + return db.WithTx(tx) + }, + ) +} + +// NewKVTestDB is a helper function that creates an BBolt database for testing +// and there is no need to convert the interface to the KVStore because for +// some unit tests we still need access to the kvdb interface. +func NewKVTestDB(t *testing.T, opts ...OptionModifier) *KVStore { + backend, backendCleanup, err := kvdb.GetTestBackend( + t.TempDir(), "kvPaymentDB", + ) + require.NoError(t, err) + + t.Cleanup(backendCleanup) + + paymentDB, err := NewKVStore(backend, opts...) + require.NoError(t, err) + + return paymentDB +} diff --git a/payments/db/test_sqlite.go b/payments/db/test_sqlite.go new file mode 100644 index 00000000000..d751bb9977f --- /dev/null +++ b/payments/db/test_sqlite.go @@ -0,0 +1,74 @@ +//go:build !test_db_postgres && test_db_sqlite + +package paymentsdb + +import ( + "database/sql" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB, opts ...OptionModifier) DB { + return NewTestDBWithFixture(t, nil, opts...) +} + +// NewTestDBFixture is a no-op for the sqlite build. +func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture { + return nil +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture, + opts ...OptionModifier) DB { + + store, err := NewSQLStore( + &SQLStoreConfig{ + QueryCfg: sqldb.DefaultSQLiteConfig(), + }, newBatchQuerier(t), opts..., + ) + require.NoError(t, err) + return store +} + +// newBatchQuerier creates a new BatchedSQLQueries instance for testing +// using a SQLite database. +func newBatchQuerier(t testing.TB) BatchedSQLQueries { + return newBatchQuerierWithFixture(t, nil) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a SQLite database. +func newBatchQuerierWithFixture(t testing.TB, + _ *sqldb.TestPgFixture) BatchedSQLQueries { + + db := sqldb.NewTestSqliteDB(t).BaseDB + + return sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLQueries { + return db.WithTx(tx) + }, + ) +} + +// NewKVTestDB is a helper function that creates an BBolt database for testing +// and there is no need to convert the interface to the KVStore because for +// some unit tests we still need access to the kvdb interface. +func NewKVTestDB(t *testing.T, opts ...OptionModifier) *KVStore { + backend, backendCleanup, err := kvdb.GetTestBackend( + t.TempDir(), "kvPaymentDB", + ) + require.NoError(t, err) + + t.Cleanup(backendCleanup) + + paymentDB, err := NewKVStore(backend, opts...) + require.NoError(t, err) + + return paymentDB +} diff --git a/routing/control_tower.go b/routing/control_tower.go index 2b9e7dd9d28..1c246f17d9b 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -1,6 +1,7 @@ package routing import ( + "context" "sync" "github.com/lightningnetwork/lnd/lntypes" @@ -19,17 +20,19 @@ type ControlTower interface { // also notifies subscribers of the payment creation. // // NOTE: Subscribers should be notified by the new state of the payment. - InitPayment(lntypes.Hash, *paymentsdb.PaymentCreationInfo) error + InitPayment(context.Context, lntypes.Hash, + *paymentsdb.PaymentCreationInfo) error // DeleteFailedAttempts removes all failed HTLCs from the db. It should // be called for a given payment whenever all inflight htlcs are // completed, and the payment has reached a final settled state. - DeleteFailedAttempts(lntypes.Hash) error + DeleteFailedAttempts(context.Context, lntypes.Hash) error // RegisterAttempt atomically records the provided HTLCAttemptInfo. // // NOTE: Subscribers should be notified by the new state of the payment. - RegisterAttempt(lntypes.Hash, *paymentsdb.HTLCAttemptInfo) error + RegisterAttempt(context.Context, lntypes.Hash, + *paymentsdb.HTLCAttemptInfo) error // SettleAttempt marks the given attempt settled with the preimage. If // this is a multi shard payment, this might implicitly mean the the @@ -41,18 +44,19 @@ type ControlTower interface { // for record keeping. // // NOTE: Subscribers should be notified by the new state of the payment. - SettleAttempt(lntypes.Hash, uint64, *paymentsdb.HTLCSettleInfo) ( - *paymentsdb.HTLCAttempt, error) + SettleAttempt(context.Context, lntypes.Hash, uint64, + *paymentsdb.HTLCSettleInfo) (*paymentsdb.HTLCAttempt, error) // FailAttempt marks the given payment attempt failed. // // NOTE: Subscribers should be notified by the new state of the payment. - FailAttempt(lntypes.Hash, uint64, *paymentsdb.HTLCFailInfo) ( - *paymentsdb.HTLCAttempt, error) + FailAttempt(context.Context, lntypes.Hash, uint64, + *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, error) // FetchPayment fetches the payment corresponding to the given payment // hash. - FetchPayment(paymentHash lntypes.Hash) (paymentsdb.DBMPPayment, error) + FetchPayment(ctx context.Context, + paymentHash lntypes.Hash) (paymentsdb.DBMPPayment, error) // FailPayment transitions a payment into the Failed state, and records // the ultimate reason the payment failed. Note that this should only @@ -62,10 +66,12 @@ type ControlTower interface { // payment. // // NOTE: Subscribers should be notified by the new state of the payment. - FailPayment(lntypes.Hash, paymentsdb.FailureReason) error + FailPayment(context.Context, lntypes.Hash, + paymentsdb.FailureReason) error // FetchInFlightPayments returns all payments with status InFlight. - FetchInFlightPayments() ([]*paymentsdb.MPPayment, error) + FetchInFlightPayments(ctx context.Context) ([]*paymentsdb.MPPayment, + error) // SubscribePayment subscribes to updates for the payment with the given // hash. A first update with the current state of the payment is always @@ -161,10 +167,10 @@ func NewControlTower(db paymentsdb.DB) ControlTower { // making sure it does not already exist as an in-flight payment. Then this // method returns successfully, the payment is guaranteed to be in the // Initiated state. -func (p *controlTower) InitPayment(paymentHash lntypes.Hash, - info *paymentsdb.PaymentCreationInfo) error { +func (p *controlTower) InitPayment(ctx context.Context, + paymentHash lntypes.Hash, info *paymentsdb.PaymentCreationInfo) error { - err := p.db.InitPayment(paymentHash, info) + err := p.db.InitPayment(ctx, paymentHash, info) if err != nil { return err } @@ -174,7 +180,7 @@ func (p *controlTower) InitPayment(paymentHash lntypes.Hash, p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) - payment, err := p.db.FetchPayment(paymentHash) + payment, err := p.db.FetchPayment(ctx, paymentHash) if err != nil { return err } @@ -186,19 +192,21 @@ func (p *controlTower) InitPayment(paymentHash lntypes.Hash, // DeleteFailedAttempts deletes all failed htlcs if the payment was // successfully settled. -func (p *controlTower) DeleteFailedAttempts(paymentHash lntypes.Hash) error { - return p.db.DeleteFailedAttempts(paymentHash) +func (p *controlTower) DeleteFailedAttempts(ctx context.Context, + paymentHash lntypes.Hash) error { + + return p.db.DeleteFailedAttempts(ctx, paymentHash) } // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. -func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, - attempt *paymentsdb.HTLCAttemptInfo) error { +func (p *controlTower) RegisterAttempt(ctx context.Context, + paymentHash lntypes.Hash, attempt *paymentsdb.HTLCAttemptInfo) error { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) - payment, err := p.db.RegisterAttempt(paymentHash, attempt) + payment, err := p.db.RegisterAttempt(ctx, paymentHash, attempt) if err != nil { return err } @@ -212,14 +220,17 @@ func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, // SettleAttempt marks the given attempt settled with the preimage. If // this is a multi shard payment, this might implicitly mean the the // full payment succeeded. -func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash, - attemptID uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( - *paymentsdb.HTLCAttempt, error) { +func (p *controlTower) SettleAttempt(ctx context.Context, + paymentHash lntypes.Hash, attemptID uint64, + settleInfo *paymentsdb.HTLCSettleInfo) (*paymentsdb.HTLCAttempt, + error) { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) - payment, err := p.db.SettleAttempt(paymentHash, attemptID, settleInfo) + payment, err := p.db.SettleAttempt( + ctx, paymentHash, attemptID, settleInfo, + ) if err != nil { return nil, err } @@ -231,14 +242,14 @@ func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash, } // FailAttempt marks the given payment attempt failed. -func (p *controlTower) FailAttempt(paymentHash lntypes.Hash, - attemptID uint64, failInfo *paymentsdb.HTLCFailInfo) ( - *paymentsdb.HTLCAttempt, error) { +func (p *controlTower) FailAttempt(ctx context.Context, + paymentHash lntypes.Hash, attemptID uint64, + failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, error) { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) - payment, err := p.db.FailAttempt(paymentHash, attemptID, failInfo) + payment, err := p.db.FailAttempt(ctx, paymentHash, attemptID, failInfo) if err != nil { return nil, err } @@ -250,10 +261,11 @@ func (p *controlTower) FailAttempt(paymentHash lntypes.Hash, } // FetchPayment fetches the payment corresponding to the given payment hash. -func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) ( +func (p *controlTower) FetchPayment(ctx context.Context, + paymentHash lntypes.Hash) ( paymentsdb.DBMPPayment, error) { - return p.db.FetchPayment(paymentHash) + return p.db.FetchPayment(ctx, paymentHash) } // FailPayment transitions a payment into the Failed state, and records the @@ -263,13 +275,13 @@ func (p *controlTower) FetchPayment(paymentHash lntypes.Hash) ( // // NOTE: This method will overwrite the failure reason if the payment is already // failed. -func (p *controlTower) FailPayment(paymentHash lntypes.Hash, - reason paymentsdb.FailureReason) error { +func (p *controlTower) FailPayment(ctx context.Context, + paymentHash lntypes.Hash, reason paymentsdb.FailureReason) error { p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) - payment, err := p.db.Fail(paymentHash, reason) + payment, err := p.db.Fail(ctx, paymentHash, reason) if err != nil { return err } @@ -281,10 +293,10 @@ func (p *controlTower) FailPayment(paymentHash lntypes.Hash, } // FetchInFlightPayments returns all payments with status InFlight. -func (p *controlTower) FetchInFlightPayments() ([]*paymentsdb.MPPayment, - error) { +func (p *controlTower) FetchInFlightPayments( + ctx context.Context) ([]*paymentsdb.MPPayment, error) { - return p.db.FetchInFlightPayments() + return p.db.FetchInFlightPayments(ctx) } // SubscribePayment subscribes to updates for the payment with the given hash. A @@ -293,12 +305,14 @@ func (p *controlTower) FetchInFlightPayments() ([]*paymentsdb.MPPayment, func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( ControlTowerSubscriber, error) { + ctx := context.TODO() + // Take lock before querying the db to prevent missing or duplicating an // update. p.paymentsMtx.Lock(paymentHash) defer p.paymentsMtx.Unlock(paymentHash) - payment, err := p.db.FetchPayment(paymentHash) + payment, err := p.db.FetchPayment(ctx, paymentHash) if err != nil { return nil, err } @@ -335,6 +349,8 @@ func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) { subscriber := newControlTowerSubscriber() + ctx := context.TODO() + // Add the subscriber to the list before fetching in-flight payments, so // no events are missed. If a payment attempt update occurs after // appending and before fetching in-flight payments, an out-of-order @@ -346,7 +362,7 @@ func (p *controlTower) SubscribeAllPayments() (ControlTowerSubscriber, error) { p.subscribersMtx.Unlock() log.Debugf("Scanning for inflight payments") - inflightPayments, err := p.db.FetchInFlightPayments() + inflightPayments, err := p.db.FetchInFlightPayments(ctx) if err != nil { return nil, err } diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index de0aacf880b..c9e8f485733 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -81,7 +81,7 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { t.Fatal(err) } - err = pControl.InitPayment(info.PaymentIdentifier, info) + err = pControl.InitPayment(t.Context(), info.PaymentIdentifier, info) if err != nil { t.Fatal(err) } @@ -92,7 +92,9 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { require.NoError(t, err, "expected subscribe to succeed, but got") // Register an attempt. - err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + err = pControl.RegisterAttempt( + t.Context(), info.PaymentIdentifier, attempt, + ) if err != nil { t.Fatal(err) } @@ -106,7 +108,8 @@ func TestControlTowerSubscribeSuccess(t *testing.T) { Preimage: preimg, } htlcAttempt, err := pControl.SettleAttempt( - info.PaymentIdentifier, attempt.AttemptID, &settleInfo, + t.Context(), info.PaymentIdentifier, attempt.AttemptID, + &settleInfo, ) if err != nil { t.Fatal(err) @@ -212,7 +215,7 @@ func TestKVStoreSubscribeAllSuccess(t *testing.T) { info1, attempt1, preimg1, err := genInfo() require.NoError(t, err) - err = pControl.InitPayment(info1.PaymentIdentifier, info1) + err = pControl.InitPayment(t.Context(), info1.PaymentIdentifier, info1) require.NoError(t, err) // Subscription should succeed and immediately report the Initiated @@ -221,18 +224,22 @@ func TestKVStoreSubscribeAllSuccess(t *testing.T) { require.NoError(t, err, "expected subscribe to succeed, but got: %v") // Register an attempt. - err = pControl.RegisterAttempt(info1.PaymentIdentifier, attempt1) + err = pControl.RegisterAttempt( + t.Context(), info1.PaymentIdentifier, attempt1, + ) require.NoError(t, err) // Initiate a second payment after the subscription is already active. info2, attempt2, preimg2, err := genInfo() require.NoError(t, err) - err = pControl.InitPayment(info2.PaymentIdentifier, info2) + err = pControl.InitPayment(t.Context(), info2.PaymentIdentifier, info2) require.NoError(t, err) // Register an attempt on the second payment. - err = pControl.RegisterAttempt(info2.PaymentIdentifier, attempt2) + err = pControl.RegisterAttempt( + t.Context(), info2.PaymentIdentifier, attempt2, + ) require.NoError(t, err) // Mark the first payment as successful. @@ -240,7 +247,8 @@ func TestKVStoreSubscribeAllSuccess(t *testing.T) { Preimage: preimg1, } htlcAttempt1, err := pControl.SettleAttempt( - info1.PaymentIdentifier, attempt1.AttemptID, &settleInfo1, + t.Context(), info1.PaymentIdentifier, attempt1.AttemptID, + &settleInfo1, ) require.NoError(t, err) require.Equal( @@ -253,7 +261,8 @@ func TestKVStoreSubscribeAllSuccess(t *testing.T) { Preimage: preimg2, } htlcAttempt2, err := pControl.SettleAttempt( - info2.PaymentIdentifier, attempt2.AttemptID, &settleInfo2, + t.Context(), info2.PaymentIdentifier, attempt2.AttemptID, + &settleInfo2, ) require.NoError(t, err) require.Equal( @@ -337,11 +346,13 @@ func TestKVStoreSubscribeAllImmediate(t *testing.T) { info, attempt, _, err := genInfo() require.NoError(t, err) - err = pControl.InitPayment(info.PaymentIdentifier, info) + err = pControl.InitPayment(t.Context(), info.PaymentIdentifier, info) require.NoError(t, err) // Register a payment update. - err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + err = pControl.RegisterAttempt( + t.Context(), info.PaymentIdentifier, attempt, + ) require.NoError(t, err) subscription, err := pControl.SubscribeAllPayments() @@ -392,7 +403,7 @@ func TestKVStoreUnsubscribeSuccess(t *testing.T) { info, attempt, _, err := genInfo() require.NoError(t, err) - err = pControl.InitPayment(info.PaymentIdentifier, info) + err = pControl.InitPayment(t.Context(), info.PaymentIdentifier, info) require.NoError(t, err) // Assert all subscriptions receive the update. @@ -414,7 +425,9 @@ func TestKVStoreUnsubscribeSuccess(t *testing.T) { subscription1.Close() // Register a payment update. - err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) + err = pControl.RegisterAttempt( + t.Context(), info.PaymentIdentifier, attempt, + ) require.NoError(t, err) // Assert only subscription 2 receives the update. @@ -435,7 +448,8 @@ func TestKVStoreUnsubscribeSuccess(t *testing.T) { Reason: paymentsdb.HTLCFailInternal, } _, err = pControl.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, &failInfo, + t.Context(), info.PaymentIdentifier, attempt.AttemptID, + &failInfo, ) require.NoError(t, err, "unable to fail htlc") @@ -465,7 +479,7 @@ func testKVStoreSubscribeFail(t *testing.T, registerAttempt, t.Fatal(err) } - err = pControl.InitPayment(info.PaymentIdentifier, info) + err = pControl.InitPayment(t.Context(), info.PaymentIdentifier, info) if err != nil { t.Fatal(err) } @@ -479,17 +493,18 @@ func testKVStoreSubscribeFail(t *testing.T, registerAttempt, // making any attempts at all. if registerAttempt { // Register an attempt. - err = pControl.RegisterAttempt(info.PaymentIdentifier, attempt) - if err != nil { - t.Fatal(err) - } + err = pControl.RegisterAttempt( + t.Context(), info.PaymentIdentifier, attempt, + ) + require.NoError(t, err) // Fail the payment attempt. failInfo := paymentsdb.HTLCFailInfo{ Reason: paymentsdb.HTLCFailInternal, } htlcAttempt, err := pControl.FailAttempt( - info.PaymentIdentifier, attempt.AttemptID, &failInfo, + t.Context(), info.PaymentIdentifier, attempt.AttemptID, + &failInfo, ) if err != nil { t.Fatalf("unable to fail htlc: %v", err) @@ -501,7 +516,8 @@ func testKVStoreSubscribeFail(t *testing.T, registerAttempt, // Mark the payment as failed. err = pControl.FailPayment( - info.PaymentIdentifier, paymentsdb.FailureReasonTimeout, + t.Context(), info.PaymentIdentifier, + paymentsdb.FailureReasonTimeout, ) if err != nil { t.Fatal(err) diff --git a/routing/mock_test.go b/routing/mock_test.go index 19a76ee9010..472f1261623 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "errors" "fmt" "sync" @@ -296,8 +297,8 @@ func makeMockControlTower() *mockControlTowerOld { } } -func (m *mockControlTowerOld) InitPayment(phash lntypes.Hash, - c *paymentsdb.PaymentCreationInfo) error { +func (m *mockControlTowerOld) InitPayment(_ context.Context, + phash lntypes.Hash, c *paymentsdb.PaymentCreationInfo) error { if m.init != nil { m.init <- initArgs{c} @@ -327,7 +328,9 @@ func (m *mockControlTowerOld) InitPayment(phash lntypes.Hash, return nil } -func (m *mockControlTowerOld) DeleteFailedAttempts(phash lntypes.Hash) error { +func (m *mockControlTowerOld) DeleteFailedAttempts(_ context.Context, + phash lntypes.Hash) error { + p, ok := m.payments[phash] if !ok { return paymentsdb.ErrPaymentNotInitiated @@ -353,8 +356,8 @@ func (m *mockControlTowerOld) DeleteFailedAttempts(phash lntypes.Hash) error { return nil } -func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash, - a *paymentsdb.HTLCAttemptInfo) error { +func (m *mockControlTowerOld) RegisterAttempt(_ context.Context, + phash lntypes.Hash, a *paymentsdb.HTLCAttemptInfo) error { if m.registerAttempt != nil { m.registerAttempt <- registerAttemptArgs{a} @@ -407,8 +410,8 @@ func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash, return nil } -func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash, - pid uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( +func (m *mockControlTowerOld) SettleAttempt(_ context.Context, + phash lntypes.Hash, pid uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( *paymentsdb.HTLCAttempt, error) { if m.settleAttempt != nil { @@ -450,8 +453,9 @@ func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash, return nil, fmt.Errorf("pid not found") } -func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64, - failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, error) { +func (m *mockControlTowerOld) FailAttempt(_ context.Context, phash lntypes.Hash, + pid uint64, failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, + error) { if m.failAttempt != nil { m.failAttempt <- failAttemptArgs{failInfo} @@ -489,7 +493,7 @@ func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64, return nil, fmt.Errorf("pid not found") } -func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash, +func (m *mockControlTowerOld) FailPayment(_ context.Context, phash lntypes.Hash, reason paymentsdb.FailureReason) error { m.Lock() @@ -509,8 +513,8 @@ func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash, return nil } -func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) ( - paymentsdb.DBMPPayment, error) { +func (m *mockControlTowerOld) FetchPayment(_ context.Context, + phash lntypes.Hash) (paymentsdb.DBMPPayment, error) { m.Lock() defer m.Unlock() @@ -545,7 +549,7 @@ func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) ( return mp, nil } -func (m *mockControlTowerOld) FetchInFlightPayments() ( +func (m *mockControlTowerOld) FetchInFlightPayments(_ context.Context) ( []*paymentsdb.MPPayment, error) { if m.fetchInFlight != nil { @@ -733,26 +737,28 @@ type mockControlTower struct { var _ ControlTower = (*mockControlTower)(nil) -func (m *mockControlTower) InitPayment(phash lntypes.Hash, +func (m *mockControlTower) InitPayment(_ context.Context, phash lntypes.Hash, c *paymentsdb.PaymentCreationInfo) error { args := m.Called(phash, c) return args.Error(0) } -func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error { +func (m *mockControlTower) DeleteFailedAttempts(_ context.Context, + phash lntypes.Hash) error { + args := m.Called(phash) return args.Error(0) } -func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, - a *paymentsdb.HTLCAttemptInfo) error { +func (m *mockControlTower) RegisterAttempt(_ context.Context, + phash lntypes.Hash, a *paymentsdb.HTLCAttemptInfo) error { args := m.Called(phash, a) return args.Error(0) } -func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, +func (m *mockControlTower) SettleAttempt(_ context.Context, phash lntypes.Hash, pid uint64, settleInfo *paymentsdb.HTLCSettleInfo) ( *paymentsdb.HTLCAttempt, error) { @@ -766,8 +772,9 @@ func (m *mockControlTower) SettleAttempt(phash lntypes.Hash, return attempt.(*paymentsdb.HTLCAttempt), args.Error(1) } -func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, - failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, error) { +func (m *mockControlTower) FailAttempt(_ context.Context, phash lntypes.Hash, + pid uint64, failInfo *paymentsdb.HTLCFailInfo) (*paymentsdb.HTLCAttempt, + error) { args := m.Called(phash, pid, failInfo) @@ -779,15 +786,15 @@ func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64, return attempt.(*paymentsdb.HTLCAttempt), args.Error(1) } -func (m *mockControlTower) FailPayment(phash lntypes.Hash, +func (m *mockControlTower) FailPayment(_ context.Context, phash lntypes.Hash, reason paymentsdb.FailureReason) error { args := m.Called(phash, reason) return args.Error(0) } -func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( - paymentsdb.DBMPPayment, error) { +func (m *mockControlTower) FetchPayment(_ context.Context, + phash lntypes.Hash) (paymentsdb.DBMPPayment, error) { args := m.Called(phash) @@ -800,7 +807,7 @@ func (m *mockControlTower) FetchPayment(phash lntypes.Hash) ( return payment, args.Error(1) } -func (m *mockControlTower) FetchInFlightPayments() ( +func (m *mockControlTower) FetchInFlightPayments(_ context.Context) ( []*paymentsdb.MPPayment, error) { args := m.Called() diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 8353cba157f..39c74d90f16 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -128,7 +128,7 @@ const ( // results is sent back. then process its result here. When there's no need to // wait for results, the method will exit with `stepExit` such that the payment // lifecycle loop will terminate. -func (p *paymentLifecycle) decideNextStep( +func (p *paymentLifecycle) decideNextStep(ctx context.Context, payment paymentsdb.DBMPPayment) (stateStep, error) { // Check whether we could make new HTLC attempts. @@ -168,7 +168,7 @@ func (p *paymentLifecycle) decideNextStep( // stepSkip and move to the next lifecycle iteration, which will // refresh the payment and wait for the next attempt result, if // any. - _, err := p.handleAttemptResult(r.attempt, r.result) + _, err := p.handleAttemptResult(ctx, r.attempt, r.result) // We would only get a DB-related error here, which will cause // us to abort the payment flow. @@ -198,7 +198,7 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, // If we had any existing attempts outstanding, we'll start by spinning // up goroutines that'll collect their results and deliver them to the // lifecycle loop below. - payment, err := p.reloadInflightAttempts() + payment, err := p.reloadInflightAttempts(ctx) if err != nil { return [32]byte{}, nil, err } @@ -239,7 +239,7 @@ lifecycle: } // We update the payment state on every iteration. - currentPayment, ps, err := p.reloadPayment() + currentPayment, ps, err := p.reloadPayment(ctx) if err != nil { return exitWithErr(err) } @@ -260,7 +260,7 @@ lifecycle: // // Now decide the next step of the current lifecycle. - step, err := p.decideNextStep(payment) + step, err := p.decideNextStep(ctx, payment) if err != nil { return exitWithErr(err) } @@ -284,7 +284,7 @@ lifecycle: } // Now request a route to be used to create our HTLC attempt. - rt, err := p.requestRoute(ps) + rt, err := p.requestRoute(ctx, ps) if err != nil { return exitWithErr(err) } @@ -303,13 +303,13 @@ lifecycle: log.Tracef("Found route: %s", lnutils.SpewLogClosure(rt.Hops)) // We found a route to try, create a new HTLC attempt to try. - attempt, err := p.registerAttempt(rt, ps.RemainingAmt) + attempt, err := p.registerAttempt(ctx, rt, ps.RemainingAmt) if err != nil { return exitWithErr(err) } // Once the attempt is created, send it to the htlcswitch. - result, err := p.sendAttempt(attempt) + result, err := p.sendAttempt(ctx, attempt) if err != nil { return exitWithErr(err) } @@ -328,7 +328,7 @@ lifecycle: // Optionally delete the failed attempts from the database. Depends on // the database options deleting attempts is not allowed so this will // just be a no-op. - err = p.router.cfg.Control.DeleteFailedAttempts(p.identifier) + err = p.router.cfg.Control.DeleteFailedAttempts(ctx, p.identifier) if err != nil { log.Errorf("Error deleting failed htlc attempts for payment "+ "%v: %v", p.identifier, err) @@ -368,7 +368,9 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // inflight HTLCs or not, its status will now either be // `StatusInflight` or `StatusFailed`. In either case, no more // HTLCs will be attempted. - err := p.router.cfg.Control.FailPayment(p.identifier, reason) + err := p.router.cfg.Control.FailPayment( + ctx, p.identifier, reason, + ) if err != nil { return fmt.Errorf("FailPayment got %w", err) } @@ -386,7 +388,7 @@ func (p *paymentLifecycle) checkContext(ctx context.Context) error { // requestRoute is responsible for finding a route to be used to create an HTLC // attempt. -func (p *paymentLifecycle) requestRoute( +func (p *paymentLifecycle) requestRoute(ctx context.Context, ps *paymentsdb.MPPaymentState) (*route.Route, error) { remainingFees := p.calcFeeBudget(ps.FeesPaid) @@ -430,7 +432,9 @@ func (p *paymentLifecycle) requestRoute( log.Warnf("Marking payment %v permanently failed with no route: %v", p.identifier, failureCode) - err = p.router.cfg.Control.FailPayment(p.identifier, failureCode) + err = p.router.cfg.Control.FailPayment( + ctx, p.identifier, failureCode, + ) if err != nil { return nil, fmt.Errorf("FailPayment got: %w", err) } @@ -581,7 +585,7 @@ func (p *paymentLifecycle) collectResult( // registerAttempt is responsible for creating and saving an HTLC attempt in db // by using the route info provided. The `remainingAmt` is used to decide // whether this is the last attempt. -func (p *paymentLifecycle) registerAttempt(rt *route.Route, +func (p *paymentLifecycle) registerAttempt(ctx context.Context, rt *route.Route, remainingAmt lnwire.MilliSatoshi) (*paymentsdb.HTLCAttempt, error) { // If this route will consume the last remaining amount to send @@ -601,7 +605,7 @@ func (p *paymentLifecycle) registerAttempt(rt *route.Route, // Switch for its whereabouts. The route is needed to handle the result // when it eventually comes back. err = p.router.cfg.Control.RegisterAttempt( - p.identifier, &attempt.HTLCAttemptInfo, + ctx, p.identifier, &attempt.HTLCAttemptInfo, ) return attempt, err @@ -657,7 +661,7 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, // sendAttempt attempts to send the current attempt to the switch to complete // the payment. If this attempt fails, then we'll continue on to the next // available route. -func (p *paymentLifecycle) sendAttempt( +func (p *paymentLifecycle) sendAttempt(ctx context.Context, attempt *paymentsdb.HTLCAttempt) (*attemptResult, error) { log.Debugf("Sending HTLC attempt(id=%v, total_amt=%v, first_hop_amt=%d"+ @@ -689,7 +693,7 @@ func (p *paymentLifecycle) sendAttempt( "payment=%v, err:%v", attempt.AttemptID, p.identifier, err) - return p.failAttempt(attempt.AttemptID, err) + return p.failAttempt(ctx, attempt.AttemptID, err) } htlcAdd.OnionBlob = onionBlob @@ -703,7 +707,7 @@ func (p *paymentLifecycle) sendAttempt( log.Errorf("Failed sending attempt %d for payment %v to "+ "switch: %v", attempt.AttemptID, p.identifier, err) - return p.handleSwitchErr(attempt, err) + return p.handleSwitchErr(ctx, attempt, err) } log.Debugf("Attempt %v for payment %v successfully sent to switch, "+ @@ -794,7 +798,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // failAttemptAndPayment fails both the payment and its attempt via the // router's control tower, which marks the payment as failed in db. -func (p *paymentLifecycle) failPaymentAndAttempt( +func (p *paymentLifecycle) failPaymentAndAttempt(ctx context.Context, attemptID uint64, reason *paymentsdb.FailureReason, sendErr error) (*attemptResult, error) { @@ -806,14 +810,16 @@ func (p *paymentLifecycle) failPaymentAndAttempt( // NOTE: we must fail the payment first before failing the attempt. // Otherwise, once the attempt is marked as failed, another goroutine // might make another attempt while we are failing the payment. - err := p.router.cfg.Control.FailPayment(p.identifier, *reason) + err := p.router.cfg.Control.FailPayment( + ctx, p.identifier, *reason, + ) if err != nil { log.Errorf("Unable to fail payment: %v", err) return nil, err } // Fail the attempt. - return p.failAttempt(attemptID, sendErr) + return p.failAttempt(ctx, attemptID, sendErr) } // handleSwitchErr inspects the given error from the Switch and determines @@ -824,7 +830,8 @@ func (p *paymentLifecycle) failPaymentAndAttempt( // the error type, the error is either the final outcome of the payment or we // need to continue with an alternative route. A final outcome is indicated by // a non-nil reason value. -func (p *paymentLifecycle) handleSwitchErr(attempt *paymentsdb.HTLCAttempt, +func (p *paymentLifecycle) handleSwitchErr(ctx context.Context, + attempt *paymentsdb.HTLCAttempt, sendErr error) (*attemptResult, error) { internalErrorReason := paymentsdb.FailureReasonError @@ -851,11 +858,11 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *paymentsdb.HTLCAttempt, // Fail the attempt only if there's no reason. if reason == nil { // Fail the attempt. - return p.failAttempt(attemptID, sendErr) + return p.failAttempt(ctx, attemptID, sendErr) } // Otherwise fail both the payment and the attempt. - return p.failPaymentAndAttempt(attemptID, reason, sendErr) + return p.failPaymentAndAttempt(ctx, attemptID, reason, sendErr) } // If this attempt ID is unknown to the Switch, it means it was never @@ -866,7 +873,7 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *paymentsdb.HTLCAttempt, log.Warnf("Failing attempt=%v for payment=%v as it's not "+ "found in the Switch", attempt.AttemptID, p.identifier) - return p.failAttempt(attemptID, sendErr) + return p.failAttempt(ctx, attemptID, sendErr) } if errors.Is(sendErr, htlcswitch.ErrUnreadableFailureMessage) { @@ -888,7 +895,7 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *paymentsdb.HTLCAttempt, ok := errors.As(sendErr, &rtErr) if !ok { return p.failPaymentAndAttempt( - attemptID, &internalErrorReason, sendErr, + ctx, attemptID, &internalErrorReason, sendErr, ) } @@ -914,7 +921,7 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *paymentsdb.HTLCAttempt, ) if err != nil { return p.failPaymentAndAttempt( - attemptID, &internalErrorReason, sendErr, + ctx, attemptID, &internalErrorReason, sendErr, ) } @@ -998,7 +1005,7 @@ func (p *paymentLifecycle) handleFailureMessage(rt *route.Route, } // failAttempt calls control tower to fail the current payment attempt. -func (p *paymentLifecycle) failAttempt(attemptID uint64, +func (p *paymentLifecycle) failAttempt(ctx context.Context, attemptID uint64, sendError error) (*attemptResult, error) { log.Warnf("Attempt %v for payment %v failed: %v", attemptID, @@ -1017,7 +1024,7 @@ func (p *paymentLifecycle) failAttempt(attemptID uint64, } attempt, err := p.router.cfg.Control.FailAttempt( - p.identifier, attemptID, failInfo, + ctx, p.identifier, attemptID, failInfo, ) if err != nil { return nil, err @@ -1111,10 +1118,10 @@ func (p *paymentLifecycle) patchLegacyPaymentHash( // reloadInflightAttempts is called when the payment lifecycle is resumed after // a restart. It reloads all inflight attempts from the control tower and // collects the results of the attempts that have been sent before. -func (p *paymentLifecycle) reloadInflightAttempts() (paymentsdb.DBMPPayment, - error) { +func (p *paymentLifecycle) reloadInflightAttempts( + ctx context.Context) (paymentsdb.DBMPPayment, error) { - payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + payment, err := p.router.cfg.Control.FetchPayment(ctx, p.identifier) if err != nil { return nil, err } @@ -1136,11 +1143,12 @@ func (p *paymentLifecycle) reloadInflightAttempts() (paymentsdb.DBMPPayment, } // reloadPayment returns the latest payment found in the db (control tower). -func (p *paymentLifecycle) reloadPayment() (paymentsdb.DBMPPayment, +func (p *paymentLifecycle) reloadPayment( + ctx context.Context) (paymentsdb.DBMPPayment, *paymentsdb.MPPaymentState, error) { // Read the db to get the latest state of the payment. - payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + payment, err := p.router.cfg.Control.FetchPayment(ctx, p.identifier) if err != nil { return nil, nil, err } @@ -1157,13 +1165,14 @@ func (p *paymentLifecycle) reloadPayment() (paymentsdb.DBMPPayment, // handleAttemptResult processes the result of an HTLC attempt returned from // the htlcswitch. -func (p *paymentLifecycle) handleAttemptResult(attempt *paymentsdb.HTLCAttempt, +func (p *paymentLifecycle) handleAttemptResult(ctx context.Context, + attempt *paymentsdb.HTLCAttempt, result *htlcswitch.PaymentResult) (*attemptResult, error) { // If the result has an error, we need to further process it by failing // the attempt and maybe fail the payment. if result.Error != nil { - return p.handleSwitchErr(attempt, result.Error) + return p.handleSwitchErr(ctx, attempt, result.Error) } // We got an attempt settled result back from the switch. @@ -1181,7 +1190,7 @@ func (p *paymentLifecycle) handleAttemptResult(attempt *paymentsdb.HTLCAttempt, // In case of success we atomically store settle result to the DB and // move the shard to the settled state. htlcAttempt, err := p.router.cfg.Control.SettleAttempt( - p.identifier, attempt.AttemptID, + ctx, p.identifier, attempt.AttemptID, &paymentsdb.HTLCSettleInfo{ Preimage: result.Preimage, SettleTime: p.router.cfg.Clock.Now(), @@ -1206,7 +1215,7 @@ func (p *paymentLifecycle) handleAttemptResult(attempt *paymentsdb.HTLCAttempt, // available from the Switch, then records the attempt outcome with the control // tower. An attemptResult is returned, indicating the final outcome of this // HTLC attempt. -func (p *paymentLifecycle) collectAndHandleResult( +func (p *paymentLifecycle) collectAndHandleResult(ctx context.Context, attempt *paymentsdb.HTLCAttempt) (*attemptResult, error) { result, err := p.collectResult(attempt) @@ -1214,5 +1223,5 @@ func (p *paymentLifecycle) collectAndHandleResult( return nil, err } - return p.handleAttemptResult(attempt, result) + return p.handleAttemptResult(ctx, attempt, result) } diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 7e94315a7dc..82e2f800d8f 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -393,7 +393,7 @@ func TestRequestRouteSucceed(t *testing.T) { mock.Anything, ).Return(dummyRoute, nil) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(t.Context(), ps) require.NoError(t, err, "expect no error") require.Equal(t, dummyRoute, result, "returned route not matched") @@ -430,7 +430,7 @@ func TestRequestRouteHandleCriticalErr(t *testing.T) { mock.Anything, ).Return(nil, errDummy) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(t.Context(), ps) // Expect an error is returned since it's critical. require.ErrorIs(t, err, errDummy, "error not matched") @@ -470,7 +470,7 @@ func TestRequestRouteHandleNoRouteErr(t *testing.T) { p.identifier, paymentsdb.FailureReasonNoRoute, ).Return(nil).Once() - result, err := p.requestRoute(ps) + result, err := p.requestRoute(t.Context(), ps) // Expect no error is returned since it's not critical. require.NoError(t, err, "expected no error") @@ -513,7 +513,7 @@ func TestRequestRouteFailPaymentError(t *testing.T) { mock.Anything, ).Return(nil, errNoTlvPayload) - result, err := p.requestRoute(ps) + result, err := p.requestRoute(t.Context(), ps) // Expect an error is returned. require.ErrorIs(t, err, errDummy, "error not matched") @@ -599,7 +599,7 @@ func TestDecideNextStep(t *testing.T) { // Once the setup is finished, run the test cases. t.Run(tc.name, func(t *testing.T) { - step, err := p.decideNextStep(payment) + step, err := p.decideNextStep(t.Context(), payment) require.Equal(t, tc.expectedStep, step) require.ErrorIs(t, tc.expectedErr, err) }) @@ -628,7 +628,7 @@ func TestDecideNextStepOnRouterQuit(t *testing.T) { close(p.router.quit) // Call the method under test. - step, err := p.decideNextStep(payment) + step, err := p.decideNextStep(t.Context(), payment) // We expect stepExit and an error to be returned. require.Equal(t, stepExit, step) @@ -657,7 +657,7 @@ func TestDecideNextStepOnLifecycleQuit(t *testing.T) { close(p.quit) // Call the method under test. - step, err := p.decideNextStep(payment) + step, err := p.decideNextStep(t.Context(), payment) // We expect stepExit and an error to be returned. require.Equal(t, stepExit, step) @@ -716,7 +716,7 @@ func TestDecideNextStepHandleAttemptResultSucceed(t *testing.T) { mock.Anything).Return(attempt, nil).Once() // Call the method under test. - step, err := p.decideNextStep(payment) + step, err := p.decideNextStep(t.Context(), payment) // We expect stepSkip and no error to be returned. require.Equal(t, stepSkip, step) @@ -774,7 +774,7 @@ func TestDecideNextStepHandleAttemptResultFail(t *testing.T) { mock.Anything).Return(attempt, errDummy).Once() // Call the method under test. - step, err := p.decideNextStep(payment) + step, err := p.decideNextStep(t.Context(), payment) // We expect stepExit and the above error to be returned. require.Equal(t, stepExit, step) @@ -1467,7 +1467,7 @@ func TestCollectResultExitOnErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1513,7 +1513,7 @@ func TestCollectResultExitOnResultErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1539,7 +1539,7 @@ func TestCollectResultExitOnSwitchQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting, "expected switch exit") require.Nil(t, result, "expected nil attempt") @@ -1566,7 +1566,7 @@ func TestCollectResultExitOnRouterQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit") require.Nil(t, result, "expected nil attempt") } @@ -1592,7 +1592,7 @@ func TestCollectResultExitOnLifecycleQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.ErrorIs(t, err, ErrPaymentLifecycleExiting, "expected lifecycle exit") require.Nil(t, result, "expected nil attempt") @@ -1636,7 +1636,7 @@ func TestCollectResultExitOnSettleErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.ErrorIs(t, err, errDummy, "expected settle error") require.Nil(t, result, "expected nil attempt") } @@ -1678,7 +1678,7 @@ func TestCollectResultSuccess(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectAndHandleResult(attempt) + result, err := p.collectAndHandleResult(t.Context(), attempt) require.NoError(t, err, "expected no error") require.Equal(t, preimage, result.attempt.Settle.Preimage, "preimage mismatch") @@ -1762,7 +1762,9 @@ func TestHandleAttemptResultWithError(t *testing.T) { // Call the method under test and expect the dummy error to be // returned. - attemptResult, err := p.handleAttemptResult(attempt, result) + attemptResult, err := p.handleAttemptResult( + t.Context(), attempt, result, + ) require.ErrorIs(t, err, errDummy, "expected fail error") require.Nil(t, attemptResult, "expected nil attempt result") } @@ -1800,7 +1802,9 @@ func TestHandleAttemptResultSuccess(t *testing.T) { // Call the method under test and expect the dummy error to be // returned. - attemptResult, err := p.handleAttemptResult(attempt, result) + attemptResult, err := p.handleAttemptResult( + t.Context(), attempt, result, + ) require.NoError(t, err, "expected no error") require.Equal(t, attempt, attemptResult.attempt) } @@ -1846,7 +1850,7 @@ func TestReloadInflightAttemptsLegacy(t *testing.T) { }) // Now call the method under test. - payment, err := p.reloadInflightAttempts() + payment, err := p.reloadInflightAttempts(t.Context()) require.NoError(t, err) require.Equal(t, m.payment, payment) diff --git a/routing/router.go b/routing/router.go index 3c35b7c52cc..b81b202f974 100644 --- a/routing/router.go +++ b/routing/router.go @@ -967,6 +967,8 @@ func spewPayment(payment *LightningPayment) lnutils.LogClosure { func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( PaymentSession, shards.ShardTracker, error) { + ctx := context.TODO() + // Assemble any custom data we want to send to the first hop only. var firstHopData fn.Option[tlv.Blob] if len(payment.FirstHopCustomRecords) > 0 { @@ -1026,7 +1028,7 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( ) } - err = r.cfg.Control.InitPayment(payment.Identifier(), info) + err = r.cfg.Control.InitPayment(ctx, payment.Identifier(), info) if err != nil { return nil, nil, err } @@ -1036,21 +1038,22 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( // SendToRoute sends a payment using the provided route and fails the payment // when an error is returned from the attempt. -func (r *ChannelRouter) SendToRoute(htlcHash lntypes.Hash, rt *route.Route, +func (r *ChannelRouter) SendToRoute(ctx context.Context, htlcHash lntypes.Hash, + rt *route.Route, firstHopCustomRecords lnwire.CustomRecords) (*paymentsdb.HTLCAttempt, error) { - return r.sendToRoute(htlcHash, rt, false, firstHopCustomRecords) + return r.sendToRoute(ctx, htlcHash, rt, false, firstHopCustomRecords) } // SendToRouteSkipTempErr sends a payment using the provided route and fails // the payment ONLY when a terminal error is returned from the attempt. -func (r *ChannelRouter) SendToRouteSkipTempErr(htlcHash lntypes.Hash, - rt *route.Route, +func (r *ChannelRouter) SendToRouteSkipTempErr(ctx context.Context, + htlcHash lntypes.Hash, rt *route.Route, firstHopCustomRecords lnwire.CustomRecords) (*paymentsdb.HTLCAttempt, error) { - return r.sendToRoute(htlcHash, rt, true, firstHopCustomRecords) + return r.sendToRoute(ctx, htlcHash, rt, true, firstHopCustomRecords) } // sendToRoute attempts to send a payment with the given hash through the @@ -1059,8 +1062,8 @@ func (r *ChannelRouter) SendToRouteSkipTempErr(htlcHash lntypes.Hash, // information will contain the preimage. If an error occurs after the attempt // was initiated, both return values will be non-nil. If skipTempErr is true, // the payment won't be failed unless a terminal error has occurred. -func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, - skipTempErr bool, +func (r *ChannelRouter) sendToRoute(ctx context.Context, htlcHash lntypes.Hash, + rt *route.Route, skipTempErr bool, firstHopCustomRecords lnwire.CustomRecords) (*paymentsdb.HTLCAttempt, error) { @@ -1070,7 +1073,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, reason paymentsdb.FailureReason) error { payment, fetchErr := r.cfg.Control.FetchPayment( - paymentIdentifier, + ctx, paymentIdentifier, ) if fetchErr != nil { return fetchErr @@ -1084,7 +1087,9 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, return nil } - return r.cfg.Control.FailPayment(paymentIdentifier, reason) + return r.cfg.Control.FailPayment( + ctx, paymentIdentifier, reason, + ) } log.Debugf("SendToRoute for payment %v with skipTempErr=%v", @@ -1129,7 +1134,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, FirstHopCustomRecords: firstHopCustomRecords, } - err := r.cfg.Control.InitPayment(paymentIdentifier, info) + err := r.cfg.Control.InitPayment(ctx, paymentIdentifier, info) switch { // If this is an MPP attempt and the hash is already registered with // the database, we can go on to launch the shard. @@ -1173,7 +1178,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // NOTE: we use zero `remainingAmt` here to simulate the same effect of // setting the lastShard to be false, which is used by previous // implementation. - attempt, err := p.registerAttempt(rt, 0) + attempt, err := p.registerAttempt(ctx, rt, 0) if err != nil { return nil, err } @@ -1182,7 +1187,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // the `err` returned here has already been processed by // `handleSwitchErr`, which means if there's a terminal failure, the // payment has been failed. - result, err := p.sendAttempt(attempt) + result, err := p.sendAttempt(ctx, attempt) if err != nil { return nil, err } @@ -1210,7 +1215,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // The attempt was successfully sent, wait for the result to be // available. - result, err = p.collectAndHandleResult(attempt) + result, err = p.collectAndHandleResult(ctx, attempt) if err != nil { return nil, err } @@ -1415,9 +1420,11 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], // resumePayments fetches inflight payments and resumes their payment // lifecycles. func (r *ChannelRouter) resumePayments() error { + ctx := context.TODO() + // Get all payments that are inflight. log.Debugf("Scanning for inflight payments") - payments, err := r.cfg.Control.FetchInFlightPayments() + payments, err := r.cfg.Control.FetchInFlightPayments(ctx) if err != nil { return err } @@ -1525,6 +1532,8 @@ func (r *ChannelRouter) resumePayments() error { func (r *ChannelRouter) failStaleAttempt(a paymentsdb.HTLCAttempt, payHash lntypes.Hash) { + ctx := context.TODO() + // We can only fail inflight HTLCs so we skip the settled/failed ones. if a.Failure != nil || a.Settle != nil { return @@ -1608,7 +1617,7 @@ func (r *ChannelRouter) failStaleAttempt(a paymentsdb.HTLCAttempt, Reason: paymentsdb.HTLCFailUnknown, FailTime: r.cfg.Clock.Now(), } - _, err = r.cfg.Control.FailAttempt(payHash, a.AttemptID, failInfo) + _, err = r.cfg.Control.FailAttempt(ctx, payHash, a.AttemptID, failInfo) if err != nil { log.Errorf("Fail attempt=%v got error: %v", a.AttemptID, err) } diff --git a/routing/router_test.go b/routing/router_test.go index b811793d258..2a72ff39fc5 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -519,7 +519,7 @@ func TestChannelUpdateValidation(t *testing.T) { // Send off the payment request to the router. The specified route // should be attempted and the channel update should be received by // graph and ignored because it is missing a valid signature. - _, err = ctx.router.SendToRoute(payment, rt, nil) + _, err = ctx.router.SendToRoute(t.Context(), payment, rt, nil) require.Error(t, err, "expected route to fail with channel update") _, e1, e2, err = ctx.graph.FetchChannelEdgesByID( @@ -539,7 +539,7 @@ func TestChannelUpdateValidation(t *testing.T) { ctx.graphBuilder.setNextReject(false) // Retry the payment using the same route as before. - _, err = ctx.router.SendToRoute(payment, rt, nil) + _, err = ctx.router.SendToRoute(t.Context(), payment, rt, nil) require.Error(t, err, "expected route to fail with channel update") // This time a valid signature was supplied and the policy change should @@ -1094,7 +1094,9 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { require.Equal(t, paymentsdb.FailureReasonNoRoute, err) // Inspect the two attempts that were made before the payment failed. - p, err := ctx.router.cfg.Control.FetchPayment(*payment.paymentHash) + p, err := ctx.router.cfg.Control.FetchPayment( + t.Context(), *payment.paymentHash, + ) require.NoError(t, err) htlcs := p.GetHTLCs() @@ -1422,7 +1424,9 @@ func TestSendToRouteStructuredError(t *testing.T) { // update should be received by router and ignored // because it is missing a valid // signature. - _, err = ctx.router.SendToRoute(payment, rt, nil) + _, err = ctx.router.SendToRoute( + t.Context(), payment, rt, nil, + ) fErr, ok := err.(*htlcswitch.ForwardingError) require.True( @@ -1501,7 +1505,7 @@ func TestSendToRouteMaxHops(t *testing.T) { // Send off the payment request to the router. We expect an error back // indicating that the route is too long. var payHash lntypes.Hash - _, err = ctx.router.SendToRoute(payHash, rt, nil) + _, err = ctx.router.SendToRoute(t.Context(), payHash, rt, nil) if err != route.ErrMaxRouteHopsExceeded { t.Fatalf("expected ErrMaxRouteHopsExceeded, but got %v", err) } @@ -2216,7 +2220,9 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { ).Return(nil) // Expect a successful send to route. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) + attempt, err := router.SendToRouteSkipTempErr( + t.Context(), payHash, rt, nil, + ) require.NoError(t, err) require.Equal(t, testAttempt, attempt) @@ -2271,7 +2277,9 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { }} // Expect an error to be returned. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) + attempt, err := router.SendToRouteSkipTempErr( + t.Context(), payHash, rt, nil, + ) require.ErrorIs(t, ErrSkipTempErr, err) require.Nil(t, attempt) @@ -2351,7 +2359,9 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { ).Return(nil, nil) // Expect a failed send to route. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) + attempt, err := router.SendToRouteSkipTempErr( + t.Context(), payHash, rt, nil, + ) require.Equal(t, tempErr, err) require.Equal(t, testAttempt, attempt) @@ -2435,7 +2445,9 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { ).Return(&failureReason, nil) // Expect a failed send to route. - attempt, err := router.SendToRouteSkipTempErr(payHash, rt, nil) + attempt, err := router.SendToRouteSkipTempErr( + t.Context(), payHash, rt, nil, + ) require.Equal(t, permErr, err) require.Equal(t, testAttempt, attempt) @@ -2524,7 +2536,7 @@ func TestSendToRouteTempFailure(t *testing.T) { ).Return(nil, nil) // Expect a failed send to route. - attempt, err := router.SendToRoute(payHash, rt, nil) + attempt, err := router.SendToRoute(t.Context(), payHash, rt, nil) require.Equal(t, tempErr, err) require.Equal(t, testAttempt, attempt) diff --git a/rpcserver.go b/rpcserver.go index d3d3c518014..8698aa10a1b 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -5509,6 +5509,7 @@ func (r *rpcServer) SubscribeChannelEvents(req *lnrpc.ChannelEventSubscription, // execute sendPayment. We use this struct as a sort of bridge to enable code // re-use between SendPayment and SendToRoute. type paymentStream struct { + ctx context.Context recv func() (*rpcPaymentRequest, error) send func(*lnrpc.SendResponse) error } @@ -5524,10 +5525,13 @@ type rpcPaymentRequest struct { // through the Lightning Network. A single RPC invocation creates a persistent // bi-directional stream allowing clients to rapidly send payments through the // Lightning Network with a single persistent connection. -func (r *rpcServer) SendPayment(stream lnrpc.Lightning_SendPaymentServer) error { +func (r *rpcServer) SendPayment( + stream lnrpc.Lightning_SendPaymentServer) error { + var lock sync.Mutex return r.sendPayment(&paymentStream{ + ctx: stream.Context(), recv: func() (*rpcPaymentRequest, error) { req, err := stream.Recv() if err != nil { @@ -5552,10 +5556,13 @@ func (r *rpcServer) SendPayment(stream lnrpc.Lightning_SendPaymentServer) error // invocation creates a persistent bi-directional stream allowing clients to // rapidly send payments through the Lightning Network with a single persistent // connection. -func (r *rpcServer) SendToRoute(stream lnrpc.Lightning_SendToRouteServer) error { +func (r *rpcServer) SendToRoute( + stream lnrpc.Lightning_SendToRouteServer) error { + var lock sync.Mutex return r.sendPayment(&paymentStream{ + ctx: stream.Context(), recv: func() (*rpcPaymentRequest, error) { req, err := stream.Recv() if err != nil { @@ -5625,7 +5632,9 @@ type rpcPaymentIntent struct { // dispatch a client from the information presented by an RPC client. There are // three ways a client can specify their payment details: a payment request, // via manual details, or via a complete route. -func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPaymentIntent, error) { +func (r *rpcServer) extractPaymentIntent( + rpcPayReq *rpcPaymentRequest) (rpcPaymentIntent, error) { + payIntent := rpcPaymentIntent{} // If a route was specified, then we can use that directly. @@ -5897,7 +5906,7 @@ type paymentIntentResponse struct { // pre-built route. The first error this method returns denotes if we were // unable to save the payment. The second error returned denotes if the payment // didn't succeed. -func (r *rpcServer) dispatchPaymentIntent( +func (r *rpcServer) dispatchPaymentIntent(ctx context.Context, payIntent *rpcPaymentIntent) (*paymentIntentResponse, error) { // Construct a payment request to send to the channel router. If the @@ -5944,7 +5953,7 @@ func (r *rpcServer) dispatchPaymentIntent( } else { var attempt *paymentsdb.HTLCAttempt attempt, routerErr = r.server.chanRouter.SendToRoute( - payIntent.rHash, payIntent.route, nil, + ctx, payIntent.rHash, payIntent.route, nil, ) if routerErr == nil { @@ -6117,7 +6126,7 @@ sendLoop: }() resp, saveErr := r.dispatchPaymentIntent( - payIntent, + stream.ctx, payIntent, ) switch { @@ -6195,7 +6204,7 @@ sendLoop: func (r *rpcServer) SendPaymentSync(ctx context.Context, nextPayment *lnrpc.SendRequest) (*lnrpc.SendResponse, error) { - return r.sendPaymentSync(&rpcPaymentRequest{ + return r.sendPaymentSync(ctx, &rpcPaymentRequest{ SendRequest: nextPayment, }) } @@ -6216,12 +6225,12 @@ func (r *rpcServer) SendToRouteSync(ctx context.Context, return nil, err } - return r.sendPaymentSync(paymentRequest) + return r.sendPaymentSync(ctx, paymentRequest) } // sendPaymentSync is the synchronous variant of sendPayment. It will block and // wait until the payment has been fully completed. -func (r *rpcServer) sendPaymentSync( +func (r *rpcServer) sendPaymentSync(ctx context.Context, nextPayment *rpcPaymentRequest) (*lnrpc.SendResponse, error) { // We don't allow payments to be sent while the daemon itself is still @@ -6240,7 +6249,7 @@ func (r *rpcServer) sendPaymentSync( // With the payment validated, we'll now attempt to dispatch the // payment. - resp, saveErr := r.dispatchPaymentIntent(&payIntent) + resp, saveErr := r.dispatchPaymentIntent(ctx, &payIntent) switch { case saveErr != nil: return nil, saveErr @@ -7666,7 +7675,7 @@ func (r *rpcServer) DeletePayment(ctx context.Context, rpcsLog.Infof("[DeletePayment] payment_identifier=%v, "+ "failed_htlcs_only=%v", hash, req.FailedHtlcsOnly) - err = r.server.paymentsDB.DeletePayment(hash, req.FailedHtlcsOnly) + err = r.server.paymentsDB.DeletePayment(ctx, hash, req.FailedHtlcsOnly) if err != nil { return nil, err } @@ -7707,7 +7716,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, req.FailedHtlcsOnly) numDeletedPayments, err := r.server.paymentsDB.DeletePayments( - req.FailedPaymentsOnly, req.FailedHtlcsOnly, + ctx, req.FailedPaymentsOnly, req.FailedHtlcsOnly, ) if err != nil { return nil, err diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go index f7bc499185a..db0a03a4210 100644 --- a/sqldb/sqlc/db_custom.go +++ b/sqldb/sqlc/db_custom.go @@ -161,3 +161,79 @@ func (r GetChannelsBySCIDRangeRow) Node1Pub() []byte { func (r GetChannelsBySCIDRangeRow) Node2Pub() []byte { return r.Node2PubKey } + +// PaymentAndIntent is an interface that provides access to a payment and its +// associated payment intent. +type PaymentAndIntent interface { + // GetPayment returns the Payment associated with this interface. + GetPayment() Payment + + // GetPaymentIntent returns the PaymentIntent associated with this payment. + GetPaymentIntent() PaymentIntent +} + +// GetPayment returns the Payment associated with this interface. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FilterPaymentsRow) GetPayment() Payment { + return r.Payment +} + +// GetPaymentIntent returns the PaymentIntent associated with this payment. +// If the payment has no intent (IntentType is NULL), this returns a zero-value +// PaymentIntent. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FilterPaymentsRow) GetPaymentIntent() PaymentIntent { + if !r.IntentType.Valid { + return PaymentIntent{} + } + return PaymentIntent{ + IntentType: r.IntentType.Int16, + IntentPayload: r.IntentPayload, + } +} + +// GetPayment returns the Payment associated with this interface. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentRow) GetPayment() Payment { + return r.Payment +} + +// GetPaymentIntent returns the PaymentIntent associated with this payment. +// If the payment has no intent (IntentType is NULL), this returns a zero-value +// PaymentIntent. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentRow) GetPaymentIntent() PaymentIntent { + if !r.IntentType.Valid { + return PaymentIntent{} + } + return PaymentIntent{ + IntentType: r.IntentType.Int16, + IntentPayload: r.IntentPayload, + } +} + +// GetPayment returns the Payment associated with this interface. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentsByIDsRow) GetPayment() Payment { + return r.Payment +} + +// GetPaymentIntent returns the PaymentIntent associated with this payment. +// If the payment has no intent (IntentType is NULL), this returns a zero-value +// PaymentIntent. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentsByIDsRow) GetPaymentIntent() PaymentIntent { + if !r.IntentType.Valid { + return PaymentIntent{} + } + return PaymentIntent{ + IntentType: r.IntentType.Int16, + IntentPayload: r.IntentPayload, + } +} diff --git a/sqldb/sqlc/migrations/000009_payments.up.sql b/sqldb/sqlc/migrations/000009_payments.up.sql index c856db8f442..0d85b497b00 100644 --- a/sqldb/sqlc/migrations/000009_payments.up.sql +++ b/sqldb/sqlc/migrations/000009_payments.up.sql @@ -32,9 +32,13 @@ CREATE TABLE IF NOT EXISTS payment_intents ( ); -- Index for efficient querying by intent type -CREATE INDEX IF NOT EXISTS idx_payment_intents_type +CREATE INDEX IF NOT EXISTS idx_payment_intents_type ON payment_intents(intent_type); +-- Unique constraint for deduplication of payment intents +CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_intents_unique +ON payment_intents(intent_type, intent_payload); + -- ───────────────────────────────────────────── -- Payments Table -- ───────────────────────────────────────────── @@ -187,7 +191,8 @@ CREATE TABLE IF NOT EXISTS payment_htlc_attempt_resolutions ( -- HTLC failure reason code htlc_fail_reason INTEGER, - -- Failure message from the failing node + -- Failure message from the failing node, this message is binary encoded + -- using the lightning wire protocol, see also lnwire/onion_error.go failure_msg BLOB, -- Ensure data integrity: settled attempts must have preimage, diff --git a/sqldb/sqlc/payments.sql.go b/sqldb/sqlc/payments.sql.go new file mode 100644 index 00000000000..babab68247d --- /dev/null +++ b/sqldb/sqlc/payments.sql.go @@ -0,0 +1,1084 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: payments.sql + +package sqlc + +import ( + "context" + "database/sql" + "strings" + "time" +) + +const countPayments = `-- name: CountPayments :one +SELECT COUNT(*) FROM payments +` + +func (q *Queries) CountPayments(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, countPayments) + var count int64 + err := row.Scan(&count) + return count, err +} + +const deleteFailedAttempts = `-- name: DeleteFailedAttempts :exec +DELETE FROM payment_htlc_attempts WHERE payment_id = $1 AND attempt_index IN ( + SELECT attempt_index FROM payment_htlc_attempt_resolutions WHERE resolution_type = 2 +) +` + +func (q *Queries) DeleteFailedAttempts(ctx context.Context, paymentID int64) error { + _, err := q.db.ExecContext(ctx, deleteFailedAttempts, paymentID) + return err +} + +const deletePayment = `-- name: DeletePayment :exec +DELETE FROM payments WHERE id = $1 +` + +func (q *Queries) DeletePayment(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deletePayment, id) + return err +} + +const failAttempt = `-- name: FailAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + failure_source_index, + htlc_fail_reason, + failure_msg +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) +` + +type FailAttemptParams struct { + AttemptIndex int64 + ResolutionTime time.Time + ResolutionType int32 + FailureSourceIndex sql.NullInt32 + HtlcFailReason sql.NullInt32 + FailureMsg []byte +} + +func (q *Queries) FailAttempt(ctx context.Context, arg FailAttemptParams) error { + _, err := q.db.ExecContext(ctx, failAttempt, + arg.AttemptIndex, + arg.ResolutionTime, + arg.ResolutionType, + arg.FailureSourceIndex, + arg.HtlcFailReason, + arg.FailureMsg, + ) + return err +} + +const failPayment = `-- name: FailPayment :execresult +UPDATE payments SET fail_reason = $1 WHERE payment_identifier = $2 +` + +type FailPaymentParams struct { + FailReason sql.NullInt32 + PaymentIdentifier []byte +} + +func (q *Queries) FailPayment(ctx context.Context, arg FailPaymentParams) (sql.Result, error) { + return q.db.ExecContext(ctx, failPayment, arg.FailReason, arg.PaymentIdentifier) +} + +const fetchAllInflightAttempts = `-- name: FetchAllInflightAttempts :many +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key +FROM payment_htlc_attempts ha +WHERE NOT EXISTS ( + SELECT 1 FROM payment_htlc_attempt_resolutions hr + WHERE hr.attempt_index = ha.attempt_index +) +` + +// Fetch all inflight attempts across all payments +func (q *Queries) FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAttempt, error) { + rows, err := q.db.QueryContext(ctx, fetchAllInflightAttempts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentHtlcAttempt + for rows.Next() { + var i PaymentHtlcAttempt + if err := rows.Scan( + &i.ID, + &i.AttemptIndex, + &i.PaymentID, + &i.SessionKey, + &i.AttemptTime, + &i.PaymentHash, + &i.FirstHopAmountMsat, + &i.RouteTotalTimeLock, + &i.RouteTotalAmount, + &i.RouteSourceKey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHopLevelCustomRecords = `-- name: FetchHopLevelCustomRecords :many +SELECT + l.id, + l.hop_id, + l.key, + l.value +FROM payment_hop_custom_records l +WHERE l.hop_id IN (/*SLICE:hop_ids*/?) +` + +func (q *Queries) FetchHopLevelCustomRecords(ctx context.Context, hopIds []int64) ([]PaymentHopCustomRecord, error) { + query := fetchHopLevelCustomRecords + var queryParams []interface{} + if len(hopIds) > 0 { + for _, v := range hopIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:hop_ids*/?", makeQueryParams(len(queryParams), len(hopIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:hop_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentHopCustomRecord + for rows.Next() { + var i PaymentHopCustomRecord + if err := rows.Scan( + &i.ID, + &i.HopID, + &i.Key, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHopsForAttempt = `-- name: FetchHopsForAttempt :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index = $1 +ORDER BY h.hop_index ASC +` + +type FetchHopsForAttemptRow struct { + ID int64 + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte + MppPaymentAddr []byte + MppTotalMsat sql.NullInt64 + AmpRootShare []byte + AmpSetID []byte + AmpChildIndex sql.NullInt32 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) FetchHopsForAttempt(ctx context.Context, htlcAttemptIndex int64) ([]FetchHopsForAttemptRow, error) { + rows, err := q.db.QueryContext(ctx, fetchHopsForAttempt, htlcAttemptIndex) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHopsForAttemptRow + for rows.Next() { + var i FetchHopsForAttemptRow + if err := rows.Scan( + &i.ID, + &i.HtlcAttemptIndex, + &i.HopIndex, + &i.PubKey, + &i.Scid, + &i.OutgoingTimeLock, + &i.AmtToForward, + &i.MetaData, + &i.MppPaymentAddr, + &i.MppTotalMsat, + &i.AmpRootShare, + &i.AmpSetID, + &i.AmpChildIndex, + &i.EncryptedData, + &i.BlindingPoint, + &i.BlindedPathTotalAmt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHopsForAttempts = `-- name: FetchHopsForAttempts :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index IN (/*SLICE:htlc_attempt_indices*/?) +` + +type FetchHopsForAttemptsRow struct { + ID int64 + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte + MppPaymentAddr []byte + MppTotalMsat sql.NullInt64 + AmpRootShare []byte + AmpSetID []byte + AmpChildIndex sql.NullInt32 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]FetchHopsForAttemptsRow, error) { + query := fetchHopsForAttempts + var queryParams []interface{} + if len(htlcAttemptIndices) > 0 { + for _, v := range htlcAttemptIndices { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", makeQueryParams(len(queryParams), len(htlcAttemptIndices)), 1) + } else { + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHopsForAttemptsRow + for rows.Next() { + var i FetchHopsForAttemptsRow + if err := rows.Scan( + &i.ID, + &i.HtlcAttemptIndex, + &i.HopIndex, + &i.PubKey, + &i.Scid, + &i.OutgoingTimeLock, + &i.AmtToForward, + &i.MetaData, + &i.MppPaymentAddr, + &i.MppTotalMsat, + &i.AmpRootShare, + &i.AmpSetID, + &i.AmpChildIndex, + &i.EncryptedData, + &i.BlindingPoint, + &i.BlindedPathTotalAmt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHtlcAttemptsForPayment = `-- name: FetchHtlcAttemptsForPayment :many +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key, + hr.resolution_type, + hr.resolution_time, + hr.failure_source_index, + hr.htlc_fail_reason, + hr.failure_msg, + hr.settle_preimage +FROM payment_htlc_attempts ha +LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_index +WHERE ha.payment_id = $1 +ORDER BY ha.attempt_time ASC +` + +type FetchHtlcAttemptsForPaymentRow struct { + ID int64 + AttemptIndex int64 + PaymentID int64 + SessionKey []byte + AttemptTime time.Time + PaymentHash []byte + FirstHopAmountMsat int64 + RouteTotalTimeLock int32 + RouteTotalAmount int64 + RouteSourceKey []byte + ResolutionType sql.NullInt32 + ResolutionTime sql.NullTime + FailureSourceIndex sql.NullInt32 + HtlcFailReason sql.NullInt32 + FailureMsg []byte + SettlePreimage []byte +} + +// This fetches all htlc attempts for a payment. +func (q *Queries) FetchHtlcAttemptsForPayment(ctx context.Context, paymentID int64) ([]FetchHtlcAttemptsForPaymentRow, error) { + rows, err := q.db.QueryContext(ctx, fetchHtlcAttemptsForPayment, paymentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHtlcAttemptsForPaymentRow + for rows.Next() { + var i FetchHtlcAttemptsForPaymentRow + if err := rows.Scan( + &i.ID, + &i.AttemptIndex, + &i.PaymentID, + &i.SessionKey, + &i.AttemptTime, + &i.PaymentHash, + &i.FirstHopAmountMsat, + &i.RouteTotalTimeLock, + &i.RouteTotalAmount, + &i.RouteSourceKey, + &i.ResolutionType, + &i.ResolutionTime, + &i.FailureSourceIndex, + &i.HtlcFailReason, + &i.FailureMsg, + &i.SettlePreimage, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchPayment = `-- name: FetchPayment :one +SELECT + p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.payment_identifier = $1 +` + +type FetchPaymentRow struct { + Payment Payment + IntentType sql.NullInt16 + IntentPayload []byte +} + +func (q *Queries) FetchPayment(ctx context.Context, paymentIdentifier []byte) (FetchPaymentRow, error) { + row := q.db.QueryRowContext(ctx, fetchPayment, paymentIdentifier) + var i FetchPaymentRow + err := row.Scan( + &i.Payment.ID, + &i.Payment.IntentID, + &i.Payment.AmountMsat, + &i.Payment.CreatedAt, + &i.Payment.PaymentIdentifier, + &i.Payment.FailReason, + &i.IntentType, + &i.IntentPayload, + ) + return i, err +} + +const fetchPaymentLevelFirstHopCustomRecords = `-- name: FetchPaymentLevelFirstHopCustomRecords :many +SELECT + l.id, + l.payment_id, + l.key, + l.value +FROM payment_first_hop_custom_records l +WHERE l.payment_id = $1 +` + +func (q *Queries) FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentID int64) ([]PaymentFirstHopCustomRecord, error) { + rows, err := q.db.QueryContext(ctx, fetchPaymentLevelFirstHopCustomRecords, paymentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentFirstHopCustomRecord + for rows.Next() { + var i PaymentFirstHopCustomRecord + if err := rows.Scan( + &i.ID, + &i.PaymentID, + &i.Key, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchPaymentsByIDs = `-- name: FetchPaymentsByIDs :many +SELECT + p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.id IN (/*SLICE:payment_ids*/?) +` + +type FetchPaymentsByIDsRow struct { + Payment Payment + IntentType sql.NullInt16 + IntentPayload []byte +} + +func (q *Queries) FetchPaymentsByIDs(ctx context.Context, paymentIds []int64) ([]FetchPaymentsByIDsRow, error) { + query := fetchPaymentsByIDs + var queryParams []interface{} + if len(paymentIds) > 0 { + for _, v := range paymentIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:payment_ids*/?", makeQueryParams(len(queryParams), len(paymentIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:payment_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchPaymentsByIDsRow + for rows.Next() { + var i FetchPaymentsByIDsRow + if err := rows.Scan( + &i.Payment.ID, + &i.Payment.IntentID, + &i.Payment.AmountMsat, + &i.Payment.CreatedAt, + &i.Payment.PaymentIdentifier, + &i.Payment.FailReason, + &i.IntentType, + &i.IntentPayload, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchRouteLevelFirstHopCustomRecords = `-- name: FetchRouteLevelFirstHopCustomRecords :many +SELECT + l.id, + l.htlc_attempt_index, + l.key, + l.value +FROM payment_attempt_first_hop_custom_records l +WHERE l.htlc_attempt_index IN (/*SLICE:htlc_attempt_indices*/?) +` + +func (q *Queries) FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]PaymentAttemptFirstHopCustomRecord, error) { + query := fetchRouteLevelFirstHopCustomRecords + var queryParams []interface{} + if len(htlcAttemptIndices) > 0 { + for _, v := range htlcAttemptIndices { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", makeQueryParams(len(queryParams), len(htlcAttemptIndices)), 1) + } else { + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentAttemptFirstHopCustomRecord + for rows.Next() { + var i PaymentAttemptFirstHopCustomRecord + if err := rows.Scan( + &i.ID, + &i.HtlcAttemptIndex, + &i.Key, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const filterPayments = `-- name: FilterPayments :many +/* ───────────────────────────────────────────── + fetch queries + ───────────────────────────────────────────── +*/ + +SELECT + p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE ( + p.id > $1 OR + $1 IS NULL +) AND ( + p.id < $2 OR + $2 IS NULL +) AND ( + p.created_at >= $3 OR + $3 IS NULL +) AND ( + p.created_at <= $4 OR + $4 IS NULL +) AND ( + i.intent_type = $5 OR + $5 IS NULL OR i.intent_type IS NULL +) +ORDER BY + CASE WHEN $6 = false OR $6 IS NULL THEN p.id END ASC, + CASE WHEN $6 = true THEN p.id END DESC +LIMIT $7 +` + +type FilterPaymentsParams struct { + IndexOffsetGet sql.NullInt64 + IndexOffsetLet sql.NullInt64 + CreatedAfter sql.NullTime + CreatedBefore sql.NullTime + IntentType sql.NullInt16 + Reverse interface{} + NumLimit int32 +} + +type FilterPaymentsRow struct { + Payment Payment + IntentType sql.NullInt16 + IntentPayload []byte +} + +func (q *Queries) FilterPayments(ctx context.Context, arg FilterPaymentsParams) ([]FilterPaymentsRow, error) { + rows, err := q.db.QueryContext(ctx, filterPayments, + arg.IndexOffsetGet, + arg.IndexOffsetLet, + arg.CreatedAfter, + arg.CreatedBefore, + arg.IntentType, + arg.Reverse, + arg.NumLimit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FilterPaymentsRow + for rows.Next() { + var i FilterPaymentsRow + if err := rows.Scan( + &i.Payment.ID, + &i.Payment.IntentID, + &i.Payment.AmountMsat, + &i.Payment.CreatedAt, + &i.Payment.PaymentIdentifier, + &i.Payment.FailReason, + &i.IntentType, + &i.IntentPayload, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertHtlcAttempt = `-- name: InsertHtlcAttempt :one +INSERT INTO payment_htlc_attempts ( + payment_id, + attempt_index, + session_key, + attempt_time, + payment_hash, + first_hop_amount_msat, + route_total_time_lock, + route_total_amount, + route_source_key) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9) +RETURNING id +` + +type InsertHtlcAttemptParams struct { + PaymentID int64 + AttemptIndex int64 + SessionKey []byte + AttemptTime time.Time + PaymentHash []byte + FirstHopAmountMsat int64 + RouteTotalTimeLock int32 + RouteTotalAmount int64 + RouteSourceKey []byte +} + +func (q *Queries) InsertHtlcAttempt(ctx context.Context, arg InsertHtlcAttemptParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertHtlcAttempt, + arg.PaymentID, + arg.AttemptIndex, + arg.SessionKey, + arg.AttemptTime, + arg.PaymentHash, + arg.FirstHopAmountMsat, + arg.RouteTotalTimeLock, + arg.RouteTotalAmount, + arg.RouteSourceKey, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertPayment = `-- name: InsertPayment :one +INSERT INTO payments ( + intent_id, + amount_msat, + created_at, + payment_identifier, + fail_reason) +VALUES ( + $1, + $2, + $3, + $4, + NULL +) +RETURNING id +` + +type InsertPaymentParams struct { + IntentID sql.NullInt64 + AmountMsat int64 + CreatedAt time.Time + PaymentIdentifier []byte +} + +// Insert a new payment with the given intent ID and return its ID. +func (q *Queries) InsertPayment(ctx context.Context, arg InsertPaymentParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertPayment, + arg.IntentID, + arg.AmountMsat, + arg.CreatedAt, + arg.PaymentIdentifier, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertPaymentAttemptFirstHopCustomRecord = `-- name: InsertPaymentAttemptFirstHopCustomRecord :exec +INSERT INTO payment_attempt_first_hop_custom_records ( + htlc_attempt_index, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentAttemptFirstHopCustomRecordParams struct { + HtlcAttemptIndex int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg InsertPaymentAttemptFirstHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentAttemptFirstHopCustomRecord, arg.HtlcAttemptIndex, arg.Key, arg.Value) + return err +} + +const insertPaymentFirstHopCustomRecord = `-- name: InsertPaymentFirstHopCustomRecord :exec +INSERT INTO payment_first_hop_custom_records ( + payment_id, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentFirstHopCustomRecordParams struct { + PaymentID int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentFirstHopCustomRecord(ctx context.Context, arg InsertPaymentFirstHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentFirstHopCustomRecord, arg.PaymentID, arg.Key, arg.Value) + return err +} + +const insertPaymentHopCustomRecord = `-- name: InsertPaymentHopCustomRecord :exec +INSERT INTO payment_hop_custom_records ( + hop_id, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentHopCustomRecordParams struct { + HopID int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentHopCustomRecord(ctx context.Context, arg InsertPaymentHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentHopCustomRecord, arg.HopID, arg.Key, arg.Value) + return err +} + +const insertPaymentIntent = `-- name: InsertPaymentIntent :one +INSERT INTO payment_intents ( + intent_type, + intent_payload) +VALUES ( + $1, + $2 +) +ON CONFLICT (intent_type, intent_payload) DO UPDATE SET + intent_type = EXCLUDED.intent_type, + intent_payload = EXCLUDED.intent_payload +RETURNING id +` + +type InsertPaymentIntentParams struct { + IntentType int16 + IntentPayload []byte +} + +// Insert a new payment intent and return its ID. +func (q *Queries) InsertPaymentIntent(ctx context.Context, arg InsertPaymentIntentParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertPaymentIntent, arg.IntentType, arg.IntentPayload) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertRouteHop = `-- name: InsertRouteHop :one +INSERT INTO payment_route_hops ( + htlc_attempt_index, + hop_index, + pub_key, + scid, + outgoing_time_lock, + amt_to_forward, + meta_data +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING id +` + +type InsertRouteHopParams struct { + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte +} + +func (q *Queries) InsertRouteHop(ctx context.Context, arg InsertRouteHopParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertRouteHop, + arg.HtlcAttemptIndex, + arg.HopIndex, + arg.PubKey, + arg.Scid, + arg.OutgoingTimeLock, + arg.AmtToForward, + arg.MetaData, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertRouteHopAmp = `-- name: InsertRouteHopAmp :exec +INSERT INTO payment_route_hop_amp ( + hop_id, + root_share, + set_id, + child_index +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertRouteHopAmpParams struct { + HopID int64 + RootShare []byte + SetID []byte + ChildIndex int32 +} + +func (q *Queries) InsertRouteHopAmp(ctx context.Context, arg InsertRouteHopAmpParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopAmp, + arg.HopID, + arg.RootShare, + arg.SetID, + arg.ChildIndex, + ) + return err +} + +const insertRouteHopBlinded = `-- name: InsertRouteHopBlinded :exec +INSERT INTO payment_route_hop_blinded ( + hop_id, + encrypted_data, + blinding_point, + blinded_path_total_amt +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertRouteHopBlindedParams struct { + HopID int64 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) InsertRouteHopBlinded(ctx context.Context, arg InsertRouteHopBlindedParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopBlinded, + arg.HopID, + arg.EncryptedData, + arg.BlindingPoint, + arg.BlindedPathTotalAmt, + ) + return err +} + +const insertRouteHopMpp = `-- name: InsertRouteHopMpp :exec +INSERT INTO payment_route_hop_mpp ( + hop_id, + payment_addr, + total_msat +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertRouteHopMppParams struct { + HopID int64 + PaymentAddr []byte + TotalMsat int64 +} + +func (q *Queries) InsertRouteHopMpp(ctx context.Context, arg InsertRouteHopMppParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopMpp, arg.HopID, arg.PaymentAddr, arg.TotalMsat) + return err +} + +const settleAttempt = `-- name: SettleAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + settle_preimage +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type SettleAttemptParams struct { + AttemptIndex int64 + ResolutionTime time.Time + ResolutionType int32 + SettlePreimage []byte +} + +func (q *Queries) SettleAttempt(ctx context.Context, arg SettleAttemptParams) error { + _, err := q.db.ExecContext(ctx, settleAttempt, + arg.AttemptIndex, + arg.ResolutionTime, + arg.ResolutionType, + arg.SettlePreimage, + ) + return err +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 0087559be8f..3319e1e3026 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -14,24 +14,41 @@ type Querier interface { AddSourceNode(ctx context.Context, nodeID int64) error AddV1ChannelProof(ctx context.Context, arg AddV1ChannelProofParams) (sql.Result, error) ClearKVInvoiceHashIndex(ctx context.Context) error + CountPayments(ctx context.Context) (int64, error) CountZombieChannels(ctx context.Context, version int16) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error DeleteChannels(ctx context.Context, ids []int64) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error + DeleteFailedAttempts(ctx context.Context, paymentID int64) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) DeleteNode(ctx context.Context, id int64) error DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error + DeletePayment(ctx context.Context, id int64) error DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error) DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) + FailAttempt(ctx context.Context, arg FailAttemptParams) error + FailPayment(ctx context.Context, arg FailPaymentParams) (sql.Result, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) + // Fetch all inflight attempts across all payments + FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAttempt, error) + FetchHopLevelCustomRecords(ctx context.Context, hopIds []int64) ([]PaymentHopCustomRecord, error) + FetchHopsForAttempt(ctx context.Context, htlcAttemptIndex int64) ([]FetchHopsForAttemptRow, error) + FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]FetchHopsForAttemptsRow, error) + // This fetches all htlc attempts for a payment. + FetchHtlcAttemptsForPayment(ctx context.Context, paymentID int64) ([]FetchHtlcAttemptsForPaymentRow, error) + FetchPayment(ctx context.Context, paymentIdentifier []byte) (FetchPaymentRow, error) + FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentID int64) ([]PaymentFirstHopCustomRecord, error) + FetchPaymentsByIDs(ctx context.Context, paymentIds []int64) ([]FetchPaymentsByIDsRow, error) + FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]PaymentAttemptFirstHopCustomRecord, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) + FilterPayments(ctx context.Context, arg FilterPaymentsParams) ([]FilterPaymentsRow, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) @@ -101,6 +118,7 @@ type Querier interface { // UpsertEdgePolicy query is used because of the constraint in that query that // requires a policy update to have a newer last_update than the existing one). InsertEdgePolicyMig(ctx context.Context, arg InsertEdgePolicyMigParams) (int64, error) + InsertHtlcAttempt(ctx context.Context, arg InsertHtlcAttemptParams) (int64, error) InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error) @@ -114,6 +132,17 @@ type Querier interface { // is used because of the constraint in that query that requires a node update // to have a newer last_update than the existing node). InsertNodeMig(ctx context.Context, arg InsertNodeMigParams) (int64, error) + // Insert a new payment with the given intent ID and return its ID. + InsertPayment(ctx context.Context, arg InsertPaymentParams) (int64, error) + InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg InsertPaymentAttemptFirstHopCustomRecordParams) error + InsertPaymentFirstHopCustomRecord(ctx context.Context, arg InsertPaymentFirstHopCustomRecordParams) error + InsertPaymentHopCustomRecord(ctx context.Context, arg InsertPaymentHopCustomRecordParams) error + // Insert a new payment intent and return its ID. + InsertPaymentIntent(ctx context.Context, arg InsertPaymentIntentParams) (int64, error) + InsertRouteHop(ctx context.Context, arg InsertRouteHopParams) (int64, error) + InsertRouteHopAmp(ctx context.Context, arg InsertRouteHopAmpParams) error + InsertRouteHopBlinded(ctx context.Context, arg InsertRouteHopBlindedParams) error + InsertRouteHopMpp(ctx context.Context, arg InsertRouteHopMppParams) error IsClosedChannel(ctx context.Context, scid []byte) (bool, error) IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error) IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error) @@ -133,6 +162,7 @@ type Querier interface { OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error SetMigration(ctx context.Context, arg SetMigrationParams) error + SettleAttempt(ctx context.Context, arg SettleAttemptParams) error UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error) UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error) diff --git a/sqldb/sqlc/queries/payments.sql b/sqldb/sqlc/queries/payments.sql new file mode 100644 index 00000000000..db21a9be4b5 --- /dev/null +++ b/sqldb/sqlc/queries/payments.sql @@ -0,0 +1,368 @@ +/* ───────────────────────────────────────────── + fetch queries + ───────────────────────────────────────────── +*/ + +-- name: FilterPayments :many +SELECT + sqlc.embed(p), + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE ( + p.id > sqlc.narg('index_offset_get') OR + sqlc.narg('index_offset_get') IS NULL +) AND ( + p.id < sqlc.narg('index_offset_let') OR + sqlc.narg('index_offset_let') IS NULL +) AND ( + p.created_at >= sqlc.narg('created_after') OR + sqlc.narg('created_after') IS NULL +) AND ( + p.created_at <= sqlc.narg('created_before') OR + sqlc.narg('created_before') IS NULL +) AND ( + i.intent_type = sqlc.narg('intent_type') OR + sqlc.narg('intent_type') IS NULL OR i.intent_type IS NULL +) +ORDER BY + CASE WHEN sqlc.narg('reverse') = false OR sqlc.narg('reverse') IS NULL THEN p.id END ASC, + CASE WHEN sqlc.narg('reverse') = true THEN p.id END DESC +LIMIT @num_limit; + +-- name: FetchPayment :one +SELECT + sqlc.embed(p), + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.payment_identifier = $1; + +-- name: FetchPaymentsByIDs :many +SELECT + sqlc.embed(p), + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.id IN (sqlc.slice('payment_ids')/*SLICE:payment_ids*/); + +-- name: CountPayments :one +SELECT COUNT(*) FROM payments; + + +-- This fetches all htlc attempts for a payment. +-- name: FetchHtlcAttemptsForPayment :many +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key, + hr.resolution_type, + hr.resolution_time, + hr.failure_source_index, + hr.htlc_fail_reason, + hr.failure_msg, + hr.settle_preimage +FROM payment_htlc_attempts ha +LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_index +WHERE ha.payment_id = $1 +ORDER BY ha.attempt_time ASC; + +-- name: FetchAllInflightAttempts :many +-- Fetch all inflight attempts across all payments +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key +FROM payment_htlc_attempts ha +WHERE NOT EXISTS ( + SELECT 1 FROM payment_htlc_attempt_resolutions hr + WHERE hr.attempt_index = ha.attempt_index +); + +-- name: FetchHopsForAttempt :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index = $1 +ORDER BY h.hop_index ASC; + +-- name: FetchHopsForAttempts :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index IN (sqlc.slice('htlc_attempt_indices')/*SLICE:htlc_attempt_indices*/); + +-- name: FetchPaymentLevelFirstHopCustomRecords :many +SELECT + l.id, + l.payment_id, + l.key, + l.value +FROM payment_first_hop_custom_records l +WHERE l.payment_id = $1; + +-- name: FetchRouteLevelFirstHopCustomRecords :many +SELECT + l.id, + l.htlc_attempt_index, + l.key, + l.value +FROM payment_attempt_first_hop_custom_records l +WHERE l.htlc_attempt_index IN (sqlc.slice('htlc_attempt_indices')/*SLICE:htlc_attempt_indices*/); + +-- name: FetchHopLevelCustomRecords :many +SELECT + l.id, + l.hop_id, + l.key, + l.value +FROM payment_hop_custom_records l +WHERE l.hop_id IN (sqlc.slice('hop_ids')/*SLICE:hop_ids*/); + + +-- name: DeletePayment :exec +DELETE FROM payments WHERE id = $1; + +-- name: DeleteFailedAttempts :exec +DELETE FROM payment_htlc_attempts WHERE payment_id = $1 AND attempt_index IN ( + SELECT attempt_index FROM payment_htlc_attempt_resolutions WHERE resolution_type = 2 +); + +-- name: InsertPaymentIntent :one +-- Insert a new payment intent and return its ID. +INSERT INTO payment_intents ( + intent_type, + intent_payload) +VALUES ( + @intent_type, + @intent_payload +) +ON CONFLICT (intent_type, intent_payload) DO UPDATE SET + intent_type = EXCLUDED.intent_type, + intent_payload = EXCLUDED.intent_payload +RETURNING id; + +-- name: InsertPayment :one +-- Insert a new payment with the given intent ID and return its ID. +INSERT INTO payments ( + intent_id, + amount_msat, + created_at, + payment_identifier, + fail_reason) +VALUES ( + @intent_id, + @amount_msat, + @created_at, + @payment_identifier, + NULL +) +RETURNING id; + +-- name: InsertPaymentFirstHopCustomRecord :exec +INSERT INTO payment_first_hop_custom_records ( + payment_id, + key, + value +) +VALUES ( + @payment_id, + @key, + @value +); + +-- name: InsertHtlcAttempt :one +INSERT INTO payment_htlc_attempts ( + payment_id, + attempt_index, + session_key, + attempt_time, + payment_hash, + first_hop_amount_msat, + route_total_time_lock, + route_total_amount, + route_source_key) +VALUES ( + @payment_id, + @attempt_index, + @session_key, + @attempt_time, + @payment_hash, + @first_hop_amount_msat, + @route_total_time_lock, + @route_total_amount, + @route_source_key) +RETURNING id; + +-- name: InsertPaymentAttemptFirstHopCustomRecord :exec +INSERT INTO payment_attempt_first_hop_custom_records ( + htlc_attempt_index, + key, + value +) +VALUES ( + @htlc_attempt_index, + @key, + @value +); + +-- name: InsertRouteHop :one +INSERT INTO payment_route_hops ( + htlc_attempt_index, + hop_index, + pub_key, + scid, + outgoing_time_lock, + amt_to_forward, + meta_data +) +VALUES ( + @htlc_attempt_index, + @hop_index, + @pub_key, + @scid, + @outgoing_time_lock, + @amt_to_forward, + @meta_data +) +RETURNING id; + +-- name: InsertRouteHopMpp :exec +INSERT INTO payment_route_hop_mpp ( + hop_id, + payment_addr, + total_msat +) +VALUES ( + @hop_id, + @payment_addr, + @total_msat +); + +-- name: InsertRouteHopAmp :exec +INSERT INTO payment_route_hop_amp ( + hop_id, + root_share, + set_id, + child_index +) +VALUES ( + @hop_id, + @root_share, + @set_id, + @child_index +); + +-- name: InsertRouteHopBlinded :exec +INSERT INTO payment_route_hop_blinded ( + hop_id, + encrypted_data, + blinding_point, + blinded_path_total_amt +) +VALUES ( + @hop_id, + @encrypted_data, + @blinding_point, + @blinded_path_total_amt +); + +-- name: InsertPaymentHopCustomRecord :exec +INSERT INTO payment_hop_custom_records ( + hop_id, + key, + value +) +VALUES ( + @hop_id, + @key, + @value +); + +-- name: SettleAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + settle_preimage +) +VALUES ( + @attempt_index, + @resolution_time, + @resolution_type, + @settle_preimage +); + +-- name: FailAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + failure_source_index, + htlc_fail_reason, + failure_msg +) +VALUES ( + @attempt_index, + @resolution_time, + @resolution_type, + @failure_source_index, + @htlc_fail_reason, + @failure_msg +); + +-- name: FailPayment :execresult +UPDATE payments SET fail_reason = $1 WHERE payment_identifier = $2;