Skip to content

Commit 6c9e0f3

Browse files
authored
Merge pull request #10196 from ziggie1984/refactor-payment-part-6
refactor payments part 6
2 parents 221c3e8 + f540e9c commit 6c9e0f3

File tree

7 files changed

+147
-93
lines changed

7 files changed

+147
-93
lines changed

payments/db/kv_store.go

Lines changed: 6 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,14 @@ type KVStore struct {
133133
keepFailedPaymentAttempts bool
134134
}
135135

136-
// defaultKVStoreOptions returns the default options for the KV store.
137-
func defaultKVStoreOptions() *StoreOptions {
138-
return &StoreOptions{
139-
KeepFailedPaymentAttempts: false,
140-
}
141-
}
136+
// A compile-time constraint to ensure KVStore implements DB.
137+
var _ DB = (*KVStore)(nil)
142138

143139
// NewKVStore creates a new KVStore for payments.
144140
func NewKVStore(db kvdb.Backend,
145141
options ...OptionModifier) (*KVStore, error) {
146142

147-
opts := defaultKVStoreOptions()
143+
opts := DefaultOptions()
148144
for _, applyOption := range options {
149145
applyOption(opts)
150146
}
@@ -393,88 +389,9 @@ func (p *KVStore) RegisterAttempt(paymentHash lntypes.Hash,
393389
return err
394390
}
395391

396-
// If the final hop has encrypted data, then we know this is a
397-
// blinded payment. In blinded payments, MPP records are not set
398-
// for split payments and the recipient is responsible for using
399-
// a consistent PathID across the various encrypted data
400-
// payloads that we received from them for this payment. All we
401-
// need to check is that the total amount field for each HTLC
402-
// in the split payment is correct.
403-
isBlinded := len(attempt.Route.FinalHop().EncryptedData) != 0
404-
405-
// Make sure any existing shards match the new one with regards
406-
// to MPP options.
407-
mpp := attempt.Route.FinalHop().MPP
408-
409-
// MPP records should not be set for attempts to blinded paths.
410-
if isBlinded && mpp != nil {
411-
return ErrMPPRecordInBlindedPayment
412-
}
413-
414-
for _, h := range payment.InFlightHTLCs() {
415-
hMpp := h.Route.FinalHop().MPP
416-
417-
// If this is a blinded payment, then no existing HTLCs
418-
// should have MPP records.
419-
if isBlinded && hMpp != nil {
420-
return ErrMPPRecordInBlindedPayment
421-
}
422-
423-
// If this is a blinded payment, then we just need to
424-
// check that the TotalAmtMsat field for this shard
425-
// is equal to that of any other shard in the same
426-
// payment.
427-
if isBlinded {
428-
if attempt.Route.FinalHop().TotalAmtMsat !=
429-
h.Route.FinalHop().TotalAmtMsat {
430-
431-
//nolint:ll
432-
return ErrBlindedPaymentTotalAmountMismatch
433-
}
434-
435-
continue
436-
}
437-
438-
switch {
439-
// We tried to register a non-MPP attempt for a MPP
440-
// payment.
441-
case mpp == nil && hMpp != nil:
442-
return ErrMPPayment
443-
444-
// We tried to register a MPP shard for a non-MPP
445-
// payment.
446-
case mpp != nil && hMpp == nil:
447-
return ErrNonMPPayment
448-
449-
// Non-MPP payment, nothing more to validate.
450-
case mpp == nil:
451-
continue
452-
}
453-
454-
// Check that MPP options match.
455-
if mpp.PaymentAddr() != hMpp.PaymentAddr() {
456-
return ErrMPPPaymentAddrMismatch
457-
}
458-
459-
if mpp.TotalMsat() != hMpp.TotalMsat() {
460-
return ErrMPPTotalAmountMismatch
461-
}
462-
}
463-
464-
// If this is a non-MPP attempt, it must match the total amount
465-
// exactly. Note that a blinded payment is considered an MPP
466-
// attempt.
467-
amt := attempt.Route.ReceiverAmt()
468-
if !isBlinded && mpp == nil && amt != payment.Info.Value {
469-
return ErrValueMismatch
470-
}
471-
472-
// Ensure we aren't sending more than the total payment amount.
473-
sentAmt, _ := payment.SentAmt()
474-
if sentAmt+amt > payment.Info.Value {
475-
return fmt.Errorf("%w: attempted=%v, payment amount="+
476-
"%v", ErrValueExceedsAmt,
477-
sentAmt+amt, payment.Info.Value)
392+
// Verify the attempt is compatible with the existing payment.
393+
if err := verifyAttempt(payment, attempt); err != nil {
394+
return err
478395
}
479396

480397
htlcsBucket, err := bucket.CreateBucketIfNotExists(

payments/db/options.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ type StoreOptions struct {
1010
KeepFailedPaymentAttempts bool
1111
}
1212

13+
// DefaultOptions returns a StoreOptions populated with default values.
14+
func DefaultOptions() *StoreOptions {
15+
return &StoreOptions{
16+
KeepFailedPaymentAttempts: false,
17+
NoMigration: false,
18+
}
19+
}
20+
1321
// OptionModifier is a function signature for modifying the default
1422
// StoreOptions.
1523
type OptionModifier func(*StoreOptions)

payments/db/payment.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ func (h *HTLCAttemptInfo) SessionKey() *btcec.PrivateKey {
175175
return h.cachedSessionKey
176176
}
177177

178+
// setSessionKey sets the session key for the htlc attempt.
179+
//
180+
// NOTE: Only used for testing.
181+
//
182+
//nolint:unused
183+
func (h *HTLCAttemptInfo) setSessionKey(sessionKey *btcec.PrivateKey) {
184+
h.cachedSessionKey = sessionKey
185+
186+
// Also set the session key as a raw bytes.
187+
var scratch [btcec.PrivKeyBytesLen]byte
188+
copy(scratch[:], sessionKey.Serialize())
189+
h.sessionKey = scratch
190+
}
191+
178192
// OnionBlob returns the onion blob created from the sphinx construction.
179193
func (h *HTLCAttemptInfo) OnionBlob() ([lnwire.OnionPacketSize]byte, error) {
180194
var zeroBytes [lnwire.OnionPacketSize]byte
@@ -712,3 +726,95 @@ func generateSphinxPacket(rt *route.Route, paymentHash []byte,
712726
PaymentPath: sphinxPath.NodeKeys(),
713727
}, nil
714728
}
729+
730+
// verifyAttempt validates that a new HTLC attempt is compatible with the
731+
// existing payment and its in-flight HTLCs. This function checks:
732+
// 1. MPP (Multi-Path Payment) compatibility between attempts
733+
// 2. Blinded payment consistency
734+
// 3. Amount validation
735+
// 4. Total payment amount limits
736+
func verifyAttempt(payment *MPPayment, attempt *HTLCAttemptInfo) error {
737+
// If the final hop has encrypted data, then we know this is a
738+
// blinded payment. In blinded payments, MPP records are not set
739+
// for split payments and the recipient is responsible for using
740+
// a consistent PathID across the various encrypted data
741+
// payloads that we received from them for this payment. All we
742+
// need to check is that the total amount field for each HTLC
743+
// in the split payment is correct.
744+
isBlinded := len(attempt.Route.FinalHop().EncryptedData) != 0
745+
746+
// Make sure any existing shards match the new one with regards
747+
// to MPP options.
748+
mpp := attempt.Route.FinalHop().MPP
749+
750+
// MPP records should not be set for attempts to blinded paths.
751+
if isBlinded && mpp != nil {
752+
return ErrMPPRecordInBlindedPayment
753+
}
754+
755+
for _, h := range payment.InFlightHTLCs() {
756+
hMpp := h.Route.FinalHop().MPP
757+
758+
// If this is a blinded payment, then no existing HTLCs
759+
// should have MPP records.
760+
if isBlinded && hMpp != nil {
761+
return ErrMPPRecordInBlindedPayment
762+
}
763+
764+
// If this is a blinded payment, then we just need to
765+
// check that the TotalAmtMsat field for this shard
766+
// is equal to that of any other shard in the same
767+
// payment.
768+
if isBlinded {
769+
if attempt.Route.FinalHop().TotalAmtMsat !=
770+
h.Route.FinalHop().TotalAmtMsat {
771+
772+
return ErrBlindedPaymentTotalAmountMismatch
773+
}
774+
775+
continue
776+
}
777+
778+
switch {
779+
// We tried to register a non-MPP attempt for a MPP
780+
// payment.
781+
case mpp == nil && hMpp != nil:
782+
return ErrMPPayment
783+
784+
// We tried to register a MPP shard for a non-MPP
785+
// payment.
786+
case mpp != nil && hMpp == nil:
787+
return ErrNonMPPayment
788+
789+
// Non-MPP payment, nothing more to validate.
790+
case mpp == nil:
791+
continue
792+
}
793+
794+
// Check that MPP options match.
795+
if mpp.PaymentAddr() != hMpp.PaymentAddr() {
796+
return ErrMPPPaymentAddrMismatch
797+
}
798+
799+
if mpp.TotalMsat() != hMpp.TotalMsat() {
800+
return ErrMPPTotalAmountMismatch
801+
}
802+
}
803+
804+
// If this is a non-MPP attempt, it must match the total amount
805+
// exactly. Note that a blinded payment is considered an MPP
806+
// attempt.
807+
amt := attempt.Route.ReceiverAmt()
808+
if !isBlinded && mpp == nil && amt != payment.Info.Value {
809+
return ErrValueMismatch
810+
}
811+
812+
// Ensure we aren't sending more than the total payment amount.
813+
sentAmt, _ := payment.SentAmt()
814+
if sentAmt+amt > payment.Info.Value {
815+
return fmt.Errorf("%w: attempted=%v, payment amount=%v",
816+
ErrValueExceedsAmt, sentAmt+amt, payment.Info.Value)
817+
}
818+
819+
return nil
820+
}

payments/db/query.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
package paymentsdb
22

3+
const (
4+
// DefaultMaxPayments is the default maximum number of payments returned
5+
// in the payments query pagination.
6+
DefaultMaxPayments = 100
7+
)
8+
39
// Query represents a query to the payments database starting or ending
410
// at a certain offset index. The number of retrieved records can be limited.
511
type Query struct {

rpcserver.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7523,10 +7523,15 @@ func (r *rpcServer) ListPayments(ctx context.Context,
75237523
CreationDateEnd: int64(req.CreationDateEnd),
75247524
}
75257525

7526-
// If the maximum number of payments wasn't specified, then we'll
7527-
// default to return the maximal number of payments representable.
7526+
// If the maximum number of payments wasn't specified, we default to
7527+
// a reasonable number to prevent resource exhaustion. All of the
7528+
// payments are fetched into memory. Moreover we don't want our daemon
7529+
// to remain stable and do other stuff rather than serving payments.
7530+
//
7531+
// TODO(ziggie): Choose a more specific default value when results of
7532+
// performance testing are available.
75287533
if req.MaxPayments == 0 {
7529-
query.MaxPayments = math.MaxUint64
7534+
query.MaxPayments = paymentsdb.DefaultMaxPayments
75307535
}
75317536

75327537
paymentsQuerySlice, err := r.server.paymentsDB.QueryPayments(

sqldb/postgres.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ var (
3030
postgresSchemaReplacements = map[string]string{
3131
"BLOB": "BYTEA",
3232
"INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
33-
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
33+
// We need this space in front of the TIMESTAMP keyword to
34+
// avoid replacing words which just have the word "TIMESTAMP" in
35+
// them.
36+
" TIMESTAMP": " TIMESTAMP WITHOUT TIME ZONE",
3437
}
3538

3639
// Make sure PostgresStore implements the MigrationExecutor interface.

sqldb/sqlutils.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ func SQLStrValid(s string) sql.NullString {
7575
}
7676
}
7777

78+
// SQLBool turns a boolean into the NullBool that sql/sqlc uses when a boolean
79+
// can be permitted to be NULL.
80+
func SQLBool(b bool) sql.NullBool {
81+
return sql.NullBool{
82+
Bool: b,
83+
Valid: true,
84+
}
85+
}
86+
7887
// SQLTime turns a time.Time into the NullTime that sql/sqlc uses when a time
7988
// can be permitted to be NULL.
8089
func SQLTime(t time.Time) sql.NullTime {

0 commit comments

Comments
 (0)