@@ -10,6 +10,7 @@ import (
1010 "strconv"
1111 "time"
1212
13+ "github.com/davecgh/go-spew/spew"
1314 "github.com/lightningnetwork/lnd/channeldb/models"
1415 "github.com/lightningnetwork/lnd/clock"
1516 "github.com/lightningnetwork/lnd/lntypes"
@@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
4647 GetInvoice (ctx context.Context ,
4748 arg sqlc.GetInvoiceParams ) ([]sqlc.Invoice , error )
4849
50+ GetInvoiceBySetID (ctx context.Context , setID []byte ) ([]sqlc.Invoice ,
51+ error )
52+
4953 GetInvoiceFeatures (ctx context.Context ,
5054 invoiceID int64 ) ([]sqlc.InvoiceFeature , error )
5155
@@ -343,16 +347,31 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
343347 params .SetID = ref .SetID ()[:]
344348 }
345349
346- rows , err := db .GetInvoice (ctx , params )
350+ var (
351+ rows []sqlc.Invoice
352+ err error
353+ )
354+
355+ // We need to split the query based on how we intend to look up the
356+ // invoice. If only the set ID is given then we want to have an exact
357+ // match on the set ID. If other fields are given, we want to match on
358+ // those fields and the set ID but with a less strict join condition.
359+ if params .Hash == nil && params .PaymentAddr == nil &&
360+ params .SetID != nil {
361+
362+ rows , err = db .GetInvoiceBySetID (ctx , params .SetID )
363+ } else {
364+ rows , err = db .GetInvoice (ctx , params )
365+ }
347366 switch {
348367 case len (rows ) == 0 :
349368 return nil , ErrInvoiceNotFound
350369
351370 case len (rows ) > 1 :
352371 // In case the reference is ambiguous, meaning it matches more
353372 // than one invoice, we'll return an error.
354- return nil , fmt .Errorf ("ambiguous invoice ref: %s" ,
355- ref .String ())
373+ return nil , fmt .Errorf ("ambiguous invoice ref: %s: %s " ,
374+ ref .String (), spew . Sdump ( rows ) )
356375
357376 case err != nil :
358377 return nil , fmt .Errorf ("unable to fetch invoice: %w" , err )
@@ -1308,13 +1327,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
13081327// invoice and is therefore atomic. The fields to update are controlled by the
13091328// supplied callback.
13101329func (i * SQLStore ) UpdateInvoice (ctx context.Context , ref InvoiceRef ,
1311- _ * SetID , callback InvoiceUpdateCallback ) (
1330+ setID * SetID , callback InvoiceUpdateCallback ) (
13121331 * Invoice , error ) {
13131332
13141333 var updatedInvoice * Invoice
13151334
13161335 txOpt := SQLInvoiceQueriesTxOptions {readOnly : false }
13171336 txErr := i .db .ExecTx (ctx , & txOpt , func (db SQLInvoiceQueries ) error {
1337+ if setID != nil {
1338+ // Make sure to use the set ID if this is an AMP update.
1339+ var setIDBytes [32 ]byte
1340+ copy (setIDBytes [:], setID [:])
1341+ ref .setID = & setIDBytes
1342+
1343+ // If we're updating an AMP invoice, we'll also only
1344+ // need to fetch the HTLCs for the given set ID.
1345+ ref .refModifier = HtlcSetOnlyModifier
1346+ }
1347+
13181348 invoice , err := i .fetchInvoice (ctx , db , ref )
13191349 if err != nil {
13201350 return err
0 commit comments