Skip to content

Commit b3dc3ed

Browse files
bhandrasRoasbeef
authored andcommitted
channeldb: filter AMP state to relevant set IDs
When fetching an AMP invoice we sometimes filter HTLCs to selected set IDs, however we always kept the full AMP state which is irrelevant as it contains state for all AMP payments. This was a side effect of UpdateInvoice needing to serialize the whole invoice when storing after an update but it is an unwanted "feature" as users will need to filter to relevant set when listing an AMP payment or subsribing to an update.
1 parent cadce23 commit b3dc3ed

File tree

2 files changed

+54
-13
lines changed

2 files changed

+54
-13
lines changed

channeldb/invoices.go

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
269269

270270
// For each key found, we'll look up the actual
271271
// invoice, then accumulate it into our return value.
272-
invoice, err := fetchInvoice(invoiceKey, invoices)
272+
invoice, err := fetchInvoice(
273+
invoiceKey, invoices, nil, false,
274+
)
273275
if err != nil {
274276
return err
275277
}
@@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
341343

342344
// An invoice was found, retrieve the remainder of the invoice
343345
// body.
344-
i, err := fetchInvoice(invoiceNum, invoices, setID)
346+
i, err := fetchInvoice(
347+
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
348+
)
345349
if err != nil {
346350
return err
347351
}
@@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
468472
return nil
469473
}
470474

471-
invoice, err := fetchInvoice(v, invoices)
475+
invoice, err := fetchInvoice(v, invoices, nil, false)
472476
if err != nil {
473477
return err
474478
}
@@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
526530
// characteristics for our query and returns the number of items
527531
// we have added to our set of invoices.
528532
accumulateInvoices := func(_, indexValue []byte) (bool, error) {
529-
invoice, err := fetchInvoice(indexValue, invoices)
533+
invoice, err := fetchInvoice(
534+
indexValue, invoices, nil, false,
535+
)
530536
if err != nil {
531537
return false, err
532538
}
@@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
654660
if setIDHint != nil {
655661
invSetID = *setIDHint
656662
}
657-
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID)
663+
invoice, err := fetchInvoice(
664+
invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
665+
)
658666
if err != nil {
659667
return err
660668
}
@@ -676,15 +684,43 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
676684
updatedInvoice, err = invpkg.UpdateInvoice(
677685
payHash, updater.invoice, now, callback, updater,
678686
)
687+
if err != nil {
688+
return err
689+
}
679690

680-
return err
691+
// If this is an AMP update, then limit the returned AMP state
692+
// to only the requested set ID.
693+
if setIDHint != nil {
694+
filterInvoiceAMPState(updatedInvoice, &invSetID)
695+
}
696+
697+
return nil
681698
}, func() {
682699
updatedInvoice = nil
683700
})
684701

685702
return updatedInvoice, err
686703
}
687704

705+
// filterInvoiceAMPState filters the AMP state of the invoice to only include
706+
// state for the specified set IDs.
707+
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
708+
filteredAMPState := make(invpkg.AMPInvoiceState)
709+
710+
for _, setID := range setIDs {
711+
if setID == nil {
712+
return
713+
}
714+
715+
ampState, ok := invoice.AMPState[*setID]
716+
if ok {
717+
filteredAMPState[*setID] = ampState
718+
}
719+
}
720+
721+
invoice.AMPState = filteredAMPState
722+
}
723+
688724
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
689725
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
690726

@@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
10561092
// For each key found, we'll look up the actual
10571093
// invoice, then accumulate it into our return value.
10581094
invoice, err := fetchInvoice(
1059-
invoiceKey[:], invoices, setID,
1095+
invoiceKey[:], invoices, []*invpkg.SetID{setID},
1096+
true,
10601097
)
10611098
if err != nil {
10621099
return err
@@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
14851522
// specified by the invoice number. If the setID fields are set, then only the
14861523
// HTLC information pertaining to those set IDs is returned.
14871524
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
1488-
setIDs ...*invpkg.SetID) (invpkg.Invoice, error) {
1525+
setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
14891526

14901527
invoiceBytes := invoices.Get(invoiceNum)
14911528
if invoiceBytes == nil {
@@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
15181555
log.Errorf("unable to fetch amp htlcs for inv "+
15191556
"%v and setIDs %v: %w", invoiceNum, setIDs, err)
15201557
}
1558+
1559+
if filterAMPState {
1560+
filterInvoiceAMPState(&invoice, setIDs...)
1561+
}
15211562
}
15221563

15231564
return invoice, nil
@@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
21632204
return nil
21642205
}
21652206

2166-
invoice, err := fetchInvoice(v, invoices)
2207+
invoice, err := fetchInvoice(v, invoices, nil, false)
21672208
if err != nil {
21682209
return err
21692210
}

itest/lnd_amp_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
303303
// return the "projected" sub-invoice for a given setID.
304304
require.Equal(ht, 1, len(invoiceNtfn.Htlcs))
305305

306-
// However the AMP state index should show that there've been two
307-
// repeated payments to this invoice so far.
308-
require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState))
306+
// The AMP state should also be restricted to a single entry for the
307+
// "projected" sub-invoice.
308+
require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState))
309309

310310
// Now we'll look up the invoice using the new LookupInvoice2 RPC call
311311
// by the set ID of each of the invoices.
@@ -364,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
364364
// through.
365365
backlogInv := ht.ReceiveInvoiceUpdate(invSub2)
366366
require.Equal(ht, 1, len(backlogInv.Htlcs))
367-
require.Equal(ht, 2, len(backlogInv.AmpInvoiceState))
367+
require.Equal(ht, 1, len(backlogInv.AmpInvoiceState))
368368
require.True(ht, backlogInv.Settled)
369369
require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat))
370370
}

0 commit comments

Comments
 (0)