From c04e3eed3e227f5e5f0efc186077e77c73b2d820 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Mon, 10 Feb 2025 15:53:16 -0300 Subject: [PATCH 01/12] sweepbatcher: make func constructUnsignedTx pure Also added a unit test for it. --- sweepbatcher/sweep_batch.go | 14 +- sweepbatcher/sweep_batch_test.go | 319 +++++++++++++++++++++++++++++++ 2 files changed, 326 insertions(+), 7 deletions(-) create mode 100644 sweepbatcher/sweep_batch_test.go diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index 6a4989cd9..b087a8994 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -882,9 +882,9 @@ func (b *batch) createPsbt(unsignedTx *wire.MsgTx, sweeps []sweep) ([]byte, // constructUnsignedTx creates unsigned tx from the sweeps, paying to the addr. // It also returns absolute fee (from weight and clamped). -func (b *batch) constructUnsignedTx(sweeps []sweep, - address btcutil.Address) (*wire.MsgTx, lntypes.WeightUnit, - btcutil.Amount, btcutil.Amount, error) { +func constructUnsignedTx(sweeps []sweep, address btcutil.Address, + currentHeight int32, feeRate chainfee.SatPerKWeight) (*wire.MsgTx, + lntypes.WeightUnit, btcutil.Amount, btcutil.Amount, error) { // Sanity check, there should be at least 1 sweep in this batch. if len(sweeps) == 0 { @@ -894,7 +894,7 @@ func (b *batch) constructUnsignedTx(sweeps []sweep, // Create the batch transaction. batchTx := &wire.MsgTx{ Version: 2, - LockTime: uint32(b.currentHeight), + LockTime: uint32(currentHeight), } // Add transaction inputs and estimate its weight. @@ -946,7 +946,7 @@ func (b *batch) constructUnsignedTx(sweeps []sweep, // Find weight and fee. weight := weightEstimate.Weight() - feeForWeight := b.rbfCache.FeeRate.FeeForWeight(weight) + feeForWeight := feeRate.FeeForWeight(weight) // Clamp the calculated fee to the max allowed fee amount for the batch. fee := clampBatchFee(feeForWeight, batchAmt) @@ -1031,8 +1031,8 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, // Construct unsigned batch transaction. var err error - tx, weight, feeForWeight, fee, err = b.constructUnsignedTx( - sweeps, address, + tx, weight, feeForWeight, fee, err = constructUnsignedTx( + sweeps, address, b.currentHeight, b.rbfCache.FeeRate, ) if err != nil { return 0, fmt.Errorf("failed to construct tx: %w", err), diff --git a/sweepbatcher/sweep_batch_test.go b/sweepbatcher/sweep_batch_test.go new file mode 100644 index 000000000..dd873b300 --- /dev/null +++ b/sweepbatcher/sweep_batch_test.go @@ -0,0 +1,319 @@ +package sweepbatcher + +import ( + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/utils" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +// TestConstructUnsignedTx verifies that the function constructUnsignedTx +// correctly creates unsigned transactions. +func TestConstructUnsignedTx(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + p2trAddr := "bcrt1pa38tp2hgjevqv3jcsxeu7v72n0s5a3ck8q2u8r" + + "k6mm67dv7uk26qq8je7e" + p2trAddress, err := btcutil.DecodeAddress(p2trAddr, nil) + require.NoError(t, err) + p2trPkScript, err := txscript.PayToAddrScript(p2trAddress) + require.NoError(t, err) + + serializedPubKey := []byte{ + 0x02, 0x19, 0x2d, 0x74, 0xd0, 0xcb, 0x94, 0x34, 0x4c, 0x95, + 0x69, 0xc2, 0xe7, 0x79, 0x01, 0x57, 0x3d, 0x8d, 0x79, 0x03, + 0xc3, 0xeb, 0xec, 0x3a, 0x95, 0x77, 0x24, 0x89, 0x5d, 0xca, + 0x52, 0xc6, 0xb4} + p2pkAddress, err := btcutil.NewAddressPubKey( + serializedPubKey, &chaincfg.RegressionNetParams, + ) + require.NoError(t, err) + + swapHash := lntypes.Hash{1, 1, 1} + + swapContract := &loopdb.SwapContract{ + CltvExpiry: 222, + AmountRequested: 2_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + } + + htlc, err := utils.GetHtlc( + swapHash, swapContract, &chaincfg.RegressionNetParams, + ) + require.NoError(t, err) + estimator := htlc.AddSuccessToEstimator + + brokenEstimator := func(*input.TxWeightEstimator) error { + return fmt.Errorf("weight estimator test failure") + } + + cases := []struct { + name string + sweeps []sweep + address btcutil.Address + currentHeight int32 + feeRate chainfee.SatPerKWeight + wantErr string + wantTx *wire.MsgTx + wantWeight lntypes.WeightUnit + wantFeeForWeight btcutil.Amount + wantFee btcutil.Amount + }{ + { + name: "no sweeps error", + wantErr: "no sweeps in batch", + }, + + { + name: "two coop sweeps", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + wantWeight: 626, + wantFeeForWeight: 626, + wantFee: 626, + }, + + { + name: "p2tr destination address", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: p2trAddress, + currentHeight: 800_000, + feeRate: 1000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999326, + PkScript: p2trPkScript, + }, + }, + }, + wantWeight: 674, + wantFeeForWeight: 674, + wantFee: 674, + }, + + { + name: "unknown kind of address", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: nil, + wantErr: "unsupported address type", + }, + + { + name: "pay-to-pubkey address", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: p2pkAddress, + wantErr: "unknown address type", + }, + + { + name: "fee more than 20% clamped", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1_000_000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2400000, + PkScript: batchPkScript, + }, + }, + }, + wantWeight: 626, + wantFeeForWeight: 626_000, + wantFee: 600_000, + }, + + { + name: "coop and noncoop", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + nonCoopHint: true, + htlc: *htlc, + htlcSuccessEstimator: estimator, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1000, + wantTx: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + }, + { + PreviousOutPoint: op2, + Sequence: 1, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999211, + PkScript: batchPkScript, + }, + }, + }, + wantWeight: 789, + wantFeeForWeight: 789, + wantFee: 789, + }, + + { + name: "weight estimator fails", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + nonCoopHint: true, + htlc: *htlc, + htlcSuccessEstimator: brokenEstimator, + }, + }, + address: destAddr, + currentHeight: 800_000, + feeRate: 1000, + wantErr: "sweep.htlcSuccessEstimator failed: " + + "weight estimator test failure", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + tx, weight, feeForW, fee, err := constructUnsignedTx( + tc.sweeps, tc.address, tc.currentHeight, + tc.feeRate, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.wantTx, tx) + require.Equal(t, tc.wantWeight, weight) + require.Equal(t, tc.wantFeeForWeight, feeForW) + require.Equal(t, tc.wantFee, fee) + } + }) + } +} From ba492be36393254a43d3145b875926a5830dec3b Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Thu, 20 Feb 2025 01:51:07 -0300 Subject: [PATCH 02/12] test: implement MinRelayFee RPC --- test/lnd_services_mock.go | 5 +++++ test/walletkit_mock.go | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/test/lnd_services_mock.go b/test/lnd_services_mock.go index db4447448..07e25170b 100644 --- a/test/lnd_services_mock.go +++ b/test/lnd_services_mock.go @@ -29,6 +29,7 @@ func NewMockLnd() *LndMockServices { lightningClient := &mockLightningClient{} walletKit := &mockWalletKit{ feeEstimates: make(map[int32]chainfee.SatPerKWeight), + minRelayFee: chainfee.FeePerKwFloor, } chainNotifier := &mockChainNotifier{} signer := &mockSigner{} @@ -278,3 +279,7 @@ func (s *LndMockServices) SetFeeEstimate(confTarget int32, confTarget, feeEstimate, ) } + +func (s *LndMockServices) SetMinRelayFee(feeEstimate chainfee.SatPerKWeight) { + s.LndServices.WalletKit.(*mockWalletKit).setMinRelayFee(feeEstimate) +} diff --git a/test/walletkit_mock.go b/test/walletkit_mock.go index 637686c68..12a8d52ec 100644 --- a/test/walletkit_mock.go +++ b/test/walletkit_mock.go @@ -34,6 +34,7 @@ type mockWalletKit struct { feeEstimateLock sync.Mutex feeEstimates map[int32]chainfee.SatPerKWeight + minRelayFee chainfee.SatPerKWeight } var _ lndclient.WalletKitClient = (*mockWalletKit)(nil) @@ -169,6 +170,24 @@ func (m *mockWalletKit) EstimateFeeRate(ctx context.Context, return feeEstimate, nil } +func (m *mockWalletKit) setMinRelayFee(fee chainfee.SatPerKWeight) { + m.feeEstimateLock.Lock() + defer m.feeEstimateLock.Unlock() + + m.minRelayFee = fee +} + +// MinRelayFee returns the current minimum relay fee based on our chain backend +// in sat/kw. It can be set with setMinRelayFee. +func (m *mockWalletKit) MinRelayFee( + ctx context.Context) (chainfee.SatPerKWeight, error) { + + m.feeEstimateLock.Lock() + defer m.feeEstimateLock.Unlock() + + return m.minRelayFee, nil +} + // ListSweeps returns a list of the sweep transaction ids known to our node. func (m *mockWalletKit) ListSweeps(_ context.Context, _ int32) ( []string, error) { From ddfc925558e35f0608dee09b5af9384f270b8857 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 21 Feb 2025 14:44:08 -0300 Subject: [PATCH 03/12] test: implement SignPsbt RPC --- test/walletkit_mock.go | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/test/walletkit_mock.go b/test/walletkit_mock.go index 12a8d52ec..55d81644d 100644 --- a/test/walletkit_mock.go +++ b/test/walletkit_mock.go @@ -1,6 +1,7 @@ package test import ( + "bytes" "context" "errors" "fmt" @@ -246,6 +247,25 @@ func (m *mockWalletKit) FundPsbt(_ context.Context, return nil, 0, nil, nil } +// finalScriptWitness is a sample signature suitable to put into PSBT. +var finalScriptWitness = func() []byte { + const pver = 0 + var buf bytes.Buffer + + // Write the number of witness elements. + if err := wire.WriteVarInt(&buf, pver, 1); err != nil { + panic(err) + } + + // Write a single witness element with a signature. + signature := make([]byte, 64) + if err := wire.WriteVarBytes(&buf, pver, signature); err != nil { + panic(err) + } + + return buf.Bytes() +}() + // SignPsbt expects a partial transaction with all inputs and outputs // fully declared and tries to sign all unsigned inputs that have all // required fields (UTXO information, BIP32 derivation information, @@ -258,9 +278,19 @@ func (m *mockWalletKit) FundPsbt(_ context.Context, // locking or input/output/fee value validation, PSBT finalization). Any // input that is incomplete will be skipped. func (m *mockWalletKit) SignPsbt(_ context.Context, - _ *psbt.Packet) (*psbt.Packet, error) { + packet *psbt.Packet) (*psbt.Packet, error) { - return nil, nil + inputs := make([]psbt.PInput, len(packet.Inputs)) + copy(inputs, packet.Inputs) + + for i := range inputs { + inputs[i].FinalScriptWitness = finalScriptWitness + } + + signedPacket := *packet + signedPacket.Inputs = inputs + + return &signedPacket, nil } // FinalizePsbt expects a partial transaction with all inputs and From aad7008aa557bb9c3ffc6faa7448aa284211c4af Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 21 Feb 2025 20:47:18 -0300 Subject: [PATCH 04/12] test: allow intercepting PublishTransaction --- test/lnd_services_mock.go | 7 +++++++ test/walletkit_mock.go | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/test/lnd_services_mock.go b/test/lnd_services_mock.go index 07e25170b..aaf5c1106 100644 --- a/test/lnd_services_mock.go +++ b/test/lnd_services_mock.go @@ -129,6 +129,11 @@ type SignOutputRawRequest struct { SignDescriptors []*lndclient.SignDescriptor } +// PublishHandler is optional transaction handler function called upon calling +// the method PublishTransaction. +type PublishHandler func(ctx context.Context, tx *wire.MsgTx, + label string) error + // LndMockServices provides a full set of mocked lnd services. type LndMockServices struct { lndclient.LndServices @@ -174,6 +179,8 @@ type LndMockServices struct { WaitForFinished func() + PublishHandler PublishHandler + lock sync.Mutex } diff --git a/test/walletkit_mock.go b/test/walletkit_mock.go index 55d81644d..332d78866 100644 --- a/test/walletkit_mock.go +++ b/test/walletkit_mock.go @@ -113,7 +113,13 @@ func (m *mockWalletKit) NextAddr(context.Context, string, walletrpc.AddressType, } func (m *mockWalletKit) PublishTransaction(ctx context.Context, tx *wire.MsgTx, - _ string) error { + label string) error { + + if m.lnd.PublishHandler != nil { + if err := m.lnd.PublishHandler(ctx, tx, label); err != nil { + return err + } + } m.lnd.AddTx(tx) m.lnd.TxPublishChannel <- tx From 45ed71672f92ba853b10ca125e8e784b60ed99d5 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Tue, 25 Feb 2025 00:20:59 -0300 Subject: [PATCH 05/12] sweepbatcher: fix race in store_mock --- sweepbatcher/store_mock.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sweepbatcher/store_mock.go b/sweepbatcher/store_mock.go index 57cdd34b7..815b19917 100644 --- a/sweepbatcher/store_mock.go +++ b/sweepbatcher/store_mock.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sort" + "sync" "github.com/btcsuite/btcd/btcutil" "github.com/lightningnetwork/lnd/lntypes" @@ -13,6 +14,7 @@ import ( type StoreMock struct { batches map[int32]dbBatch sweeps map[lntypes.Hash]dbSweep + mu sync.Mutex } // NewStoreMock instantiates a new mock store. @@ -28,6 +30,9 @@ func NewStoreMock() *StoreMock { func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) ( []*dbBatch, error) { + s.mu.Lock() + defer s.mu.Unlock() + result := []*dbBatch{} for _, batch := range s.batches { batch := batch @@ -44,6 +49,9 @@ func (s *StoreMock) FetchUnconfirmedSweepBatches(ctx context.Context) ( func (s *StoreMock) InsertSweepBatch(ctx context.Context, batch *dbBatch) (int32, error) { + s.mu.Lock() + defer s.mu.Unlock() + var id int32 if len(s.batches) == 0 { @@ -66,12 +74,18 @@ func (s *StoreMock) DropBatch(ctx context.Context, id int32) error { func (s *StoreMock) UpdateSweepBatch(ctx context.Context, batch *dbBatch) error { + s.mu.Lock() + defer s.mu.Unlock() + s.batches[batch.ID] = *batch return nil } // ConfirmBatch confirms a batch. func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error { + s.mu.Lock() + defer s.mu.Unlock() + batch, ok := s.batches[id] if !ok { return errors.New("batch not found") @@ -87,6 +101,9 @@ func (s *StoreMock) ConfirmBatch(ctx context.Context, id int32) error { func (s *StoreMock) FetchBatchSweeps(ctx context.Context, id int32) ([]*dbSweep, error) { + s.mu.Lock() + defer s.mu.Unlock() + result := []*dbSweep{} for _, sweep := range s.sweeps { sweep := sweep @@ -104,7 +121,11 @@ func (s *StoreMock) FetchBatchSweeps(ctx context.Context, // UpsertSweep inserts a sweep into the database, or updates an existing sweep. func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { + s.mu.Lock() + defer s.mu.Unlock() + s.sweeps[sweep.SwapHash] = *sweep + return nil } @@ -112,6 +133,9 @@ func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { func (s *StoreMock) GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + sweep, ok := s.sweeps[swapHash] if !ok { return false, nil @@ -127,6 +151,9 @@ func (s *StoreMock) Close() error { // AssertSweepStored asserts that a sweep is stored. func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.sweeps[id] return ok } @@ -135,6 +162,9 @@ func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool { func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( *dbBatch, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, sweep := range s.sweeps { if sweep.SwapHash == swapHash { batch, ok := s.batches[sweep.BatchID] @@ -153,6 +183,9 @@ func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( func (s *StoreMock) TotalSweptAmount(ctx context.Context, batchID int32) ( btcutil.Amount, error) { + s.mu.Lock() + defer s.mu.Unlock() + batch, ok := s.batches[batchID] if !ok { return 0, errors.New("batch not found") From 94818deedcd749ee3b93d77808117325f37e4a2c Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Fri, 21 Feb 2025 20:48:56 -0300 Subject: [PATCH 06/12] sweepbatcher: fix usage of EventuallyWithT It should use the c variable passed into the lambda, not the parent t. It should use assert, not require package. --- sweepbatcher/sweep_batcher_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index fbfb0d418..d861a5b29 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -950,7 +950,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for batch publishing to be skipped, because initialDelay has not // ended. require.EventuallyWithT(t, func(c *assert.CollectT) { - require.Contains(t, testLogger.debugMessages, stillWaitingMsg) + assert.Contains(c, testLogger.debugMessages, stillWaitingMsg) }, test.Timeout, eventuallyCheckFrequency) // Advance the clock to the end of initialDelay. @@ -1274,7 +1274,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for sweep to be added to the batch. require.EventuallyWithT(t, func(c *assert.CollectT) { - require.Contains(t, testLogger2.infoMessages, "adding sweep %x") + assert.Contains(c, testLogger2.infoMessages, "adding sweep %x") }, test.Timeout, eventuallyCheckFrequency) // Advance the clock by publishDelay. Don't wait largeInitialDelay. From edd013d5eb20d6a820f943f9149ab7e1ff313842 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Tue, 25 Feb 2025 23:14:36 -0300 Subject: [PATCH 07/12] sweepbatcher: remove all completed batches Previously, if a completed batch was visited after a batch to which the sweep was added, it was not deleted because the function returned early. This has been separated into two loops: the first one removes completed batches, and the second one adds the sweep to a batch. --- sweepbatcher/sweep_batcher.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 3fe9fe9c4..5b40c5fcf 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -590,16 +590,18 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, sweep.notifier = notifier - // Check if the sweep is already in a batch. If that is the case, we - // provide the sweep to that batch and return. + // This is a check to see if a batch is completed. In that case we just + // lazily delete it. for _, batch := range b.batches { - // This is a check to see if a batch is completed. In that case - // we just lazily delete it and continue our scan. if batch.isComplete() { delete(b.batches, batch.id) continue } + } + // Check if the sweep is already in a batch. If that is the case, we + // provide the sweep to that batch and return. + for _, batch := range b.batches { if batch.sweepExists(sweep.swapHash) { accepted, err := batch.addSweep(ctx, sweep) if err != nil && !errors.Is(err, ErrBatchShuttingDown) { From d0fbdb052ebd3488449febd1bb9c365ef75619fb Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Tue, 25 Feb 2025 23:29:25 -0300 Subject: [PATCH 08/12] sweepbatcher: replace batch logger atomically This is needed to fix crashes in unit tests under -race. --- sweepbatcher/sweep_batch.go | 102 ++++++++++++++++------------- sweepbatcher/sweep_batcher_test.go | 4 +- 2 files changed, 60 insertions(+), 46 deletions(-) diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index b087a8994..9537c17cc 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -9,6 +9,7 @@ import ( "math" "strings" "sync" + "sync/atomic" "time" "github.com/btcsuite/btcd/blockchain" @@ -284,8 +285,8 @@ type batch struct { // cfg is the configuration for this batch. cfg *batchConfig - // log is the logger for this batch. - log btclog.Logger + // log_ is the logger for this batch. + log_ atomic.Pointer[btclog.Logger] wg sync.WaitGroup } @@ -387,7 +388,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { } } - return &batch{ + b := &batch{ id: bk.id, state: bk.state, primarySweepID: bk.primaryID, @@ -412,9 +413,22 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { publishErrorHandler: bk.publishErrorHandler, purger: bk.purger, store: bk.store, - log: bk.log, cfg: &cfg, - }, nil + } + + b.setLog(bk.log) + + return b, nil +} + +// log returns current logger. +func (b *batch) log() btclog.Logger { + return *b.log_.Load() +} + +// setLog atomically replaces the logger. +func (b *batch) setLog(logger btclog.Logger) { + b.log_.Store(&logger) } // addSweep tries to add a sweep to the batch. If this is the first sweep being @@ -430,7 +444,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // If the provided sweep is nil, we can't proceed with any checks, so // we just return early. if sweep == nil { - b.log.Infof("the sweep is nil") + b.log().Infof("the sweep is nil") return false, nil } @@ -473,7 +487,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // the batch, do not add another sweep to prevent the tx from becoming // non-standard. if len(b.sweeps) >= MaxSweepsPerBatch { - b.log.Infof("the batch has already too many sweeps (%d >= %d)", + b.log().Infof("the batch has already too many sweeps %d >= %d", len(b.sweeps), MaxSweepsPerBatch) return false, nil @@ -483,7 +497,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // arrive here after the batch got closed because of a spend. In this // case we cannot add the sweep to this batch. if b.state != Open { - b.log.Infof("the batch state (%v) is not open", b.state) + b.log().Infof("the batch state (%v) is not open", b.state) return false, nil } @@ -493,14 +507,14 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // we cannot add this sweep to the batch. for _, s := range b.sweeps { if s.isExternalAddr { - b.log.Infof("the batch already has a sweep (%x) with "+ + b.log().Infof("the batch already has a sweep %x with "+ "an external address", s.swapHash[:6]) return false, nil } if sweep.isExternalAddr { - b.log.Infof("the batch is not empty and new sweep (%x)"+ + b.log().Infof("the batch is not empty and new sweep %x"+ " has an external address", sweep.swapHash[:6]) return false, nil @@ -515,7 +529,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { int32(math.Abs(float64(sweep.timeout - s.timeout))) if timeoutDistance > b.cfg.maxTimeoutDistance { - b.log.Infof("too long timeout distance between the "+ + b.log().Infof("too long timeout distance between the "+ "batch and sweep %x: %d > %d", sweep.swapHash[:6], timeoutDistance, b.cfg.maxTimeoutDistance) @@ -544,7 +558,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { } // Add the sweep to the batch's sweeps. - b.log.Infof("adding sweep %x", sweep.swapHash[:6]) + b.log().Infof("adding sweep %x", sweep.swapHash[:6]) b.sweeps[sweep.swapHash] = *sweep // Update FeeRate. Max(sweep.minFeeRate) for all the sweeps of @@ -572,7 +586,7 @@ func (b *batch) sweepExists(hash lntypes.Hash) bool { // Wait waits for the batch to gracefully stop. func (b *batch) Wait() { - b.log.Infof("Stopping") + b.log().Infof("Stopping") <-b.finished } @@ -613,7 +627,7 @@ func (b *batch) Run(ctx context.Context) error { // Set currentHeight here, because it may be needed in monitorSpend. select { case b.currentHeight = <-blockChan: - b.log.Debugf("initial height for the batch is %v", + b.log().Debugf("initial height for the batch is %v", b.currentHeight) case <-runCtx.Done(): @@ -652,7 +666,7 @@ func (b *batch) Run(ctx context.Context) error { // completes. timerChan := clock.TickAfter(b.cfg.batchPublishDelay) - b.log.Infof("started, primary %x, total sweeps %v", + b.log().Infof("started, primary %x, total sweeps %v", b.primarySweepID[0:6], len(b.sweeps)) for { @@ -662,7 +676,7 @@ func (b *batch) Run(ctx context.Context) error { // blockChan provides immediately the current tip. case height := <-blockChan: - b.log.Debugf("received block %v", height) + b.log().Debugf("received block %v", height) // Set the timer to publish the batch transaction after // the configured delay. @@ -670,7 +684,7 @@ func (b *batch) Run(ctx context.Context) error { b.currentHeight = height case <-initialDelayChan: - b.log.Debugf("initial delay of duration %v has ended", + b.log().Debugf("initial delay of duration %v has ended", b.cfg.initialDelay) // Set the timer to publish the batch transaction after @@ -680,8 +694,8 @@ func (b *batch) Run(ctx context.Context) error { case <-timerChan: // Check that batch is still open. if b.state != Open { - b.log.Debugf("Skipping publishing, because the"+ - " batch is not open (%v).", b.state) + b.log().Debugf("Skipping publishing, because "+ + "the batch is not open (%v).", b.state) continue } @@ -695,7 +709,7 @@ func (b *batch) Run(ctx context.Context) error { // initialDelayChan has just fired, this check passes. now := clock.Now() if skipBefore.After(now) { - b.log.Debugf(stillWaitingMsg, skipBefore, now) + b.log().Debugf(stillWaitingMsg, skipBefore, now) continue } @@ -715,8 +729,8 @@ func (b *batch) Run(ctx context.Context) error { case <-b.reorgChan: b.state = Open - b.log.Warnf("reorg detected, batch is able to accept " + - "new sweeps") + b.log().Warnf("reorg detected, batch is able to " + + "accept new sweeps") err := b.monitorSpend(ctx, b.sweeps[b.primarySweepID]) if err != nil { @@ -755,7 +769,7 @@ func (b *batch) timeout() int32 { func (b *batch) isUrgent(skipBefore time.Time) bool { timeout := b.timeout() if timeout <= 0 { - b.log.Warnf("Method timeout() returned %v. Number of"+ + b.log().Warnf("Method timeout() returned %v. Number of"+ " sweeps: %d. It may be an empty batch.", timeout, len(b.sweeps)) return false @@ -779,7 +793,7 @@ func (b *batch) isUrgent(skipBefore time.Time) bool { return false } - b.log.Debugf("cancelling waiting for urgent sweep (timeBank is %v, "+ + b.log().Debugf("cancelling waiting for urgent sweep (timeBank is %v, "+ "remainingWaiting is %v)", timeBank, remainingWaiting) // Signal to the caller to cancel initialDelay. @@ -795,7 +809,7 @@ func (b *batch) publish(ctx context.Context) error { ) if len(b.sweeps) == 0 { - b.log.Debugf("skipping publish: no sweeps in the batch") + b.log().Debugf("skipping publish: no sweeps in the batch") return nil } @@ -808,7 +822,7 @@ func (b *batch) publish(ctx context.Context) error { // logPublishError is a function which logs publish errors. logPublishError := func(errMsg string, err error) { - b.publishErrorHandler(err, errMsg, b.log) + b.publishErrorHandler(err, errMsg, b.log()) } fee, err, signSuccess = b.publishMixedBatch(ctx) @@ -830,9 +844,9 @@ func (b *batch) publish(ctx context.Context) error { } } - b.log.Infof("published, total sweeps: %v, fees: %v", len(b.sweeps), fee) + b.log().Infof("published, total sweeps: %v, fees: %v", len(b.sweeps), fee) for _, sweep := range b.sweeps { - b.log.Infof("published sweep %x, value: %v", + b.log().Infof("published sweep %x, value: %v", sweep.swapHash[:6], sweep.value) } @@ -1026,8 +1040,8 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, coopInputs int ) for attempt := 1; ; attempt++ { - b.log.Infof("Attempt %d of collecting cooperative signatures.", - attempt) + b.log().Infof("Attempt %d of collecting cooperative "+ + "signatures.", attempt) // Construct unsigned batch transaction. var err error @@ -1062,7 +1076,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, ctx, i, sweep, tx, prevOutsMap, psbtBytes, ) if err != nil { - b.log.Infof("cooperative signing failed for "+ + b.log().Infof("cooperative signing failed for "+ "sweep %x: %v", sweep.swapHash[:6], err) // Set coopFailed flag for this sweep in all the @@ -1201,7 +1215,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, } } txHash := tx.TxHash() - b.log.Infof("attempting to publish batch tx=%v with feerate=%v, "+ + b.log().Infof("attempting to publish batch tx=%v with feerate=%v, "+ "weight=%v, feeForWeight=%v, fee=%v, sweeps=%d, "+ "%d cooperative: (%s) and %d non-cooperative (%s), destAddr=%s", txHash, b.rbfCache.FeeRate, weight, feeForWeight, fee, @@ -1215,7 +1229,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, blockchain.GetTransactionWeight(btcutil.NewTx(tx)), ) if realWeight != weight { - b.log.Warnf("actual weight of tx %v is %v, estimated as %d", + b.log().Warnf("actual weight of tx %v is %v, estimated as %d", txHash, realWeight, weight) } @@ -1239,11 +1253,11 @@ func (b *batch) debugLogTx(msg string, tx *wire.MsgTx) { // Serialize the transaction and convert to hex string. buf := bytes.NewBuffer(make([]byte, 0, tx.SerializeSize())) if err := tx.Serialize(buf); err != nil { - b.log.Errorf("failed to serialize tx for debug log: %v", err) + b.log().Errorf("failed to serialize tx for debug log: %v", err) return } - b.log.Debugf("%s: %s", msg, hex.EncodeToString(buf.Bytes())) + b.log().Debugf("%s: %s", msg, hex.EncodeToString(buf.Bytes())) } // musig2sign signs one sweep using musig2. @@ -1405,14 +1419,14 @@ func (b *batch) updateRbfRate(ctx context.Context) error { if b.rbfCache.FeeRate == 0 { // We set minFeeRate in each sweep, so fee rate is expected to // be initiated here. - b.log.Warnf("rbfCache.FeeRate is 0, which must not happen.") + b.log().Warnf("rbfCache.FeeRate is 0, which must not happen.") if b.cfg.batchConfTarget == 0 { - b.log.Warnf("updateRbfRate called with zero " + + b.log().Warnf("updateRbfRate called with zero " + "batchConfTarget") } - b.log.Infof("initializing rbf fee rate for conf target=%v", + b.log().Infof("initializing rbf fee rate for conf target=%v", b.cfg.batchConfTarget) rate, err := b.wallet.EstimateFeeRate( ctx, b.cfg.batchConfTarget, @@ -1461,7 +1475,7 @@ func (b *batch) monitorSpend(ctx context.Context, primarySweep sweep) error { defer cancel() defer b.wg.Done() - b.log.Infof("monitoring spend for outpoint %s", + b.log().Infof("monitoring spend for outpoint %s", primarySweep.outpoint.String()) for { @@ -1584,7 +1598,7 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { if len(spendTx.TxOut) > 0 { b.batchPkScript = spendTx.TxOut[0].PkScript } else { - b.log.Warnf("transaction %v has no outputs", txHash) + b.log().Warnf("transaction %v has no outputs", txHash) } // As a previous version of the batch transaction may get confirmed, @@ -1666,13 +1680,13 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { err := b.purger(&sweep) if err != nil { - b.log.Errorf("unable to purge sweep %x: %v", + b.log().Errorf("unable to purge sweep %x: %v", sweep.SwapHash[:6], err) } } }() - b.log.Infof("spent, total sweeps: %v, purged sweeps: %v", + b.log().Infof("spent, total sweeps: %v, purged sweeps: %v", len(notifyList), len(purgeList)) err := b.monitorConfirmations(ctx) @@ -1690,7 +1704,7 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { // handleConf handles a confirmation notification. This is the final step of the // batch. Here we signal to the batcher that this batch was completed. func (b *batch) handleConf(ctx context.Context) error { - b.log.Infof("confirmed in txid %s", b.batchTxid) + b.log().Infof("confirmed in txid %s", b.batchTxid) b.state = Confirmed return b.store.ConfirmBatch(ctx, b.id) @@ -1769,7 +1783,7 @@ func (b *batch) insertAndAcquireID(ctx context.Context) (int32, error) { } b.id = id - b.log = batchPrefixLogger(fmt.Sprintf("%d", b.id)) + b.setLog(batchPrefixLogger(fmt.Sprintf("%d", b.id))) return id, nil } diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index d861a5b29..e198d9e09 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -940,7 +940,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { } require.NotNil(t, batch1) testLogger := &wrappedLogger{Logger: batch1.log} - batch1.log = testLogger + batch1.setLog(testLogger) // Advance the clock to publishDelay. It will trigger the publishDelay // timer, but won't result in publishing, because of initialDelay. @@ -1234,7 +1234,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { } require.NotNil(t, batch2) testLogger2 := &wrappedLogger{Logger: batch2.log} - batch2.log = testLogger2 + batch2.setLog(testLogger2) // Add another sweep which is urgent. It will go to the same batch // to make sure minimum timeout is calculated properly. From 34d70f93549b5f1eab9a9222cfd09526dc2bf59e Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Tue, 25 Feb 2025 23:38:08 -0300 Subject: [PATCH 09/12] sweepbatcher/test: protect mock data with mutex Several structures were accessed without protection causing crashes under -race. --- sweepbatcher/sweep_batcher_test.go | 40 +++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index e198d9e09..7fee7efa1 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -811,18 +811,26 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, type wrappedLogger struct { btclog.Logger + mu sync.Mutex + debugMessages []string infoMessages []string } // Debugf logs debug message. func (l *wrappedLogger) Debugf(format string, params ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() + l.debugMessages = append(l.debugMessages, format) l.Logger.Debugf(format, params...) } // Infof logs info message. func (l *wrappedLogger) Infof(format string, params ...interface{}) { + l.mu.Lock() + defer l.mu.Unlock() + l.infoMessages = append(l.infoMessages, format) l.Logger.Infof(format, params...) } @@ -950,6 +958,9 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for batch publishing to be skipped, because initialDelay has not // ended. require.EventuallyWithT(t, func(c *assert.CollectT) { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + assert.Contains(c, testLogger.debugMessages, stillWaitingMsg) }, test.Timeout, eventuallyCheckFrequency) @@ -1274,6 +1285,9 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for sweep to be added to the batch. require.EventuallyWithT(t, func(c *assert.CollectT) { + testLogger2.mu.Lock() + defer testLogger2.mu.Unlock() + assert.Contains(c, testLogger2.infoMessages, "adding sweep %x") }, test.Timeout, eventuallyCheckFrequency) @@ -2810,11 +2824,22 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, type sweepFetcherMock struct { store map[lntypes.Hash]*SweepInfo + mu sync.Mutex +} + +func (f *sweepFetcherMock) setSweep(hash lntypes.Hash, info *SweepInfo) { + f.mu.Lock() + defer f.mu.Unlock() + + f.store[hash] = info } func (f *sweepFetcherMock) FetchSweep(ctx context.Context, hash lntypes.Hash) ( *SweepInfo, error) { + f.mu.Lock() + defer f.mu.Unlock() + return f.store[hash], nil } @@ -3279,7 +3304,7 @@ func testWithMixedBatch(t *testing.T, store testStore, if i == 0 { sweepInfo.NonCoopHint = true } - sweepFetcher.store[swapHash] = sweepInfo + sweepFetcher.setSweep(swapHash, sweepInfo) // Create sweep request. sweepReq := SweepRequest{ @@ -3433,7 +3458,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, ) require.NoError(t, err) - sweepFetcher.store[swapHash] = &SweepInfo{ + sweepFetcher.setSweep(swapHash, &SweepInfo{ Preimage: preimages[i], NonCoopHint: nonCoopHints[i], @@ -3445,7 +3470,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, HTLC: *htlc, HTLCSuccessEstimator: htlc.AddSuccessToEstimator, DestAddr: destAddr, - } + }) // Create sweep request. sweepReq := SweepRequest{ @@ -4035,6 +4060,8 @@ type loopdbBatcherStore struct { BatcherStore sweepsSet map[lntypes.Hash]struct{} + + mu sync.Mutex } // UpsertSweep inserts a sweep into the database, or updates an existing sweep @@ -4042,6 +4069,9 @@ type loopdbBatcherStore struct { func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context, sweep *dbSweep) error { + s.mu.Lock() + defer s.mu.Unlock() + err := s.BatcherStore.UpsertSweep(ctx, sweep) if err == nil { s.sweepsSet[sweep.SwapHash] = struct{}{} @@ -4051,7 +4081,11 @@ func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context, // AssertSweepStored asserts that a sweep is stored. func (s *loopdbBatcherStore) AssertSweepStored(id lntypes.Hash) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, has := s.sweepsSet[id] + return has } From cd087d48365a69b40d93d5239ff5c1f6a27a5181 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Tue, 25 Feb 2025 23:42:27 -0300 Subject: [PATCH 10/12] sweepbatcher/test: fix races in require.Eventually The code inside require.Eventually runs in parallel with the event loops of the batcher and its batches. Accessing fields of the batcher and batches must be done within an event loop. To address this, testRunInEventLoop methods were added to the Batcher and batch types. Unit tests were then rewritten to use this approach when accessing batcher and batch fields. Additionally, in many cases, receive operations from RegisterSpendChannel were moved before require.Eventually. This prevents testRunInEventLoop from getting stuck in an event loop while blocked on a RegisterSpendChannel send operation. --- sweepbatcher/sweep_batch.go | 42 ++ sweepbatcher/sweep_batcher.go | 51 +++ sweepbatcher/sweep_batcher_test.go | 620 ++++++++++++++++++----------- 3 files changed, 485 insertions(+), 228 deletions(-) diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index 9537c17cc..bd97dd4ac 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -215,6 +215,12 @@ type batch struct { // reorgChan is the channel over which reorg notifications are received. reorgChan chan struct{} + // testReqs is a channel where test requests are received. + // This is used only in unit tests! The reason to have this is to + // avoid data races in require.Eventually calls running in parallel + // to the event loop. See method testRunInEventLoop(). + testReqs chan *testRequest + // errChan is the channel over which errors are received. errChan chan error @@ -352,6 +358,7 @@ func NewBatch(cfg batchConfig, bk batchKit) *batch { spendChan: make(chan *chainntnfs.SpendDetail), confChan: make(chan *chainntnfs.TxConfirmation, 1), reorgChan: make(chan struct{}, 1), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), callEnter: make(chan struct{}), callLeave: make(chan struct{}), @@ -396,6 +403,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { spendChan: make(chan *chainntnfs.SpendDetail), confChan: make(chan *chainntnfs.TxConfirmation, 1), reorgChan: make(chan struct{}, 1), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), callEnter: make(chan struct{}), callLeave: make(chan struct{}), @@ -737,6 +745,10 @@ func (b *batch) Run(ctx context.Context) error { return err } + case testReq := <-b.testReqs: + testReq.handler() + close(testReq.quit) + case err := <-blockErrChan: return err @@ -749,6 +761,36 @@ func (b *batch) Run(ctx context.Context) error { } } +// testRunInEventLoop runs a function in the event loop blocking until +// the function returns. For unit tests only! +func (b *batch) testRunInEventLoop(ctx context.Context, handler func()) { + // If the event loop is finished, run the function. + select { + case <-b.stopping: + handler() + + return + default: + } + + quit := make(chan struct{}) + req := &testRequest{ + handler: handler, + quit: quit, + } + + select { + case b.testReqs <- req: + case <-ctx.Done(): + return + } + + select { + case <-quit: + case <-ctx.Done(): + } +} + // timeout returns minimum timeout as block height among sweeps of the batch. // If the batch is empty, return -1. func (b *batch) timeout() int32 { diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 5b40c5fcf..fae60274b 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -225,6 +225,16 @@ var ( ErrBatcherShuttingDown = errors.New("batcher shutting down") ) +// testRequest is a function passed to an event loop and a channel used to +// wait until the function is executed. This is used in unit tests only! +type testRequest struct { + // handler is the function to an event loop. + handler func() + + // quit is closed when the handler completes. + quit chan struct{} +} + // Batcher is a system that is responsible for accepting sweep requests and // placing them in appropriate batches. It will spin up new batches as needed. type Batcher struct { @@ -234,6 +244,12 @@ type Batcher struct { // sweepReqs is a channel where sweep requests are received. sweepReqs chan SweepRequest + // testReqs is a channel where test requests are received. + // This is used only in unit tests! The reason to have this is to + // avoid data races in require.Eventually calls running in parallel + // to the event loop. See method testRunInEventLoop(). + testReqs chan *testRequest + // errChan is a channel where errors are received. errChan chan error @@ -461,6 +477,7 @@ func NewBatcher(wallet lndclient.WalletKitClient, return &Batcher{ batches: make(map[int32]*batch), sweepReqs: make(chan SweepRequest), + testReqs: make(chan *testRequest), errChan: make(chan error, 1), quit: make(chan struct{}), initDone: make(chan struct{}), @@ -528,6 +545,10 @@ func (b *Batcher) Run(ctx context.Context) error { return err } + case testReq := <-b.testReqs: + testReq.handler() + close(testReq.quit) + case err := <-b.errChan: log.Warnf("Batcher received an error: %v.", err) return err @@ -551,6 +572,36 @@ func (b *Batcher) AddSweep(sweepReq *SweepRequest) error { } } +// testRunInEventLoop runs a function in the event loop blocking until +// the function returns. For unit tests only! +func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) { + // If the event loop is finished, run the function. + select { + case <-b.quit: + handler() + + return + default: + } + + quit := make(chan struct{}) + req := &testRequest{ + handler: handler, + quit: quit, + } + + select { + case b.testReqs <- req: + case <-ctx.Done(): + return + } + + select { + case <-quit: + case <-ctx.Done(): + } +} + // handleSweep handles a sweep request by either placing it in an existing // batch, or by spinning up a new batch for it. func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index 7fee7efa1..22551619a 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -109,18 +109,35 @@ func checkBatcherError(t *testing.T, err error) { } } -// getOnlyBatch makes sure the batcher has exactly one batch and returns it. -func getOnlyBatch(batcher *Batcher) *batch { - if len(batcher.batches) != 1 { - panic(fmt.Sprintf("getOnlyBatch called on a batcher having "+ - "%d batches", len(batcher.batches))) - } +// getBatches returns batches in thread-safe way. +func getBatches(ctx context.Context, batcher *Batcher) []*batch { + var batches []*batch + batcher.testRunInEventLoop(ctx, func() { + for _, batch := range batcher.batches { + batches = append(batches, batch) + } + }) + + return batches +} + +// tryGetOnlyBatch returns a single batch if there is exactly one batch, or nil. +func tryGetOnlyBatch(ctx context.Context, batcher *Batcher) *batch { + batches := getBatches(ctx, batcher) - for _, batch := range batcher.batches { - return batch + if len(batches) == 1 { + return batches[0] + } else { + return nil } +} - panic("unreachable") +// getOnlyBatch makes sure the batcher has exactly one batch and returns it. +func getOnlyBatch(t *testing.T, ctx context.Context, batcher *Batcher) *batch { + batches := getBatches(ctx, batcher) + require.Len(t, batches, 1) + + return batches[0] } // testSweepBatcherBatchCreation tests that sweep requests enter the expected @@ -186,7 +203,7 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Wait for tx to be published. @@ -236,7 +253,7 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, // Batcher should not create a second batch as timeout distance is small // enough. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Create a third sweep request that has more timeout distance than @@ -273,33 +290,43 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since the second batch got created we check that it registered its + // primary sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as timeout distance is greater // than the threshold require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since the second batch got created we check that it registered its - // primary sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps // in it. - for _, batch := range batcher.batches { - switch batch.primarySweepID { - case sweepReq1.SwapHash: - if len(batch.sweeps) != 2 { - return false - } + batches := getBatches(ctx, batcher) - case sweepReq3.SwapHash: - if len(batch.sweeps) != 1 { - return false + for _, batch := range batches { + var bad bool + + batch.testRunInEventLoop(ctx, func() { + switch batch.primarySweepID { + case sweepReq1.SwapHash: + if len(batch.sweeps) != 2 { + bad = true + } + + case sweepReq3.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } } + }) + + if bad { + return false } } @@ -480,24 +507,26 @@ func testTxLabeler(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return len(getBatches(ctx, batcher)) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Wait for tx to be published. <-lnd.TxPublishChannel // Find the batch and assign it to a local variable for easier access. var theBatch *batch - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - theBatch = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + theBatch = btch + } + }) } // Now test the label. @@ -632,15 +661,15 @@ func testPublishErrorHandler(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return len(getBatches(ctx, batcher)) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // The first attempt to publish the batch tx is expected to fail. require.ErrorIs(t, <-publishErrorChan, testPublishError) @@ -710,26 +739,33 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) - // Eventually request will be consumed and a new batch will spin up. - require.Eventually(t, func() bool { - return len(batcher.batches) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // When batch is successfully created it will execute it's first step, // which leads to a spend monitor of the primary sweep. <-lnd.RegisterSpendChannel + // Eventually request will be consumed and a new batch will spin up. + require.Eventually(t, func() bool { + return len(getBatches(ctx, batcher)) == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Find the batch and assign it to a local variable for easier access. batch := &batch{} - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - batch = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + batch = btch + } + }) } require.Eventually(t, func() bool { // Batch should have the sweep stored. - return len(batch.sweeps) == 1 + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + + return numSweeps == 1 }, test.Timeout, eventuallyCheckFrequency) // The primary sweep id should be that of the first inserted sweep. @@ -744,7 +780,12 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // After receiving a height notification the batch will step again, // leading to a new spend monitoring. require.Eventually(t, func() bool { - return batch.currentHeight == 601 + var currentHeight int32 + batch.testRunInEventLoop(ctx, func() { + currentHeight = batch.currentHeight + }) + + return currentHeight == 601 }, test.Timeout, eventuallyCheckFrequency) // Wait for tx to be published. @@ -788,7 +829,12 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // The batch should eventually read the spend notification and progress // its state to closed. require.Eventually(t, func() bool { - return batch.state == Closed + var state batchState + batch.testRunInEventLoop(ctx, func() { + state = batch.state + }) + + return state == Closed }, test.Timeout, eventuallyCheckFrequency) err = lnd.NotifyHeight(604) @@ -802,7 +848,12 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, // Eventually the batch receives the confirmation notification and // confirms itself. require.Eventually(t, func() bool { - return batch.isComplete() + var complete bool + batch.testRunInEventLoop(ctx, func() { + complete = batch.isComplete() + }) + + return complete }, test.Timeout, eventuallyCheckFrequency) } @@ -938,17 +989,18 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Eventually the batch is launched. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Replace the logger in the batch with wrappedLogger to watch messages. - var batch1 *batch - for _, batch := range batcher.batches { - batch1 = batch - } - require.NotNil(t, batch1) - testLogger := &wrappedLogger{Logger: batch1.log} - batch1.setLog(testLogger) + batch1 := getOnlyBatch(t, ctx, batcher) + var testLogger *wrappedLogger + batch1.testRunInEventLoop(ctx, func() { + testLogger = &wrappedLogger{ + Logger: batch1.log(), + } + batch1.setLog(testLogger) + }) // Advance the clock to publishDelay. It will trigger the publishDelay // timer, but won't result in publishing, because of initialDelay. @@ -986,16 +1038,19 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) + // Make sure the batch has one sweep. + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 + return numSweeps == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -1031,25 +1086,6 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for the batcher to be initialized. <-batcher.initDone - // Wait for batch to load. - require.Eventually(t, func() bool { - // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { - return false - } - - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { - return false - } - - // Get the batch. - batch := getOnlyBatch(batcher) - - // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 - }, test.Timeout, eventuallyCheckFrequency) - // Expect a timer to be set: 0 (instead of publishDelay), and // RegisterSpend to be called. The order is not determined, so catch // these actions from two separate goroutines. @@ -1062,6 +1098,9 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Since a batch was created we check that it registered for its // primary sweep's spend. <-lnd.RegisterSpendChannel + + // Wait for tx to be published. + <-lnd.TxPublishChannel }() wg3.Add(1) @@ -1076,6 +1115,28 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for RegisterSpend and for timer registration. wg3.Wait() + // Wait for batch to load. + require.Eventually(t, func() bool { + // Make sure that the sweep was stored + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { + return false + } + + // Make sure the batch has one sweep. + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + + // Make sure the batch has one sweep. + return numSweeps == 1 + }, test.Timeout, eventuallyCheckFrequency) + // Expect one timer: publishDelay (0). wantDelays = []time.Duration{0} require.Equal(t, wantDelays, delays) @@ -1084,9 +1145,6 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { now = now.Add(time.Millisecond) testClock.SetTime(now) - // Wait for tx to be published. - <-lnd.TxPublishChannel - // Tick tock next block. err = lnd.NotifyHeight(601) require.NoError(t, err) @@ -1237,15 +1295,18 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { require.Equal(t, wantDelays, delays) // Replace the logger in the batch with wrappedLogger to watch messages. - var batch2 *batch - for _, batch := range batcher.batches { - if batch.id != batch1.id { - batch2 = batch - } + var testLogger2 *wrappedLogger + for _, batch := range getBatches(ctx, batcher) { + batch.testRunInEventLoop(ctx, func() { + if batch.id != batch1.id { + testLogger2 = &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger2) + } + }) } - require.NotNil(t, batch2) - testLogger2 := &wrappedLogger{Logger: batch2.log} - batch2.setLog(testLogger2) + require.NotNil(t, testLogger2) // Add another sweep which is urgent. It will go to the same batch // to make sure minimum timeout is calculated properly. @@ -1413,15 +1474,19 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, // Eventually the batches are launched and all the sweeps are added. require.Eventually(t, func() bool { // Make sure all the batches have started. - if len(batcher.batches) != expectedBatches { + batches := getBatches(ctx, batcher) + if len(batches) != expectedBatches { return false } // Make sure all the sweeps were added. sweepsNum := 0 - for _, batch := range batcher.batches { - sweepsNum += len(batch.sweeps) + for _, batch := range batches { + batch.testRunInEventLoop(ctx, func() { + sweepsNum += len(batch.sweeps) + }) } + return sweepsNum == swapsNum }, test.Timeout, eventuallyCheckFrequency) @@ -1602,20 +1667,27 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Batcher should create a batch for the sweeps. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Find the batch and store it in a local variable for easier access. b := &batch{} - for _, btch := range batcher.batches { - if btch.primarySweepID == sweepReq1.SwapHash { - b = btch - } + for _, btch := range getBatches(ctx, batcher) { + btch.testRunInEventLoop(ctx, func() { + if btch.primarySweepID == sweepReq1.SwapHash { + b = btch + } + }) } // Batcher should contain all sweeps. require.Eventually(t, func() bool { - return len(b.sweeps) == 3 + var numSweeps int + b.testRunInEventLoop(ctx, func() { + numSweeps = len(b.sweeps) + }) + + return numSweeps == 3 }, test.Timeout, eventuallyCheckFrequency) // Verify that the batch has a primary sweep id that matches the first @@ -1664,20 +1736,25 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Eventually the batch reads the notification and proceeds to a closed // state. require.Eventually(t, func() bool { - return b.state == Closed + var state batchState + b.testRunInEventLoop(ctx, func() { + state = b.state + }) + + return state == Closed }, test.Timeout, eventuallyCheckFrequency) + // Since second batch was created we check that it registered for its + // primary sweep's spend. + <-lnd.RegisterSpendChannel + // While handling the spend notification the batch should detect that // some sweeps did not appear in the spending tx, therefore it redirects // them back to the batcher and the batcher inserts them in a new batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since second batch was created we check that it registered for its - // primary sweep's spend. - <-lnd.RegisterSpendChannel - // We mock the confirmation notification. lnd.ConfChannel <- &chainntnfs.TxConfirmation{ Tx: spendingTx, @@ -1692,26 +1769,35 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // confirmation forever. <-lnd.TxPublishChannel + // Re-add one of remaining sweeps to trigger removing the completed + // batch from the batcher. + require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Eventually the batch receives the confirmation notification, // gracefully exits and the batcher deletes it. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Find the other batch, which includes the sweeps that did not appear // in the spending tx. - b = &batch{} - for _, btch := range batcher.batches { - b = btch - } + b = getOnlyBatch(t, ctx, batcher) // After all the sweeps enter, it should contain 2 sweeps. require.Eventually(t, func() bool { - return len(b.sweeps) == 2 + var numSweeps int + b.testRunInEventLoop(ctx, func() { + numSweeps = len(b.sweeps) + }) + return numSweeps == 2 }, test.Timeout, eventuallyCheckFrequency) // The batch should be in an open state. - require.Equal(t, b.state, Open) + var state batchState + b.testRunInEventLoop(ctx, func() { + state = b.state + }) + require.Equal(t, state, Open) } // testSweepBatcherNonWalletAddr tests that sweep requests that sweep to a non @@ -1767,16 +1853,16 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel @@ -1817,16 +1903,16 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq2)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as first batch is a non wallet // addr batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for second batch to be published. <-lnd.TxPublishChannel @@ -1864,38 +1950,47 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a new batch as timeout distance is greater than // the threshold require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return len(getBatches(ctx, batcher)) == 3 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published for 3rd batch. <-lnd.TxPublishChannel require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps // in it. - for _, batch := range batcher.batches { - switch batch.primarySweepID { - case sweepReq1.SwapHash: - if len(batch.sweeps) != 1 { - return false - } - - case sweepReq2.SwapHash: - if len(batch.sweeps) != 1 { - return false + batches := getBatches(ctx, batcher) + for _, batch := range batches { + var bad bool + + batch.testRunInEventLoop(ctx, func() { + switch batch.primarySweepID { + case sweepReq1.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } + + case sweepReq2.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } + + case sweepReq3.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } } + }) - case sweepReq3.SwapHash: - if len(batch.sweeps) != 1 { - return false - } + if bad { + return false } } @@ -2117,16 +2212,16 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Deliver sweep request to batcher. require.NoError(t, batcher.AddSweep(&sweepReq1)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Once batcher receives sweep request it will eventually spin up a // batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx to be published. <-lnd.TxPublishChannel @@ -2138,7 +2233,7 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Batcher should not create a second batch as timeout distance is small // enough. require.Eventually(t, func() bool { - return len(batcher.batches) == 1 + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Publish a block to trigger batch 1 republishing. @@ -2151,32 +2246,32 @@ func testSweepBatcherComposite(t *testing.T, store testStore, require.NoError(t, batcher.AddSweep(&sweepReq3)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a second batch as this sweep pays to a non // wallet address. require.Eventually(t, func() bool { - return len(batcher.batches) == 2 + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the second batch to be published (1 sweep). tx = <-lnd.TxPublishChannel require.Equal(t, 1, len(tx.TxIn)) require.NoError(t, batcher.AddSweep(&sweepReq4)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a third batch as timeout distance is greater // than the threshold. require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return len(getBatches(ctx, batcher)) == 3 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the third batch to be published (1 sweep). tx = <-lnd.TxPublishChannel require.Equal(t, 1, len(tx.TxIn)) @@ -2195,21 +2290,21 @@ func testSweepBatcherComposite(t *testing.T, store testStore, // Batcher should not create a fourth batch as timeout distance is small // enough for it to join the last batch. require.Eventually(t, func() bool { - return len(batcher.batches) == 3 + return len(getBatches(ctx, batcher)) == 3 }, test.Timeout, eventuallyCheckFrequency) require.NoError(t, batcher.AddSweep(&sweepReq6)) + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + // Batcher should create a fourth batch as this sweep pays to a non // wallet address. require.Eventually(t, func() bool { - return len(batcher.batches) == 4 + return len(getBatches(ctx, batcher)) == 4 }, test.Timeout, eventuallyCheckFrequency) - // Since a batch was created we check that it registered for its primary - // sweep's spend. - <-lnd.RegisterSpendChannel - // Wait for tx for the 4th batch to be published (1 sweep). tx = <-lnd.TxPublishChannel require.Equal(t, 1, len(tx.TxIn)) @@ -2217,27 +2312,35 @@ func testSweepBatcherComposite(t *testing.T, store testStore, require.Eventually(t, func() bool { // Verify that each batch has the correct number of sweeps in // it. - for _, batch := range batcher.batches { - switch batch.primarySweepID { - case sweepReq1.SwapHash: - if len(batch.sweeps) != 2 { - return false - } - - case sweepReq3.SwapHash: - if len(batch.sweeps) != 1 { - return false - } - - case sweepReq4.SwapHash: - if len(batch.sweeps) != 2 { - return false + batches := getBatches(ctx, batcher) + for _, batch := range batches { + var bad bool + batch.testRunInEventLoop(ctx, func() { + switch batch.primarySweepID { + case sweepReq1.SwapHash: + if len(batch.sweeps) != 2 { + bad = true + } + + case sweepReq3.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } + + case sweepReq4.SwapHash: + if len(batch.sweeps) != 2 { + bad = true + } + + case sweepReq6.SwapHash: + if len(batch.sweeps) != 1 { + bad = true + } } + }) - case sweepReq6.SwapHash: - if len(batch.sweeps) != 1 { - return false - } + if bad { + return false } } @@ -2374,8 +2477,11 @@ func testRestoringEmptyBatch(t *testing.T, store testStore, require.Eventually(t, func() bool { // Make sure that the sweep was stored and we have exactly one // active batch. - return batcherStore.AssertSweepStored(sweepReq.SwapHash) && - len(batcher.batches) == 1 + if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + return false + } + + return len(getBatches(ctx, batcher)) == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have only one batch stored (as we dropped the dormant @@ -2591,9 +2697,14 @@ func testHandleSweepTwice(t *testing.T, backend testStore, require.Eventually(t, func() bool { // Make sure that the sweep was stored and we have exactly one // active batch. - return batcherStore.AssertSweepStored(sweepReq1.SwapHash) && - batcherStore.AssertSweepStored(sweepReq2.SwapHash) && - len(batcher.batches) == 2 + if !batcherStore.AssertSweepStored(sweepReq1.SwapHash) { + return false + } + if !batcherStore.AssertSweepStored(sweepReq2.SwapHash) { + return false + } + + return len(getBatches(ctx, batcher)) == 2 }, test.Timeout, eventuallyCheckFrequency) // Change the second sweep so that it can be added to the first batch. @@ -2622,7 +2733,8 @@ func testHandleSweepTwice(t *testing.T, backend testStore, require.Eventually(t, func() bool { // Make sure there are two batches. - batches := batcher.batches + batches := getBatches(ctx, batcher) + if len(batches) != 2 { return false } @@ -2638,7 +2750,13 @@ func testHandleSweepTwice(t *testing.T, backend testStore, } // Make sure the second batch has the second sweep. - sweep2, has := secondBatch.sweeps[sweepReq2.SwapHash] + var ( + sweep2 sweep + has bool + ) + secondBatch.testRunInEventLoop(ctx, func() { + sweep2, has = secondBatch.sweeps[sweepReq2.SwapHash] + }) if !has { return false } @@ -2649,8 +2767,15 @@ func testHandleSweepTwice(t *testing.T, backend testStore, // Make sure each batch has one sweep. If the second sweep was added to // both batches, the following check won't pass. - for _, batch := range batcher.batches { - require.Equal(t, 1, len(batch.sweeps)) + batches := getBatches(ctx, batcher) + for _, batch := range batches { + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + + // Make sure the batch has one sweep. + require.Equal(t, 1, numSweeps) } // Publish a block to trigger batch 2 republishing. @@ -2743,21 +2868,28 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) + // Make sure the batch has one sweep. + var ( + numSweeps int + confTarget int32 + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + confTarget = batch.cfg.batchConfTarget + }) // Make sure the batch has one sweep. - if len(batch.sweeps) != 1 { + if numSweeps != 1 { return false } // Make sure the batch has proper batchConfTarget. - return batch.cfg.batchConfTarget == 123 + return confTarget == 123 }, test.Timeout, eventuallyCheckFrequency) // Make sure we have stored the batch. @@ -2786,6 +2918,12 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, // Wait for the batcher to be initialized. <-batcher.initDone + // Expect registration for spend notification. + <-lnd.RegisterSpendChannel + + // Wait for tx to be published. + <-lnd.TxPublishChannel + // Wait for batch to load. require.Eventually(t, func() bool { // Make sure that the sweep was stored @@ -2793,26 +2931,28 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) + // Make sure the batch has one sweep. + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) // Make sure the batch has one sweep. - return len(batch.sweeps) == 1 + return numSweeps == 1 }, test.Timeout, eventuallyCheckFrequency) // Make sure batchConfTarget was preserved. - require.Equal(t, 123, int(getOnlyBatch(batcher).cfg.batchConfTarget)) - - // Expect registration for spend notification. - <-lnd.RegisterSpendChannel - - // Wait for tx to be published. - <-lnd.TxPublishChannel + batch := getOnlyBatch(t, ctx, batcher) + var confTarget int32 + batch.testRunInEventLoop(ctx, func() { + confTarget = batch.cfg.batchConfTarget + }) + require.Equal(t, int32(123), confTarget) // Now make the batcher quit by canceling the context. cancel() @@ -2957,21 +3097,27 @@ func testSweepFetcher(t *testing.T, store testStore, return false } - // Make sure there is exactly one active batch. - if len(batcher.batches) != 1 { + // Try to get the batch. + batch := tryGetOnlyBatch(ctx, batcher) + if batch == nil { return false } - // Get the batch. - batch := getOnlyBatch(batcher) - // Make sure the batch has one sweep. - if len(batch.sweeps) != 1 { + var ( + numSweeps int + confTarget int32 + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + confTarget = batch.cfg.batchConfTarget + }) + if numSweeps != 1 { return false } // Make sure the batch has proper batchConfTarget. - return batch.cfg.batchConfTarget == 123 + return confTarget == 123 }, test.Timeout, eventuallyCheckFrequency) // Get the published transaction and check the fee rate. @@ -3819,9 +3965,17 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate is feeRateMedium. - batch := getOnlyBatch(batcher) - require.Len(t, batch.sweeps, 1) - require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate) + batch := getOnlyBatch(t, ctx, batcher) + var ( + numSweeps int + cachedFeeRate chainfee.SatPerKWeight + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, 1, numSweeps) + require.Equal(t, feeRateMedium, cachedFeeRate) // Now decrease the fee of sweep1. setFeeRate(swapHash1, feeRateLow) @@ -3835,7 +3989,10 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate is still feeRateMedium. - require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate) + batch.testRunInEventLoop(ctx, func() { + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, feeRateMedium, cachedFeeRate) // Add sweep2, with feeRateMedium. swapHash2 := lntypes.Hash{2, 2, 2} @@ -3881,8 +4038,12 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate is still feeRateMedium. - require.Len(t, batch.sweeps, 2) - require.Equal(t, feeRateMedium, batch.rbfCache.FeeRate) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, 2, numSweeps) + require.Equal(t, feeRateMedium, cachedFeeRate) // Now update fee rate of second sweep (which is not primary) to // feeRateHigh. Fee rate of sweep 1 is still feeRateLow. @@ -3898,7 +4059,10 @@ func testFeeRateGrows(t *testing.T, store testStore, <-lnd.TxPublishChannel // Make sure the fee rate increased to feeRateHigh. - require.Equal(t, feeRateHigh, batch.rbfCache.FeeRate) + batch.testRunInEventLoop(ctx, func() { + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, feeRateHigh, cachedFeeRate) } // TestSweepBatcherBatchCreation tests that sweep requests enter the expected From d079fe9db4ba51b273b776e622dca9a5bdce127c Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Wed, 26 Feb 2025 00:13:22 -0300 Subject: [PATCH 11/12] sweepbatcher: fix race conditions in UseLogger --- sweepbatcher/greedy_batch_selection.go | 4 ++-- sweepbatcher/log.go | 14 +++++++---- sweepbatcher/sweep_batcher.go | 32 +++++++++++++++----------- sweepbatcher/sweep_batcher_test.go | 2 +- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/sweepbatcher/greedy_batch_selection.go b/sweepbatcher/greedy_batch_selection.go index 30f1cb33a..88036630c 100644 --- a/sweepbatcher/greedy_batch_selection.go +++ b/sweepbatcher/greedy_batch_selection.go @@ -92,8 +92,8 @@ func (b *Batcher) greedyAddSweep(ctx context.Context, sweep *sweep) error { return nil } - log.Debugf("Batch selection algorithm returned batch id %d for"+ - " sweep %x, but acceptance failed.", batchId, + log().Debugf("Batch selection algorithm returned batch id %d "+ + "for sweep %x, but acceptance failed.", batchId, sweep.swapHash[:6]) } diff --git a/sweepbatcher/log.go b/sweepbatcher/log.go index 24d6cc297..7f33dc76a 100644 --- a/sweepbatcher/log.go +++ b/sweepbatcher/log.go @@ -2,15 +2,21 @@ package sweepbatcher import ( "fmt" + "sync/atomic" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" ) -// log is a logger that is initialized with no output filters. This +// log_ is a logger that is initialized with no output filters. This // means the package will not perform any logging by default until the // caller requests it. -var log btclog.Logger +var log_ atomic.Pointer[btclog.Logger] + +// log returns active logger. +func log() btclog.Logger { + return *log_.Load() +} // The default amount of logging is none. func init() { @@ -20,12 +26,12 @@ func init() { // batchPrefixLogger returns a logger that prefixes all log messages with // the ID. func batchPrefixLogger(batchID string) btclog.Logger { - return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log) + return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log()) } // UseLogger uses a specified Logger to output package logging info. // This should be used in preference to SetLogWriter if the caller is also // using btclog. func UseLogger(logger btclog.Logger) { - log = logger + log_.Store(&logger) } diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index fae60274b..939604c30 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -535,13 +535,15 @@ func (b *Batcher) Run(ctx context.Context) error { case sweepReq := <-b.sweepReqs: sweep, err := b.fetchSweep(runCtx, sweepReq) if err != nil { - log.Warnf("fetchSweep failed: %v.", err) + log().Warnf("fetchSweep failed: %v.", err) + return err } err = b.handleSweep(runCtx, sweep, sweepReq.Notifier) if err != nil { - log.Warnf("handleSweep failed: %v.", err) + log().Warnf("handleSweep failed: %v.", err) + return err } @@ -550,11 +552,13 @@ func (b *Batcher) Run(ctx context.Context) error { close(testReq.quit) case err := <-b.errChan: - log.Warnf("Batcher received an error: %v.", err) + log().Warnf("Batcher received an error: %v.", err) + return err case <-runCtx.Done(): - log.Infof("Stopping Batcher: run context cancelled.") + log().Infof("Stopping Batcher: run context cancelled.") + return runCtx.Err() } } @@ -612,8 +616,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return err } - log.Infof("Batcher handling sweep %x, completed=%v", sweep.swapHash[:6], - completed) + log().Infof("Batcher handling sweep %x, completed=%v", + sweep.swapHash[:6], completed) // If the sweep has already been completed in a confirmed batch then we // can't attach its notifier to the batch as that is no longer running. @@ -624,8 +628,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, // on-chain confirmations to prevent issues caused by reorgs. parentBatch, err := b.store.GetParentBatch(ctx, sweep.swapHash) if err != nil { - log.Errorf("unable to get parent batch for sweep %x: "+ - "%v", sweep.swapHash[:6], err) + log().Errorf("unable to get parent batch for sweep %x:"+ + " %v", sweep.swapHash[:6], err) return err } @@ -677,8 +681,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return nil } - log.Warnf("Greedy batch selection algorithm failed for sweep %x: %v. "+ - "Falling back to old approach.", sweep.swapHash[:6], err) + log().Warnf("Greedy batch selection algorithm failed for sweep %x: %v."+ + " Falling back to old approach.", sweep.swapHash[:6], err) // If one of the batches accepts the sweep, we provide it to that batch. for _, batch := range b.batches { @@ -783,13 +787,13 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error { } if len(dbSweeps) == 0 { - log.Infof("skipping restored batch %d as it has no sweeps", + log().Infof("skipping restored batch %d as it has no sweeps", batch.id) // It is safe to drop this empty batch as it has no sweeps. err := b.store.DropBatch(ctx, batch.id) if err != nil { - log.Warnf("unable to drop empty batch %d: %v", + log().Warnf("unable to drop empty batch %d: %v", batch.id, err) } @@ -931,7 +935,7 @@ func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep, b.wg.Add(1) go func() { defer b.wg.Done() - log.Infof("Batcher monitoring spend for swap %x", + log().Infof("Batcher monitoring spend for swap %x", sweep.swapHash[:6]) for { @@ -1110,7 +1114,7 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, } } else { if s.ConfTarget == 0 { - log.Warnf("Fee estimation was requested for zero "+ + log().Warnf("Fee estimation was requested for zero "+ "confTarget for sweep %x.", swapHash[:6]) } minFeeRate, err = b.wallet.EstimateFeeRate(ctx, s.ConfTarget) diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index 22551619a..392da7698 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -1373,7 +1373,7 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, batcherStore testBatcherStore) { // Disable logging, because this test is very noisy. - oldLogger := log + oldLogger := log() UseLogger(build.NewSubLogger("SWEEP", nil)) defer UseLogger(oldLogger) From ddd3fec5b3d8f89ebea291e286ac1f6e392dc777 Mon Sep 17 00:00:00 2001 From: Boris Nagaev Date: Mon, 10 Feb 2025 18:39:17 -0300 Subject: [PATCH 12/12] sweepbatcher: add mode with presigned txs and CPFP In this mode sweepbatcher uses transactions provided by the CPFP helper, which may be pre-signed and not replace-by-fee (RBF) compatible. In such cases, CPFP may be necessary. A single Batcher instance can handle both CPFP and regular batches. Currently CPFP and non-CPFP sweeps never appear in the same batch. --- sweepbatcher/cpfp.go | 651 +++++++++++++ sweepbatcher/cpfp_test.go | 1293 ++++++++++++++++++++++++++ sweepbatcher/sweep_batch.go | 129 ++- sweepbatcher/sweep_batcher.go | 110 ++- sweepbatcher/sweep_batcher_test.go | 1382 ++++++++++++++++++++++++++++ 5 files changed, 3559 insertions(+), 6 deletions(-) create mode 100644 sweepbatcher/cpfp.go create mode 100644 sweepbatcher/cpfp_test.go diff --git a/sweepbatcher/cpfp.go b/sweepbatcher/cpfp.go new file mode 100644 index 000000000..9a695e127 --- /dev/null +++ b/sweepbatcher/cpfp.go @@ -0,0 +1,651 @@ +package sweepbatcher + +import ( + "bytes" + "context" + "fmt" + + "github.com/btcsuite/btcd/blockchain" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/btcutil/psbt" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +// ensurePresigned checks that we can sign a 1:1 transaction sweeping the input. +func (b *batch) ensurePresigned(ctx context.Context, newSweep *sweep) error { + if b.cfg.cpfpHelper == nil { + return fmt.Errorf("cpfpHelper is not installed") + } + if len(b.sweeps) != 0 { + return fmt.Errorf("ensurePresigned is done when adding to an " + + "empty batch") + } + + sweeps := []sweep{ + { + outpoint: newSweep.outpoint, + value: newSweep.value, + cpfp: newSweep.cpfp, + }, + } + + // Cache the destination address. + destAddr, err := b.getSweepsDestAddr(ctx, sweeps) + if err != nil { + return fmt.Errorf("failed to find destination address: %w", err) + } + + // Set LockTime to 0. It is not critical. + const currentHeight = 0 + + // Check if we can sign with minimum fee rate. + const feeRate = chainfee.FeePerKwFloor + + tx, _, _, _, err := constructUnsignedTx( + sweeps, destAddr, currentHeight, feeRate, + ) + if err != nil { + return fmt.Errorf("failed to construct unsigned tx "+ + "for feeRate %v: %w", feeRate, err) + } + + // Try to presign this transaction. + batchAmt := newSweep.value + signedTx, err := b.cfg.cpfpHelper.SignTx(ctx, tx, batchAmt, feeRate) + if err != nil { + return fmt.Errorf("failed to sign unsigned tx %v "+ + "for feeRate %v: %w", tx.TxHash(), feeRate, err) + } + + // Check the SignTx worked correctly. + err = CheckSignedTx(tx, signedTx, batchAmt, feeRate) + if err != nil { + return fmt.Errorf("signed tx doesn't correspond the "+ + "unsigned tx: %w", err) + } + + return nil +} + +// presign tries to presign batch sweep transactions composed of this batch and +// the sweep. It signs multiple versions of the transaction to make sure there +// is a transaction to be published if minRelayFee grows. +func (b *batch) presign(ctx context.Context, newSweep *sweep) error { + if b.cfg.cpfpHelper == nil { + return fmt.Errorf("cpfpHelper is not installed") + } + if len(b.sweeps) == 0 { + return fmt.Errorf("presigning is done when adding to a " + + "non-empty batch") + } + + // Create the list of sweeps of the future batch. + sweeps := make([]sweep, 0, len(b.sweeps)+1) + for _, sweep := range b.sweeps { + sweeps = append(sweeps, sweep) + } + existingSweeps := sweeps + sweeps = append(sweeps, *newSweep) + + // Cache the destination address. + destAddr, err := b.getSweepsDestAddr(ctx, existingSweeps) + if err != nil { + return fmt.Errorf("failed to find destination address: %w", err) + } + + return presign(ctx, b.cfg.cpfpHelper, sweeps, destAddr) +} + +// presigner tries to presign a batch transaction. +type presigner interface { + // Presign tries to presign a batch transaction. If the method returns + // nil, it is guaranteed that future calls to SignTx on this set of + // sweeps return valid signed transactions. + Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error +} + +// presign tries to presign batch sweep transactions of the sweeps. It signs +// multiple versions of the transaction to make sure there is a transaction to +// be published if minRelayFee grows. +func presign(ctx context.Context, presigner presigner, sweeps []sweep, + destAddr btcutil.Address) error { + + if presigner == nil { + return fmt.Errorf("presigner is not installed") + } + + if len(sweeps) == 0 { + return fmt.Errorf("there are no sweeps") + } + + // Keep track of the total amount this batch is sweeping back. + batchAmt := btcutil.Amount(0) + for _, sweep := range sweeps { + batchAmt += sweep.value + } + + // Go from the floor (1.01 sat/vbyte) to 2k sat/vbyte with step of 1.5x. + const ( + start = chainfee.FeePerKwFloor + stop = chainfee.AbsoluteFeePerKwFloor * 2_000 + ) + + // Set LockTime to 0. It is not critical. + const currentHeight = 0 + + for feeRate := start; feeRate <= stop; feeRate = (feeRate * 3) / 2 { + // Construct an unsigned transaction for this fee rate. + tx, _, feeForWeight, fee, err := constructUnsignedTx( + sweeps, destAddr, currentHeight, feeRate, + ) + if err != nil { + return fmt.Errorf("failed to construct unsigned tx "+ + "for feeRate %v: %w", feeRate, err) + } + + // Try to presign this transaction. + err = presigner.Presign(ctx, tx, batchAmt) + if err != nil { + return fmt.Errorf("failed to presign unsigned tx %v "+ + "for feeRate %v: %w", tx.TxHash(), feeRate, err) + } + + // If fee was clamped, stop here, because fee rate won't grow. + if fee < feeForWeight { + break + } + } + + return nil +} + +// cpfpLabelPrefix is a prefix added to the label of the batch to form a label +// for CPFP transaction. +const cpfpLabelPrefix = "cpfp-for-" + +// publishWithCPFP creates sweep transaction using a custom transaction signer +// and publishes it. It may use CPFP if the custom signer returned a pre-signed +// transaction with insufficient fee. It returns fee of the first transaction, +// not including CPFP's fee, an error (if signing and/or publishing failed) and +// a boolean flag indicating signing success. This mode is incompatible with +// an external address, because it may use CPFP and is designed for batches. +func (b *batch) publishWithCPFP(ctx context.Context) (btcutil.Amount, error, + bool) { + + // Sanity check, there should be at least 1 sweep in this batch. + if len(b.sweeps) == 0 { + return 0, fmt.Errorf("no sweeps in batch"), false + } + + // Make sure that no external address is used. + for _, sweep := range b.sweeps { + if sweep.isExternalAddr { + return 0, fmt.Errorf("external address was used with " + + "a custom transaction signer"), false + } + } + + // Cache current height and desired feerate of the batch. + currentHeight := b.currentHeight + feeRate := b.rbfCache.FeeRate + + // Append this sweep to an array of sweeps. This is needed to keep the + // order of sweeps stored, as iterating the sweeps map does not + // guarantee same order. + sweeps := make([]sweep, 0, len(b.sweeps)) + for _, sweep := range b.sweeps { + sweeps = append(sweeps, sweep) + } + + // Cache the destination address. + address, err := b.getSweepsDestAddr(ctx, sweeps) + if err != nil { + return 0, fmt.Errorf("failed to find destination address: %w", + err), false + } + + // Construct unsigned batch transaction. + tx, weight, _, fee, err := constructUnsignedTx( + sweeps, address, currentHeight, feeRate, + ) + if err != nil { + return 0, fmt.Errorf("failed to construct tx: %w", err), + false + } + + // Adjust feeRate, because it may have been clamped. + feeRate = chainfee.NewSatPerKWeight(fee, weight) + + // Calculate total input amount. + batchAmt := btcutil.Amount(0) + for _, sweep := range sweeps { + batchAmt += sweep.value + } + + // Determine the current minimum relay fee based on our chain backend. + minRelayFee, err := b.wallet.MinRelayFee(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get minRelayFee: %w", err), + false + } + + // Get a signed transaction. It may be either new transaction or a + // pre-signed one. + signedTx, err := b.cfg.cpfpHelper.SignTx(ctx, tx, batchAmt, minRelayFee) + if err != nil { + return 0, fmt.Errorf("failed to sign tx: %w", err), + false + } + + // Run sanity checks to make sure cpfpHelper.SignTx complied with all + // the invariants. + err = CheckSignedTx(tx, signedTx, batchAmt, minRelayFee) + if err != nil { + return 0, fmt.Errorf("signed tx doesn't correspond the "+ + "unsigned tx: %w", err), false + } + tx = signedTx + txHash := tx.TxHash() + + // Make sure tx weight matches the expected value. + realWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx)), + ) + if realWeight != weight { + b.log().Warnf("actual weight of tx %v is %v, estimated as %d", + txHash, realWeight, weight) + } + + // Find actual fee rate of the signed transaction. It may differ from + // the desired fee rate, because SignTx may return a presigned tx. + output := btcutil.Amount(tx.TxOut[0].Value) + fee = batchAmt - output + signedFeeRate := chainfee.NewSatPerKWeight(fee, realWeight) + + b.log().Infof("attempting to publish custom signed tx=%v, "+ + "desiredFeerate=%v, signedFeeRate=%v, weight=%v, fee=%v, "+ + "sweeps=%d, destAddr=%s", txHash, feeRate, signedFeeRate, + weight, fee, len(tx.TxIn), address) + b.debugLogTx("serialized batch", tx) + + // Publish the transaction. If it fails, we don't return immediately, + // because we may still need a CPFP and it can be done against a + // previously published transaction. + publishErr1 := b.wallet.PublishTransaction( + ctx, tx, b.cfg.txLabeler(b.id), + ) + if publishErr1 == nil { + // Store the batch transaction's txid and pkScript, to use in + // CPFP and for monitoring purposes. + b.batchTxid = &txHash + b.batchPkScript = tx.TxOut[0].PkScript + + if err := b.persist(ctx); err != nil { + return 0, fmt.Errorf("failed to persist: %w", err), true + } + } else { + b.log().Infof("failed to publish custom signed tx=%v, "+ + "desiredFeerate=%v, signedFeeRate=%v, weight=%v, "+ + "fee=%v, sweeps=%d, destAddr=%s", txHash, feeRate, + signedFeeRate, weight, fee, len(tx.TxIn), address) + } + + // Load previously published tx if it exists. + var parentTx *wire.MsgTx + if b.batchTxid != nil { + parentTx, err = b.cfg.cpfpHelper.LoadTx(ctx, *b.batchTxid) + if err != nil { + return 0, fmt.Errorf("failed to load batch tx %v: %w", + *b.batchTxid, err), true + } + } else { + b.log().Warnf("need a CPFP, but there is no published tx known") + } + + // Print this log here, to keep isCPFPNeeded a pure function. + if parentTx != nil && len(parentTx.TxIn) < len(tx.TxIn) { + b.log().Infof("Skip publishing CPFP, because batch tx in mempool"+ + "has %d inputs, while the batch has now %d inputs", + len(parentTx.TxIn), len(tx.TxIn)) + } + + // Determine if CPFP is needed and its feerate. + needsCPFP, err := isCPFPNeeded( + parentTx, batchAmt, len(tx.TxIn), feeRate, signedFeeRate, + ) + if err != nil { + return 0, fmt.Errorf("failed to determine if CPFP is "+ + "needed: %w", err), false + } + + // If CPFP is not needed, we are done now. + if !needsCPFP { + b.log().Infof("CPFP is not needed") + + return fee, publishErr1, true + } + + b.log().Infof("CPFP is needed, parent is %v", parentTx.TxHash()) + + // Create and sign CPFP. + parentWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(parentTx)), + ) + parentOutput := btcutil.Amount(parentTx.TxOut[0].Value) + parentFee := batchAmt - parentOutput + childTx, childFeeRate, err := makeUnsignedCPFP( + *b.batchTxid, parentOutput, parentWeight, parentFee, + minRelayFee, feeRate, address, currentHeight, + ) + if err != nil { + return 0, fmt.Errorf("failed to make CPFP tx: %w", err), + true + } + + childTx, err = b.signChildTx(ctx, childTx) + if err != nil { + return 0, fmt.Errorf("failed to sign CPFP tx: %w", err), + true + } + + childTxHash := childTx.TxHash() + parentFeeRate := chainfee.NewSatPerKWeight(parentFee, parentWeight) + b.log().Infof("attempting to publish child tx %v to CPFP parent tx %v,"+ + " effectiveFeeRate=%v, parentFeeRate=%v, childFeeRate=%v", + childTxHash, *b.batchTxid, feeRate, parentFeeRate, + childFeeRate) + b.debugLogTx("serialized child tx", childTx) + + // Publish child transaction. + publishErr2 := b.wallet.PublishTransaction( + ctx, childTx, cpfpLabelPrefix+b.cfg.txLabeler(b.id), + ) + if publishErr2 != nil { + b.log().Infof("failed to publish child tx %v to CPFP parent "+ + "tx %v, effectiveFeeRate=%v, parentFeeRate=%v, "+ + "childFeeRate=%v", childTxHash, *b.batchTxid, feeRate, + parentFeeRate, childFeeRate) + + return fee, publishErr2, true + } + + return fee, publishErr1, true +} + +// getSweepsDestAddr returns the destination address used by a group of sweeps. +// The method must be used in CPFP mode only. +func (b *batch) getSweepsDestAddr(ctx context.Context, + sweeps []sweep) (btcutil.Address, error) { + + if b.cfg.cpfpHelper == nil { + return nil, fmt.Errorf("getSweepsDestAddr used without CPFP") + } + + inputs := make([]wire.OutPoint, len(sweeps)) + for i, s := range sweeps { + if !s.cpfp { + return nil, fmt.Errorf("getSweepsDestAddr used on a " + + "non-CPFP input") + } + + inputs[i] = s.outpoint + } + + // Load pkScript from the CPFP helper. + pkScriptBytes, err := b.cfg.cpfpHelper.DestPkScript(ctx, inputs) + if err != nil { + return nil, fmt.Errorf("cpfpHelper.DestPkScript failed for "+ + "inputs %v: %w", inputs, err) + } + + // Convert pkScript to btcutil.Address. + pkScript, err := txscript.ParsePkScript(pkScriptBytes) + if err != nil { + return nil, fmt.Errorf("txscript.ParsePkScript failed for "+ + "pkScript %x returned for inputs %v: %w", pkScriptBytes, + inputs, err) + } + + address, err := pkScript.Address(b.cfg.chainParams) + if err != nil { + return nil, fmt.Errorf("pkScript.Address failed for "+ + "pkScript %x returned for inputs %v: %w", pkScriptBytes, + inputs, err) + } + + return address, nil +} + +// CheckSignedTx makes sure that signedTx matches the unsignedTx. It checks +// according to criteria specified in the description of CpfpHelper.SignTx. +func CheckSignedTx(unsignedTx, signedTx *wire.MsgTx, inputAmt btcutil.Amount, + minRelayFee chainfee.SatPerKWeight) error { + + // Make sure the set of inputs is the same. + unsignedMap := make(map[wire.OutPoint]uint32, len(unsignedTx.TxIn)) + for _, txIn := range unsignedTx.TxIn { + unsignedMap[txIn.PreviousOutPoint] = txIn.Sequence + } + for _, txIn := range signedTx.TxIn { + seq, has := unsignedMap[txIn.PreviousOutPoint] + if !has { + return fmt.Errorf("input %s is new in signed tx", + txIn.PreviousOutPoint) + } + if seq != txIn.Sequence { + return fmt.Errorf("sequence mismatch in input %s: "+ + "%d in unsigned, %d in signed", + txIn.PreviousOutPoint, seq, txIn.Sequence) + } + delete(unsignedMap, txIn.PreviousOutPoint) + } + for outpoint := range unsignedMap { + return fmt.Errorf("input %s is missing in signed tx", outpoint) + } + + // Compare outputs. + if len(unsignedTx.TxOut) != 1 { + return fmt.Errorf("unsigned tx has %d outputs, want 1", + len(unsignedTx.TxOut)) + } + if len(signedTx.TxOut) != 1 { + return fmt.Errorf("the signed tx has %d outputs, want 1", + len(signedTx.TxOut)) + } + unsignedOut := unsignedTx.TxOut[0] + signedOut := signedTx.TxOut[0] + if !bytes.Equal(unsignedOut.PkScript, signedOut.PkScript) { + return fmt.Errorf("mismatch of output pkScript: %v, %v", + unsignedOut.PkScript, signedOut.PkScript) + } + + // Find the feerate of signedTx. + fee := inputAmt - btcutil.Amount(signedOut.Value) + weight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(signedTx)), + ) + feeRate := chainfee.NewSatPerKWeight(fee, weight) + if feeRate < minRelayFee { + return fmt.Errorf("feerate (%v) of signed tx is lower than "+ + "minRelayFee (%v)", feeRate, minRelayFee) + } + + // Check LockTime. + if signedTx.LockTime > unsignedTx.LockTime { + return fmt.Errorf("locktime (%d) of signed tx is higher than "+ + "locktime of unsigned tx (%d)", signedTx.LockTime, + unsignedTx.LockTime) + } + + // Check Version. + if signedTx.Version != unsignedTx.Version { + return fmt.Errorf("version (%d) of signed tx is not equal to "+ + "version of unsigned tx (%d)", signedTx.Version, + unsignedTx.Version) + } + + return nil +} + +// feeRateThresholdPPM is the ratio of accepted underpayment of fee for which +// no CPFP is used to adjust the effective fee rate. If the underpayment is +// higher, then CPFP is enabled. It is measured in PPM, current level is 2%. +const feeRateThresholdPPM = 2_0000 + +// isCPFPNeeded returns if CPFP is needed to make the effective fee rate close +// to the desired feeRate. The threshold is feeRateThresholdPPM. +func isCPFPNeeded(parentTx *wire.MsgTx, inputAmt btcutil.Amount, numSweeps int, + desiredFeeRate, signedFeeRate chainfee.SatPerKWeight) (bool, error) { + + // First, if feerate of the signed tx matches exactly the desired + // feerate, this means, that we didn't use a presigned transaction, + // which means that all the input are likely to be online, so we don't + // use CPFP. + if desiredFeeRate == signedFeeRate { + return false, nil + } + + // If no transaction was ever published, we can't do CPFP anyway. A + // warning is produced by the calling function in this case. + if parentTx == nil { + return false, nil + } + + // Sanity checks. + if len(parentTx.TxOut) != 1 { + return false, fmt.Errorf("batch tx must have one output, "+ + "but it has %d", len(parentTx.TxOut)) + } + + // Make sure that the parent transaction is signed. + for _, txIn := range parentTx.TxIn { + if len(txIn.Witness) == 0 { + return false, fmt.Errorf("the tx must be signed") + } + } + + // If previously published tx has fewer inputs than the current state + // of the batch, skip CPFP, since it would bump an outdated state. + if len(parentTx.TxIn) < numSweeps { + return false, nil + } + + // Previously published transaction must not have more inputs than the + // current batch state, because inputs are only added. + if len(parentTx.TxIn) > numSweeps { + return false, fmt.Errorf("parent tx has more inputs (%d) than "+ + "exist in the batch currently (%d)", len(parentTx.TxIn), + numSweeps) + } + + // Calculate fee rate of the transaction. + weight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(parentTx)), + ) + fee := inputAmt - btcutil.Amount(parentTx.TxOut[0].Value) + if fee < 0 { + return false, fmt.Errorf("the tx has negative fee %v", fee) + } + parentFeeRate := chainfee.NewSatPerKWeight(fee, weight) + + // Check of the observed_feerate < desired_feerate - threshold. + threshold := desiredFeeRate * feeRateThresholdPPM / 1_000_000 + cpfpNeeded := parentFeeRate < desiredFeeRate-threshold + + return cpfpNeeded, nil +} + +// maxChildFeeSharePPM specifies max share (in ppm) of total funds that can be +// burn in the child transaction in CPFP. Currently it is set to 20%. +const maxChildFeeSharePPM = 20_0000 + +// makeUnsignedCPFP constructs unsigned child tx for CPFP to achieve desired +// effective fee rate. It also returns fee rate of the constructed child tx. +// The transaction spends the UTXO to the same address. Supports P2WKH, P2TR. +func makeUnsignedCPFP(parentTxid chainhash.Hash, parentOutput btcutil.Amount, + parentWeight lntypes.WeightUnit, parentFee btcutil.Amount, minRelayFee, + effectiveFeeRate chainfee.SatPerKWeight, address btcutil.Address, + currentHeight int32) (*wire.MsgTx, chainfee.SatPerKWeight, error) { + + // Estimate the weight of the child tx. + var estimator input.TxWeightEstimator + switch address.(type) { + case *btcutil.AddressWitnessPubKeyHash: + estimator.AddP2WKHInput() + estimator.AddP2WKHOutput() + + case *btcutil.AddressTaproot: + estimator.AddTaprootKeySpendInput(txscript.SigHashDefault) + estimator.AddP2TROutput() + + default: + return nil, 0, fmt.Errorf("unknown address type %T", address) + } + childWeight := estimator.Weight() + + // Estimate the fee of the child tx. + totalWeight := parentWeight + childWeight + totalFee := effectiveFeeRate.FeeForWeight(totalWeight) + childFee := totalFee - parentFee + childFeeRate := chainfee.NewSatPerKWeight(childFee, childWeight) + if childFeeRate < minRelayFee { + childFeeRate = minRelayFee + childFee = childFeeRate.FeeForWeight(childWeight) + } + if childFeeRate < effectiveFeeRate { + return nil, 0, fmt.Errorf("got child fee rate %v lower than "+ + "effective fee rate %v", childFeeRate, effectiveFeeRate) + } + if childFee > parentOutput*maxChildFeeSharePPM/1_000_000 { + return nil, 0, fmt.Errorf("child fee %v is higher than %d%% "+ + "of total funds %v", childFee, + maxChildFeeSharePPM*100/1_000_000, parentOutput) + } + + // Construct child tx. + childTx := &wire.MsgTx{ + Version: 2, + LockTime: uint32(currentHeight), + } + childTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: wire.OutPoint{ + Hash: parentTxid, + Index: 0, + }, + }) + pkScript, err := txscript.PayToAddrScript(address) + if err != nil { + return nil, 0, fmt.Errorf("txscript.PayToAddrScript "+ + "failed: %w", err) + } + childTx.AddTxOut(&wire.TxOut{ + PkScript: pkScript, + Value: int64(parentOutput - childFee), + }) + + return childTx, childFeeRate, nil +} + +// signChildTx signs child CPFP transaction using LND client. +func (b *batch) signChildTx(ctx context.Context, + unsignedTx *wire.MsgTx) (*wire.MsgTx, error) { + + // Create PSBT packet object. + packet, err := psbt.NewFromUnsignedTx(unsignedTx) + if err != nil { + return nil, fmt.Errorf("failed to create PSBT: %w", err) + } + + packet, err = b.wallet.SignPsbt(ctx, packet) + if err != nil { + return nil, fmt.Errorf("signing PSBT failed: %w", err) + } + + return psbt.Extract(packet) +} diff --git a/sweepbatcher/cpfp_test.go b/sweepbatcher/cpfp_test.go new file mode 100644 index 000000000..399e37184 --- /dev/null +++ b/sweepbatcher/cpfp_test.go @@ -0,0 +1,1293 @@ +package sweepbatcher + +import ( + "context" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/require" +) + +// mockPresigner is an implementation of Presigner used in TestPresign. +type mockPresigner struct { + // outputs collects outputs of presigned transactions. + outputs []btcutil.Amount + + // failAt is optional index of a call at which it fails, 1 based. + failAt int +} + +// Presign memorizes the value of the output and fails if the number of +// calls previously made is failAt. +func (p *mockPresigner) Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error { + + if len(p.outputs)+1 == p.failAt { + return fmt.Errorf("test error in Presign") + } + + p.outputs = append(p.outputs, btcutil.Amount(tx.TxOut[0].Value)) + + return nil +} + +// TestPresign checks that function presign presigns correct set of transactions +// and handles edge cases properly. +func TestPresign(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + ctx := context.Background() + + cases := []struct { + name string + presigner presigner + sweeps []sweep + destAddr btcutil.Address + wantErr string + wantOutputs []btcutil.Amount + }{ + { + name: "error: no presigner", + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + }, + destAddr: destAddr, + wantErr: "presigner is not installed", + }, + + { + name: "error: no sweeps", + presigner: &mockPresigner{}, + destAddr: destAddr, + wantErr: "there are no sweeps", + }, + + { + name: "error: no destAddr", + presigner: &mockPresigner{}, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + }, + wantErr: "unsupported address type ", + }, + + { + name: "two coop sweeps", + presigner: &mockPresigner{}, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000_000, + }, + { + outpoint: op2, + value: 2_000_000, + }, + }, + destAddr: destAddr, + wantOutputs: []btcutil.Amount{ + 2999842, 2999763, 2999645, 2999467, 2999200, + 2998800, 2998201, 2997301, 2995952, 2993927, + 2990890, 2986336, 2979503, 2969255, 2953882, + 2930824, 2896235, 2844353, 2766529, + }, + }, + + { + name: "small amount => fewer steps until clamped", + presigner: &mockPresigner{}, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000, + }, + { + outpoint: op2, + value: 2_000, + }, + }, + destAddr: destAddr, + wantOutputs: []btcutil.Amount{ + 2842, 2763, 2645, 2467, 2400, + }, + }, + + { + name: "third signing fails", + presigner: &mockPresigner{ + failAt: 3, + }, + sweeps: []sweep{ + { + outpoint: op1, + value: 1_000, + }, + { + outpoint: op2, + value: 2_000, + }, + }, + destAddr: destAddr, + wantErr: "for feeRate 568 sat/kw", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := presign(ctx, tc.presigner, tc.sweeps, tc.destAddr) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + outputs := tc.presigner.(*mockPresigner).outputs + require.Equal(t, tc.wantOutputs, outputs) + } + }) + } +} + +// TestCheckSignedTx tests that function CheckSignedTx checks all the criteria +// of CpfpHelper.SignTx correctly. +func TestCheckSignedTx(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + cases := []struct { + name string + unsignedTx *wire.MsgTx + signedTx *wire.MsgTx + inputAmt btcutil.Amount + minRelayFee chainfee.SatPerKWeight + wantErr string + }{ + { + name: "success", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "", + }, + + { + name: "bad locktime", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_001, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "locktime", + }, + + { + name: "bad version", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 3, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "version", + }, + + { + name: "missing input", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "is missing in signed tx", + }, + + { + name: "extra input", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "is new in signed tx", + }, + + { + name: "mismatch of sequence numbers", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 3, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "sequence mismatch", + }, + + { + name: "extra output in unsignedTx", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "unsigned tx has 2 outputs, want 1", + }, + + { + name: "extra output in signedTx", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "the signed tx has 2 outputs, want 1", + }, + + { + name: "mismatch of output pk_script", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript[1:], + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 253, + wantErr: "mismatch of output pkScript", + }, + + { + name: "too low feerate in signedTx", + unsignedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 800_000, + }, + signedTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: wire.TxWitness{ + []byte("test"), + }, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + LockTime: 799_999, + }, + inputAmt: 3_000_000, + minRelayFee: 250_000, + wantErr: "is lower than minRelayFee", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := CheckSignedTx( + tc.unsignedTx, tc.signedTx, tc.inputAmt, + tc.minRelayFee, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +// TestIsCPFPNeeded tests that function isCPFPNeeded works correctly, satisfying +// feeRateThresholdPPM. +func TestIsCPFPNeeded(t *testing.T) { + // Prepare the necessary data for test cases. + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + witness := wire.TxWitness{ + make([]byte, 64), + } + + cases := []struct { + name string + parentTx *wire.MsgTx + inputAmt btcutil.Amount + numSweeps int + desiredFeeRate chainfee.SatPerKWeight + signedFeeRate chainfee.SatPerKWeight + wantErr string + wantNeedsCPFP bool + }{ + { + name: "fee rate matches exacly", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "fee rate higher than needed", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 900, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "fee rate slightly lower than needed", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1020, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "fee rate significantly lower than needed", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1100, + wantErr: "", + wantNeedsCPFP: true, + }, + { + name: "fewer inputs in parent transaction", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 3, + desiredFeeRate: 1100, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "more inputs in parent transaction", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 1, + desiredFeeRate: 1100, + wantErr: "parent tx has more inputs", + }, + { + name: "signed fee rate equal to desired fee rate", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1100, + signedFeeRate: 1100, + wantErr: "", + wantNeedsCPFP: false, + }, + { + name: "error: tx has negative fee", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 3_001_000, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "negative fee", + }, + { + name: "error: tx has multiple outputs", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + Witness: witness, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + Witness: witness, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 1_000_000, + PkScript: batchPkScript, + }, + { + Value: 2_000_000, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "must have one output", + }, + { + name: "error: unsigned tx", + parentTx: &wire.MsgTx{ + Version: 2, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op1, + Sequence: 1, + }, + { + PreviousOutPoint: op2, + Sequence: 2, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2999374, + PkScript: batchPkScript, + }, + }, + }, + inputAmt: 3_000_000, + numSweeps: 2, + desiredFeeRate: 1000, + wantErr: "the tx must be signed", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + needsCPFP, err := isCPFPNeeded( + tc.parentTx, tc.inputAmt, tc.numSweeps, + tc.desiredFeeRate, tc.signedFeeRate, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.wantNeedsCPFP, needsCPFP) + } + }) + } +} + +// TestMakeUnsignedCPFP tests that function makeUnsignedCPFP works correctly, +// satisfying maxChildFeeSharePPM and making sure that child fee rate is higher +// than effective fee rate and of minRelayFee. +func TestMakeUnsignedCPFP(t *testing.T) { + // Prepare the necessary data for test cases. + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + p2trAddr := "bcrt1pa38tp2hgjevqv3jcsxeu7v72n0s5a3ck8q2u8r" + + "k6mm67dv7uk26qq8je7e" + p2trAddress, err := btcutil.DecodeAddress(p2trAddr, nil) + require.NoError(t, err) + p2trPkScript, err := txscript.PayToAddrScript(p2trAddress) + require.NoError(t, err) + + serializedPubKey := []byte{ + 0x02, 0x19, 0x2d, 0x74, 0xd0, 0xcb, 0x94, 0x34, 0x4c, 0x95, + 0x69, 0xc2, 0xe7, 0x79, 0x01, 0x57, 0x3d, 0x8d, 0x79, 0x03, + 0xc3, 0xeb, 0xec, 0x3a, 0x95, 0x77, 0x24, 0x89, 0x5d, 0xca, + 0x52, 0xc6, 0xb4} + p2pkAddress, err := btcutil.NewAddressPubKey( + serializedPubKey, &chaincfg.RegressionNetParams, + ) + require.NoError(t, err) + + batchTxid := chainhash.Hash{5, 5, 5} + + op := wire.OutPoint{ + Hash: batchTxid, + Index: 0, + } + + cases := []struct { + name string + parentTxid chainhash.Hash + parentOutput btcutil.Amount + parentWeight lntypes.WeightUnit + parentFee btcutil.Amount + minRelayFee chainfee.SatPerKWeight + effectiveFeeRate chainfee.SatPerKWeight + address btcutil.Address + currentHeight int32 + wantErr string + wantUnsignedChild *wire.MsgTx + wantChildFeeRate chainfee.SatPerKWeight + }{ + { + name: "normal child creation", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 2000, + address: p2trAddress, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2997860, + PkScript: p2trPkScript, + }, + }, + }, + wantChildFeeRate: 3410, + }, + { + name: "p2wpkh address", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 2000, + address: destAddr, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2997870, + PkScript: batchPkScript, + }, + }, + }, + wantChildFeeRate: 3426, + }, + { + name: "error: p2pk address", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 2000, + address: p2pkAddress, + currentHeight: 800_000, + wantErr: "unknown address type", + }, + { + name: "effective feerate as in parent", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 1000, + address: p2trAddress, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2998930, + PkScript: p2trPkScript, + }, + }, + }, + wantChildFeeRate: 1000, + }, + { + name: "effective feerate below parent", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 500, + address: p2trAddress, + currentHeight: 800_000, + wantErr: "lower than effective fee rate", + }, + { + name: "high minRelayFee", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 10_000, + effectiveFeeRate: 2000, + address: p2trAddress, + currentHeight: 800_000, + wantUnsignedChild: &wire.MsgTx{ + Version: 2, + LockTime: 800_000, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: op, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 2994934, + PkScript: p2trPkScript, + }, + }, + }, + wantChildFeeRate: 10_000, + }, + { + name: "child fee too high", + parentTxid: batchTxid, + parentOutput: 2999374, + parentWeight: 626, + parentFee: 626, + minRelayFee: 253, + effectiveFeeRate: 750_000, + address: p2trAddress, + currentHeight: 800_000, + wantErr: "is higher than 20% of total funds", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + childTx, childFeeRate, err := makeUnsignedCPFP( + tc.parentTxid, tc.parentOutput, tc.parentWeight, + tc.parentFee, tc.minRelayFee, + tc.effectiveFeeRate, tc.address, + tc.currentHeight, + ) + if tc.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, tc.wantErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.wantUnsignedChild, childTx) + require.Equal( + t, tc.wantChildFeeRate, childFeeRate, + ) + } + }) + } +} diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index bd97dd4ac..a0ffb1b8a 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -17,6 +17,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil/psbt" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" @@ -120,6 +121,9 @@ type sweep struct { // but it failed. We try to spend a sweep cooperatively only once. This // status is not persisted in the DB. coopFailed bool + + // cpfp is set, if the sweep should be handled in CPFP mode. + cpfp bool } // batchState is the state of the batch. @@ -173,6 +177,14 @@ type batchConfig struct { // Note that musig2SignSweep must be nil in this case, however signer // client must still be provided, as it is used for non-coop spendings. customMuSig2Signer SignMuSig2 + + // cpfpHelper provides methods used when a custom tx signer and CPFP + // are enabled. + cpfpHelper CpfpHelper + + // chainParams are the chain parameters of the chain that is used by + // batches. + chainParams *chaincfg.Params } // rbfCache stores data related to our last fee bump. @@ -440,7 +452,9 @@ func (b *batch) setLog(logger btclog.Logger) { } // addSweep tries to add a sweep to the batch. If this is the first sweep being -// added to the batch then it also sets the primary sweep ID. +// added to the batch then it also sets the primary sweep ID. If CPFP mode is +// enabled, the result depends on the outcome of cpfpHelper.Presign for a +// non-empty batch. For an empty batch, the input needs to pass PresignSweep. func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { done, err := b.scheduleNextCall() defer done() @@ -546,6 +560,54 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { } } + // If CPFP mode is enabled, we should first presign the new version of + // batch transaction. Also ensure that all the sweeps in the batch use + // the same mode (CPFP or regular). + if sweep.cpfp { + // Ensure that all the sweeps in the batch use also CPFP mode. + for _, s := range b.sweeps { + if !s.cpfp { + b.log().Infof("failed to add sweep %x to the "+ + "batch, because the batch has "+ + "non-CPFP sweep %x", sweep.swapHash[:6], + s.swapHash[:6]) + + return false, nil + } + } + + if len(b.sweeps) != 0 { + if err := b.presign(ctx, sweep); err != nil { + b.log().Infof("failed to add sweep %x to the "+ + "batch, because failed to presign new "+ + "version of batch tx: %v", + sweep.swapHash[:6], err) + + return false, nil + } + } else { + if err := b.ensurePresigned(ctx, sweep); err != nil { + return false, fmt.Errorf("failed to check "+ + "signing of input %x, this means that "+ + "batcher.PresignSweep was not called "+ + "prior to AddSweep for this input: %w", + sweep.swapHash[:6], err) + } + } + } else { + // Ensure that all the sweeps in the batch don't use CPFP. + for _, s := range b.sweeps { + if s.cpfp { + b.log().Infof("failed to add sweep %x to the "+ + "batch, because the batch has "+ + "CPFP sweep %x", sweep.swapHash[:6], + s.swapHash[:6]) + + return false, nil + } + } + } + // Past this point we know that a new incoming sweep passes the // acceptance criteria and is now ready to be added to this batch. @@ -842,6 +904,22 @@ func (b *batch) isUrgent(skipBefore time.Time) bool { return true } +// isCPFP returns if the batch uses CPFP mode. Currently CPFP and non-CPFP +// sweeps never appear in the same batch. Fails if the batch is empty. +func (b *batch) isCPFP() (bool, error) { + if len(b.sweeps) == 0 { + return false, fmt.Errorf("the batch is empty") + } + + for _, sweep := range b.sweeps { + if sweep.cpfp { + return true, nil + } + } + + return false, nil +} + // publish creates and publishes the latest batch transaction to the network. func (b *batch) publish(ctx context.Context) error { var ( @@ -867,7 +945,19 @@ func (b *batch) publish(ctx context.Context) error { b.publishErrorHandler(err, errMsg, b.log()) } - fee, err, signSuccess = b.publishMixedBatch(ctx) + // Determine if we should use CPFP mode for the batch. + cpfp, err := b.isCPFP() + if err != nil { + return fmt.Errorf("failed to determine if the batch %d uses "+ + "CPFP mode: %w", b.id, err) + } + + if cpfp { + fee, err, signSuccess = b.publishWithCPFP(ctx) + } else { + fee, err, signSuccess = b.publishMixedBatch(ctx) + } + if err != nil { if signSuccess { logPublishError("publish error", err) @@ -1746,6 +1836,28 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { // handleConf handles a confirmation notification. This is the final step of the // batch. Here we signal to the batcher that this batch was completed. func (b *batch) handleConf(ctx context.Context) error { + // If the batch is in CPFP mode, cleanup cpfpHelper. + cpfp, err := b.isCPFP() + if err != nil { + return fmt.Errorf("failed to determine if the batch %d uses "+ + "CPFP mode: %w", b.id, err) + } + + if cpfp { + b.log().Infof("Cleaning up CPFP store") + + inputs := make([]wire.OutPoint, 0, len(b.sweeps)) + for _, sweep := range b.sweeps { + inputs = append(inputs, sweep.outpoint) + } + + err := b.cfg.cpfpHelper.CleanupTransactions(ctx, inputs) + if err != nil { + return fmt.Errorf("failed to clean up store for "+ + "batch %d, inputs %v: %w", b.id, inputs, err) + } + } + b.log().Infof("confirmed in txid %s", b.batchTxid) b.state = Confirmed @@ -1788,7 +1900,20 @@ func (b *batch) persist(ctx context.Context) error { // getBatchDestAddr returns the batch's destination address. If the batch // has already generated an address then the same one will be returned. +// The method must not be used in CPFP mode. Use getSweepsDestAddr instead. func (b *batch) getBatchDestAddr(ctx context.Context) (btcutil.Address, error) { + // Determine if we should use CPFP mode for the batch. + cpfp, err := b.isCPFP() + if err != nil { + return nil, fmt.Errorf("failed to determine if the batch %d "+ + "uses CPFP mode: %w", b.id, err) + } + + // Make sure that the method is not used for CPFP batches. + if cpfp { + return nil, fmt.Errorf("getBatchDestAddr used in CPFP mode") + } + var address btcutil.Address // If a batch address is set, use that. Otherwise, generate a diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 939604c30..65b230303 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -153,6 +153,50 @@ type SignMuSig2 func(ctx context.Context, muSig2Version input.MuSig2Version, swapHash lntypes.Hash, rootHash chainhash.Hash, sigHash [32]byte, ) ([]byte, error) +// CpfpHelper provides methods used when a custom tx signer and CPFP are used. +// In this mode sweepbatcher uses transactions provided by CPFP helper, which +// may be pre-signed and non-RBF'able, in which case CPFP may be needed. CPFP +// helper also provides transactions it previously produced by txid and affects +// batch selection - it has method Presign called upon new batch creation and +// adding to existing batch. +type CpfpHelper interface { + // IsCpfpApplied returns if CPFP mode is enabled for a particular sweep. + // This method should always return the same value for the same sweep. + // Currently CPFP and non-CPFP sweeps never appear in the same batch. + IsCpfpApplied(ctx context.Context, input wire.OutPoint) (bool, error) + + // Presign tries to presign a batch transaction. If the method returns + // nil, it is guaranteed that future calls to SignTx on this set of + // sweeps return valid signed transactions. + Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error + + // DestPkScript returns destination pkScript used in a presigned + // transaction sweeping the inputs. Returns an error, if such tx + // doesn't exist. If there are many such transactions, returns any + // of pkScript's. + DestPkScript(ctx context.Context, + inputs []wire.OutPoint) ([]byte, error) + + // SignTx signs an unsigned transaction or returns a pre-signed tx. + // It must satisfy the following invariants: + // - the set of inputs is the same, though the order may change; + // - the output is the same, but its amount may be different; + // - feerate is higher or equal to minRelayFee; + // - LockTime may be decreased; + // - transaction version must be the same; + // - Sequence numbers in the inputs must be preserved. + SignTx(ctx context.Context, tx *wire.MsgTx, inputAmt btcutil.Amount, + minRelayFee chainfee.SatPerKWeight) (*wire.MsgTx, error) + + // LoadTx returns any tx previously returned by SignTx. + LoadTx(ctx context.Context, txid chainhash.Hash) (*wire.MsgTx, error) + + // CleanupTransactions removes all transactions related to any of the + // outpoints. Should be called after sweep batch tx is fully confirmed. + CleanupTransactions(ctx context.Context, inputs []wire.OutPoint) error +} + // VerifySchnorrSig is a function that can be used to verify a schnorr // signature. type VerifySchnorrSig func(pubKey *btcec.PublicKey, hash, sig []byte) error @@ -329,6 +373,10 @@ type Batcher struct { // error. By default, it logs all errors as warnings, but "insufficient // fee" as Info. publishErrorHandler PublishErrorHandler + + // cpfpHelper provides methods used when a custom tx signer and CPFP + // are enabled. + cpfpHelper CpfpHelper } // BatcherConfig holds batcher configuration. @@ -369,6 +417,10 @@ type BatcherConfig struct { // error. By default, it logs all errors as warnings, but "insufficient // fee" as Info. publishErrorHandler PublishErrorHandler + + // cpfpHelper provides methods used when a custom tx signer and CPFP + // are enabled. + cpfpHelper CpfpHelper } // BatcherOption configures batcher behaviour. @@ -442,6 +494,20 @@ func WithPublishErrorHandler(handler PublishErrorHandler) BatcherOption { } } +// WithCpfpHelper instructs sweepbatcher to switch to mode in which it may use +// CPFP for fee bumping. In this mode it uses transactions provided by CPFP +// helper, which may be pre-signed and non-RBF'able, in which case CPFP may be +// needed. CPFP helper also provides transactions it previously produced by txid +// and affects batch selection - it has method Presign called upon new batch +// creation and adding to existing batch. In CPFP mode method PresignSweep must +// be called prior to AddSweep. If PresignSweep fails, AddSweep must not be +// called. +func WithCpfpHelper(cpfpHelper CpfpHelper) BatcherOption { + return func(cfg *BatcherConfig) { + cfg.cpfpHelper = cpfpHelper + } +} + // NewBatcher creates a new Batcher instance. func NewBatcher(wallet lndclient.WalletKitClient, chainNotifier lndclient.ChainNotifierClient, @@ -496,6 +562,7 @@ func NewBatcher(wallet lndclient.WalletKitClient, txLabeler: cfg.txLabeler, customMuSig2Signer: cfg.customMuSig2Signer, publishErrorHandler: cfg.publishErrorHandler, + cpfpHelper: cfg.cpfpHelper, } } @@ -564,8 +631,29 @@ func (b *Batcher) Run(ctx context.Context) error { } } +// PresignSweep creates and stores presigned 1:1 transactions for the sweep. +// This method must be called prior to AddSweep if CPFP mode is enabled. +func (b *Batcher) PresignSweep(ctx context.Context, sweepOutpoint wire.OutPoint, + sweepValue btcutil.Amount, destAddress btcutil.Address) error { + + if b.cpfpHelper == nil { + return fmt.Errorf("cpfpHelper is not installed") + } + + sweeps := []sweep{ + { + outpoint: sweepOutpoint, + value: sweepValue, + }, + } + + return presign(ctx, b.cpfpHelper, sweeps, destAddress) +} + // AddSweep adds a sweep request to the batcher for handling. This will either -// place the sweep in an existing batch or create a new one. +// place the sweep in an existing batch or create a new one. In CPFP mode call +// PresignSweep prior to AddSweep. If PresignSweep fails, AddSweep must not be +// called. func (b *Batcher) AddSweep(sweepReq *SweepRequest) error { select { case b.sweepReqs <- *sweepReq: @@ -616,8 +704,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return err } - log().Infof("Batcher handling sweep %x, completed=%v", - sweep.swapHash[:6], completed) + log().Infof("Batcher handling sweep %x, cpfp=%v, completed=%v", + sweep.swapHash[:6], sweep.cpfp, completed) // If the sweep has already been completed in a confirmed batch then we // can't attach its notifier to the batch as that is no longer running. @@ -703,7 +791,8 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, return b.spinUpNewBatch(ctx, sweep) } -// spinUpNewBatch creates new batch, starts it and adds the sweep to it. +// spinUpNewBatch creates new batch, starts it and adds the sweep to it. If CPFP +// mode is enabled, the result also depends on outcome of cpfpHelper.Presign. func (b *Batcher) spinUpNewBatch(ctx context.Context, sweep *sweep) error { // Spin up a fresh batch. newBatch, err := b.spinUpBatch(ctx) @@ -1099,6 +1188,16 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, swapHash[:6], err) } + // Determine if CPFP mode is used for this sweep. + var cpfp bool + if b.cpfpHelper != nil { + cpfp, err = b.cpfpHelper.IsCpfpApplied(ctx, outpoint) + if err != nil { + return nil, fmt.Errorf("failed to determine CPFP "+ + "status for sweep %x: %w", swapHash[:6], err) + } + } + // Find minimum fee rate for the sweep. Use customFeeRate if it is // provided, otherwise use wallet's EstimateFeeRate. var minFeeRate chainfee.SatPerKWeight @@ -1142,6 +1241,7 @@ func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, destAddr: s.DestAddr, minFeeRate: minFeeRate, nonCoopHint: s.NonCoopHint, + cpfp: cpfp, }, nil } @@ -1152,7 +1252,9 @@ func (b *Batcher) newBatchConfig(maxTimeoutDistance int32) batchConfig { noBumping: b.customFeeRate != nil, txLabeler: b.txLabeler, customMuSig2Signer: b.customMuSig2Signer, + cpfpHelper: b.cpfpHelper, clock: b.clock, + chainParams: b.chainParams, } } diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index 392da7698..d89456dba 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "strings" "sync" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog" "github.com/lightninglabs/lndclient" @@ -4065,6 +4067,1351 @@ func testFeeRateGrows(t *testing.T, store testStore, require.Equal(t, feeRateHigh, cachedFeeRate) } +// mockCpfpHelper implements CpfpHelper interface and stores arguments passed +// in its methods to validate correctness of function publishWithCPFP. +type mockCpfpHelper struct { + // onlineOutpoints specifies which outpoints are capable of + // participating in presigning. + onlineOutpoints map[wire.OutPoint]bool + + // presignedBatches is the collection of presigned batches. + presignedBatches []*wire.MsgTx + + // mu should be hold by all the public methods of this type. + mu sync.Mutex + + // cleanupCalled is a channel where an element is sent every time + // CleanupTransactions is called. + cleanupCalled chan struct{} +} + +// newMockCpfpHelper returns new instance of mockCpfpHelper. +func newMockCpfpHelper() *mockCpfpHelper { + return &mockCpfpHelper{ + onlineOutpoints: make(map[wire.OutPoint]bool), + cleanupCalled: make(chan struct{}), + } +} + +// SetOutpointOnline changes the online status of an outpoint. +func (h *mockCpfpHelper) SetOutpointOnline(op wire.OutPoint, online bool) { + h.mu.Lock() + defer h.mu.Unlock() + + h.onlineOutpoints[op] = online +} + +// findOfflineInputs returns inputs of a tx which are offline. +func (h *mockCpfpHelper) findOfflineInputs(tx *wire.MsgTx) []wire.OutPoint { + offline := make([]wire.OutPoint, 0, len(tx.TxIn)) + for _, txIn := range tx.TxIn { + if !h.onlineOutpoints[txIn.PreviousOutPoint] { + offline = append(offline, txIn.PreviousOutPoint) + } + } + + return offline +} + +// sign signs the transaction. +func (h *mockCpfpHelper) sign(tx *wire.MsgTx) { + // Sign all the inputs. + for i := range tx.TxIn { + tx.TxIn[i].Witness = wire.TxWitness{ + make([]byte, 64), + } + } +} + +// getTxFeerate returns fee rate of a transaction. +func (h *mockCpfpHelper) getTxFeerate(tx *wire.MsgTx, + inputAmt btcutil.Amount) chainfee.SatPerKWeight { + + // "Sign" tx's copy to assess the weight. + tx2 := tx.Copy() + h.sign(tx2) + weight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx2)), + ) + fee := btcutil.Amount(tx.TxOut[0].Value) - inputAmt + + return chainfee.NewSatPerKWeight(fee, weight) +} + +// IsCpfpApplied returns if the input was previously used in any call to the +// SetOutpointOnline method. +func (h *mockCpfpHelper) IsCpfpApplied(ctx context.Context, + input wire.OutPoint) (bool, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + _, has := h.onlineOutpoints[input] + + return has, nil +} + +// Presign tries to presign the transaction. It succeeds if all the inputs +// are online. In case of success it adds the transaction to presignedBatches. +func (h *mockCpfpHelper) Presign(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount) error { + + h.mu.Lock() + defer h.mu.Unlock() + + if offline := h.findOfflineInputs(tx); len(offline) != 0 { + return fmt.Errorf("some inputs of tx are offline: %v", offline) + } + + tx = tx.Copy() + h.sign(tx) + h.presignedBatches = append(h.presignedBatches, tx) + + return nil +} + +// DestPkScript returns destination pkScript used in 1:1 presigned tx. +func (h *mockCpfpHelper) DestPkScript(ctx context.Context, + inputs []wire.OutPoint) ([]byte, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + inputsSet := make(map[wire.OutPoint]struct{}, len(inputs)) + for _, input := range inputs { + inputsSet[input] = struct{}{} + } + if len(inputsSet) != len(inputs) { + return nil, fmt.Errorf("duplicate inputs") + } + + inputsMatch := func(tx *wire.MsgTx) bool { + if len(tx.TxIn) != len(inputsSet) { + return false + } + + for _, txIn := range tx.TxIn { + if _, has := inputsSet[txIn.PreviousOutPoint]; !has { + return false + } + } + + return true + } + + for _, tx := range h.presignedBatches { + if inputsMatch(tx) { + return tx.TxOut[0].PkScript, nil + } + } + + return nil, fmt.Errorf("tx sweeping inputs %v not found", inputs) +} + +// SignTx tries to sign the transaction. If all the inputs are online, it signs +// the exact transaction passed and adds it to presignedBatches. Otherwise it +// looks for a transaction in presignedBatches satisfying the criteria. +func (h *mockCpfpHelper) SignTx(ctx context.Context, tx *wire.MsgTx, + inputAmt btcutil.Amount, + minRelayFee chainfee.SatPerKWeight) (*wire.MsgTx, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + // If all the inputs are online, sign this exact transaction. + if offline := h.findOfflineInputs(tx); len(offline) == 0 { + tx = tx.Copy() + h.sign(tx) + + // Add to the collection. + h.presignedBatches = append(h.presignedBatches, tx) + + return tx, nil + } + + // Find feerate of input tx. + inputFeeRate := h.getTxFeerate(tx, inputAmt) + + // Try to find a transaction in the collection satisfying all the + // criteria of CpfpHelper.SignTx. If there are many such transactions, + // select a transaction with feerate which is the closest to the feerate + // of the input tx. + var ( + bestTx *wire.MsgTx + bestFeerateDistance chainfee.SatPerKWeight + ) + for _, candidate := range h.presignedBatches { + err := CheckSignedTx(tx, candidate, inputAmt, minRelayFee) + if err != nil { + continue + } + + feeRate := h.getTxFeerate(candidate, inputAmt) + feeRateDistance := feeRate - inputFeeRate + if feeRateDistance < 0 { + feeRateDistance = -feeRateDistance + } + + if bestTx == nil || feeRateDistance < bestFeerateDistance { + bestTx = candidate + bestFeerateDistance = feeRateDistance + } + } + + if bestTx == nil { + return nil, fmt.Errorf("no such presigned tx found") + } + + return bestTx.Copy(), nil +} + +// LoadTx tries to load the transaction by txid. It scans presignedBatches. +func (h *mockCpfpHelper) LoadTx(ctx context.Context, + txid chainhash.Hash) (*wire.MsgTx, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + for _, tx := range h.presignedBatches { + if tx.TxHash() == txid { + return tx.Copy(), nil + } + } + + return nil, fmt.Errorf("tx with ID %v not found", txid) +} + +// CleanupTransactions removes all transactions related to any of the outpoints. +func (h *mockCpfpHelper) CleanupTransactions(ctx context.Context, + inputs []wire.OutPoint) error { + + h.mu.Lock() + defer h.mu.Unlock() + + inputsSet := make(map[wire.OutPoint]struct{}, len(inputs)) + for _, input := range inputs { + inputsSet[input] = struct{}{} + } + if len(inputsSet) != len(inputs) { + return fmt.Errorf("duplicate inputs") + } + + var presignedBatches []*wire.MsgTx + + // Filter out transactions spending any of the inputs passed. + for _, tx := range h.presignedBatches { + var match bool + for _, txIn := range tx.TxIn { + if _, has := inputsSet[txIn.PreviousOutPoint]; has { + match = true + break + } + } + + if !match { + presignedBatches = append(presignedBatches, tx) + } + } + + h.presignedBatches = presignedBatches + + h.cleanupCalled <- struct{}{} + + return nil +} + +// dummySweepFetcherMock implements SweepFetcher by returning blank SweepInfo. +// It is used in TestCPFP, because it doesn't use any fields from SweepInfo. +type dummySweepFetcherMock struct { +} + +// FetchSweep returns blank SweepInfo. +func (f *dummySweepFetcherMock) FetchSweep(ctx context.Context, + hash lntypes.Hash) (*SweepInfo, error) { + + return &SweepInfo{ + // Set Timeout to prevent warning messages about timeout=0. + Timeout: 1000, + }, nil +} + +// testCPFP_input1_offline_then_input2 tests CPFP mode for the following +// scenario: first input is added, then goes offline, then feerate grows, one of +// presigned transactions is published, feerate grows further, then CPFP is used +// and then another online input is added and is assigned to another batch. +func testCPFP_input1_offline_then_input2(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper)) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + + // This should fail, because the input is offline. + cpfpHelper.SetOutpointOnline(op1, false) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.Error(t, err) + require.ErrorContains(t, err, "offline") + + // Enable the input and try again. + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + + // Increase fee rate and turn off the input, so it can't sign updated + // tx. The feerate is close to the feerate of one of presigned txs, so + // there is no CPFP. + setFeeRate(feeRateMedium) + cpfpHelper.SetOutpointOnline(op1, false) + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(987034), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Make sure the fee rate is feeRateMedium. + batch := getOnlyBatch(t, ctx, batcher) + var ( + numSweeps int + cachedFeeRate chainfee.SatPerKWeight + ) + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + cachedFeeRate = batch.rbfCache.FeeRate + }) + require.Equal(t, 1, numSweeps) + require.Equal(t, feeRateMedium, cachedFeeRate) + + // Raise feerate and trigger new publishing. The parent tx should be the + // same plus a CPFP tx. + setFeeRate(feeRateHigh) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + parent2 := <-lnd.TxPublishChannel + child := <-lnd.TxPublishChannel + require.Equal(t, parent.TxHash(), parent2.TxHash()) + require.Len(t, child.TxIn, 1) + require.Len(t, child.TxOut, 1) + parentOp := wire.OutPoint{ + Hash: parent2.TxHash(), + Index: 0, + } + require.Equal(t, parentOp, child.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(966600), child.TxOut[0].Value) + require.Equal(t, batchPkScript, child.TxOut[0].PkScript) + + // Now add another input. It is online, but the first input is still + // offline, so another input should go to another batch. + swapHash2 := lntypes.Hash{2, 2, 2} + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2}, + Index: 2, + } + sweepReq2 := SweepRequest{ + SwapHash: swapHash2, + Value: 2_000_000, + Outpoint: op2, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op2, true) + err = batcher.PresignSweep(ctx, op2, 2_000_000, destAddr) + require.NoError(t, err) + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq2)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + batch2 := <-lnd.TxPublishChannel + require.Len(t, batch2.TxIn, 1) + require.Len(t, batch2.TxOut, 1) + require.Equal(t, op2, batch2.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(1984160), batch2.TxOut[0].Value) + require.Equal(t, batchPkScript, batch2.TxOut[0].PkScript) + + // Now confirm the first batch. Make sure its presigned transactions + // were removed, but not the transactions of the second batch. + presignedSize1 := len(cpfpHelper.presignedBatches) + + parent2hash := parent2.TxHash() + spendDetail := &chainntnfs.SpendDetail{ + SpentOutPoint: &sweepReq1.Outpoint, + SpendingTx: parent2, + SpenderTxHash: &parent2hash, + SpenderInputIndex: 0, + SpendingHeight: 601, + } + lnd.SpendChannel <- spendDetail + <-lnd.RegisterConfChannel + require.NoError(t, lnd.NotifyHeight(604)) + lnd.ConfChannel <- &chainntnfs.TxConfirmation{ + Tx: parent2, + } + + <-cpfpHelper.cleanupCalled + + presignedSize2 := len(cpfpHelper.presignedBatches) + require.Greater(t, presignedSize2, 0) + require.Greater(t, presignedSize1, presignedSize2) + + // Make sure we still have presigned transactions for the second batch. + cpfpHelper.SetOutpointOnline(op2, false) + _, err = cpfpHelper.SignTx( + ctx, batch2, 2_000_000, chainfee.FeePerKwFloor, + ) + require.NoError(t, err) +} + +// testCPFP_two_inputs_one_goes_offline tests CPFP mode for the following +// scenario: two online inputs are added, then one of them goes offline, then +// feerate grows and a presigned transaction is used. +func testCPFP_two_inputs_one_goes_offline(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Add second sweep. + swapHash2 := lntypes.Hash{2, 2, 2} + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2}, + Index: 2, + } + sweepReq2 := SweepRequest{ + SwapHash: swapHash2, + Value: 2_000_000, + Outpoint: op2, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op2, true) + err = batcher.PresignSweep(ctx, op2, 2_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq2)) + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 2) + require.Len(t, parent.TxOut, 1) + require.ElementsMatch( + t, []wire.OutPoint{op1, op2}, + []wire.OutPoint{ + parent.TxIn[0].PreviousOutPoint, + parent.TxIn[1].PreviousOutPoint, + }, + ) + require.Equal(t, int64(2993740), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Now turn off the second input, raise feerate and trigger new + // publishing. The feerate is close to one of the presigned feerates, + // so this should result in RBF. + cpfpHelper.SetOutpointOnline(op2, false) + setFeeRate(feeRateMedium) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, batcher.AddSweep(&sweepReq2)) + require.NoError(t, lnd.NotifyHeight(601)) + + parent2 := <-lnd.TxPublishChannel + require.NotEqual(t, parent.TxHash(), parent2.TxHash()) + require.Len(t, parent2.TxIn, 2) + require.Len(t, parent2.TxOut, 1) + require.ElementsMatch( + t, []wire.OutPoint{op1, op2}, + []wire.OutPoint{ + parent.TxIn[0].PreviousOutPoint, + parent.TxIn[1].PreviousOutPoint, + }, + ) + require.Equal(t, int64(2979503), parent2.TxOut[0].Value) + require.Equal(t, batchPkScript, parent2.TxOut[0].PkScript) +} + +// testCPFP_cpfp_previous_version tests CPFP mode for the following scenario: +// one input is added, a transaction is published, then the input goes offline +// and feerate grows, RBF is attempted, but broadcast fails, so the batcher +// CPFPs previously published version. +func testCPFP_cpfp_previous_version(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(996040), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Now turn off the first input, raise feerate and trigger new + // publishing, which will fail. + var failedToPublishTx *wire.MsgTx + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + // We should fail the first publishing, which is a sweep, + // but we shouldn't fail CPFP publishing. + if strings.HasPrefix(label, cpfpLabelPrefix) { + return nil + } + + failedToPublishTx = tx + + return fmt.Errorf("test error") + } + cpfpHelper.SetOutpointOnline(op1, false) + setFeeRate(feeRateMedium) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + child := <-lnd.TxPublishChannel + require.NotEqual(t, parent.TxHash(), child.TxHash()) + require.Len(t, child.TxIn, 1) + require.Len(t, child.TxOut, 1) + require.Equal(t, wire.OutPoint{ + Hash: parent.TxHash(), + Index: 0, + }, child.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(974950), child.TxOut[0].Value) + + // Make sure the failed attempt used higher feerate than parent. + require.Equal(t, int64(987034), failedToPublishTx.TxOut[0].Value) +} + +// testCPFP_no_cpfp_if_all_online tests CPFP mode for the following scenario: +// one input is added, a transaction is published, then feerate grows, RBF is +// attempted, but broadcast fails, but CPFP is not used, because all the inputs +// are online (which is deduced by SignTx signing a tx with the same feerate as +// requested). +func testCPFP_no_cpfp_if_all_online(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(996040), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Replace the logger in the batch with wrappedLogger to watch messages. + batch := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger) + + // Now turn off the first input, raise feerate and trigger new + // publishing, which will fail. + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + return fmt.Errorf("test error") + } + setFeeRate(feeRateMedium) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + // Wait for batcher to log that CPFP is not needed. + require.EventuallyWithT(t, func(c *assert.CollectT) { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + assert.Contains( + c, testLogger.infoMessages, "CPFP is not needed", + ) + }, test.Timeout, eventuallyCheckFrequency) +} + +// testCPFP_first_publish_fails tests CPFP mode for the following scenario: +// one input is added and goes offline, feerate grows a transaction is attempted +// to be published, but fails, no CPFP is attempted. Then the input goes online +// and is published being signed online. +func testCPFP_first_publish_fails(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + cpfpHelper.SetOutpointOnline(op1, false) + + // Make sure that publish attempt fails. + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + return fmt.Errorf("test error") + } + + // Add the sweep, triggering the publish attempt. + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Replace the logger in the batch with wrappedLogger to watch messages. + batch := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger) + + // Trigger another publish attempt in case "CPFP is not needed" was + // logged before we installed the logger watcher. + require.NoError(t, lnd.NotifyHeight(601)) + + // Wait for batcher to log that CPFP is not needed. + require.EventuallyWithT(t, func(c *assert.CollectT) { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + assert.Contains( + c, testLogger.infoMessages, "CPFP is not needed", + ) + }, test.Timeout, eventuallyCheckFrequency) + + // Now turn on the first input, raise feerate and trigger new + // publishing, which should succeed. + lnd.PublishHandler = nil + setFeeRate(feeRateMedium) + cpfpHelper.SetOutpointOnline(op1, true) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(602)) + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(988120), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) +} + +// testCPFP_cpfp_publishing_fails tests CPFP mode for the following scenario: +// one input is added, a transaction is published, then the input goes offline +// and feerate grows, RBF is published and then CPFP is attempted to achieve +// the exact desired fee rate, but fails to be published. After then another +// block comes in and both the parent and the child are published and this +// succeeds. +func testCPFP_cpfp_publishing_fails(t *testing.T, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + batchPkScript, err := txscript.PayToAddrScript(destAddr) + require.NoError(t, err) + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, &dummySweepFetcherMock{}, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + // Create the first sweep. + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op1, true) + err = batcher.PresignSweep(ctx, op1, 1_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + parent := <-lnd.TxPublishChannel + require.Len(t, parent.TxIn, 1) + require.Len(t, parent.TxOut, 1) + require.Equal(t, op1, parent.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(996040), parent.TxOut[0].Value) + require.Equal(t, batchPkScript, parent.TxOut[0].PkScript) + + // Replace the logger in the batch with wrappedLogger to watch messages. + batch := getOnlyBatch(t, ctx, batcher) + testLogger := &wrappedLogger{ + Logger: batch.log(), + } + batch.setLog(testLogger) + + // Now turn off the first input, raise feerate and trigger new + // publishing, which will succeed. But then the CPFP will fail. + var failedToPublishTx *wire.MsgTx + lnd.PublishHandler = func(ctx context.Context, tx *wire.MsgTx, + label string) error { + + // We should fail the CPFP only. + if strings.HasPrefix(label, cpfpLabelPrefix) { + failedToPublishTx = tx + + return fmt.Errorf("test error") + } + + return nil + } + cpfpHelper.SetOutpointOnline(op1, false) + setFeeRate(feeRateHigh) + require.NoError(t, batcher.AddSweep(&sweepReq1)) + require.NoError(t, lnd.NotifyHeight(601)) + + // Expect new version of the batch to be published. This is one + // of the presigned transactions. + parent2 := <-lnd.TxPublishChannel + require.NotEqual(t, parent.TxHash(), parent2.TxHash()) + require.Len(t, parent2.TxIn, 1) + require.Len(t, parent2.TxOut, 1) + require.Equal(t, op1, parent2.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(987034), parent2.TxOut[0].Value) + require.Equal(t, batchPkScript, parent2.TxOut[0].PkScript) + + // Wait for batcher to log that CPFP has failed. + require.Eventually(t, func() bool { + testLogger.mu.Lock() + defer testLogger.mu.Unlock() + + for _, msg := range testLogger.infoMessages { + match := strings.Contains( + msg, "failed to publish child tx", + ) + if match { + return true + } + } + + return false + }, test.Timeout, eventuallyCheckFrequency) + + // Make sure that the failed to publish tx is our expected CPFP + // spending parent2. + require.Len(t, failedToPublishTx.TxIn, 1) + require.Len(t, failedToPublishTx.TxOut, 1) + require.Equal(t, wire.OutPoint{ + Hash: parent2.TxHash(), + Index: 0, + }, failedToPublishTx.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(966600), failedToPublishTx.TxOut[0].Value) + require.Equal(t, batchPkScript, failedToPublishTx.TxOut[0].PkScript) + + // Great, now les all published transactions pass and trigger another + // publishing. + lnd.PublishHandler = nil + require.NoError(t, lnd.NotifyHeight(602)) + + // Expect a parent and a child to be published. + parent2a := <-lnd.TxPublishChannel + require.Equal(t, parent2.TxHash(), parent2a.TxHash()) + + child := <-lnd.TxPublishChannel + require.Len(t, child.TxIn, 1) + require.Len(t, child.TxOut, 1) + require.Equal(t, wire.OutPoint{ + Hash: parent2a.TxHash(), + Index: 0, + }, child.TxIn[0].PreviousOutPoint) + require.Equal(t, int64(966600), child.TxOut[0].Value) + require.Equal(t, batchPkScript, child.TxOut[0].PkScript) +} + +// testCPFP_cpfp_and_regular_sweeps tests a combination of CPFP mode and regular +// mode for the following scenario: one regular input is added, then a CPFP +// input is added and it goes to another batch, because they shouldn't appear +// in the same batch. Then another regular and another CPFP inputs are added and +// go to the existing batches of their types. +func testCPFP_cpfp_and_regular_sweeps(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + lnd := test.NewMockLnd() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + const ( + feeRateLow = chainfee.SatPerKWeight(10_000) + feeRateMedium = chainfee.SatPerKWeight(30_000) + feeRateHigh = chainfee.SatPerKWeight(40_000) + ) + + currentFeeRate := feeRateLow + setFeeRate := func(feeRate chainfee.SatPerKWeight) { + currentFeeRate = feeRate + } + customFeeRate := func(_ context.Context, + _ lntypes.Hash) (chainfee.SatPerKWeight, error) { + + return currentFeeRate, nil + } + + cpfpHelper := newMockCpfpHelper() + + sweepStore, err := NewSweepFetcherFromSwapStore(store, lnd.ChainParams) + require.NoError(t, err) + + batcher := NewBatcher( + lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + testMuSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, sweepStore, + WithCustomFeeRate(customFeeRate), WithCpfpHelper(cpfpHelper), + ) + + go func() { + err := batcher.Run(ctx) + checkBatcherError(t, err) + }() + + setFeeRate(feeRateLow) + + ///////////////////////////////////// + // Create the first regular sweep. // + ///////////////////////////////////// + swapHash1 := lntypes.Hash{1, 1, 1} + op1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + } + sweepReq1 := SweepRequest{ + SwapHash: swapHash1, + Value: 1_000_000, + Outpoint: op1, + Notifier: &dummyNotifier, + } + + swap1 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 1_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{1}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash1, swap1) + require.NoError(t, err) + store.AssertLoopOutStored() + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq1)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + tx1 := <-lnd.TxPublishChannel + require.Len(t, tx1.TxIn, 1) + require.Len(t, tx1.TxOut, 1) + + ////////////////////////////////// + // Create the first CPFP sweep. // + ////////////////////////////////// + swapHash2 := lntypes.Hash{2, 2, 2} + op2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2}, + Index: 2, + } + + swap2 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 2_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{2}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash2, swap2) + require.NoError(t, err) + store.AssertLoopOutStored() + + sweepReq2 := SweepRequest{ + SwapHash: swapHash2, + Value: 2_000_000, + Outpoint: op2, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op2, true) + err = batcher.PresignSweep(ctx, op2, 2_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq2)) + + // Since a batch was created we check that it registered for its primary + // sweep's spend. + <-lnd.RegisterSpendChannel + + // Wait for a transactions to be published. + tx2 := <-lnd.TxPublishChannel + require.Len(t, tx2.TxIn, 1) + require.Len(t, tx2.TxOut, 1) + require.Equal(t, op2, tx2.TxIn[0].PreviousOutPoint) + + ////////////////////////////////////// + // Create the second regular sweep. // + ////////////////////////////////////// + swapHash3 := lntypes.Hash{3, 3, 3} + op3 := wire.OutPoint{ + Hash: chainhash.Hash{3, 3}, + Index: 3, + } + sweepReq3 := SweepRequest{ + SwapHash: swapHash3, + Value: 4_000_000, + Outpoint: op3, + Notifier: &dummyNotifier, + } + + swap3 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 4_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{3}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash3, swap3) + require.NoError(t, err) + store.AssertLoopOutStored() + + // Deliver sweep request to batcher. + require.NoError(t, batcher.AddSweep(&sweepReq3)) + + /////////////////////////////////// + // Create the second CPFP sweep. // + /////////////////////////////////// + swapHash4 := lntypes.Hash{4, 4, 4} + op4 := wire.OutPoint{ + Hash: chainhash.Hash{4, 4}, + Index: 4, + } + + swap4 := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 3_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: lntypes.Preimage{4}, + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + err = store.CreateLoopOut(ctx, swapHash4, swap4) + require.NoError(t, err) + store.AssertLoopOutStored() + + sweepReq4 := SweepRequest{ + SwapHash: swapHash4, + Value: 3_000_000, + Outpoint: op4, + Notifier: &dummyNotifier, + } + cpfpHelper.SetOutpointOnline(op4, true) + err = batcher.PresignSweep(ctx, op4, 4_000_000, destAddr) + require.NoError(t, err) + require.NoError(t, batcher.AddSweep(&sweepReq4)) + + // Wait for the both batches to have two sweeps. + require.Eventually(t, func() bool { + // Make sure there are two batches. + batches := getBatches(ctx, batcher) + if len(batches) != 2 { + return false + } + + // Make sure each batch has two sweeps. + for _, batch := range batches { + var numSweeps int + batch.testRunInEventLoop(ctx, func() { + numSweeps = len(batch.sweeps) + }) + if numSweeps != 2 { + return false + } + } + + return true + }, test.Timeout, eventuallyCheckFrequency) + + // Mine a block to trigger both batches publishing. + require.NoError(t, lnd.NotifyHeight(601)) + + // Wait for a transactions to be published. + tx3 := <-lnd.TxPublishChannel + require.Len(t, tx3.TxIn, 2) + require.Len(t, tx3.TxOut, 1) + require.Equal(t, int64(4993740), tx3.TxOut[0].Value) + + tx4 := <-lnd.TxPublishChannel + require.Len(t, tx4.TxIn, 2) + require.Len(t, tx4.TxOut, 1) + require.Equal(t, int64(4993740), tx4.TxOut[0].Value) +} + // TestSweepBatcherBatchCreation tests that sweep requests enter the expected // batch based on their timeout distance. func TestSweepBatcherBatchCreation(t *testing.T) { @@ -4212,6 +5559,41 @@ func TestFeeRateGrows(t *testing.T) { runTests(t, testFeeRateGrows) } +// TestCPFP tests CPFP mode. This test doesn't use loopdb. +func TestCPFP(t *testing.T) { + logger := btclog.NewBackend(os.Stdout).Logger("SWEEP") + logger.SetLevel(btclog.LevelTrace) + UseLogger(logger) + + t.Run("input1_offline_then_input2", func(t *testing.T) { + testCPFP_input1_offline_then_input2(t, NewStoreMock()) + }) + + t.Run("two_inputs_one_goes_offline", func(t *testing.T) { + testCPFP_two_inputs_one_goes_offline(t, NewStoreMock()) + }) + + t.Run("cpfp_previous_version", func(t *testing.T) { + testCPFP_cpfp_previous_version(t, NewStoreMock()) + }) + + t.Run("no_cpfp_if_all_online", func(t *testing.T) { + testCPFP_no_cpfp_if_all_online(t, NewStoreMock()) + }) + + t.Run("first_publish_fails", func(t *testing.T) { + testCPFP_first_publish_fails(t, NewStoreMock()) + }) + + t.Run("cpfp_publishing_fails", func(t *testing.T) { + testCPFP_cpfp_publishing_fails(t, NewStoreMock()) + }) + + t.Run("cpfp_and_regular_sweeps", func(t *testing.T) { + runTests(t, testCPFP_cpfp_and_regular_sweeps) + }) +} + // testBatcherStore is BatcherStore used in tests. type testBatcherStore interface { BatcherStore