diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index bb02b9c276..d7f416f5be 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -54,6 +54,12 @@ refacotor the payment related LND code to make it more modular. * Implement the SQL backend for the [payments database](https://github.com/lightningnetwork/lnd/pull/9147) + * Implement query methods for the [payments db + SQL Backend](https://github.com/lightningnetwork/lnd/pull/10287) + * Implement insert methods for the [payments db + SQL Backend](https://github.com/lightningnetwork/lnd/pull/10291) + * Finalize SQL payments implementation [enabling unit and itests + for SQL backend](https://github.com/lightningnetwork/lnd/pull/10292) ## Code Health diff --git a/itest/lnd_payment_test.go b/itest/lnd_payment_test.go index 37aff05226..f683cd44a8 100644 --- a/itest/lnd_payment_test.go +++ b/itest/lnd_payment_test.go @@ -504,61 +504,86 @@ func testListPayments(ht *lntest.HarnessTest) { expected bool } - // Create test cases to check the timestamp filters. - createCases := func(createTimeSeconds uint64) []testCase { + // Create test cases with proper rounding for start and end dates. + createCases := func(startTimeSeconds, + endTimeSeconds uint64) []testCase { + return []testCase{ { // Use a start date same as the creation date - // should return us the item. + // (truncated) should return us the item. name: "exact start date", - startDate: createTimeSeconds, + startDate: startTimeSeconds, expected: true, }, { // Use an earlier start date should return us // the item. name: "earlier start date", - startDate: createTimeSeconds - 1, + startDate: startTimeSeconds - 1, expected: true, }, { // Use a future start date should return us // nothing. name: "future start date", - startDate: createTimeSeconds + 1, + startDate: startTimeSeconds + 1, expected: false, }, { // Use an end date same as the creation date - // should return us the item. + // (ceiling) should return us the item. name: "exact end date", - endDate: createTimeSeconds, + endDate: endTimeSeconds, expected: true, }, { // Use an end date in the future should return // us the item. name: "future end date", - endDate: createTimeSeconds + 1, + endDate: endTimeSeconds + 1, expected: true, }, { // Use an earlier end date should return us // nothing. - name: "earlier end date", - endDate: createTimeSeconds - 1, + name: "earlier end date", + // The native sql backend has a higher + // precision than the kv backend, the native sql + // backend uses microseconds, the kv backend + // when filtering uses seconds so we need to + // subtract 2 seconds to ensure the payment is + // not included. + // We could also truncate before inserting + // into the sql db but I rather relax this test + // here. + endDate: endTimeSeconds - 2, expected: false, }, } } - // Get the payment creation time in seconds. - paymentCreateSeconds := uint64( - p.CreationTimeNs / time.Second.Nanoseconds(), + // Get the payment creation time in seconds, using different approaches + // for start and end date comparisons to avoid rounding issues. + creationTime := time.Unix(0, p.CreationTimeNs) + + // For start date comparisons: use truncation (floor) to include + // payments from the beginning of that second. + paymentCreateSecondsStart := uint64( + creationTime.Truncate(time.Second).Unix(), + ) + + // For end date comparisons: use ceiling to include payments up to the + // end of that second. + paymentCreateSecondsEnd := uint64( + (p.CreationTimeNs + time.Second.Nanoseconds() - 1) / + time.Second.Nanoseconds(), ) // Create test cases from the payment creation time. - testCases := createCases(paymentCreateSeconds) + testCases := createCases( + paymentCreateSecondsStart, paymentCreateSecondsEnd, + ) // We now check the timestamp filters in `ListPayments`. for _, tc := range testCases { @@ -578,7 +603,9 @@ func testListPayments(ht *lntest.HarnessTest) { } // Create test cases from the invoice creation time. - testCases = createCases(uint64(invoice.CreationDate)) + testCases = createCases( + uint64(invoice.CreationDate), uint64(invoice.CreationDate), + ) // We now do the same check for `ListInvoices`. for _, tc := range testCases { diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 377d0be3ea..04faa7840b 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1745,6 +1745,8 @@ func (r *RouterBackend) MarshallPayment(payment *paymentsdb.MPPayment) ( // If any of the htlcs have settled, extract a valid // preimage. if htlc.Settle != nil { + // For AMP payments all hashes will be different so we + // will depict the last htlc preimage. preimage = htlc.Settle.Preimage fee += htlc.Route.TotalFees() } diff --git a/payments/db/errors.go b/payments/db/errors.go index 40e37d95c6..424b63c7bd 100644 --- a/payments/db/errors.go +++ b/payments/db/errors.go @@ -130,4 +130,8 @@ var ( // NOTE: Only used for the kv backend. ErrNoSequenceNrIndex = errors.New("payment sequence number index " + "does not exist") + + // errMaxPaymentsReached is used internally to signal that the maximum + // number of payments has been reached during a paginated query. + errMaxPaymentsReached = errors.New("max payments reached") ) diff --git a/payments/db/kv_store_test.go b/payments/db/kv_store_test.go index 73df8ead4e..1d3652409f 100644 --- a/payments/db/kv_store_test.go +++ b/payments/db/kv_store_test.go @@ -2,6 +2,7 @@ package paymentsdb import ( "bytes" + "crypto/sha256" "math" "reflect" "testing" @@ -61,10 +62,15 @@ func TestKVStoreDeleteNonInFlight(t *testing.T) { var numSuccess, numInflight int for _, p := range payments { - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash( + t, 0, genSessionKey(t), rhash, + ) + require.NoError(t, err) // Sends base htlc message which initiate StatusInFlight. err = paymentDB.InitPayment(info.PaymentIdentifier, info) @@ -474,7 +480,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { paymentDB := NewKVTestDB(t) // Generate a test payment which does not have duplicates. - noDuplicates, _, _, err := genInfo(t) + noDuplicates, _, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -490,7 +496,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { require.NoError(t, err) // Generate a test payment which we will add duplicates to. - hasDuplicates, _, preimg, err := genInfo(t) + hasDuplicates, preimg, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -648,7 +654,7 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) // Generate fake information for the duplicate payment. - info, _, _, err := genInfo(t) + info, _, err := genInfo(t) require.NoError(t, err) // Write the payment info to disk under the creation info key. This code @@ -680,17 +686,19 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) } -// TestQueryPayments tests retrieval of payments with forwards and reversed -// queries. -// -// TODO(ziggie): Make this test db agnostic. -func TestQueryPayments(t *testing.T) { - // Define table driven test for QueryPayments. +// TestKVStoreQueryPaymentsDuplicates tests the KV store's legacy duplicate +// payment handling. This tests the specific case where duplicate payments +// are stored in a nested bucket within the parent payment bucket. +func TestKVStoreQueryPaymentsDuplicates(t *testing.T) { + t.Parallel() + // Test payments have sequence indices [1, 3, 4, 5, 6, 7]. // Note that the payment with index 7 has the same payment hash as 6, // and is stored in a nested bucket within payment 6 rather than being - // its own entry in the payments bucket. We do this to test retrieval - // of legacy payments. + // its own entry in the payments bucket. This tests retrieval of legacy + // duplicate payments which is KV-store specific. + // These test cases focus on validating that duplicate payments (seq 7, + // nested under payment 6) are correctly returned in queries. tests := []struct { name string query Query @@ -702,31 +710,20 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs []uint64 }{ { - name: "IndexOffset at the end of the payments range", + name: "query includes duplicate payment in forward " + + "order", query: Query{ - IndexOffset: 7, - MaxPayments: 7, + IndexOffset: 5, + MaxPayments: 3, Reversed: false, IncludeIncomplete: true, }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "query in forwards order, start at beginning", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, + firstIndex: 6, + lastIndex: 7, + expectedSeqNrs: []uint64{6, 7}, }, { - name: "query in forwards order, start at end, overflow", + name: "query duplicate payment at end", query: Query{ IndexOffset: 6, MaxPayments: 2, @@ -738,44 +735,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{7}, }, { - name: "start at offset index outside of payments", - query: Query{ - IndexOffset: 20, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 0, - lastIndex: 0, - expectedSeqNrs: nil, - }, - { - name: "overflow in forwards order", - query: Query{ - IndexOffset: 4, - MaxPayments: math.MaxUint64, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 5, - lastIndex: 7, - expectedSeqNrs: []uint64{5, 6, 7}, - }, - { - name: "start at offset index outside of payments, " + - "reversed order", - query: Query{ - IndexOffset: 9, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 6, - lastIndex: 7, - expectedSeqNrs: []uint64{6, 7}, - }, - { - name: "query in reverse order, start at end", + name: "query includes duplicate in reverse order", query: Query{ IndexOffset: 0, MaxPayments: 2, @@ -787,36 +747,11 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{6, 7}, }, { - name: "query in reverse order, starting in middle", - query: Query{ - IndexOffset: 4, - MaxPayments: 2, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "query in reverse order, starting in middle, " + - "with underflow", - query: Query{ - IndexOffset: 4, - MaxPayments: 5, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 3, - expectedSeqNrs: []uint64{1, 3}, - }, - { - name: "all payments in reverse, order maintained", + name: "query all payments includes duplicate", query: Query{ IndexOffset: 0, - MaxPayments: 7, - Reversed: true, + MaxPayments: math.MaxUint64, + Reversed: false, IncludeIncomplete: true, }, firstIndex: 1, @@ -824,7 +759,7 @@ func TestQueryPayments(t *testing.T) { expectedSeqNrs: []uint64{1, 3, 4, 5, 6, 7}, }, { - name: "exclude incomplete payments", + name: "exclude incomplete includes duplicate", query: Query{ IndexOffset: 0, MaxPayments: 7, @@ -835,96 +770,6 @@ func TestQueryPayments(t *testing.T) { lastIndex: 7, expectedSeqNrs: []uint64{7}, }, - { - name: "query payments at index gap", - query: Query{ - IndexOffset: 1, - MaxPayments: 7, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 7, - expectedSeqNrs: []uint64{3, 4, 5, 6, 7}, - }, - { - name: "query payments reverse before index gap", - query: Query{ - IndexOffset: 3, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments reverse on index gap", - query: Query{ - IndexOffset: 2, - MaxPayments: 7, - Reversed: true, - IncludeIncomplete: true, - }, - firstIndex: 1, - lastIndex: 1, - expectedSeqNrs: []uint64{1}, - }, - { - name: "query payments forward on index gap", - query: Query{ - IndexOffset: 2, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - }, - firstIndex: 3, - lastIndex: 4, - expectedSeqNrs: []uint64{3, 4}, - }, - { - name: "query in forwards order, with start creation " + - "time", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 5, - }, - firstIndex: 5, - lastIndex: 6, - expectedSeqNrs: []uint64{5, 6}, - }, - { - name: "query in forwards order, with start creation " + - "time at end, overflow", - query: Query{ - IndexOffset: 0, - MaxPayments: 2, - Reversed: false, - IncludeIncomplete: true, - CreationDateStart: 7, - }, - firstIndex: 7, - lastIndex: 7, - expectedSeqNrs: []uint64{7}, - }, - { - name: "query with start and end creation time", - query: Query{ - IndexOffset: 9, - MaxPayments: math.MaxUint64, - Reversed: true, - IncludeIncomplete: true, - CreationDateStart: 3, - CreationDateEnd: 5, - }, - firstIndex: 3, - lastIndex: 5, - expectedSeqNrs: []uint64{3, 4, 5}, - }, } for _, tt := range tests { @@ -956,7 +801,7 @@ func TestQueryPayments(t *testing.T) { for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. - info, _, preimg, err := genInfo(t) + info, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to create test "+ "payment: %v", err) diff --git a/payments/db/payment_test.go b/payments/db/payment_test.go index 534a1b1e5e..f2d772c800 100644 --- a/payments/db/payment_test.go +++ b/payments/db/payment_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math" "reflect" "testing" "time" @@ -58,7 +59,10 @@ var ( ChannelID: 12345, OutgoingTimeLock: 111, AmtToForward: 555, - LegacyPayload: true, + + // Only tlv payloads are now supported in LND therefore we set + // LegacyPayload to false. + LegacyPayload: false, } testRoute = route.Route{ @@ -116,13 +120,20 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { attemptID := uint64(0) for i := 0; i < len(payments); i++ { - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Set the payment id accordingly in the payments slice. payments[i].id = info.PaymentIdentifier - attempt.AttemptID = attemptID + attempt, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + attemptID++ // Init the payment. @@ -148,7 +159,10 @@ func createTestPayments(t *testing.T, p DB, payments []*payment) { // Depending on the test case, fail or succeed the next // attempt. - attempt.AttemptID = attemptID + attempt, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) attemptID++ _, err = p.RegisterAttempt(info.PaymentIdentifier, attempt) @@ -334,7 +348,7 @@ func assertDBPayments(t *testing.T, paymentDB DB, payments []*payment) { } // genPreimage generates a random preimage. -func genPreimage(t *testing.T) ([32]byte, error) { +func genPreimage(t *testing.T) (lntypes.Preimage, error) { t.Helper() var preimage [32]byte @@ -345,31 +359,85 @@ func genPreimage(t *testing.T) ([32]byte, error) { return preimage, nil } -// genInfo generates a payment creation info, an attempt info and a preimage. -func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, - lntypes.Preimage, error) { +// genSessionKey generates a new random private key for use as a session key. +func genSessionKey(t *testing.T) *btcec.PrivateKey { + t.Helper() - preimage, err := genPreimage(t) - if err != nil { - return nil, nil, preimage, fmt.Errorf("unable to "+ - "generate preimage: %v", err) + key, err := btcec.NewPrivateKey() + require.NoError(t, err) + + return key +} + +// genPaymentCreationInfo generates a payment creation info. +func genPaymentCreationInfo(t *testing.T, + paymentHash lntypes.Hash) *PaymentCreationInfo { + + t.Helper() + + // Add constant first hop custom records for testing for testing + // purposes. + firstHopCustomRecords := lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 1: []byte("test_record_1"), + lnwire.MinCustomRecordsTlvType + 2: []byte("test_record_2"), + lnwire.MinCustomRecordsTlvType + 3: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, + }, } + return &PaymentCreationInfo{ + PaymentIdentifier: paymentHash, + Value: testRoute.ReceiverAmt(), + CreationTime: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + FirstHopCustomRecords: firstHopCustomRecords, + } +} + +// genPreimageAndHash generates a random preimage and its corresponding hash. +func genPreimageAndHash(t *testing.T) (lntypes.Preimage, lntypes.Hash, error) { + t.Helper() + + preimage, err := genPreimage(t) + require.NoError(t, err) + rhash := sha256.Sum256(preimage[:]) var hash lntypes.Hash copy(hash[:], rhash[:]) + return preimage, hash, nil +} + +// genAttemptWithPreimage generates an HTLC attempt and returns both the +// attempt and preimage. +func genAttemptWithHash(t *testing.T, attemptID uint64, + sessionKey *btcec.PrivateKey, hash lntypes.Hash) (*HTLCAttemptInfo, + error) { + + t.Helper() + attempt, err := NewHtlcAttempt( - 0, priv, *testRoute.Copy(), time.Time{}, &hash, + attemptID, sessionKey, *testRoute.Copy(), time.Time{}, + &hash, ) - require.NoError(t, err) + if err != nil { + return nil, err + } - return &PaymentCreationInfo{ - PaymentIdentifier: rhash, - Value: testRoute.ReceiverAmt(), - CreationTime: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("hola"), - }, &attempt.HTLCAttemptInfo, preimage, nil + return &attempt.HTLCAttemptInfo, nil +} + +// genInfo generates a payment creation info and the corresponding preimage. +func genInfo(t *testing.T) (*PaymentCreationInfo, lntypes.Preimage, error) { + preimage, _, err := genPreimageAndHash(t) + if err != nil { + return nil, preimage, err + } + + rhash := sha256.Sum256(preimage[:]) + creationInfo := genPaymentCreationInfo(t, rhash) + + return creationInfo, preimage, nil } // TestDeleteFailedAttempts checks that DeleteFailedAttempts properly removes @@ -481,7 +549,17 @@ func TestMPPRecordValidation(t *testing.T) { paymentDB := NewTestDB(t) - info, attempt, _, err := genInfo(t) + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + + attemptID := uint64(0) + + attempt, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) require.NoError(t, err, "unable to generate htlc message") // Init the payment. @@ -502,29 +580,45 @@ func TestMPPRecordValidation(t *testing.T) { require.NoError(t, err, "unable to send htlc message") // Now try to register a non-MPP attempt, which should fail. - b := *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = nil - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + attemptID++ + attempt2, err := genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + attempt2.Route.FinalHop().MPP = nil + + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrMPPayment) // Try to register attempt one with a different payment address. - b.Route.FinalHop().MPP = record.NewMPP( + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{2}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrMPPPaymentAddrMismatch) // Try registering one with a different total amount. - b.Route.FinalHop().MPP = record.NewMPP( + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value/2, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrMPPTotalAmountMismatch) // Create and init a new payment. This time we'll check that we cannot // register an MPP attempt if we already registered a non-MPP one. - info, attempt, _, err = genInfo(t) + preimg, err = genPreimage(t) + require.NoError(t, err) + + rhash = sha256.Sum256(preimg[:]) + info = genPaymentCreationInfo(t, rhash) + + attemptID++ + attempt, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + require.NoError(t, err, "unable to generate htlc message") err = paymentDB.InitPayment(info.PaymentIdentifier, info) @@ -535,13 +629,17 @@ func TestMPPRecordValidation(t *testing.T) { require.NoError(t, err, "unable to send htlc message") // Attempt to register an MPP attempt, which should fail. - b = *attempt - b.AttemptID = 1 - b.Route.FinalHop().MPP = record.NewMPP( + attemptID++ + attempt2, err = genAttemptWithHash( + t, attemptID, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + attempt2.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{1}, ) - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt2) require.ErrorIs(t, err, ErrNonMPPayment) } @@ -1197,8 +1295,11 @@ func TestSuccessesWithoutInFlight(t *testing.T) { paymentDB := NewTestDB(t) - info, _, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Attempt to complete the payment should fail. _, err = paymentDB.SettleAttempt( @@ -1217,8 +1318,11 @@ func TestFailsWithoutInFlight(t *testing.T) { paymentDB := NewTestDB(t) - info, _, _, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Calling Fail should return an error. _, err = paymentDB.Fail( @@ -1292,8 +1396,13 @@ func TestSwitchDoubleSend(t *testing.T) { paymentDB := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash(t, 0, genSessionKey(t), rhash) + require.NoError(t, err) // Sends base htlc message which initiate base status and move it to // StatusInFlight and verifies that it was changed. @@ -1365,8 +1474,13 @@ func TestSwitchFail(t *testing.T) { paymentDB := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - require.NoError(t, err, "unable to generate htlc message") + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) + attempt, err := genAttemptWithHash(t, 0, genSessionKey(t), rhash) + require.NoError(t, err) // Sends base htlc message which initiate StatusInFlight. err = paymentDB.InitPayment(info.PaymentIdentifier, info) @@ -1444,7 +1558,11 @@ func TestSwitchFail(t *testing.T) { assertPaymentInfo(t, paymentDB, info.PaymentIdentifier, info, nil, htlc) // Record another attempt. - attempt.AttemptID = 1 + attempt, err = genAttemptWithHash( + t, 1, genSessionKey(t), rhash, + ) + require.NoError(t, err) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, attempt) require.NoError(t, err, "unable to send htlc message") assertDBPaymentstatus( @@ -1522,16 +1640,15 @@ func TestMultiShard(t *testing.T) { runSubTest := func(t *testing.T, test testCase) { paymentDB := NewTestDB(t) - info, attempt, preimg, err := genInfo(t) - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } + preimg, err := genPreimage(t) + require.NoError(t, err) + + rhash := sha256.Sum256(preimg[:]) + info := genPaymentCreationInfo(t, rhash) // Init the payment, moving it to the StatusInFlight state. err = paymentDB.InitPayment(info.PaymentIdentifier, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } + require.NoError(t, err) assertPaymentIndex(t, paymentDB, info.PaymentIdentifier) assertDBPaymentstatus( @@ -1546,19 +1663,23 @@ func TestMultiShard(t *testing.T) { // attempts's value to one third of the payment amount, and // populate the MPP options. shardAmt := info.Value / 3 - attempt.Route.FinalHop().AmtToForward = shardAmt - attempt.Route.FinalHop().MPP = record.NewMPP( - info.Value, [32]byte{1}, - ) var attempts []*HTLCAttemptInfo for i := uint64(0); i < 3; i++ { - a := *attempt - a.AttemptID = i - attempts = append(attempts, &a) + a, err := genAttemptWithHash( + t, i, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + a.Route.FinalHop().AmtToForward = shardAmt + a.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + attempts = append(attempts, a) _, err = paymentDB.RegisterAttempt( - info.PaymentIdentifier, &a, + info.PaymentIdentifier, a, ) if err != nil { t.Fatalf("unable to send htlc message: %v", err) @@ -1569,7 +1690,7 @@ func TestMultiShard(t *testing.T) { ) htlc := &htlcStatus{ - HTLCAttemptInfo: &a, + HTLCAttemptInfo: a, } assertPaymentInfo( t, paymentDB, info.PaymentIdentifier, info, nil, @@ -1580,9 +1701,17 @@ func TestMultiShard(t *testing.T) { // For a fourth attempt, check that attempting to // register it will fail since the total sent amount // will be too large. - b := *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + b, err := genAttemptWithHash( + t, 3, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + b.Route.FinalHop().AmtToForward = shardAmt + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, b) require.ErrorIs(t, err, ErrValueExceedsAmt) // Fail the second attempt. @@ -1679,9 +1808,17 @@ func TestMultiShard(t *testing.T) { // Try to register yet another attempt. This should fail now // that the payment has reached a terminal condition. - b = *attempt - b.AttemptID = 3 - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) + b, err = genAttemptWithHash( + t, 3, genSessionKey(t), rhash, + ) + require.NoError(t, err) + + b.Route.FinalHop().AmtToForward = shardAmt + b.Route.FinalHop().MPP = record.NewMPP( + info.Value, [32]byte{1}, + ) + + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, b) if test.settleFirst { require.ErrorIs( t, err, ErrPaymentPendingSettled, @@ -1780,8 +1917,8 @@ func TestMultiShard(t *testing.T) { ) // Finally assert we cannot register more attempts. - _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, &b) - require.Equal(t, registerErr, err) + _, err = paymentDB.RegisterAttempt(info.PaymentIdentifier, b) + require.ErrorIs(t, err, registerErr) } for _, test := range tests { @@ -1793,3 +1930,477 @@ func TestMultiShard(t *testing.T) { }) } } + +// TestQueryPayments tests retrieval of payments with forwards and reversed +// queries. +func TestQueryPayments(t *testing.T) { + // Define table driven test for QueryPayments. + // Test payments have sequence indices [1, 3, 4, 5, 6]. + // Note that payment with index 2 is deleted to create a gap in the + // sequence numbers. + tests := []struct { + name string + query Query + firstIndex uint64 + lastIndex uint64 + + // expectedSeqNrs contains the set of sequence numbers we expect + // our query to return. + expectedSeqNrs []uint64 + }{ + { + name: "IndexOffset at the end of the payments range", + query: Query{ + IndexOffset: 6, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "query in forwards order, start at beginning", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in forwards order, start at end, overflow", + query: Query{ + IndexOffset: 5, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "start at offset index outside of payments", + query: Query{ + IndexOffset: 20, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 0, + lastIndex: 0, + expectedSeqNrs: nil, + }, + { + name: "overflow in forwards order", + query: Query{ + IndexOffset: 4, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "start at offset index outside of payments, " + + "reversed order", + query: Query{ + IndexOffset: 9, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in reverse order, start at end", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in reverse order, starting in middle", + query: Query{ + IndexOffset: 4, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "query in reverse order, starting in middle, " + + "with underflow", + query: Query{ + IndexOffset: 4, + MaxPayments: 5, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "all payments in reverse, order maintained", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 6, + expectedSeqNrs: []uint64{1, 3, 4, 5, 6}, + }, + { + name: "exclude incomplete payments", + query: Query{ + IndexOffset: 0, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: false, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "query payments at index gap", + query: Query{ + IndexOffset: 1, + MaxPayments: 7, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 6, + expectedSeqNrs: []uint64{3, 4, 5, 6}, + }, + { + name: "query payments reverse before index gap", + query: Query{ + IndexOffset: 3, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments reverse on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 7, + Reversed: true, + IncludeIncomplete: true, + }, + firstIndex: 1, + lastIndex: 1, + expectedSeqNrs: []uint64{1}, + }, + { + name: "query payments forward on index gap", + query: Query{ + IndexOffset: 2, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + }, + firstIndex: 3, + lastIndex: 4, + expectedSeqNrs: []uint64{3, 4}, + }, + { + name: "query in forwards order, with start creation " + + "time", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 5, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "query in forwards order, with start creation " + + "time at end, overflow", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CreationDateStart: 6, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + { + name: "query with start and end creation time", + query: Query{ + IndexOffset: 9, + MaxPayments: math.MaxUint64, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + CreationDateEnd: 5, + }, + firstIndex: 3, + lastIndex: 5, + expectedSeqNrs: []uint64{3, 4, 5}, + }, + { + name: "query with only end creation time", + query: Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: true, + CreationDateEnd: 4, + }, + firstIndex: 1, + lastIndex: 4, + expectedSeqNrs: []uint64{1, 3, 4}, + }, + { + name: "query reversed with creation date start", + query: Query{ + IndexOffset: 0, + MaxPayments: 3, + Reversed: true, + IncludeIncomplete: true, + CreationDateStart: 3, + }, + firstIndex: 4, + lastIndex: 6, + expectedSeqNrs: []uint64{4, 5, 6}, + }, + { + name: "count total with forward pagination", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: false, + IncludeIncomplete: true, + CountTotal: true, + }, + firstIndex: 1, + lastIndex: 3, + expectedSeqNrs: []uint64{1, 3}, + }, + { + name: "count total with reverse pagination", + query: Query{ + IndexOffset: 0, + MaxPayments: 2, + Reversed: true, + IncludeIncomplete: true, + CountTotal: true, + }, + firstIndex: 5, + lastIndex: 6, + expectedSeqNrs: []uint64{5, 6}, + }, + { + name: "count total with filters", + query: Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + Reversed: false, + IncludeIncomplete: false, + CountTotal: true, + }, + firstIndex: 6, + lastIndex: 6, + expectedSeqNrs: []uint64{6}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + paymentDB := NewTestDB(t) + + // Make a preliminary query to make sure it's ok to + // query when we have no payments. + resp, err := paymentDB.QueryPayments(ctx, tt.query) + require.NoError(t, err) + require.Len(t, resp.Payments, 0) + + // Populate the database with a set of test payments. + // We create 6 payments, deleting the payment at index + // 2 so that we cover the case where sequence numbers + // are missing. + numberOfPayments := 6 + + // Store payment info for all payments so we can delete + // one after all are created. + var paymentInfos []*PaymentCreationInfo + + // First, create all payments. + for i := range numberOfPayments { + // Generate a test payment. + info, _, err := genInfo(t) + require.NoError(t, err) + + // Override creation time to allow for testing + // of CreationDateStart and CreationDateEnd. + info.CreationTime = time.Unix(int64(i+1), 0) + + paymentInfos = append(paymentInfos, info) + + // Create a new payment entry in the database. + err = paymentDB.InitPayment( + info.PaymentIdentifier, info, + ) + require.NoError(t, err) + } + + // Now delete the payment at index 1 (the second + // payment). + pmt, err := paymentDB.FetchPayment( + paymentInfos[1].PaymentIdentifier, + ) + require.NoError(t, err) + + // We delete the whole payment. + err = paymentDB.DeletePayment( + paymentInfos[1].PaymentIdentifier, false, + ) + require.NoError(t, err) + + // Verify the payment is deleted. + _, err = paymentDB.FetchPayment( + paymentInfos[1].PaymentIdentifier, + ) + require.ErrorIs( + t, err, ErrPaymentNotInitiated, + ) + + // Verify the index is removed (KV store only). + assertNoIndex( + t, paymentDB, pmt.SequenceNum, + ) + + // For the last payment, settle it so we have at least + // one completed payment for the "exclude incomplete" + // test case. + lastPaymentInfo := paymentInfos[numberOfPayments-1] + attempt, err := NewHtlcAttempt( + 1, priv, testRoute, + time.Unix(100, 0), + &lastPaymentInfo.PaymentIdentifier, + ) + require.NoError(t, err) + + _, err = paymentDB.RegisterAttempt( + lastPaymentInfo.PaymentIdentifier, + &attempt.HTLCAttemptInfo, + ) + require.NoError(t, err) + + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + _, err = paymentDB.SettleAttempt( + lastPaymentInfo.PaymentIdentifier, + attempt.AttemptID, + &HTLCSettleInfo{ + Preimage: preimg, + }, + ) + require.NoError(t, err) + + // Fetch all payments in the database. + resp, err = paymentDB.QueryPayments( + ctx, Query{ + IndexOffset: 0, + MaxPayments: math.MaxUint64, + IncludeIncomplete: true, + }, + ) + require.NoError(t, err) + + allPayments := resp.Payments + + if len(allPayments) != 5 { + t.Fatalf("Number of payments received does "+ + "not match expected one. Got %v, "+ + "want %v.", len(allPayments), 5) + } + + querySlice, err := paymentDB.QueryPayments( + ctx, tt.query, + ) + require.NoError(t, err) + + if tt.firstIndex != querySlice.FirstIndexOffset || + tt.lastIndex != querySlice.LastIndexOffset { + + t.Errorf("First or last index does not match "+ + "expected index. Want (%d, %d), "+ + "got (%d, %d).", + tt.firstIndex, tt.lastIndex, + querySlice.FirstIndexOffset, + querySlice.LastIndexOffset) + } + + if len(querySlice.Payments) != len(tt.expectedSeqNrs) { + t.Errorf("expected: %v payments, got: %v", + len(tt.expectedSeqNrs), + len(querySlice.Payments)) + } + + for i, seqNr := range tt.expectedSeqNrs { + q := querySlice.Payments[i] + if seqNr != q.SequenceNum { + t.Errorf("sequence numbers do not "+ + "match, got %v, want %v", + q.SequenceNum, seqNr) + } + } + + // Verify CountTotal is set correctly when requested. + if tt.query.CountTotal { + // We should have 5 total payments + // (6 created - 1 deleted). + expectedTotal := uint64(5) + if querySlice.TotalCount != expectedTotal { + t.Errorf("expected total count %v, "+ + "got %v", expectedTotal, + querySlice.TotalCount) + } + } else if querySlice.TotalCount != 0 { + t.Errorf("expected total count 0 when "+ + "CountTotal=false, got %v", + querySlice.TotalCount) + } + }) + } +} diff --git a/payments/db/sql_converters.go b/payments/db/sql_converters.go new file mode 100644 index 0000000000..7f6d222def --- /dev/null +++ b/payments/db/sql_converters.go @@ -0,0 +1,277 @@ +package paymentsdb + +import ( + "bytes" + "fmt" + "sort" + "strconv" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/lightningnetwork/lnd/tlv" +) + +// dbPaymentToCreationInfo converts database payment data to +// PaymentCreationInfo. +func dbPaymentToCreationInfo(paymentIdentifier []byte, amountMsat int64, + createdAt time.Time, intentPayload []byte, + firstHopCustomRecords lnwire.CustomRecords) *PaymentCreationInfo { + + // This is the payment hash for non-AMP payments and the SetID for AMP + // payments. + var identifier lntypes.Hash + copy(identifier[:], paymentIdentifier) + + return &PaymentCreationInfo{ + PaymentIdentifier: identifier, + Value: lnwire.MilliSatoshi(amountMsat), + // The creation time is stored in the database as UTC but here + // we convert it to local time. + CreationTime: createdAt.Local(), + PaymentRequest: intentPayload, + FirstHopCustomRecords: firstHopCustomRecords, + } +} + +// dbAttemptToHTLCAttempt converts a database HTLC attempt to an HTLCAttempt. +func dbAttemptToHTLCAttempt( + dbAttempt sqlc.FetchHtlcAttemptsForPaymentRow, + hops []sqlc.FetchHopsForAttemptsRow, + hopCustomRecords map[int64][]sqlc.PaymentHopCustomRecord, + routeCustomRecords []sqlc.PaymentAttemptFirstHopCustomRecord) ( + *HTLCAttempt, error) { + + // Convert route-level first hop custom records to CustomRecords map. + var firstHopWireCustomRecords lnwire.CustomRecords + if len(routeCustomRecords) > 0 { + firstHopWireCustomRecords = make(lnwire.CustomRecords) + for _, record := range routeCustomRecords { + firstHopWireCustomRecords[uint64(record.Key)] = + record.Value + } + } + + // Build the route from the database data. + route, err := dbDataToRoute( + hops, hopCustomRecords, dbAttempt.FirstHopAmountMsat, + dbAttempt.RouteTotalTimeLock, dbAttempt.RouteTotalAmount, + dbAttempt.RouteSourceKey, firstHopWireCustomRecords, + ) + if err != nil { + return nil, fmt.Errorf("failed to convert to route: %w", + err) + } + + hash, err := lntypes.MakeHash(dbAttempt.PaymentHash) + if err != nil { + return nil, fmt.Errorf("failed to parse payment "+ + "hash: %w", err) + } + + // Create the attempt info. + var sessionKey [32]byte + copy(sessionKey[:], dbAttempt.SessionKey) + + info := HTLCAttemptInfo{ + AttemptID: uint64(dbAttempt.AttemptIndex), + sessionKey: sessionKey, + Route: *route, + AttemptTime: dbAttempt.AttemptTime, + Hash: &hash, + } + + attempt := &HTLCAttempt{ + HTLCAttemptInfo: info, + } + + // Add settlement info if present. + if dbAttempt.ResolutionType.Valid && + HTLCAttemptResolutionType(dbAttempt.ResolutionType.Int32) == + HTLCAttemptResolutionSettled { + + var preimage lntypes.Preimage + copy(preimage[:], dbAttempt.SettlePreimage) + + attempt.Settle = &HTLCSettleInfo{ + Preimage: preimage, + SettleTime: dbAttempt.ResolutionTime.Time, + } + } + + // Add failure info if present. + if dbAttempt.ResolutionType.Valid && + HTLCAttemptResolutionType(dbAttempt.ResolutionType.Int32) == + HTLCAttemptResolutionFailed { + + failure := &HTLCFailInfo{ + FailTime: dbAttempt.ResolutionTime.Time, + } + + if dbAttempt.HtlcFailReason.Valid { + failure.Reason = HTLCFailReason( + dbAttempt.HtlcFailReason.Int32, + ) + } + + if dbAttempt.FailureSourceIndex.Valid { + failure.FailureSourceIndex = uint32( + dbAttempt.FailureSourceIndex.Int32, + ) + } + + // Decode the failure message if present. + if len(dbAttempt.FailureMsg) > 0 { + msg, err := lnwire.DecodeFailureMessage( + bytes.NewReader(dbAttempt.FailureMsg), 0, + ) + if err != nil { + return nil, fmt.Errorf("failed to decode "+ + "failure message: %w", err) + } + failure.Message = msg + } + + attempt.Failure = failure + } + + return attempt, nil +} + +// dbDataToRoute converts database route data to a route.Route. +func dbDataToRoute(hops []sqlc.FetchHopsForAttemptsRow, + hopCustomRecords map[int64][]sqlc.PaymentHopCustomRecord, + firstHopAmountMsat int64, totalTimeLock int32, totalAmount int64, + sourceKey []byte, firstHopWireCustomRecords lnwire.CustomRecords) ( + *route.Route, error) { + + if len(hops) == 0 { + return nil, fmt.Errorf("no hops provided") + } + + // Sort hops by hop index. + sort.Slice(hops, func(i, j int) bool { + return hops[i].HopIndex < hops[j].HopIndex + }) + + routeHops := make([]*route.Hop, len(hops)) + + for i, hop := range hops { + pubKey, err := route.NewVertexFromBytes(hop.PubKey) + if err != nil { + return nil, fmt.Errorf("failed to parse pub key: %w", + err) + } + + var channelID uint64 + if hop.Scid != "" { + // The SCID is stored as a string representation + // of the uint64. + var err error + channelID, err = strconv.ParseUint(hop.Scid, 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse "+ + "scid: %w", err) + } + } + + routeHop := &route.Hop{ + PubKeyBytes: pubKey, + ChannelID: channelID, + OutgoingTimeLock: uint32(hop.OutgoingTimeLock), + AmtToForward: lnwire.MilliSatoshi(hop.AmtToForward), + } + + // Add MPP record if present. + if len(hop.MppPaymentAddr) > 0 { + var paymentAddr [32]byte + copy(paymentAddr[:], hop.MppPaymentAddr) + routeHop.MPP = record.NewMPP( + lnwire.MilliSatoshi(hop.MppTotalMsat.Int64), + paymentAddr, + ) + } + + // Add AMP record if present. + if len(hop.AmpRootShare) > 0 { + var rootShare [32]byte + copy(rootShare[:], hop.AmpRootShare) + var setID [32]byte + copy(setID[:], hop.AmpSetID) + + routeHop.AMP = record.NewAMP( + rootShare, setID, + uint32(hop.AmpChildIndex.Int32), + ) + } + + // Add blinding point if present (only for introduction node + // in blinded route). + if len(hop.BlindingPoint) > 0 { + pubKey, err := btcec.ParsePubKey(hop.BlindingPoint) + if err != nil { + return nil, fmt.Errorf("failed to parse "+ + "blinding point: %w", err) + } + routeHop.BlindingPoint = pubKey + } + + // Add encrypted data if present (for all blinded hops). + if len(hop.EncryptedData) > 0 { + routeHop.EncryptedData = hop.EncryptedData + } + + // Add total amount if present (only for final hop in blinded + // route). + if hop.BlindedPathTotalAmt.Valid { + routeHop.TotalAmtMsat = lnwire.MilliSatoshi( + hop.BlindedPathTotalAmt.Int64, + ) + } + + // Add hop-level custom records. + if records, ok := hopCustomRecords[hop.ID]; ok { + routeHop.CustomRecords = make( + record.CustomSet, + ) + for _, rec := range records { + routeHop.CustomRecords[uint64(rec.Key)] = + rec.Value + } + } + + // Add metadata if present. + if len(hop.MetaData) > 0 { + routeHop.Metadata = hop.MetaData + } + + routeHops[i] = routeHop + } + + // Parse the source node public key. + var sourceNode route.Vertex + copy(sourceNode[:], sourceKey) + + route := &route.Route{ + TotalTimeLock: uint32(totalTimeLock), + TotalAmount: lnwire.MilliSatoshi(totalAmount), + SourcePubKey: sourceNode, + Hops: routeHops, + FirstHopWireCustomRecords: firstHopWireCustomRecords, + } + + // Set the first hop amount if it is set. + if firstHopAmountMsat != 0 { + route.FirstHopAmount = tlv.NewRecordT[tlv.TlvType0]( + tlv.NewBigSizeT(lnwire.MilliSatoshi( + firstHopAmountMsat, + )), + ) + } + + return route, nil +} diff --git a/payments/db/sql_store.go b/payments/db/sql_store.go index 12585caf64..63b5c968b2 100644 --- a/payments/db/sql_store.go +++ b/payments/db/sql_store.go @@ -1,14 +1,91 @@ package paymentsdb import ( + "bytes" + "context" + "database/sql" + "errors" "fmt" + "math" + "strconv" + "time" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" +) + +// PaymentIntentType represents the type of payment intent. +type PaymentIntentType int16 + +const ( + // PaymentIntentTypeBolt11 indicates a BOLT11 invoice payment. + PaymentIntentTypeBolt11 PaymentIntentType = 0 +) + +// HTLCAttemptResolutionType represents the type of HTLC attempt resolution. +type HTLCAttemptResolutionType int32 + +const ( + // HTLCAttemptResolutionSettled indicates the HTLC attempt was settled + // successfully with a preimage. + HTLCAttemptResolutionSettled HTLCAttemptResolutionType = 1 + + // HTLCAttemptResolutionFailed indicates the HTLC attempt failed. + HTLCAttemptResolutionFailed HTLCAttemptResolutionType = 2 ) // SQLQueries is a subset of the sqlc.Querier interface that can be used to // execute queries against the SQL payments tables. +// +//nolint:ll,interfacebloat type SQLQueries interface { + /* + Payment DB read operations. + */ + FilterPayments(ctx context.Context, query sqlc.FilterPaymentsParams) ([]sqlc.FilterPaymentsRow, error) + FetchPayment(ctx context.Context, paymentIdentifier []byte) (sqlc.FetchPaymentRow, error) + FetchPaymentsByIDs(ctx context.Context, paymentIDs []int64) ([]sqlc.FetchPaymentsByIDsRow, error) + + CountPayments(ctx context.Context) (int64, error) + + FetchHtlcAttemptsForPayment(ctx context.Context, paymentID int64) ([]sqlc.FetchHtlcAttemptsForPaymentRow, error) + FetchAllInflightAttempts(ctx context.Context) ([]sqlc.PaymentHtlcAttempt, error) + FetchHopsForAttempt(ctx context.Context, htlcAttemptIndex int64) ([]sqlc.FetchHopsForAttemptRow, error) + FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.FetchHopsForAttemptsRow, error) + + FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentID int64) ([]sqlc.PaymentFirstHopCustomRecord, error) + FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]sqlc.PaymentAttemptFirstHopCustomRecord, error) + FetchHopLevelCustomRecords(ctx context.Context, hopIDs []int64) ([]sqlc.PaymentHopCustomRecord, error) + + /* + Payment DB write operations. + */ + InsertPaymentIntent(ctx context.Context, arg sqlc.InsertPaymentIntentParams) (int64, error) + InsertPayment(ctx context.Context, arg sqlc.InsertPaymentParams) (int64, error) + InsertPaymentFirstHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentFirstHopCustomRecordParams) error + + InsertHtlcAttempt(ctx context.Context, arg sqlc.InsertHtlcAttemptParams) (int64, error) + InsertRouteHop(ctx context.Context, arg sqlc.InsertRouteHopParams) (int64, error) + InsertRouteHopMpp(ctx context.Context, arg sqlc.InsertRouteHopMppParams) error + InsertRouteHopAmp(ctx context.Context, arg sqlc.InsertRouteHopAmpParams) error + InsertRouteHopBlinded(ctx context.Context, arg sqlc.InsertRouteHopBlindedParams) error + + InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentAttemptFirstHopCustomRecordParams) error + InsertPaymentHopCustomRecord(ctx context.Context, arg sqlc.InsertPaymentHopCustomRecordParams) error + + SettleAttempt(ctx context.Context, arg sqlc.SettleAttemptParams) error + FailAttempt(ctx context.Context, arg sqlc.FailAttemptParams) error + + FailPayment(ctx context.Context, arg sqlc.FailPaymentParams) (sql.Result, error) + + DeletePayment(ctx context.Context, paymentID int64) error + + // DeleteFailedAttempts removes all failed HTLCs from the db for a + // given payment. + DeleteFailedAttempts(ctx context.Context, paymentID int64) error } // BatchedSQLQueries is a version of the SQLQueries that's capable @@ -20,10 +97,6 @@ type BatchedSQLQueries interface { // SQLStore represents a storage backend. type SQLStore struct { - // TODO(ziggie): Remove the KVStore once all the interface functions are - // implemented. - KVStore - cfg *SQLStoreConfig db BatchedSQLQueries @@ -65,3 +138,1118 @@ func NewSQLStore(cfg *SQLStoreConfig, db BatchedSQLQueries, // A compile-time constraint to ensure SQLStore implements DB. var _ DB = (*SQLStore)(nil) + +// fetchPaymentWithCompleteData fetches a payment with all its related data +// including attempts, hops, and custom records from the database. +func (s *SQLStore) fetchPaymentWithCompleteData(ctx context.Context, + db SQLQueries, dbPayment sqlc.PaymentAndIntent) (*MPPayment, error) { + + // The query will only return BOLT 11 payment intents or intents with + // no intent type set. + paymentIntent := dbPayment.GetPaymentIntent() + paymentRequest := paymentIntent.IntentPayload + + // Fetch payment-level first hop custom records. + payment := dbPayment.GetPayment() + customRecords, err := db.FetchPaymentLevelFirstHopCustomRecords( + ctx, payment.ID, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch payment level custom "+ + "records: %w", err) + } + + // Convert to the FirstHopCustomRecords map. + var firstHopCustomRecords lnwire.CustomRecords + if len(customRecords) > 0 { + firstHopCustomRecords = make(lnwire.CustomRecords) + for _, record := range customRecords { + firstHopCustomRecords[uint64(record.Key)] = record.Value + } + } + + // Convert the basic payment info. + info := dbPaymentToCreationInfo( + payment.PaymentIdentifier, payment.AmountMsat, + payment.CreatedAt, paymentRequest, firstHopCustomRecords, + ) + + // Fetch all HTLC attempts for this payment. + attempts, err := s.fetchHTLCAttemptsForPayment( + ctx, db, payment.ID, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch HTLC attempts: %w", + err) + } + + // Set the failure reason if present. + var failureReason *FailureReason + if payment.FailReason.Valid { + reason := FailureReason(payment.FailReason.Int32) + failureReason = &reason + } + + mpPayment := &MPPayment{ + SequenceNum: uint64(payment.ID), + Info: info, + HTLCs: attempts, + FailureReason: failureReason, + } + + // The status and state will be determined by calling + // SetState after construction. + if err := mpPayment.SetState(); err != nil { + return nil, fmt.Errorf("failed to set payment state: %w", err) + } + + return mpPayment, nil +} + +// fetchHTLCAttemptsForPayment fetches all HTLC attempts for a payment and +// uses ExecuteBatchQuery to efficiently fetch hops and custom records. +func (s *SQLStore) fetchHTLCAttemptsForPayment(ctx context.Context, + db SQLQueries, paymentID int64) ([]HTLCAttempt, error) { + + // Fetch all HTLC attempts for this payment. + dbAttempts, err := db.FetchHtlcAttemptsForPayment( + ctx, paymentID, + ) + + if err != nil { + return nil, fmt.Errorf("failed to fetch HTLC attempts: %w", + err) + } + + if len(dbAttempts) == 0 { + return nil, nil + } + + // Collect all attempt indices for batch fetching. + attemptIndices := make([]int64, len(dbAttempts)) + for i, attempt := range dbAttempts { + attemptIndices[i] = attempt.AttemptIndex + } + + // Fetch all hops for all attempts using ExecuteBatchQuery. + hopsByAttempt := make(map[int64][]sqlc.FetchHopsForAttemptsRow) + err = sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, attemptIndices, + func(idx int64) int64 { return idx }, + func(ctx context.Context, indices []int64) ( + []sqlc.FetchHopsForAttemptsRow, error) { + + return db.FetchHopsForAttempts(ctx, indices) + }, + func(ctx context.Context, + hop sqlc.FetchHopsForAttemptsRow) error { + + hopsByAttempt[hop.HtlcAttemptIndex] = append( + hopsByAttempt[hop.HtlcAttemptIndex], hop, + ) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch hops for attempts: %w", + err) + } + + // Collect all hop IDs for fetching hop-level custom records. + var hopIDs []int64 + for _, hops := range hopsByAttempt { + for _, hop := range hops { + hopIDs = append(hopIDs, hop.ID) + } + } + + // Fetch all hop-level custom records using ExecuteBatchQuery. + hopCustomRecords := make(map[int64][]sqlc.PaymentHopCustomRecord) + if len(hopIDs) > 0 { + err = sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, hopIDs, + func(id int64) int64 { return id }, + func(ctx context.Context, ids []int64) ( + []sqlc.PaymentHopCustomRecord, error) { + + return db.FetchHopLevelCustomRecords(ctx, ids) + }, + func(ctx context.Context, + record sqlc.PaymentHopCustomRecord) error { + + // TODO(ziggie): Can we get rid of this? + // This has to be in place otherwise the + // comparison will not match. + if record.Value == nil { + record.Value = []byte{} + } + + hopCustomRecords[record.HopID] = append( + hopCustomRecords[record.HopID], record, + ) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch hop custom "+ + "records: %w", err) + } + } + + // Fetch route-level first hop custom records using ExecuteBatchQuery. + routeCustomRecords := make( + map[int64][]sqlc.PaymentAttemptFirstHopCustomRecord, + ) + err = sqldb.ExecuteBatchQuery( + ctx, s.cfg.QueryCfg, attemptIndices, + func(idx int64) int64 { return idx }, + func(ctx context.Context, indices []int64) ( + []sqlc.PaymentAttemptFirstHopCustomRecord, error) { + + return db.FetchRouteLevelFirstHopCustomRecords( + ctx, indices, + ) + }, + func(ctx context.Context, + record sqlc.PaymentAttemptFirstHopCustomRecord) error { + + routeCustomRecords[record.HtlcAttemptIndex] = append( + routeCustomRecords[record.HtlcAttemptIndex], + record, + ) + + return nil + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to fetch route custom "+ + "records: %w", err) + } + + // Now convert all attempts to HTLCAttempt structs. + attempts := make([]HTLCAttempt, 0, len(dbAttempts)) + for _, dbAttempt := range dbAttempts { + attemptIndex := dbAttempt.AttemptIndex + attempt, err := dbAttemptToHTLCAttempt( + dbAttempt, hopsByAttempt[attemptIndex], + hopCustomRecords, + routeCustomRecords[attemptIndex], + ) + if err != nil { + return nil, fmt.Errorf("failed to convert attempt "+ + "%d: %w", attemptIndex, err) + } + attempts = append(attempts, *attempt) + } + + return attempts, nil +} + +// QueryPayments queries the payments from the database. +// +// This is part of the DB interface. +func (s *SQLStore) QueryPayments(ctx context.Context, + query Query) (Response, error) { + + if query.MaxPayments == 0 { + return Response{}, fmt.Errorf("max payments must be non-zero") + } + + var ( + allPayments []*MPPayment + totalCount int64 + initialCursor int64 + ) + + extractCursor := func( + row sqlc.FilterPaymentsRow) int64 { + + return row.Payment.ID + } + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + // We first count all payments to determine the total count + // if requested. + if query.CountTotal { + totalPayments, err := db.CountPayments(ctx) + if err != nil { + return fmt.Errorf("failed to count "+ + "payments: %w", err) + } + totalCount = totalPayments + } + + processPayment := func(ctx context.Context, + dbPayment sqlc.FilterPaymentsRow) error { + + // Fetch all the additional data for the payment. + mpPayment, err := s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + + // To keep compatibility with the old API, we only + // return non-succeeded payments if requested. + if mpPayment.Status != StatusSucceeded && + !query.IncludeIncomplete { + + return nil + } + + if uint64(len(allPayments)) >= query.MaxPayments { + return errMaxPaymentsReached + } + + allPayments = append(allPayments, mpPayment) + + return nil + } + + queryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.FilterPaymentsRow, error) { + + filterParams := sqlc.FilterPaymentsParams{ + NumLimit: limit, + Reverse: query.Reversed, + // For now there only BOLT 11 payment intents + // exist. + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + } + + if query.Reversed { + filterParams.IndexOffsetLet = sqldb.SQLInt64( + lastID, + ) + } else { + filterParams.IndexOffsetGet = sqldb.SQLInt64( + lastID, + ) + } + + // Add potential date filters if specified. + if query.CreationDateStart != 0 { + filterParams.CreatedAfter = sqldb.SQLTime( + time.Unix(query.CreationDateStart, 0). + UTC(), + ) + } + if query.CreationDateEnd != 0 { + filterParams.CreatedBefore = sqldb.SQLTime( + time.Unix(query.CreationDateEnd, 0). + UTC(), + ) + } + + return db.FilterPayments(ctx, filterParams) + } + + if query.Reversed { + if query.IndexOffset == 0 { + initialCursor = int64(math.MaxInt64) + } else { + initialCursor = int64(query.IndexOffset) + } + } else { + initialCursor = int64(query.IndexOffset) + } + + return sqldb.ExecutePaginatedQuery( + ctx, s.cfg.QueryCfg, initialCursor, queryFunc, + extractCursor, processPayment, + ) + }, func() { + allPayments = nil + }) + + // We make sure we don't return an error if we reached the maximum + // number of payments. Which is the pagination limit for the query + // itself. + if err != nil && !errors.Is(err, errMaxPaymentsReached) { + return Response{}, fmt.Errorf("failed to query payments: %w", + err) + } + + // Handle case where no payments were found + if len(allPayments) == 0 { + return Response{ + Payments: allPayments, + FirstIndexOffset: 0, + LastIndexOffset: 0, + TotalCount: uint64(totalCount), + }, nil + } + + // If the query was reversed, we need to reverse the payment list + // to match the kvstore behavior and return payments in forward order. + if query.Reversed { + for i, j := 0, len(allPayments)-1; i < j; i, j = i+1, j-1 { + allPayments[i], allPayments[j] = allPayments[j], + allPayments[i] + } + } + + return Response{ + Payments: allPayments, + FirstIndexOffset: allPayments[0].SequenceNum, + LastIndexOffset: allPayments[len(allPayments)-1].SequenceNum, + TotalCount: uint64(totalCount), + }, nil +} + +// FetchPayment fetches the payment corresponding to the given payment +// hash. +// +// This is part of the DB interface. +func (s *SQLStore) FetchPayment(paymentHash lntypes.Hash) (*MPPayment, error) { + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + dbPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to fetch payment: %w", err) + } + + if errors.Is(err, sql.ErrNoRows) { + return ErrPaymentNotInitiated + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch payment: %w", err) + } + + return mpPayment, nil +} + +// FetchInFlightPayments fetches all payments with status InFlight. +// +// TODO(ziggie): Add pagination (LIMIT)) to this function? +// +// This is part of the DB interface. +func (s *SQLStore) FetchInFlightPayments() ([]*MPPayment, error) { + ctx := context.TODO() + + var mpPayments []*MPPayment + + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + inflightDBAttempts, err := db.FetchAllInflightAttempts(ctx) + if err != nil { + return fmt.Errorf("failed to fetch inflight "+ + "attempts: %w", err) + } + + paymentIDs := make([]int64, len(inflightDBAttempts)) + for i, attempt := range inflightDBAttempts { + paymentIDs[i] = attempt.PaymentID + } + + dbPayments, err := db.FetchPaymentsByIDs(ctx, paymentIDs) + if err != nil { + return fmt.Errorf("failed to fetch payments by IDs: %w", + err) + } + + mpPayments = make([]*MPPayment, len(dbPayments)) + for i, dbPayment := range dbPayments { + mpPayment, err := s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + mpPayments[i] = mpPayment + } + + return nil + }, func() { + mpPayments = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fetch inflight "+ + "attempts: %w", err) + } + + return mpPayments, nil +} + +// DeletePayment deletes a payment from the DB given its payment hash. If +// failedHtlcsOnly is set, only failed HTLC attempts of the payment will be +// deleted. +func (s *SQLStore) DeletePayment(paymentHash lntypes.Hash, + failedHtlcsOnly bool) error { + + ctx := context.TODO() + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + fetchPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + completePayment, err := s.fetchPaymentWithCompleteData( + ctx, db, fetchPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := completePayment.Status.removable(); err != nil { + return fmt.Errorf("payment %v cannot be deleted: %w", + paymentHash, err) + } + + // If we are only deleting failed HTLCs, we delete them. + if failedHtlcsOnly { + return db.DeleteFailedAttempts( + ctx, fetchPayment.Payment.ID, + ) + } + + // Be careful to not use s.db here, because we are in a + // transaction, is there a way to make this more secure? + return db.DeletePayment(ctx, fetchPayment.Payment.ID) + }, func() { + }) + if err != nil { + return fmt.Errorf("failed to delete payment "+ + "(failedHtlcsOnly: %v, paymentHash: %v): %w", + failedHtlcsOnly, paymentHash, err) + } + + return nil +} + +// DeleteFailedAttempts removes all failed HTLCs from the db. It should +// be called for a given payment whenever all inflight htlcs are +// completed, and the payment has reached a final terminal state. +func (s *SQLStore) DeleteFailedAttempts(paymentHash lntypes.Hash) error { + // In case we are configured to keep failed payment attempts, we exit + // early. + if s.keepFailedPaymentAttempts { + return nil + } + + ctx := context.TODO() + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // We first fetch the payment to get the payment ID. + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + + completePayment, err := s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := completePayment.Status.removable(); err != nil { + return fmt.Errorf("payment %v cannot be deleted: %w", + paymentHash, err) + } + + // Then we delete the failed attempts for this payment. + return db.DeleteFailedAttempts(ctx, payment.Payment.ID) + }, func() { + }) + if err != nil { + return fmt.Errorf("failed to delete failed attempts for "+ + "payment %v: %w", paymentHash, err) + } + + return nil +} + +// InitPayment initializes a payment. +// +// This is part of the DB interface. +func (s *SQLStore) InitPayment(paymentHash lntypes.Hash, + paymentCreationInfo *PaymentCreationInfo) error { + + ctx := context.TODO() + + // Create the payment in the database. + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + existingPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err == nil { + completePayment, err := s.fetchPaymentWithCompleteData( + ctx, db, existingPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + + // Check if the payment is initializable otherwise + // we'll return early. + err = completePayment.Status.initializable() + if err != nil { + return err + } + } else if !errors.Is(err, sql.ErrNoRows) { + // Some other error occurred + return fmt.Errorf("failed to check existing "+ + "payment: %w", err) + } + + // If payment exists and is failed, delete it first. + if existingPayment.Payment.ID != 0 { + err := db.DeletePayment(ctx, existingPayment.Payment.ID) + if err != nil { + return fmt.Errorf("failed to delete "+ + "payment: %w", err) + } + } + + var intentID *int64 + if len(paymentCreationInfo.PaymentRequest) > 0 { + intentIDValue, err := db.InsertPaymentIntent(ctx, + sqlc.InsertPaymentIntentParams{ + IntentType: int16( + PaymentIntentTypeBolt11, + ), + IntentPayload: paymentCreationInfo. + PaymentRequest, + }) + if err != nil { + return fmt.Errorf("failed to initialize "+ + "payment intent: %w", err) + } + intentID = &intentIDValue + } + + // Only set the intent ID if it's not nil. + var intentIDParam sql.NullInt64 + if intentID != nil { + intentIDParam = sqldb.SQLInt64(*intentID) + } + + paymentID, err := db.InsertPayment(ctx, + sqlc.InsertPaymentParams{ + IntentID: intentIDParam, + AmountMsat: int64( + paymentCreationInfo.Value, + ), + CreatedAt: paymentCreationInfo. + CreationTime.UTC(), + PaymentIdentifier: paymentHash[:], + }) + if err != nil { + return fmt.Errorf("failed to insert payment: %w", err) + } + + firstHopCustomRecords := paymentCreationInfo. + FirstHopCustomRecords + + for key, value := range firstHopCustomRecords { + err = db.InsertPaymentFirstHopCustomRecord(ctx, + sqlc.InsertPaymentFirstHopCustomRecordParams{ + PaymentID: paymentID, + Key: int64(key), + Value: value, + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment first hop custom "+ + "record: %w", err) + } + } + + return nil + }, func() { + }) + if err != nil { + return fmt.Errorf("failed to initialize payment: %w", err) + } + + return nil +} + +// insertRouteHops inserts all route hop data for a given set of hops. +func (s *SQLStore) insertRouteHops(ctx context.Context, db SQLQueries, + hops []*route.Hop, attemptID uint64) error { + + for i, hop := range hops { + // Insert the basic route hop data and get the generated ID + hopID, err := db.InsertRouteHop(ctx, sqlc.InsertRouteHopParams{ + HtlcAttemptIndex: int64(attemptID), + HopIndex: int32(i), + PubKey: hop.PubKeyBytes[:], + Scid: strconv.FormatUint( + hop.ChannelID, 10, + ), + OutgoingTimeLock: int32(hop.OutgoingTimeLock), + AmtToForward: int64(hop.AmtToForward), + MetaData: hop.Metadata, + }) + if err != nil { + return fmt.Errorf("failed to insert route hop: %w", err) + } + + // Insert the per-hop custom records + if len(hop.CustomRecords) > 0 { + for key, value := range hop.CustomRecords { + err = db.InsertPaymentHopCustomRecord(ctx, + sqlc.InsertPaymentHopCustomRecordParams{ + HopID: hopID, + Key: int64(key), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment hop custom "+ + "records: %w", err) + } + } + } + + // Insert MPP data if present + if hop.MPP != nil { + paymentAddr := hop.MPP.PaymentAddr() + err = db.InsertRouteHopMpp(ctx, + sqlc.InsertRouteHopMppParams{ + HopID: hopID, + PaymentAddr: paymentAddr[:], + TotalMsat: int64(hop.MPP.TotalMsat()), + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop MPP: %w", err) + } + } + + // Insert AMP data if present + if hop.AMP != nil { + rootShare := hop.AMP.RootShare() + setID := hop.AMP.SetID() + err = db.InsertRouteHopAmp(ctx, + sqlc.InsertRouteHopAmpParams{ + HopID: hopID, + RootShare: rootShare[:], + SetID: setID[:], + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop AMP: %w", err) + } + } + + // Insert blinded route data if present + if hop.EncryptedData != nil || hop.BlindingPoint != nil { + var blindingPointBytes []byte + if hop.BlindingPoint != nil { + blindingPointBytes = hop.BlindingPoint. + SerializeCompressed() + } + + err = db.InsertRouteHopBlinded(ctx, + sqlc.InsertRouteHopBlindedParams{ + HopID: hopID, + EncryptedData: hop.EncryptedData, + BlindingPoint: blindingPointBytes, + BlindedPathTotalAmt: sqldb.SQLInt64( + hop.TotalAmtMsat, + ), + }, + ) + if err != nil { + return fmt.Errorf("failed to insert "+ + "route hop blinded: %w", err) + } + } + } + + return nil +} + +// RegisterAttempt registers an attempt for a payment. +// +// This is part of the DB interface. +func (s *SQLStore) RegisterAttempt(paymentHash lntypes.Hash, + attempt *HTLCAttemptInfo) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // 1. First Fetch the payment and check if it is registrable. + existingPayment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, existingPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + if err := mpPayment.Registrable(); err != nil { + return fmt.Errorf("htlc attempt not registrable: %w", + err) + } + + // Verify the attempt is compatible with the existing payment. + if err := verifyAttempt(mpPayment, attempt); err != nil { + return fmt.Errorf("failed to verify attempt: %w", err) + } + + // Fist register the plain HTLC attempt. + // Prepare the session key. + sessionKey := attempt.SessionKey() + sessionKeyBytes := sessionKey.Serialize() + + _, err = db.InsertHtlcAttempt(ctx, sqlc.InsertHtlcAttemptParams{ + PaymentID: existingPayment.Payment.ID, + AttemptIndex: int64(attempt.AttemptID), + SessionKey: sessionKeyBytes, + AttemptTime: attempt.AttemptTime, + PaymentHash: paymentHash[:], + FirstHopAmountMsat: int64( + attempt.Route.FirstHopAmount.Val.Int(), + ), + RouteTotalTimeLock: int32(attempt.Route.TotalTimeLock), + RouteTotalAmount: int64(attempt.Route.TotalAmount), + RouteSourceKey: attempt.Route.SourcePubKey[:], + }) + if err != nil { + return fmt.Errorf("failed to insert HTLC "+ + "attempt: %w", err) + } + + // Insert the route level first hop custom records. + attemptFirstHopCustomRecords := attempt.Route. + FirstHopWireCustomRecords + + for key, value := range attemptFirstHopCustomRecords { + err = db.InsertPaymentAttemptFirstHopCustomRecord(ctx, + //nolint:ll + sqlc.InsertPaymentAttemptFirstHopCustomRecordParams{ + HtlcAttemptIndex: int64(attempt.AttemptID), + Key: int64(key), + Value: value, + }) + if err != nil { + return fmt.Errorf("failed to insert "+ + "payment attempt first hop custom "+ + "record: %w", err) + } + } + + // Insert the route hops. + err = s.insertRouteHops( + ctx, db, attempt.Route.Hops, attempt.AttemptID, + ) + if err != nil { + return fmt.Errorf("failed to insert route hops: %w", + err) + } + + // Add the attempt to the payment without fetching it from the + // DB again. + mpPayment.HTLCs = append(mpPayment.HTLCs, HTLCAttempt{ + HTLCAttemptInfo: *attempt, + }) + + if err := mpPayment.SetState(); err != nil { + return fmt.Errorf("failed to set payment state: %w", + err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to register attempt: %w", err) + } + + return mpPayment, nil +} + +// SettleAttempt marks the given attempt settled with the preimage. +func (s *SQLStore) SettleAttempt(paymentHash lntypes.Hash, + attemptID uint64, settleInfo *HTLCSettleInfo) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // Before updating the attempt, we fetch the payment to get the + // payment ID. + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to fetch payment: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + return ErrPaymentNotInitiated + } + + err = db.SettleAttempt(ctx, sqlc.SettleAttemptParams{ + AttemptIndex: int64(attemptID), + ResolutionTime: time.Now(), + ResolutionType: int32(HTLCAttemptResolutionSettled), + SettlePreimage: settleInfo.Preimage[:], + }) + if err != nil { + return fmt.Errorf("failed to settle attempt: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to settle attempt: %w", err) + } + + return mpPayment, nil +} + +// FailAttempt marks the given attempt failed. +func (s *SQLStore) FailAttempt(paymentHash lntypes.Hash, + attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + // Before updating the attempt, we fetch the payment to get the + // payment ID. + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to fetch payment: %w", err) + } + if errors.Is(err, sql.ErrNoRows) { + return ErrPaymentNotInitiated + } + + var failureMsg bytes.Buffer + if failInfo.Message != nil { + err := lnwire.EncodeFailureMessage( + &failureMsg, failInfo.Message, 0, + ) + if err != nil { + return fmt.Errorf("failed to encode "+ + "failure message: %w", err) + } + } + + err = db.FailAttempt(ctx, sqlc.FailAttemptParams{ + AttemptIndex: int64(attemptID), + ResolutionTime: time.Now(), + ResolutionType: int32(HTLCAttemptResolutionFailed), + FailureSourceIndex: sqldb.SQLInt32( + failInfo.FailureSourceIndex, + ), + HtlcFailReason: sqldb.SQLInt32(failInfo.Reason), + FailureMsg: failureMsg.Bytes(), + }) + if err != nil { + return fmt.Errorf("failed to fail attempt: %w", err) + } + + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with"+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fail attempt: %w", err) + } + + return mpPayment, nil +} + +// Fail transitions a payment into the Failed state, and records the ultimate +// reason the payment failed. Note that this should only be called when all +// active attempts are already failed. After invoking this method, InitPayment +// should return nil on its next call for this payment hash, allowing the user +// to make a subsequent payments for the same payment hash. +func (s *SQLStore) Fail(paymentHash lntypes.Hash, + reason FailureReason) (*MPPayment, error) { + + ctx := context.TODO() + + var mpPayment *MPPayment + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + result, err := db.FailPayment(ctx, sqlc.FailPaymentParams{ + PaymentIdentifier: paymentHash[:], + FailReason: sqldb.SQLInt32(reason), + }) + if err != nil { + return fmt.Errorf("failed to fail payment: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", + err) + } + if rowsAffected == 0 { + return ErrPaymentNotInitiated + } + + payment, err := db.FetchPayment(ctx, paymentHash[:]) + if err != nil { + return fmt.Errorf("failed to fetch payment: %w", err) + } + mpPayment, err = s.fetchPaymentWithCompleteData( + ctx, db, payment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment with "+ + "complete data: %w", err) + } + + return nil + }, func() { + mpPayment = nil + }) + if err != nil { + return nil, fmt.Errorf("failed to fail payment: %w", err) + } + + return mpPayment, nil +} + +// DeletePayments deletes all payments from the DB given the specified flags. +// +// TODO(ziggie): batch and use iterator instead. +func (s *SQLStore) DeletePayments(failedOnly, failedHtlcsOnly bool) (int, + error) { + + var numPayments int + ctx := context.TODO() + + extractCursor := func( + row sqlc.FilterPaymentsRow) int64 { + + return row.Payment.ID + } + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + processPayment := func(ctx context.Context, + dbPayment sqlc.FilterPaymentsRow) error { + + // Fetch all the additional data for the payment. + mpPayment, err := s.fetchPaymentWithCompleteData( + ctx, db, dbPayment, + ) + if err != nil { + return fmt.Errorf("failed to fetch payment "+ + "with complete data: %w", err) + } + + // Payments which are not final yet cannot be deleted. + // we skip them. + if err := mpPayment.Status.removable(); err != nil { + return nil + } + + // If we are only deleting failed payments, we skip + // if the payment is not failed. + if failedOnly && mpPayment.Status != StatusFailed { + return nil + } + + // If we are only deleting failed HTLCs, we delete them + // and return early. + if failedHtlcsOnly { + return db.DeleteFailedAttempts( + ctx, dbPayment.Payment.ID, + ) + } + + // Otherwise we delete the payment. + err = db.DeletePayment(ctx, dbPayment.Payment.ID) + if err != nil { + return fmt.Errorf("failed to delete "+ + "payment: %w", err) + } + + numPayments++ + + return nil + } + + queryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.FilterPaymentsRow, error) { + + filterParams := sqlc.FilterPaymentsParams{ + NumLimit: limit, + // For now there are only BOLT 11 payment + // intents. + IntentType: sqldb.SQLInt16( + PaymentIntentTypeBolt11, + ), + IndexOffsetGet: sqldb.SQLInt64( + lastID, + ), + } + + return db.FilterPayments(ctx, filterParams) + } + + return sqldb.ExecutePaginatedQuery( + ctx, s.cfg.QueryCfg, int64(-1), queryFunc, + extractCursor, processPayment, + ) + }, func() { + numPayments = 0 + }) + if err != nil { + return 0, fmt.Errorf("failed to delete payments "+ + "(failedOnly: %v, failedHtlcsOnly: %v): %w", + failedOnly, failedHtlcsOnly, err) + } + + return numPayments, nil +} diff --git a/payments/db/test_kvdb.go b/payments/db/test_kvdb.go index e0ee1738d7..a4bbfccbd9 100644 --- a/payments/db/test_kvdb.go +++ b/payments/db/test_kvdb.go @@ -1,3 +1,5 @@ +//go:build !test_db_sqlite && !test_db_postgres + package paymentsdb import ( diff --git a/payments/db/test_postgres.go b/payments/db/test_postgres.go new file mode 100644 index 0000000000..1a20c7c0df --- /dev/null +++ b/payments/db/test_postgres.go @@ -0,0 +1,95 @@ +//go:build test_db_postgres && !test_db_sqlite + +package paymentsdb + +import ( + "database/sql" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB, opts ...OptionModifier) DB { + return NewTestDBWithFixture(t, nil, opts...) +} + +// NewTestDBFixture creates a new sqldb.TestPgFixture for testing purposes. +func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture { + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + return pgFixture +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture, opts ...OptionModifier) DB { + + var querier BatchedSQLQueries + if pgFixture == nil { + querier = newBatchQuerier(t) + } else { + querier = newBatchQuerierWithFixture(t, pgFixture) + } + + store, err := NewSQLStore( + &SQLStoreConfig{ + QueryCfg: sqldb.DefaultPostgresConfig(), + }, querier, opts..., + ) + require.NoError(t, err) + + return store +} + +// newBatchQuerier creates a new BatchedSQLQueries instance for testing +// using a PostgreSQL database fixture. +func newBatchQuerier(t testing.TB) BatchedSQLQueries { + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return newBatchQuerierWithFixture(t, pgFixture) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a PostgreSQL database fixture. +func newBatchQuerierWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture) BatchedSQLQueries { + + db := sqldb.NewTestPostgresDB(t, pgFixture).BaseDB + + return sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLQueries { + return db.WithTx(tx) + }, + ) +} + +// NewKVTestDB is a helper function that creates an BBolt database for testing +// and there is no need to convert the interface to the KVStore because for +// some unit tests we still need access to the kvdb interface. +func NewKVTestDB(t *testing.T, opts ...OptionModifier) *KVStore { + backend, backendCleanup, err := kvdb.GetTestBackend( + t.TempDir(), "kvPaymentDB", + ) + require.NoError(t, err) + + t.Cleanup(backendCleanup) + + paymentDB, err := NewKVStore(backend, opts...) + require.NoError(t, err) + + return paymentDB +} diff --git a/payments/db/test_sqlite.go b/payments/db/test_sqlite.go new file mode 100644 index 0000000000..d751bb9977 --- /dev/null +++ b/payments/db/test_sqlite.go @@ -0,0 +1,74 @@ +//go:build !test_db_postgres && test_db_sqlite + +package paymentsdb + +import ( + "database/sql" + "testing" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" +) + +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB, opts ...OptionModifier) DB { + return NewTestDBWithFixture(t, nil, opts...) +} + +// NewTestDBFixture is a no-op for the sqlite build. +func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture { + return nil +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture, + opts ...OptionModifier) DB { + + store, err := NewSQLStore( + &SQLStoreConfig{ + QueryCfg: sqldb.DefaultSQLiteConfig(), + }, newBatchQuerier(t), opts..., + ) + require.NoError(t, err) + return store +} + +// newBatchQuerier creates a new BatchedSQLQueries instance for testing +// using a SQLite database. +func newBatchQuerier(t testing.TB) BatchedSQLQueries { + return newBatchQuerierWithFixture(t, nil) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a SQLite database. +func newBatchQuerierWithFixture(t testing.TB, + _ *sqldb.TestPgFixture) BatchedSQLQueries { + + db := sqldb.NewTestSqliteDB(t).BaseDB + + return sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLQueries { + return db.WithTx(tx) + }, + ) +} + +// NewKVTestDB is a helper function that creates an BBolt database for testing +// and there is no need to convert the interface to the KVStore because for +// some unit tests we still need access to the kvdb interface. +func NewKVTestDB(t *testing.T, opts ...OptionModifier) *KVStore { + backend, backendCleanup, err := kvdb.GetTestBackend( + t.TempDir(), "kvPaymentDB", + ) + require.NoError(t, err) + + t.Cleanup(backendCleanup) + + paymentDB, err := NewKVStore(backend, opts...) + require.NoError(t, err) + + return paymentDB +} diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go index f7bc499185..db0a03a421 100644 --- a/sqldb/sqlc/db_custom.go +++ b/sqldb/sqlc/db_custom.go @@ -161,3 +161,79 @@ func (r GetChannelsBySCIDRangeRow) Node1Pub() []byte { func (r GetChannelsBySCIDRangeRow) Node2Pub() []byte { return r.Node2PubKey } + +// PaymentAndIntent is an interface that provides access to a payment and its +// associated payment intent. +type PaymentAndIntent interface { + // GetPayment returns the Payment associated with this interface. + GetPayment() Payment + + // GetPaymentIntent returns the PaymentIntent associated with this payment. + GetPaymentIntent() PaymentIntent +} + +// GetPayment returns the Payment associated with this interface. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FilterPaymentsRow) GetPayment() Payment { + return r.Payment +} + +// GetPaymentIntent returns the PaymentIntent associated with this payment. +// If the payment has no intent (IntentType is NULL), this returns a zero-value +// PaymentIntent. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FilterPaymentsRow) GetPaymentIntent() PaymentIntent { + if !r.IntentType.Valid { + return PaymentIntent{} + } + return PaymentIntent{ + IntentType: r.IntentType.Int16, + IntentPayload: r.IntentPayload, + } +} + +// GetPayment returns the Payment associated with this interface. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentRow) GetPayment() Payment { + return r.Payment +} + +// GetPaymentIntent returns the PaymentIntent associated with this payment. +// If the payment has no intent (IntentType is NULL), this returns a zero-value +// PaymentIntent. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentRow) GetPaymentIntent() PaymentIntent { + if !r.IntentType.Valid { + return PaymentIntent{} + } + return PaymentIntent{ + IntentType: r.IntentType.Int16, + IntentPayload: r.IntentPayload, + } +} + +// GetPayment returns the Payment associated with this interface. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentsByIDsRow) GetPayment() Payment { + return r.Payment +} + +// GetPaymentIntent returns the PaymentIntent associated with this payment. +// If the payment has no intent (IntentType is NULL), this returns a zero-value +// PaymentIntent. +// +// NOTE: This method is part of the PaymentAndIntent interface. +func (r FetchPaymentsByIDsRow) GetPaymentIntent() PaymentIntent { + if !r.IntentType.Valid { + return PaymentIntent{} + } + return PaymentIntent{ + IntentType: r.IntentType.Int16, + IntentPayload: r.IntentPayload, + } +} diff --git a/sqldb/sqlc/migrations/000009_payments.up.sql b/sqldb/sqlc/migrations/000009_payments.up.sql index c856db8f44..0d85b497b0 100644 --- a/sqldb/sqlc/migrations/000009_payments.up.sql +++ b/sqldb/sqlc/migrations/000009_payments.up.sql @@ -32,9 +32,13 @@ CREATE TABLE IF NOT EXISTS payment_intents ( ); -- Index for efficient querying by intent type -CREATE INDEX IF NOT EXISTS idx_payment_intents_type +CREATE INDEX IF NOT EXISTS idx_payment_intents_type ON payment_intents(intent_type); +-- Unique constraint for deduplication of payment intents +CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_intents_unique +ON payment_intents(intent_type, intent_payload); + -- ───────────────────────────────────────────── -- Payments Table -- ───────────────────────────────────────────── @@ -187,7 +191,8 @@ CREATE TABLE IF NOT EXISTS payment_htlc_attempt_resolutions ( -- HTLC failure reason code htlc_fail_reason INTEGER, - -- Failure message from the failing node + -- Failure message from the failing node, this message is binary encoded + -- using the lightning wire protocol, see also lnwire/onion_error.go failure_msg BLOB, -- Ensure data integrity: settled attempts must have preimage, diff --git a/sqldb/sqlc/payments.sql.go b/sqldb/sqlc/payments.sql.go new file mode 100644 index 0000000000..babab68247 --- /dev/null +++ b/sqldb/sqlc/payments.sql.go @@ -0,0 +1,1084 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: payments.sql + +package sqlc + +import ( + "context" + "database/sql" + "strings" + "time" +) + +const countPayments = `-- name: CountPayments :one +SELECT COUNT(*) FROM payments +` + +func (q *Queries) CountPayments(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, countPayments) + var count int64 + err := row.Scan(&count) + return count, err +} + +const deleteFailedAttempts = `-- name: DeleteFailedAttempts :exec +DELETE FROM payment_htlc_attempts WHERE payment_id = $1 AND attempt_index IN ( + SELECT attempt_index FROM payment_htlc_attempt_resolutions WHERE resolution_type = 2 +) +` + +func (q *Queries) DeleteFailedAttempts(ctx context.Context, paymentID int64) error { + _, err := q.db.ExecContext(ctx, deleteFailedAttempts, paymentID) + return err +} + +const deletePayment = `-- name: DeletePayment :exec +DELETE FROM payments WHERE id = $1 +` + +func (q *Queries) DeletePayment(ctx context.Context, id int64) error { + _, err := q.db.ExecContext(ctx, deletePayment, id) + return err +} + +const failAttempt = `-- name: FailAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + failure_source_index, + htlc_fail_reason, + failure_msg +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6 +) +` + +type FailAttemptParams struct { + AttemptIndex int64 + ResolutionTime time.Time + ResolutionType int32 + FailureSourceIndex sql.NullInt32 + HtlcFailReason sql.NullInt32 + FailureMsg []byte +} + +func (q *Queries) FailAttempt(ctx context.Context, arg FailAttemptParams) error { + _, err := q.db.ExecContext(ctx, failAttempt, + arg.AttemptIndex, + arg.ResolutionTime, + arg.ResolutionType, + arg.FailureSourceIndex, + arg.HtlcFailReason, + arg.FailureMsg, + ) + return err +} + +const failPayment = `-- name: FailPayment :execresult +UPDATE payments SET fail_reason = $1 WHERE payment_identifier = $2 +` + +type FailPaymentParams struct { + FailReason sql.NullInt32 + PaymentIdentifier []byte +} + +func (q *Queries) FailPayment(ctx context.Context, arg FailPaymentParams) (sql.Result, error) { + return q.db.ExecContext(ctx, failPayment, arg.FailReason, arg.PaymentIdentifier) +} + +const fetchAllInflightAttempts = `-- name: FetchAllInflightAttempts :many +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key +FROM payment_htlc_attempts ha +WHERE NOT EXISTS ( + SELECT 1 FROM payment_htlc_attempt_resolutions hr + WHERE hr.attempt_index = ha.attempt_index +) +` + +// Fetch all inflight attempts across all payments +func (q *Queries) FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAttempt, error) { + rows, err := q.db.QueryContext(ctx, fetchAllInflightAttempts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentHtlcAttempt + for rows.Next() { + var i PaymentHtlcAttempt + if err := rows.Scan( + &i.ID, + &i.AttemptIndex, + &i.PaymentID, + &i.SessionKey, + &i.AttemptTime, + &i.PaymentHash, + &i.FirstHopAmountMsat, + &i.RouteTotalTimeLock, + &i.RouteTotalAmount, + &i.RouteSourceKey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHopLevelCustomRecords = `-- name: FetchHopLevelCustomRecords :many +SELECT + l.id, + l.hop_id, + l.key, + l.value +FROM payment_hop_custom_records l +WHERE l.hop_id IN (/*SLICE:hop_ids*/?) +` + +func (q *Queries) FetchHopLevelCustomRecords(ctx context.Context, hopIds []int64) ([]PaymentHopCustomRecord, error) { + query := fetchHopLevelCustomRecords + var queryParams []interface{} + if len(hopIds) > 0 { + for _, v := range hopIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:hop_ids*/?", makeQueryParams(len(queryParams), len(hopIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:hop_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentHopCustomRecord + for rows.Next() { + var i PaymentHopCustomRecord + if err := rows.Scan( + &i.ID, + &i.HopID, + &i.Key, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHopsForAttempt = `-- name: FetchHopsForAttempt :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index = $1 +ORDER BY h.hop_index ASC +` + +type FetchHopsForAttemptRow struct { + ID int64 + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte + MppPaymentAddr []byte + MppTotalMsat sql.NullInt64 + AmpRootShare []byte + AmpSetID []byte + AmpChildIndex sql.NullInt32 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) FetchHopsForAttempt(ctx context.Context, htlcAttemptIndex int64) ([]FetchHopsForAttemptRow, error) { + rows, err := q.db.QueryContext(ctx, fetchHopsForAttempt, htlcAttemptIndex) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHopsForAttemptRow + for rows.Next() { + var i FetchHopsForAttemptRow + if err := rows.Scan( + &i.ID, + &i.HtlcAttemptIndex, + &i.HopIndex, + &i.PubKey, + &i.Scid, + &i.OutgoingTimeLock, + &i.AmtToForward, + &i.MetaData, + &i.MppPaymentAddr, + &i.MppTotalMsat, + &i.AmpRootShare, + &i.AmpSetID, + &i.AmpChildIndex, + &i.EncryptedData, + &i.BlindingPoint, + &i.BlindedPathTotalAmt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHopsForAttempts = `-- name: FetchHopsForAttempts :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index IN (/*SLICE:htlc_attempt_indices*/?) +` + +type FetchHopsForAttemptsRow struct { + ID int64 + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte + MppPaymentAddr []byte + MppTotalMsat sql.NullInt64 + AmpRootShare []byte + AmpSetID []byte + AmpChildIndex sql.NullInt32 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]FetchHopsForAttemptsRow, error) { + query := fetchHopsForAttempts + var queryParams []interface{} + if len(htlcAttemptIndices) > 0 { + for _, v := range htlcAttemptIndices { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", makeQueryParams(len(queryParams), len(htlcAttemptIndices)), 1) + } else { + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHopsForAttemptsRow + for rows.Next() { + var i FetchHopsForAttemptsRow + if err := rows.Scan( + &i.ID, + &i.HtlcAttemptIndex, + &i.HopIndex, + &i.PubKey, + &i.Scid, + &i.OutgoingTimeLock, + &i.AmtToForward, + &i.MetaData, + &i.MppPaymentAddr, + &i.MppTotalMsat, + &i.AmpRootShare, + &i.AmpSetID, + &i.AmpChildIndex, + &i.EncryptedData, + &i.BlindingPoint, + &i.BlindedPathTotalAmt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchHtlcAttemptsForPayment = `-- name: FetchHtlcAttemptsForPayment :many +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key, + hr.resolution_type, + hr.resolution_time, + hr.failure_source_index, + hr.htlc_fail_reason, + hr.failure_msg, + hr.settle_preimage +FROM payment_htlc_attempts ha +LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_index +WHERE ha.payment_id = $1 +ORDER BY ha.attempt_time ASC +` + +type FetchHtlcAttemptsForPaymentRow struct { + ID int64 + AttemptIndex int64 + PaymentID int64 + SessionKey []byte + AttemptTime time.Time + PaymentHash []byte + FirstHopAmountMsat int64 + RouteTotalTimeLock int32 + RouteTotalAmount int64 + RouteSourceKey []byte + ResolutionType sql.NullInt32 + ResolutionTime sql.NullTime + FailureSourceIndex sql.NullInt32 + HtlcFailReason sql.NullInt32 + FailureMsg []byte + SettlePreimage []byte +} + +// This fetches all htlc attempts for a payment. +func (q *Queries) FetchHtlcAttemptsForPayment(ctx context.Context, paymentID int64) ([]FetchHtlcAttemptsForPaymentRow, error) { + rows, err := q.db.QueryContext(ctx, fetchHtlcAttemptsForPayment, paymentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchHtlcAttemptsForPaymentRow + for rows.Next() { + var i FetchHtlcAttemptsForPaymentRow + if err := rows.Scan( + &i.ID, + &i.AttemptIndex, + &i.PaymentID, + &i.SessionKey, + &i.AttemptTime, + &i.PaymentHash, + &i.FirstHopAmountMsat, + &i.RouteTotalTimeLock, + &i.RouteTotalAmount, + &i.RouteSourceKey, + &i.ResolutionType, + &i.ResolutionTime, + &i.FailureSourceIndex, + &i.HtlcFailReason, + &i.FailureMsg, + &i.SettlePreimage, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchPayment = `-- name: FetchPayment :one +SELECT + p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.payment_identifier = $1 +` + +type FetchPaymentRow struct { + Payment Payment + IntentType sql.NullInt16 + IntentPayload []byte +} + +func (q *Queries) FetchPayment(ctx context.Context, paymentIdentifier []byte) (FetchPaymentRow, error) { + row := q.db.QueryRowContext(ctx, fetchPayment, paymentIdentifier) + var i FetchPaymentRow + err := row.Scan( + &i.Payment.ID, + &i.Payment.IntentID, + &i.Payment.AmountMsat, + &i.Payment.CreatedAt, + &i.Payment.PaymentIdentifier, + &i.Payment.FailReason, + &i.IntentType, + &i.IntentPayload, + ) + return i, err +} + +const fetchPaymentLevelFirstHopCustomRecords = `-- name: FetchPaymentLevelFirstHopCustomRecords :many +SELECT + l.id, + l.payment_id, + l.key, + l.value +FROM payment_first_hop_custom_records l +WHERE l.payment_id = $1 +` + +func (q *Queries) FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentID int64) ([]PaymentFirstHopCustomRecord, error) { + rows, err := q.db.QueryContext(ctx, fetchPaymentLevelFirstHopCustomRecords, paymentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentFirstHopCustomRecord + for rows.Next() { + var i PaymentFirstHopCustomRecord + if err := rows.Scan( + &i.ID, + &i.PaymentID, + &i.Key, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchPaymentsByIDs = `-- name: FetchPaymentsByIDs :many +SELECT + p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.id IN (/*SLICE:payment_ids*/?) +` + +type FetchPaymentsByIDsRow struct { + Payment Payment + IntentType sql.NullInt16 + IntentPayload []byte +} + +func (q *Queries) FetchPaymentsByIDs(ctx context.Context, paymentIds []int64) ([]FetchPaymentsByIDsRow, error) { + query := fetchPaymentsByIDs + var queryParams []interface{} + if len(paymentIds) > 0 { + for _, v := range paymentIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:payment_ids*/?", makeQueryParams(len(queryParams), len(paymentIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:payment_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchPaymentsByIDsRow + for rows.Next() { + var i FetchPaymentsByIDsRow + if err := rows.Scan( + &i.Payment.ID, + &i.Payment.IntentID, + &i.Payment.AmountMsat, + &i.Payment.CreatedAt, + &i.Payment.PaymentIdentifier, + &i.Payment.FailReason, + &i.IntentType, + &i.IntentPayload, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const fetchRouteLevelFirstHopCustomRecords = `-- name: FetchRouteLevelFirstHopCustomRecords :many +SELECT + l.id, + l.htlc_attempt_index, + l.key, + l.value +FROM payment_attempt_first_hop_custom_records l +WHERE l.htlc_attempt_index IN (/*SLICE:htlc_attempt_indices*/?) +` + +func (q *Queries) FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]PaymentAttemptFirstHopCustomRecord, error) { + query := fetchRouteLevelFirstHopCustomRecords + var queryParams []interface{} + if len(htlcAttemptIndices) > 0 { + for _, v := range htlcAttemptIndices { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", makeQueryParams(len(queryParams), len(htlcAttemptIndices)), 1) + } else { + query = strings.Replace(query, "/*SLICE:htlc_attempt_indices*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PaymentAttemptFirstHopCustomRecord + for rows.Next() { + var i PaymentAttemptFirstHopCustomRecord + if err := rows.Scan( + &i.ID, + &i.HtlcAttemptIndex, + &i.Key, + &i.Value, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const filterPayments = `-- name: FilterPayments :many +/* ───────────────────────────────────────────── + fetch queries + ───────────────────────────────────────────── +*/ + +SELECT + p.id, p.intent_id, p.amount_msat, p.created_at, p.payment_identifier, p.fail_reason, + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE ( + p.id > $1 OR + $1 IS NULL +) AND ( + p.id < $2 OR + $2 IS NULL +) AND ( + p.created_at >= $3 OR + $3 IS NULL +) AND ( + p.created_at <= $4 OR + $4 IS NULL +) AND ( + i.intent_type = $5 OR + $5 IS NULL OR i.intent_type IS NULL +) +ORDER BY + CASE WHEN $6 = false OR $6 IS NULL THEN p.id END ASC, + CASE WHEN $6 = true THEN p.id END DESC +LIMIT $7 +` + +type FilterPaymentsParams struct { + IndexOffsetGet sql.NullInt64 + IndexOffsetLet sql.NullInt64 + CreatedAfter sql.NullTime + CreatedBefore sql.NullTime + IntentType sql.NullInt16 + Reverse interface{} + NumLimit int32 +} + +type FilterPaymentsRow struct { + Payment Payment + IntentType sql.NullInt16 + IntentPayload []byte +} + +func (q *Queries) FilterPayments(ctx context.Context, arg FilterPaymentsParams) ([]FilterPaymentsRow, error) { + rows, err := q.db.QueryContext(ctx, filterPayments, + arg.IndexOffsetGet, + arg.IndexOffsetLet, + arg.CreatedAfter, + arg.CreatedBefore, + arg.IntentType, + arg.Reverse, + arg.NumLimit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FilterPaymentsRow + for rows.Next() { + var i FilterPaymentsRow + if err := rows.Scan( + &i.Payment.ID, + &i.Payment.IntentID, + &i.Payment.AmountMsat, + &i.Payment.CreatedAt, + &i.Payment.PaymentIdentifier, + &i.Payment.FailReason, + &i.IntentType, + &i.IntentPayload, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertHtlcAttempt = `-- name: InsertHtlcAttempt :one +INSERT INTO payment_htlc_attempts ( + payment_id, + attempt_index, + session_key, + attempt_time, + payment_hash, + first_hop_amount_msat, + route_total_time_lock, + route_total_amount, + route_source_key) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9) +RETURNING id +` + +type InsertHtlcAttemptParams struct { + PaymentID int64 + AttemptIndex int64 + SessionKey []byte + AttemptTime time.Time + PaymentHash []byte + FirstHopAmountMsat int64 + RouteTotalTimeLock int32 + RouteTotalAmount int64 + RouteSourceKey []byte +} + +func (q *Queries) InsertHtlcAttempt(ctx context.Context, arg InsertHtlcAttemptParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertHtlcAttempt, + arg.PaymentID, + arg.AttemptIndex, + arg.SessionKey, + arg.AttemptTime, + arg.PaymentHash, + arg.FirstHopAmountMsat, + arg.RouteTotalTimeLock, + arg.RouteTotalAmount, + arg.RouteSourceKey, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertPayment = `-- name: InsertPayment :one +INSERT INTO payments ( + intent_id, + amount_msat, + created_at, + payment_identifier, + fail_reason) +VALUES ( + $1, + $2, + $3, + $4, + NULL +) +RETURNING id +` + +type InsertPaymentParams struct { + IntentID sql.NullInt64 + AmountMsat int64 + CreatedAt time.Time + PaymentIdentifier []byte +} + +// Insert a new payment with the given intent ID and return its ID. +func (q *Queries) InsertPayment(ctx context.Context, arg InsertPaymentParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertPayment, + arg.IntentID, + arg.AmountMsat, + arg.CreatedAt, + arg.PaymentIdentifier, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertPaymentAttemptFirstHopCustomRecord = `-- name: InsertPaymentAttemptFirstHopCustomRecord :exec +INSERT INTO payment_attempt_first_hop_custom_records ( + htlc_attempt_index, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentAttemptFirstHopCustomRecordParams struct { + HtlcAttemptIndex int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg InsertPaymentAttemptFirstHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentAttemptFirstHopCustomRecord, arg.HtlcAttemptIndex, arg.Key, arg.Value) + return err +} + +const insertPaymentFirstHopCustomRecord = `-- name: InsertPaymentFirstHopCustomRecord :exec +INSERT INTO payment_first_hop_custom_records ( + payment_id, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentFirstHopCustomRecordParams struct { + PaymentID int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentFirstHopCustomRecord(ctx context.Context, arg InsertPaymentFirstHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentFirstHopCustomRecord, arg.PaymentID, arg.Key, arg.Value) + return err +} + +const insertPaymentHopCustomRecord = `-- name: InsertPaymentHopCustomRecord :exec +INSERT INTO payment_hop_custom_records ( + hop_id, + key, + value +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertPaymentHopCustomRecordParams struct { + HopID int64 + Key int64 + Value []byte +} + +func (q *Queries) InsertPaymentHopCustomRecord(ctx context.Context, arg InsertPaymentHopCustomRecordParams) error { + _, err := q.db.ExecContext(ctx, insertPaymentHopCustomRecord, arg.HopID, arg.Key, arg.Value) + return err +} + +const insertPaymentIntent = `-- name: InsertPaymentIntent :one +INSERT INTO payment_intents ( + intent_type, + intent_payload) +VALUES ( + $1, + $2 +) +ON CONFLICT (intent_type, intent_payload) DO UPDATE SET + intent_type = EXCLUDED.intent_type, + intent_payload = EXCLUDED.intent_payload +RETURNING id +` + +type InsertPaymentIntentParams struct { + IntentType int16 + IntentPayload []byte +} + +// Insert a new payment intent and return its ID. +func (q *Queries) InsertPaymentIntent(ctx context.Context, arg InsertPaymentIntentParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertPaymentIntent, arg.IntentType, arg.IntentPayload) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertRouteHop = `-- name: InsertRouteHop :one +INSERT INTO payment_route_hops ( + htlc_attempt_index, + hop_index, + pub_key, + scid, + outgoing_time_lock, + amt_to_forward, + meta_data +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING id +` + +type InsertRouteHopParams struct { + HtlcAttemptIndex int64 + HopIndex int32 + PubKey []byte + Scid string + OutgoingTimeLock int32 + AmtToForward int64 + MetaData []byte +} + +func (q *Queries) InsertRouteHop(ctx context.Context, arg InsertRouteHopParams) (int64, error) { + row := q.db.QueryRowContext(ctx, insertRouteHop, + arg.HtlcAttemptIndex, + arg.HopIndex, + arg.PubKey, + arg.Scid, + arg.OutgoingTimeLock, + arg.AmtToForward, + arg.MetaData, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const insertRouteHopAmp = `-- name: InsertRouteHopAmp :exec +INSERT INTO payment_route_hop_amp ( + hop_id, + root_share, + set_id, + child_index +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertRouteHopAmpParams struct { + HopID int64 + RootShare []byte + SetID []byte + ChildIndex int32 +} + +func (q *Queries) InsertRouteHopAmp(ctx context.Context, arg InsertRouteHopAmpParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopAmp, + arg.HopID, + arg.RootShare, + arg.SetID, + arg.ChildIndex, + ) + return err +} + +const insertRouteHopBlinded = `-- name: InsertRouteHopBlinded :exec +INSERT INTO payment_route_hop_blinded ( + hop_id, + encrypted_data, + blinding_point, + blinded_path_total_amt +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type InsertRouteHopBlindedParams struct { + HopID int64 + EncryptedData []byte + BlindingPoint []byte + BlindedPathTotalAmt sql.NullInt64 +} + +func (q *Queries) InsertRouteHopBlinded(ctx context.Context, arg InsertRouteHopBlindedParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopBlinded, + arg.HopID, + arg.EncryptedData, + arg.BlindingPoint, + arg.BlindedPathTotalAmt, + ) + return err +} + +const insertRouteHopMpp = `-- name: InsertRouteHopMpp :exec +INSERT INTO payment_route_hop_mpp ( + hop_id, + payment_addr, + total_msat +) +VALUES ( + $1, + $2, + $3 +) +` + +type InsertRouteHopMppParams struct { + HopID int64 + PaymentAddr []byte + TotalMsat int64 +} + +func (q *Queries) InsertRouteHopMpp(ctx context.Context, arg InsertRouteHopMppParams) error { + _, err := q.db.ExecContext(ctx, insertRouteHopMpp, arg.HopID, arg.PaymentAddr, arg.TotalMsat) + return err +} + +const settleAttempt = `-- name: SettleAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + settle_preimage +) +VALUES ( + $1, + $2, + $3, + $4 +) +` + +type SettleAttemptParams struct { + AttemptIndex int64 + ResolutionTime time.Time + ResolutionType int32 + SettlePreimage []byte +} + +func (q *Queries) SettleAttempt(ctx context.Context, arg SettleAttemptParams) error { + _, err := q.db.ExecContext(ctx, settleAttempt, + arg.AttemptIndex, + arg.ResolutionTime, + arg.ResolutionType, + arg.SettlePreimage, + ) + return err +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 0087559be8..3319e1e302 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -14,24 +14,41 @@ type Querier interface { AddSourceNode(ctx context.Context, nodeID int64) error AddV1ChannelProof(ctx context.Context, arg AddV1ChannelProofParams) (sql.Result, error) ClearKVInvoiceHashIndex(ctx context.Context) error + CountPayments(ctx context.Context) (int64, error) CountZombieChannels(ctx context.Context, version int16) (int64, error) CreateChannel(ctx context.Context, arg CreateChannelParams) (int64, error) DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error DeleteChannels(ctx context.Context, ids []int64) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error + DeleteFailedAttempts(ctx context.Context, paymentID int64) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) DeleteNode(ctx context.Context, id int64) error DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error + DeletePayment(ctx context.Context, id int64) error DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error DeleteUnconnectedNodes(ctx context.Context) ([][]byte, error) DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) + FailAttempt(ctx context.Context, arg FailAttemptParams) error + FailPayment(ctx context.Context, arg FailPaymentParams) (sql.Result, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) + // Fetch all inflight attempts across all payments + FetchAllInflightAttempts(ctx context.Context) ([]PaymentHtlcAttempt, error) + FetchHopLevelCustomRecords(ctx context.Context, hopIds []int64) ([]PaymentHopCustomRecord, error) + FetchHopsForAttempt(ctx context.Context, htlcAttemptIndex int64) ([]FetchHopsForAttemptRow, error) + FetchHopsForAttempts(ctx context.Context, htlcAttemptIndices []int64) ([]FetchHopsForAttemptsRow, error) + // This fetches all htlc attempts for a payment. + FetchHtlcAttemptsForPayment(ctx context.Context, paymentID int64) ([]FetchHtlcAttemptsForPaymentRow, error) + FetchPayment(ctx context.Context, paymentIdentifier []byte) (FetchPaymentRow, error) + FetchPaymentLevelFirstHopCustomRecords(ctx context.Context, paymentID int64) ([]PaymentFirstHopCustomRecord, error) + FetchPaymentsByIDs(ctx context.Context, paymentIds []int64) ([]FetchPaymentsByIDsRow, error) + FetchRouteLevelFirstHopCustomRecords(ctx context.Context, htlcAttemptIndices []int64) ([]PaymentAttemptFirstHopCustomRecord, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) + FilterPayments(ctx context.Context, arg FilterPaymentsParams) ([]FilterPaymentsRow, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) @@ -101,6 +118,7 @@ type Querier interface { // UpsertEdgePolicy query is used because of the constraint in that query that // requires a policy update to have a newer last_update than the existing one). InsertEdgePolicyMig(ctx context.Context, arg InsertEdgePolicyMigParams) (int64, error) + InsertHtlcAttempt(ctx context.Context, arg InsertHtlcAttemptParams) (int64, error) InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error) @@ -114,6 +132,17 @@ type Querier interface { // is used because of the constraint in that query that requires a node update // to have a newer last_update than the existing node). InsertNodeMig(ctx context.Context, arg InsertNodeMigParams) (int64, error) + // Insert a new payment with the given intent ID and return its ID. + InsertPayment(ctx context.Context, arg InsertPaymentParams) (int64, error) + InsertPaymentAttemptFirstHopCustomRecord(ctx context.Context, arg InsertPaymentAttemptFirstHopCustomRecordParams) error + InsertPaymentFirstHopCustomRecord(ctx context.Context, arg InsertPaymentFirstHopCustomRecordParams) error + InsertPaymentHopCustomRecord(ctx context.Context, arg InsertPaymentHopCustomRecordParams) error + // Insert a new payment intent and return its ID. + InsertPaymentIntent(ctx context.Context, arg InsertPaymentIntentParams) (int64, error) + InsertRouteHop(ctx context.Context, arg InsertRouteHopParams) (int64, error) + InsertRouteHopAmp(ctx context.Context, arg InsertRouteHopAmpParams) error + InsertRouteHopBlinded(ctx context.Context, arg InsertRouteHopBlindedParams) error + InsertRouteHopMpp(ctx context.Context, arg InsertRouteHopMppParams) error IsClosedChannel(ctx context.Context, scid []byte) (bool, error) IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error) IsZombieChannel(ctx context.Context, arg IsZombieChannelParams) (bool, error) @@ -133,6 +162,7 @@ type Querier interface { OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error SetMigration(ctx context.Context, arg SetMigrationParams) error + SettleAttempt(ctx context.Context, arg SettleAttemptParams) error UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error) UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error) diff --git a/sqldb/sqlc/queries/payments.sql b/sqldb/sqlc/queries/payments.sql new file mode 100644 index 0000000000..db21a9be4b --- /dev/null +++ b/sqldb/sqlc/queries/payments.sql @@ -0,0 +1,368 @@ +/* ───────────────────────────────────────────── + fetch queries + ───────────────────────────────────────────── +*/ + +-- name: FilterPayments :many +SELECT + sqlc.embed(p), + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE ( + p.id > sqlc.narg('index_offset_get') OR + sqlc.narg('index_offset_get') IS NULL +) AND ( + p.id < sqlc.narg('index_offset_let') OR + sqlc.narg('index_offset_let') IS NULL +) AND ( + p.created_at >= sqlc.narg('created_after') OR + sqlc.narg('created_after') IS NULL +) AND ( + p.created_at <= sqlc.narg('created_before') OR + sqlc.narg('created_before') IS NULL +) AND ( + i.intent_type = sqlc.narg('intent_type') OR + sqlc.narg('intent_type') IS NULL OR i.intent_type IS NULL +) +ORDER BY + CASE WHEN sqlc.narg('reverse') = false OR sqlc.narg('reverse') IS NULL THEN p.id END ASC, + CASE WHEN sqlc.narg('reverse') = true THEN p.id END DESC +LIMIT @num_limit; + +-- name: FetchPayment :one +SELECT + sqlc.embed(p), + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.payment_identifier = $1; + +-- name: FetchPaymentsByIDs :many +SELECT + sqlc.embed(p), + i.intent_type AS "intent_type", + i.intent_payload AS "intent_payload" +FROM payments p +LEFT JOIN payment_intents i ON i.id = p.intent_id +WHERE p.id IN (sqlc.slice('payment_ids')/*SLICE:payment_ids*/); + +-- name: CountPayments :one +SELECT COUNT(*) FROM payments; + + +-- This fetches all htlc attempts for a payment. +-- name: FetchHtlcAttemptsForPayment :many +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key, + hr.resolution_type, + hr.resolution_time, + hr.failure_source_index, + hr.htlc_fail_reason, + hr.failure_msg, + hr.settle_preimage +FROM payment_htlc_attempts ha +LEFT JOIN payment_htlc_attempt_resolutions hr ON hr.attempt_index = ha.attempt_index +WHERE ha.payment_id = $1 +ORDER BY ha.attempt_time ASC; + +-- name: FetchAllInflightAttempts :many +-- Fetch all inflight attempts across all payments +SELECT + ha.id, + ha.attempt_index, + ha.payment_id, + ha.session_key, + ha.attempt_time, + ha.payment_hash, + ha.first_hop_amount_msat, + ha.route_total_time_lock, + ha.route_total_amount, + ha.route_source_key +FROM payment_htlc_attempts ha +WHERE NOT EXISTS ( + SELECT 1 FROM payment_htlc_attempt_resolutions hr + WHERE hr.attempt_index = ha.attempt_index +); + +-- name: FetchHopsForAttempt :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index = $1 +ORDER BY h.hop_index ASC; + +-- name: FetchHopsForAttempts :many +SELECT + h.id, + h.htlc_attempt_index, + h.hop_index, + h.pub_key, + h.scid, + h.outgoing_time_lock, + h.amt_to_forward, + h.meta_data, + m.payment_addr AS mpp_payment_addr, + m.total_msat AS mpp_total_msat, + a.root_share AS amp_root_share, + a.set_id AS amp_set_id, + a.child_index AS amp_child_index, + b.encrypted_data, + b.blinding_point, + b.blinded_path_total_amt +FROM payment_route_hops h +LEFT JOIN payment_route_hop_mpp m ON m.hop_id = h.id +LEFT JOIN payment_route_hop_amp a ON a.hop_id = h.id +LEFT JOIN payment_route_hop_blinded b ON b.hop_id = h.id +WHERE h.htlc_attempt_index IN (sqlc.slice('htlc_attempt_indices')/*SLICE:htlc_attempt_indices*/); + +-- name: FetchPaymentLevelFirstHopCustomRecords :many +SELECT + l.id, + l.payment_id, + l.key, + l.value +FROM payment_first_hop_custom_records l +WHERE l.payment_id = $1; + +-- name: FetchRouteLevelFirstHopCustomRecords :many +SELECT + l.id, + l.htlc_attempt_index, + l.key, + l.value +FROM payment_attempt_first_hop_custom_records l +WHERE l.htlc_attempt_index IN (sqlc.slice('htlc_attempt_indices')/*SLICE:htlc_attempt_indices*/); + +-- name: FetchHopLevelCustomRecords :many +SELECT + l.id, + l.hop_id, + l.key, + l.value +FROM payment_hop_custom_records l +WHERE l.hop_id IN (sqlc.slice('hop_ids')/*SLICE:hop_ids*/); + + +-- name: DeletePayment :exec +DELETE FROM payments WHERE id = $1; + +-- name: DeleteFailedAttempts :exec +DELETE FROM payment_htlc_attempts WHERE payment_id = $1 AND attempt_index IN ( + SELECT attempt_index FROM payment_htlc_attempt_resolutions WHERE resolution_type = 2 +); + +-- name: InsertPaymentIntent :one +-- Insert a new payment intent and return its ID. +INSERT INTO payment_intents ( + intent_type, + intent_payload) +VALUES ( + @intent_type, + @intent_payload +) +ON CONFLICT (intent_type, intent_payload) DO UPDATE SET + intent_type = EXCLUDED.intent_type, + intent_payload = EXCLUDED.intent_payload +RETURNING id; + +-- name: InsertPayment :one +-- Insert a new payment with the given intent ID and return its ID. +INSERT INTO payments ( + intent_id, + amount_msat, + created_at, + payment_identifier, + fail_reason) +VALUES ( + @intent_id, + @amount_msat, + @created_at, + @payment_identifier, + NULL +) +RETURNING id; + +-- name: InsertPaymentFirstHopCustomRecord :exec +INSERT INTO payment_first_hop_custom_records ( + payment_id, + key, + value +) +VALUES ( + @payment_id, + @key, + @value +); + +-- name: InsertHtlcAttempt :one +INSERT INTO payment_htlc_attempts ( + payment_id, + attempt_index, + session_key, + attempt_time, + payment_hash, + first_hop_amount_msat, + route_total_time_lock, + route_total_amount, + route_source_key) +VALUES ( + @payment_id, + @attempt_index, + @session_key, + @attempt_time, + @payment_hash, + @first_hop_amount_msat, + @route_total_time_lock, + @route_total_amount, + @route_source_key) +RETURNING id; + +-- name: InsertPaymentAttemptFirstHopCustomRecord :exec +INSERT INTO payment_attempt_first_hop_custom_records ( + htlc_attempt_index, + key, + value +) +VALUES ( + @htlc_attempt_index, + @key, + @value +); + +-- name: InsertRouteHop :one +INSERT INTO payment_route_hops ( + htlc_attempt_index, + hop_index, + pub_key, + scid, + outgoing_time_lock, + amt_to_forward, + meta_data +) +VALUES ( + @htlc_attempt_index, + @hop_index, + @pub_key, + @scid, + @outgoing_time_lock, + @amt_to_forward, + @meta_data +) +RETURNING id; + +-- name: InsertRouteHopMpp :exec +INSERT INTO payment_route_hop_mpp ( + hop_id, + payment_addr, + total_msat +) +VALUES ( + @hop_id, + @payment_addr, + @total_msat +); + +-- name: InsertRouteHopAmp :exec +INSERT INTO payment_route_hop_amp ( + hop_id, + root_share, + set_id, + child_index +) +VALUES ( + @hop_id, + @root_share, + @set_id, + @child_index +); + +-- name: InsertRouteHopBlinded :exec +INSERT INTO payment_route_hop_blinded ( + hop_id, + encrypted_data, + blinding_point, + blinded_path_total_amt +) +VALUES ( + @hop_id, + @encrypted_data, + @blinding_point, + @blinded_path_total_amt +); + +-- name: InsertPaymentHopCustomRecord :exec +INSERT INTO payment_hop_custom_records ( + hop_id, + key, + value +) +VALUES ( + @hop_id, + @key, + @value +); + +-- name: SettleAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + settle_preimage +) +VALUES ( + @attempt_index, + @resolution_time, + @resolution_type, + @settle_preimage +); + +-- name: FailAttempt :exec +INSERT INTO payment_htlc_attempt_resolutions ( + attempt_index, + resolution_time, + resolution_type, + failure_source_index, + htlc_fail_reason, + failure_msg +) +VALUES ( + @attempt_index, + @resolution_time, + @resolution_type, + @failure_source_index, + @htlc_fail_reason, + @failure_msg +); + +-- name: FailPayment :execresult +UPDATE payments SET fail_reason = $1 WHERE payment_identifier = $2;