diff --git a/sweepbatcher/store.go b/sweepbatcher/store.go index 3af26bf0a..4cf899c85 100644 --- a/sweepbatcher/store.go +++ b/sweepbatcher/store.go @@ -3,6 +3,7 @@ package sweepbatcher import ( "context" "database/sql" + "fmt" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" @@ -121,6 +122,30 @@ func (s *SQLStore) UpdateSweepBatch(ctx context.Context, batch *dbBatch) error { return s.baseDb.UpdateBatch(ctx, batchToUpdateArgs(*batch)) } +// ConfirmBatchWithSweeps atomically confirms the batch and updates its sweeps. +func (s *SQLStore) ConfirmBatchWithSweeps(ctx context.Context, batch *dbBatch, + sweeps []*dbSweep) error { + + writeOpts := loopdb.NewSqlWriteOpts() + + return s.baseDb.ExecTx(ctx, writeOpts, func(tx Querier) error { + err := tx.UpdateBatch(ctx, batchToUpdateArgs(*batch)) + if err != nil { + return fmt.Errorf("update batch %d: %w", batch.ID, err) + } + + for _, sweep := range sweeps { + err := tx.UpsertSweep(ctx, sweepToUpsertArgs(*sweep)) + if err != nil { + return fmt.Errorf("upsert sweep %v: %w", + sweep.Outpoint, err) + } + } + + return nil + }) +} + // FetchBatchSweeps fetches all the sweeps that are part a batch. func (s *SQLStore) FetchBatchSweeps(ctx context.Context, id int32) ( []*dbSweep, error) { diff --git a/sweepbatcher/store_mock.go b/sweepbatcher/store_mock.go index 2e28a603b..fac6ffec6 100644 --- a/sweepbatcher/store_mock.go +++ b/sweepbatcher/store_mock.go @@ -3,6 +3,7 @@ package sweepbatcher import ( "context" "errors" + "fmt" "sort" "sync" @@ -77,6 +78,32 @@ func (s *StoreMock) UpdateSweepBatch(ctx context.Context, return nil } +// ConfirmBatchWithSweeps updates the batch and the provided sweeps atomically. 
+func (s *StoreMock) ConfirmBatchWithSweeps(ctx context.Context, + batch *dbBatch, sweeps []*dbSweep) error { + + s.mu.Lock() + defer s.mu.Unlock() + + s.batches[batch.ID] = *batch + + for _, sweep := range sweeps { + sweepCopy := *sweep + + old, exists := s.sweeps[sweep.Outpoint] + if !exists { + return fmt.Errorf("confirming unknown sweep %v", + sweep.Outpoint) + } + + sweepCopy.ID = old.ID + + s.sweeps[sweep.Outpoint] = sweepCopy + } + + return nil +} + // FetchBatchSweeps fetches all the sweeps that belong to a batch. func (s *StoreMock) FetchBatchSweeps(ctx context.Context, id int32) ([]*dbSweep, error) { diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index 58cdf19dc..fe91fd5ad 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -2224,10 +2224,6 @@ func (b *batch) handleConf(ctx context.Context, b.Infof("confirmed in txid %s", b.batchTxid) b.state = Confirmed - if err := b.persist(ctx); err != nil { - return fmt.Errorf("saving batch failed: %w", err) - } - // If the batch is in presigned mode, cleanup presignedHelper. presigned, err := b.isPresigned() if err != nil { @@ -2261,18 +2257,16 @@ func (b *batch) handleConf(ctx context.Context, confirmedSweeps = []wire.OutPoint{} purgeList = make([]SweepRequest, 0, len(b.sweeps)) totalSweptAmt btcutil.Amount + dbConfirmed = make([]*dbSweep, 0, len(allSweeps)) ) for _, sweep := range allSweeps { _, found := confirmedSet[sweep.outpoint] if found { - // Save the sweep as completed. Note that sweeps are - // marked completed after the batch is marked confirmed - // because the check in handleSweeps checks sweep's - // status first and then checks the batch status. - err := b.persistSweep(ctx, sweep, true) - if err != nil { - return err - } + // Save the sweep as completed; the batch row and all + // sweeps are persisted atomically below. 
+ dbConfirmed = append( + dbConfirmed, b.dbSweepFrom(sweep, true), + ) confirmedSweeps = append( confirmedSweeps, sweep.outpoint, @@ -2328,8 +2322,15 @@ func (b *batch) handleConf(ctx context.Context, } } - b.Infof("fully confirmed sweeps: %v, purged sweeps: %v, "+ - "purged swaps: %v", confirmedSweeps, purgedSweeps, purgedSwaps) + b.Infof("Fully confirmed sweeps: %v, purged sweeps: %v, "+ + "purged swaps: %v. Saving the batch and sweeps to DB", + confirmedSweeps, purgedSweeps, purgedSwaps) + + if err := b.persistConfirmedBatch(ctx, dbConfirmed); err != nil { + return fmt.Errorf("saving confirmed batch failed: %w", err) + } + + b.Infof("Successfully saved the batch and confirmed sweeps to DB") // Proceed with purging the sweeps. This will feed the sweeps that // didn't make it to the confirmed batch transaction back to the batcher @@ -2445,6 +2446,11 @@ func (b *batch) isComplete() bool { // persist updates the batch in the database. func (b *batch) persist(ctx context.Context) error { + return b.store.UpdateSweepBatch(ctx, b.dbBatch()) +} + +// dbBatch builds the dbBatch representation for the current in-memory state. +func (b *batch) dbBatch() *dbBatch { bch := &dbBatch{} bch.ID = b.id @@ -2459,7 +2465,7 @@ func (b *batch) persist(ctx context.Context) error { bch.LastRbfSatPerKw = int32(b.rbfCache.FeeRate) bch.MaxTimeoutDistance = b.cfg.maxTimeoutDistance - return b.store.UpdateSweepBatch(ctx, bch) + return bch } // getBatchDestAddr returns the batch's destination address. If the batch @@ -2612,16 +2618,31 @@ func (b *batch) writeToConfErrChan(ctx context.Context, confErr error) { } } +// persistSweep upserts the given sweep into the backing store and optionally +// marks it as completed. func (b *batch) persistSweep(ctx context.Context, sweep sweep, completed bool) error { - return b.store.UpsertSweep(ctx, &dbSweep{ + return b.store.UpsertSweep(ctx, b.dbSweepFrom(sweep, completed)) +} + +// dbSweepFrom builds the dbSweep representation for a batch sweep. 
+func (b *batch) dbSweepFrom(sweep sweep, completed bool) *dbSweep { + return &dbSweep{ BatchID: b.id, SwapHash: sweep.swapHash, Outpoint: sweep.outpoint, Amount: sweep.value, Completed: completed, - }) + } +} + +// persistConfirmedBatch atomically records the batch confirmation metadata +// along with all sweeps that confirmed in the same transaction. +func (b *batch) persistConfirmedBatch(ctx context.Context, + sweeps []*dbSweep) error { + + return b.store.ConfirmBatchWithSweeps(ctx, b.dbBatch(), sweeps) } // clampBatchFee takes the fee amount and total amount of the sweeps in the diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index dd5d01751..51b707bac 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -59,6 +59,11 @@ type BatcherStore interface { // UpdateSweepBatch updates a batch in the database. UpdateSweepBatch(ctx context.Context, batch *dbBatch) error + // ConfirmBatchWithSweeps atomically marks the batch as confirmed and + // updates the provided sweeps in the database. + ConfirmBatchWithSweeps(ctx context.Context, batch *dbBatch, + sweeps []*dbSweep) error + // FetchBatchSweeps fetches all the sweeps that belong to a batch. FetchBatchSweeps(ctx context.Context, id int32) ([]*dbSweep, error) @@ -975,9 +980,8 @@ func (b *Batcher) handleSweeps(ctx context.Context, sweeps []*sweep, "sweeps with primarySweep %x: confirmed=%v", len(sweeps), sweep.swapHash[:6], parentBatch.Confirmed) - // Note that sweeps are marked completed after the batch is - // marked confirmed because here we check the sweep status - // first and then check the batch status. + // Batch + sweeps are persisted atomically, so if the sweep + // shows as completed its parent batch must be confirmed. 
if parentBatch.Confirmed { debugf("Sweep group of %d sweeps with primarySweep %x "+ "is fully confirmed, switching directly to "+ diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index 2d35541b2..6a9c57c10 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -19,6 +19,7 @@ import ( "github.com/btcsuite/btclog/v2" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/loopdb/sqlc" "github.com/lightninglabs/loop/test" "github.com/lightninglabs/loop/utils" "github.com/lightningnetwork/lnd/build" @@ -4393,6 +4394,255 @@ func testSweepBatcherHandleBatchShutdown(t *testing.T, store testStore, require.NoError(t, err) } +// failingBaseDB wraps a BaseDB and injects a failure after the batch row is +// marked confirmed but before the sweeps are persisted, emulating a crash. +type failingBaseDB struct { + // BaseDB is the actual database implementation we delegate to. + BaseDB + + // mu synchronizes access to the failure state. + mu sync.Mutex + + // armed is set once we observe the batch row being marked confirmed. + armed bool + + // failed ensures we only inject the failure once. + failed bool + + // failErr is the error returned to callers when the injection triggers. + failErr error +} + +// newFailingBaseDB creates a new failure-injecting wrapper around the provided +// BaseDB implementation. +func newFailingBaseDB(inner BaseDB) *failingBaseDB { + return &failingBaseDB{ + BaseDB: inner, + failErr: errors.New("forced failure after confirming batch"), + } +} + +// markArmed remembers that the batch row was updated to confirmed so the next +// sweep update will be forced to fail. +func (f *failingBaseDB) markArmed() { + f.mu.Lock() + defer f.mu.Unlock() + + if !f.failed { + f.armed = true + } +} + +// shouldFail returns true exactly once after the wrapper has been armed. 
+func (f *failingBaseDB) shouldFail() bool { + f.mu.Lock() + defer f.mu.Unlock() + + if f.armed && !f.failed { + f.failed = true + f.armed = false + return true + } + + return false +} + +// UpdateBatch proxies the batch update and arms the failure if the batch was +// marked confirmed. +func (f *failingBaseDB) UpdateBatch(ctx context.Context, + arg sqlc.UpdateBatchParams) error { + + if arg.Confirmed { + f.markArmed() + } + + return f.BaseDB.UpdateBatch(ctx, arg) +} + +// UpsertSweep forwards the sweep update unless a failure injection is pending. +func (f *failingBaseDB) UpsertSweep(ctx context.Context, + arg sqlc.UpsertSweepParams) error { + + if f.shouldFail() { + return f.failErr + } + + return f.BaseDB.UpsertSweep(ctx, arg) +} + +// ExecTx wraps the transactional Querier with failingQuerier so the failure +// state is respected inside transactions. +func (f *failingBaseDB) ExecTx(ctx context.Context, opts loopdb.TxOptions, + txBody func(Querier) error) error { + + return f.BaseDB.ExecTx(ctx, opts, func(q Querier) error { + return txBody(&failingQuerier{ + Querier: q, + parent: f, + }) + }) +} + +// failingQuerier proxies the ExecTx-scoped Querier to propagate the failure +// injection logic into transactional code paths. +type failingQuerier struct { + // Querier is the underlying transactional view. + Querier + + // parent references the owning failingBaseDB so we share the failure + // state across transactional calls. + parent *failingBaseDB +} + +// UpdateBatch mirrors failingBaseDB.UpdateBatch within a transaction scope. +func (f *failingQuerier) UpdateBatch(ctx context.Context, + arg sqlc.UpdateBatchParams) error { + + if arg.Confirmed { + f.parent.markArmed() + } + + return f.Querier.UpdateBatch(ctx, arg) +} + +// UpsertSweep mirrors failingBaseDB.UpsertSweep for transactional calls. 
+func (f *failingQuerier) UpsertSweep(ctx context.Context, + arg sqlc.UpsertSweepParams) error { + + if f.parent.shouldFail() { + return f.parent.failErr + } + + return f.Querier.UpsertSweep(ctx, arg) +} + +// TestSweepBatcherConfirmedBatchIncompleteSweeps guards against the crash +// window where a batch is marked confirmed while its sweeps remain incomplete +// in the DB. It runs only against the loopdb backend, injects a failure at +// the BaseDB layer and checks that the batch and sweep rows stay consistent. +func TestSweepBatcherConfirmedBatchIncompleteSweeps(t *testing.T) { + logger := btclog.NewSLogger(btclog.NewDefaultHandler(os.Stdout)) + logger.SetLevel(btclog.LevelTrace) + UseLogger(logger.SubSystem("SWEEP")) + + // Set up a fresh loopdb instance so we exercise the real SQL backend. + sqlDB := loopdb.NewTestDB(t) + typedSqlDB := loopdb.NewTypedStore[Querier](sqlDB) + faultyDB := newFailingBaseDB(typedSqlDB) + lnd := test.NewMockLnd() + batcherStore := NewSQLStore(faultyDB, lnd.ChainParams) + swapStore := newLoopdbStore(t, sqlDB) + + const ( + sweepValue btcutil.Amount = 1_000_000 + confHeight = 777 + ) + + ctx := context.Background() + + sweepOutpoint := wire.OutPoint{ + Hash: chainhash.Hash{0, 0, 0, 3}, + Index: 7, + } + swapHash := lntypes.Hash{3, 3, 3} + + notifier := &SpendNotifier{ + SpendChan: make(chan *SpendDetail, ntfnBufferSize), + ConfChan: make(chan *ConfDetail, ntfnBufferSize), + QuitChan: make(chan bool, ntfnBufferSize), + } + + sweepReq := SweepRequest{ + SwapHash: swapHash, + Inputs: []Input{{ + Value: sweepValue, + Outpoint: sweepOutpoint, + }}, + Notifier: notifier, + } + + swap := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 144, + AmountRequested: sweepValue, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + Preimage: lntypes.Preimage{3}, + }, + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: confTarget, + } + + // Seed the DB with an initiated Loop Out swap so AddSweep can load it. 
+ require.NoError(t, swapStore.CreateLoopOut(ctx, swapHash, swap)) + swapStore.AssertLoopOutStored() + + sweepStore, err := NewSweepFetcherFromSwapStore( + swapStore, lnd.ChainParams, + ) + require.NoError(t, err) + + ctx1, cancel1 := context.WithCancel(ctx) + defer cancel1() + + // The failing DB wrapper will arm itself when the batch row is updated, + // then abort the first sweep update performed in the same transaction, + // mimicking a crash between those two steps. + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, sweepStore, + ) + + var wg sync.WaitGroup + wg.Add(1) + var runErr error + go func() { + defer wg.Done() + runErr = batcher.Run(ctx1) + }() + + <-batcher.initDone + + // Add the sweep once so the batcher spins up a batch. + require.NoError(t, batcher.AddSweep(ctx1, &sweepReq)) + + <-lnd.RegisterSpendChannel + publishedTx := <-lnd.TxPublishChannel + + spendDetail := &chainntnfs.SpendDetail{ + SpentOutPoint: &sweepOutpoint, + SpendingTx: publishedTx, + SpenderTxHash: new(chainhash.Hash), + SpenderInputIndex: 0, + } + *spendDetail.SpenderTxHash = publishedTx.TxHash() + lnd.SpendChannel <- spendDetail + + <-lnd.RegisterConfChannel + require.NoError(t, lnd.NotifyHeight(confHeight)) + lnd.ConfChannel <- &chainntnfs.TxConfirmation{ + BlockHeight: confHeight, + Tx: publishedTx, + } + + // The failing BaseDB injects its error while handleConf stores the + // confirmed batch/sweeps. Observe that error, then verify the DB was + // left consistent (both the batch and sweeps remain unconfirmed). 
+ wg.Wait() + require.ErrorIs(t, runErr, faultyDB.failErr) + + completed, err := batcherStore.GetSweepStatus(ctx, sweepOutpoint) + require.NoError(t, err) + + parentBatch, err := batcherStore.GetParentBatch(ctx, sweepOutpoint) + require.NoError(t, err) + + require.Equal(t, parentBatch.Confirmed, completed, + "inconsistent DB: confirmed batch vs sweep completion") +} + // testCustomSignMuSig2 tests the operation with custom musig2 signer. func testCustomSignMuSig2(t *testing.T, store testStore, batcherStore testBatcherStore) {