@@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
269269
270270 // For each key found, we'll look up the actual
271271 // invoice, then accumulate it into our return value.
272- invoice , err := fetchInvoice (invoiceKey , invoices )
272+ invoice , err := fetchInvoice (
273+ invoiceKey , invoices , nil , false ,
274+ )
273275 if err != nil {
274276 return err
275277 }
@@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
341343
342344 // An invoice was found, retrieve the remainder of the invoice
343345 // body.
344- i , err := fetchInvoice (invoiceNum , invoices , setID )
346+ i , err := fetchInvoice (
347+ invoiceNum , invoices , []* invpkg.SetID {setID }, true ,
348+ )
345349 if err != nil {
346350 return err
347351 }
@@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
468472 return nil
469473 }
470474
471- invoice , err := fetchInvoice (v , invoices )
475+ invoice , err := fetchInvoice (v , invoices , nil , false )
472476 if err != nil {
473477 return err
474478 }
@@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
526530 // characteristics for our query and returns the number of items
527531 // we have added to our set of invoices.
528532 accumulateInvoices := func (_ , indexValue []byte ) (bool , error ) {
529- invoice , err := fetchInvoice (indexValue , invoices )
533+ invoice , err := fetchInvoice (
534+ indexValue , invoices , nil , false ,
535+ )
530536 if err != nil {
531537 return false , err
532538 }
@@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
654660 if setIDHint != nil {
655661 invSetID = * setIDHint
656662 }
657- invoice , err := fetchInvoice (invoiceNum , invoices , & invSetID )
663+ invoice , err := fetchInvoice (
664+ invoiceNum , invoices , []* invpkg.SetID {& invSetID }, false ,
665+ )
658666 if err != nil {
659667 return err
660668 }
@@ -676,15 +684,43 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
676684 updatedInvoice , err = invpkg .UpdateInvoice (
677685 payHash , updater .invoice , now , callback , updater ,
678686 )
687+ if err != nil {
688+ return err
689+ }
679690
680- return err
691+ // If this is an AMP update, then limit the returned AMP state
692+ // to only the requested set ID.
693+ if setIDHint != nil {
694+ filterInvoiceAMPState (updatedInvoice , & invSetID )
695+ }
696+
697+ return nil
681698 }, func () {
682699 updatedInvoice = nil
683700 })
684701
685702 return updatedInvoice , err
686703}
687704
705+ // filterInvoiceAMPState filters the AMP state of the invoice to only include
706+ // state for the specified set IDs.
707+ func filterInvoiceAMPState (invoice * invpkg.Invoice , setIDs ... * invpkg.SetID ) {
708+ filteredAMPState := make (invpkg.AMPInvoiceState )
709+
710+ for _ , setID := range setIDs {
711+ if setID == nil {
712+ return
713+ }
714+
715+ ampState , ok := invoice .AMPState [* setID ]
716+ if ok {
717+ filteredAMPState [* setID ] = ampState
718+ }
719+ }
720+
721+ invoice .AMPState = filteredAMPState
722+ }
723+
688724// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
689725type ampHTLCsMap map [invpkg.SetID ]map [models.CircuitKey ]* invpkg.InvoiceHTLC
690726
@@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
10561092 // For each key found, we'll look up the actual
10571093 // invoice, then accumulate it into our return value.
10581094 invoice , err := fetchInvoice (
1059- invoiceKey [:], invoices , setID ,
1095+ invoiceKey [:], invoices , []* invpkg.SetID {setID },
1096+ true ,
10601097 )
10611098 if err != nil {
10621099 return err
@@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
14851522// specified by the invoice number. If the setID fields are set, then only the
14861523// HTLC information pertaining to those set IDs is returned.
14871524func fetchInvoice (invoiceNum []byte , invoices kvdb.RBucket ,
1488- setIDs ... * invpkg.SetID ) (invpkg.Invoice , error ) {
1525+ setIDs [] * invpkg.SetID , filterAMPState bool ) (invpkg.Invoice , error ) {
14891526
14901527 invoiceBytes := invoices .Get (invoiceNum )
14911528 if invoiceBytes == nil {
@@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
15181555 log .Errorf ("unable to fetch amp htlcs for inv " +
15191556 "%v and setIDs %v: %w" , invoiceNum , setIDs , err )
15201557 }
1558+
1559+ if filterAMPState {
1560+ filterInvoiceAMPState (& invoice , setIDs ... )
1561+ }
15211562 }
15221563
15231564 return invoice , nil
@@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
21632204 return nil
21642205 }
21652206
2166- invoice , err := fetchInvoice (v , invoices )
2207+ invoice , err := fetchInvoice (v , invoices , nil , false )
21672208 if err != nil {
21682209 return err
21692210 }
0 commit comments