diff --git a/loopdb/postgres.go b/loopdb/postgres.go index 86d8e6c6b..60483e934 100644 --- a/loopdb/postgres.go +++ b/loopdb/postgres.go @@ -64,6 +64,21 @@ type PostgresStore struct { *BaseDB } +// In migration of sweeps table from outpoint_txid and outpoint_index to +// outpoint we need to reverse the order of bytes in outpoint_txid and to +// convert it to hex. This is done differently in sqlite and postgres. +// +// Changes from sqlite to postgres: +// - substr(blob, ...) -> get_byte(blob, index) +// - group_concat -> string_agg +// - 1-based indexing (32+1-i) -> 0-based (32 - i) +// - to_hex() + lpad(..., 2, '0') ensures each byte is two-digit hex +const ( + txidSqlite = "group_concat(hex(substr(outpoint_txid,32+1-i,1)),'')" + txidPostgres = "string_agg(lpad(to_hex(get_byte(outpoint_txid, " + + "32 - i)), 2, '0'), '')" +) + // NewPostgresStore creates a new store that is backed by a Postgres database // backend. func NewPostgresStore(cfg *PostgresConfig, @@ -93,6 +108,7 @@ func NewPostgresStore(cfg *PostgresConfig, postgresFS := newReplacerFS(sqlSchemas, map[string]string{ "BLOB": "BYTEA", "INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY", + txidSqlite: txidPostgres, }) err = applyMigrations( diff --git a/loopdb/sqlc/batch.sql.go b/loopdb/sqlc/batch.sql.go index 656eb6dec..9c8aedc1c 100644 --- a/loopdb/sqlc/batch.sql.go +++ b/loopdb/sqlc/batch.sql.go @@ -35,7 +35,7 @@ func (q *Queries) DropBatch(ctx context.Context, id int32) error { const getBatchSweeps = `-- name: GetBatchSweeps :many SELECT - id, swap_hash, batch_id, outpoint_txid, outpoint_index, amt, completed + id, swap_hash, batch_id, outpoint, amt, completed FROM sweeps WHERE @@ -57,8 +57,7 @@ func (q *Queries) GetBatchSweeps(ctx context.Context, batchID int32) ([]Sweep, e &i.ID, &i.SwapHash, &i.BatchID, - &i.OutpointTxid, - &i.OutpointIndex, + &i.Outpoint, &i.Amt, &i.Completed, ); err != nil { @@ -101,11 +100,11 @@ FROM JOIN sweeps ON sweep_batches.id = sweeps.batch_id WHERE - sweeps.swap_hash = $1 + sweeps.outpoint = $1 ` -func (q *Queries) GetParentBatch(ctx context.Context, swapHash []byte) (SweepBatch, error) { - row := q.db.QueryRowContext(ctx, getParentBatch, swapHash) +func (q *Queries) GetParentBatch(ctx context.Context, outpoint string) (SweepBatch, error) { + row := q.db.QueryRowContext(ctx, getParentBatch, outpoint) var i SweepBatch err := row.Scan( &i.ID, @@ -125,11 +124,11 @@ SELECT FROM (SELECT false AS false_value) AS f LEFT JOIN - sweeps s ON s.swap_hash = $1 + sweeps s ON s.outpoint = $1 ` -func (q *Queries) GetSweepStatus(ctx context.Context, swapHash []byte) (bool, error) { - row := q.db.QueryRowContext(ctx, getSweepStatus, swapHash) +func (q *Queries) GetSweepStatus(ctx context.Context, outpoint string) (bool, error) { + row := q.db.QueryRowContext(ctx, getSweepStatus, outpoint) var completed bool err := row.Scan(&completed) return completed, err @@ -251,8 +250,7 @@ const upsertSweep = `-- name: UpsertSweep :exec INSERT INTO sweeps ( swap_hash, batch_id, - outpoint_txid, - outpoint_index, + outpoint, amt, completed ) VALUES ( @@ -260,31 +258,25 @@ INSERT INTO sweeps ( $2, $3, $4, - $5, - $6 -) ON CONFLICT (swap_hash) DO UPDATE SET + $5 +) ON CONFLICT (outpoint) DO UPDATE SET batch_id = $2, - outpoint_txid = $3, - outpoint_index = $4, - amt = $5, - completed = $6 + completed = $5 ` type UpsertSweepParams struct { - SwapHash []byte - BatchID int32 - OutpointTxid []byte - OutpointIndex int32 - Amt int64 - Completed bool + SwapHash []byte + BatchID int32 + Outpoint string + Amt int64 + Completed bool } func (q *Queries) UpsertSweep(ctx context.Context, arg UpsertSweepParams) error { _, err := q.db.ExecContext(ctx, upsertSweep, arg.SwapHash, arg.BatchID, - arg.OutpointTxid, - arg.OutpointIndex, + arg.Outpoint, arg.Amt, arg.Completed, ) diff --git a/loopdb/sqlc/migrations/000013_batcher_key_outpoint.down.sql b/loopdb/sqlc/migrations/000013_batcher_key_outpoint.down.sql new file mode 100644 index 000000000..327fb9cec --- /dev/null +++ b/loopdb/sqlc/migrations/000013_batcher_key_outpoint.down.sql @@ -0,0 +1,3 @@ +-- We kept old table as sweeps_old. Use it. +ALTER TABLE sweeps RENAME TO sweeps_new; +ALTER TABLE sweeps_old RENAME TO sweeps; diff --git a/loopdb/sqlc/migrations/000013_batcher_key_outpoint.up.sql b/loopdb/sqlc/migrations/000013_batcher_key_outpoint.up.sql new file mode 100644 index 000000000..36e279603 --- /dev/null +++ b/loopdb/sqlc/migrations/000013_batcher_key_outpoint.up.sql @@ -0,0 +1,67 @@ +-- We want to make column swap_hash non-unique and to use the outpoint as a key. +-- We can't make a column non-unique or remove it in sqlite, so work around. +-- See https://stackoverflow.com/a/42013422 + +-- We also made outpoint a single point replacing columns outpoint_txid and +-- outpoint_index. + +-- sweeps stores the individual sweeps that are part of a batch. +CREATE TABLE sweeps2 ( + -- id is the autoincrementing primary key. + id INTEGER PRIMARY KEY, + + -- swap_hash is the hash of the swap that is being swept. + swap_hash BLOB NOT NULL, + + -- batch_id is the id of the batch this swap is part of. + batch_id INTEGER NOT NULL, + + -- outpoint is the UTXO id of the output being swept ("txid:index"). + outpoint TEXT NOT NULL UNIQUE, + + -- amt is the amount of the output being swept. + amt BIGINT NOT NULL, + + -- completed indicates whether the sweep has been completed. + completed BOOLEAN NOT NULL DEFAULT FALSE, + + -- Foreign key constraint to ensure that we reference an existing batch + -- id. + FOREIGN KEY (batch_id) REFERENCES sweep_batches(id), + + -- Foreign key constraint to ensure that swap_hash references an + -- existing swap. + FOREIGN KEY (swap_hash) REFERENCES swaps(swap_hash) +); + +-- Copy all the data from sweeps to sweeps2. +-- Explanation: +-- - seq(i) goes from 1 to 32 +-- - substr(outpoint_txid, 32+1-i, 1) indexes BLOB bytes in reverse order +-- (SQLite uses 1-based indexing) +-- - hex(...) gives uppercase by default, so wrapped in lower(...) +-- - group_concat(..., '') combines all hex digits +-- - concatenated with ':' || CAST(outpoint_index AS TEXT) for full outpoint. +WITH RECURSIVE seq(i) AS ( + SELECT 1 + UNION ALL + SELECT i + 1 FROM seq WHERE i < 32 +) +INSERT INTO sweeps2 ( + id, swap_hash, batch_id, outpoint, amt, completed +) +SELECT + id, + swap_hash, + batch_id, + ( + SELECT lower(group_concat(hex(substr(outpoint_txid,32+1-i,1)),'')) + FROM seq + ) || ':' || CAST(outpoint_index AS TEXT), + amt, + completed +FROM sweeps; + +-- Rename tables. +ALTER TABLE sweeps RENAME TO sweeps_old; +ALTER TABLE sweeps2 RENAME TO sweeps; diff --git a/loopdb/sqlc/models.go b/loopdb/sqlc/models.go index 88aca93fb..0f2a166b7 100644 --- a/loopdb/sqlc/models.go +++ b/loopdb/sqlc/models.go @@ -180,13 +180,12 @@ type SwapUpdate struct { } type Sweep struct { - ID int32 - SwapHash []byte - BatchID int32 - OutpointTxid []byte - OutpointIndex int32 - Amt int64 - Completed bool + ID int32 + SwapHash []byte + BatchID int32 + Outpoint string + Amt int64 + Completed bool } type SweepBatch struct { @@ -198,3 +197,13 @@ type SweepBatch struct { LastRbfSatPerKw sql.NullInt32 MaxTimeoutDistance int32 } + +type SweepsOld struct { + ID int32 + SwapHash []byte + BatchID int32 + OutpointTxid []byte + OutpointIndex int32 + Amt int64 + Completed bool +} diff --git a/loopdb/sqlc/querier.go b/loopdb/sqlc/querier.go index a805196c6..d5283b868 100644 --- a/loopdb/sqlc/querier.go +++ b/loopdb/sqlc/querier.go @@ -32,7 +32,7 @@ type Querier interface { GetLoopOutSwap(ctx context.Context, swapHash []byte) (GetLoopOutSwapRow, error) GetLoopOutSwaps(ctx context.Context) ([]GetLoopOutSwapsRow, error) GetMigration(ctx context.Context, migrationID string) (MigrationTracker, error) - GetParentBatch(ctx context.Context, swapHash []byte) (SweepBatch, error) + GetParentBatch(ctx context.Context, outpoint string) (SweepBatch, error) GetReservation(ctx context.Context, reservationID []byte) (Reservation, error) GetReservationUpdates(ctx context.Context, reservationID []byte) ([]ReservationUpdate, error) GetReservations(ctx context.Context) ([]Reservation, error) @@ -40,7 +40,7 @@ type Querier interface { GetStaticAddressLoopInSwap(ctx context.Context, swapHash []byte) (GetStaticAddressLoopInSwapRow, error) GetStaticAddressLoopInSwapsByStates(ctx context.Context, dollar_1 sql.NullString) ([]GetStaticAddressLoopInSwapsByStatesRow, error) GetSwapUpdates(ctx context.Context, swapHash []byte) ([]SwapUpdate, error) - GetSweepStatus(ctx context.Context, swapHash []byte) (bool, error) + GetSweepStatus(ctx context.Context, outpoint string) (bool, error) GetUnconfirmedBatches(ctx context.Context) ([]SweepBatch, error) InsertBatch(ctx context.Context, arg InsertBatchParams) (int32, error) InsertDepositUpdate(ctx context.Context, arg InsertDepositUpdateParams) error diff --git a/loopdb/sqlc/queries/batch.sql b/loopdb/sqlc/queries/batch.sql index 03fad60d6..b02241273 100644 --- a/loopdb/sqlc/queries/batch.sql +++ b/loopdb/sqlc/queries/batch.sql @@ -47,8 +47,7 @@ WHERE INSERT INTO sweeps ( swap_hash, batch_id, - outpoint_txid, - outpoint_index, + outpoint, amt, completed ) VALUES ( @@ -56,14 +55,10 @@ INSERT INTO sweeps ( $2, $3, $4, - $5, - $6 -) ON CONFLICT (swap_hash) DO UPDATE SET + $5 +) ON CONFLICT (outpoint) DO UPDATE SET batch_id = $2, - outpoint_txid = $3, - outpoint_index = $4, - amt = $5, - completed = $6; + completed = $5; -- name: GetParentBatch :one SELECT @@ -73,7 +68,7 @@ FROM JOIN sweeps ON sweep_batches.id = sweeps.batch_id WHERE - sweeps.swap_hash = $1; + sweeps.outpoint = $1; -- name: GetBatchSweptAmount :one SELECT @@ -101,4 +96,4 @@ SELECT FROM (SELECT false AS false_value) AS f LEFT JOIN - sweeps s ON s.swap_hash = $1; + sweeps s ON s.outpoint = $1; diff --git a/sweepbatcher/greedy_batch_selection.go b/sweepbatcher/greedy_batch_selection.go index 163e0c3a3..1df9591f1 100644 --- a/sweepbatcher/greedy_batch_selection.go +++ b/sweepbatcher/greedy_batch_selection.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" sweeppkg "github.com/lightninglabs/loop/sweep" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -108,8 +109,8 @@ func estimateSweepFeeIncrement(s *sweep) (feeDetails, feeDetails, error) { rbfCache: rbfCache{ FeeRate: s.minFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - s.swapHash: *s, + sweeps: map[wire.OutPoint]sweep{ + s.outpoint: *s, }, } @@ -120,9 +121,13 @@ func estimateSweepFeeIncrement(s *sweep) (feeDetails, feeDetails, error) { } // Add the same sweep again to measure weight increments. - swapHash2 := s.swapHash - swapHash2[0]++ - batch.sweeps[swapHash2] = *s + outpoint2 := s.outpoint + outpoint2.Hash[0]++ + if _, has := batch.sweeps[outpoint2]; has { + return feeDetails{}, feeDetails{}, fmt.Errorf("dummy outpoint "+ + "%s is present in the batch", outpoint2) + } + batch.sweeps[outpoint2] = *s // Estimate weight of a batch with two sweeps. fd2, err := estimateBatchWeight(batch) diff --git a/sweepbatcher/greedy_batch_selection_test.go b/sweepbatcher/greedy_batch_selection_test.go index 959dfcf3a..d5aa58a49 100644 --- a/sweepbatcher/greedy_batch_selection_test.go +++ b/sweepbatcher/greedy_batch_selection_test.go @@ -5,6 +5,8 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/loop/swap" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lntypes" @@ -206,8 +208,14 @@ func TestEstimateSweepFeeIncrement(t *testing.T) { // for batches. func TestEstimateBatchWeight(t *testing.T) { // Useful variables reused in test cases. - swapHash1 := lntypes.Hash{1, 1, 1} - swapHash2 := lntypes.Hash{2, 2, 2} + outpoint1 := wire.OutPoint{ + Hash: chainhash.Hash{1, 1, 1}, + Index: 1, + } + outpoint2 := wire.OutPoint{ + Hash: chainhash.Hash{2, 2, 2}, + Index: 2, + } se2 := testHtlcV2SuccessEstimator se3 := testHtlcV3SuccessEstimator trAddr := (*btcutil.AddressTaproot)(nil) @@ -224,8 +232,8 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se3, }, }, @@ -244,11 +252,11 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se3, }, - swapHash2: { + outpoint2: { htlcSuccessEstimator: se3, }, }, @@ -267,11 +275,11 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se2, }, - swapHash2: { + outpoint2: { htlcSuccessEstimator: se3, }, }, @@ -290,12 +298,12 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se2, nonCoopHint: true, }, - swapHash2: { + outpoint2: { htlcSuccessEstimator: se3, nonCoopHint: true, }, @@ -315,8 +323,8 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: highFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se3, }, }, @@ -335,11 +343,11 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se3, }, - swapHash2: { + outpoint2: { htlcSuccessEstimator: se3, nonCoopHint: true, }, @@ -359,11 +367,11 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se3, }, - swapHash2: { + outpoint2: { htlcSuccessEstimator: se3, coopFailed: true, }, @@ -383,8 +391,8 @@ func TestEstimateBatchWeight(t *testing.T) { rbfCache: rbfCache{ FeeRate: lowFeeRate, }, - sweeps: map[lntypes.Hash]sweep{ - swapHash1: { + sweeps: map[wire.OutPoint]sweep{ + outpoint1: { htlcSuccessEstimator: se3, isExternalAddr: true, destAddr: trAddr, diff --git a/sweepbatcher/store.go b/sweepbatcher/store.go index 510015cc3..01b9e74a9 100644 --- a/sweepbatcher/store.go +++ b/sweepbatcher/store.go @@ -29,11 +29,11 @@ type Querier interface { GetBatchSweptAmount(ctx context.Context, batchID int32) (int64, error) // GetSweepStatus returns true if the sweep has been completed. - GetSweepStatus(ctx context.Context, swapHash []byte) (bool, error) + GetSweepStatus(ctx context.Context, outpoint string) (bool, error) // GetParentBatch fetches the parent batch of a completed sweep. - GetParentBatch(ctx context.Context, swapHash []byte) (sqlc.SweepBatch, - error) + GetParentBatch(ctx context.Context, + outpoint string) (sqlc.SweepBatch, error) // GetUnconfirmedBatches fetches all the batches from the // database that are not in a confirmed state. @@ -185,10 +185,10 @@ func (s *SQLStore) TotalSweptAmount(ctx context.Context, id int32) ( } // GetParentBatch fetches the parent batch of a completed sweep. -func (s *SQLStore) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( +func (s *SQLStore) GetParentBatch(ctx context.Context, outpoint wire.OutPoint) ( *dbBatch, error) { - batch, err := s.baseDb.GetParentBatch(ctx, swapHash[:]) + batch, err := s.baseDb.GetParentBatch(ctx, outpoint.String()) if err != nil { return nil, err } @@ -203,10 +203,10 @@ func (s *SQLStore) UpsertSweep(ctx context.Context, sweep *dbSweep) error { } // GetSweepStatus returns true if the sweep has been completed. -func (s *SQLStore) GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) ( +func (s *SQLStore) GetSweepStatus(ctx context.Context, outpoint wire.OutPoint) ( bool, error) { - return s.baseDb.GetSweepStatus(ctx, swapHash[:]) + return s.baseDb.GetSweepStatus(ctx, outpoint.String()) } type dbBatch struct { @@ -355,15 +355,12 @@ func (s *SQLStore) convertSweepRow(row sqlc.Sweep) (dbSweep, error) { sweep.SwapHash = swapHash - hash, err := chainhash.NewHash(row.OutpointTxid) + outpoint, err := wire.NewOutPointFromString(row.Outpoint) if err != nil { return sweep, err } - sweep.Outpoint = wire.OutPoint{ - Hash: *hash, - Index: uint32(row.OutpointIndex), - } + sweep.Outpoint = *outpoint return sweep, nil } @@ -371,11 +368,10 @@ func (s *SQLStore) convertSweepRow(row sqlc.Sweep) (dbSweep, error) { // sweepToUpsertArgs converts a Sweep struct to the arguments needed to insert. func sweepToUpsertArgs(sweep dbSweep) sqlc.UpsertSweepParams { return sqlc.UpsertSweepParams{ - SwapHash: sweep.SwapHash[:], - BatchID: sweep.BatchID, - OutpointTxid: sweep.Outpoint.Hash[:], - OutpointIndex: int32(sweep.Outpoint.Index), - Amt: int64(sweep.Amount), - Completed: sweep.Completed, + SwapHash: sweep.SwapHash[:], + BatchID: sweep.BatchID, + Outpoint: sweep.Outpoint.String(), + Amt: int64(sweep.Amount), + Completed: sweep.Completed, } } diff --git a/sweepbatcher/store_mock.go b/sweepbatcher/store_mock.go index 5cbd1871c..47b27ee86 100644 --- a/sweepbatcher/store_mock.go +++ b/sweepbatcher/store_mock.go @@ -7,13 +7,13 @@ import ( "sync" "github.com/btcsuite/btcd/btcutil" - "github.com/lightningnetwork/lnd/lntypes" + "github.com/btcsuite/btcd/wire" ) // StoreMock implements a mock client swap store. type StoreMock struct { batches map[int32]dbBatch - sweeps map[lntypes.Hash]dbSweep + sweeps map[wire.OutPoint]dbSweep mu sync.Mutex } @@ -21,7 +21,7 @@ type StoreMock struct { func NewStoreMock() *StoreMock { return &StoreMock{ batches: make(map[int32]dbBatch), - sweeps: make(map[lntypes.Hash]dbSweep), + sweeps: make(map[wire.OutPoint]dbSweep), } } @@ -122,19 +122,19 @@ func (s *StoreMock) UpsertSweep(ctx context.Context, sweep *dbSweep) error { s.mu.Lock() defer s.mu.Unlock() - s.sweeps[sweep.SwapHash] = *sweep + s.sweeps[sweep.Outpoint] = *sweep return nil } // GetSweepStatus returns the status of a sweep. func (s *StoreMock) GetSweepStatus(ctx context.Context, - swapHash lntypes.Hash) (bool, error) { + outpoint wire.OutPoint) (bool, error) { s.mu.Lock() defer s.mu.Unlock() - sweep, ok := s.sweeps[swapHash] + sweep, ok := s.sweeps[outpoint] if !ok { return false, nil } @@ -148,23 +148,23 @@ func (s *StoreMock) Close() error { } // AssertSweepStored asserts that a sweep is stored. -func (s *StoreMock) AssertSweepStored(id lntypes.Hash) bool { +func (s *StoreMock) AssertSweepStored(outpoint wire.OutPoint) bool { s.mu.Lock() defer s.mu.Unlock() - _, ok := s.sweeps[id] + _, ok := s.sweeps[outpoint] return ok } // GetParentBatch returns the parent batch of a swap. -func (s *StoreMock) GetParentBatch(ctx context.Context, swapHash lntypes.Hash) ( - *dbBatch, error) { +func (s *StoreMock) GetParentBatch(ctx context.Context, + outpoint wire.OutPoint) (*dbBatch, error) { s.mu.Lock() defer s.mu.Unlock() for _, sweep := range s.sweeps { - if sweep.SwapHash == swapHash { + if sweep.Outpoint == outpoint { batch, ok := s.batches[sweep.BatchID] if !ok { return nil, errors.New("batch not found") diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index fa036fffe..49045781f 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -60,6 +60,7 @@ var ( // sweep stores any data related to sweeping a specific outpoint. type sweep struct { // swapHash is the hash of the swap that the sweep belongs to. + // Multiple sweeps may belong to the same swap. swapHash lntypes.Hash // outpoint is the outpoint being swept. @@ -188,6 +189,9 @@ type rbfCache struct { SkipNextBump bool } +// zeroSweepID is default value for sweep.primarySweepID and batchKit.primaryID. +var zeroSweepID wire.OutPoint + // batch is a collection of sweeps that are published together. type batch struct { // id is the primary identifier of this batch. @@ -196,11 +200,11 @@ type batch struct { // state is the current state of the batch. state batchState - // primarySweepID is the swap hash of the primary sweep in the batch. - primarySweepID lntypes.Hash + // primarySweepID is the outpoint of the primary sweep in the batch. + primarySweepID wire.OutPoint // sweeps store the sweeps that this batch currently contains. - sweeps map[lntypes.Hash]sweep + sweeps map[wire.OutPoint]sweep // currentHeight is the current block height. currentHeight int32 @@ -310,8 +314,8 @@ type batchKit struct { batchTxid *chainhash.Hash batchPkScript []byte state batchState - primaryID lntypes.Hash - sweeps map[lntypes.Hash]sweep + primaryID wire.OutPoint + sweeps map[wire.OutPoint]sweep rbfCache rbfCache returnChan chan SweepRequest wallet lndclient.WalletKitClient @@ -354,7 +358,7 @@ func NewBatch(cfg batchConfig, bk batchKit) *batch { // never been persisted, so it needs to be assigned a new ID. id: -1, state: Open, - sweeps: make(map[lntypes.Hash]sweep), + sweeps: make(map[wire.OutPoint]sweep), spendChan: make(chan *chainntnfs.SpendDetail), confChan: make(chan *chainntnfs.TxConfirmation, 1), reorgChan: make(chan struct{}, 1), @@ -389,7 +393,7 @@ func NewBatchFromDB(cfg batchConfig, bk batchKit) (*batch, error) { // Assign batchConfTarget to primary sweep's confTarget. for _, sweep := range bk.sweeps { - if sweep.swapHash == bk.primaryID { + if sweep.outpoint == bk.primaryID { cfg.batchConfTarget = sweep.confTarget break } @@ -479,7 +483,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // Before we run through the acceptance checks, let's just see if this // sweep is already in our batch. In that case, just update the sweep. - oldSweep, ok := b.sweeps[sweep.swapHash] + oldSweep, ok := b.sweeps[sweep.outpoint] if ok { // Preserve coopFailed value not to forget about cooperative // spending failure in this sweep. @@ -490,11 +494,11 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // to sweep again, a new sweep notifier will be created by the // swap. By re-assigning to the batch's sweep we make sure that // everything, including the notifier, is up to date. - b.sweeps[sweep.swapHash] = tmp + b.sweeps[sweep.outpoint] = tmp // If this is the primary sweep, we also need to update the // batch's confirmation target and fee rate. - if b.primarySweepID == sweep.swapHash { + if b.primarySweepID == sweep.outpoint { b.cfg.batchConfTarget = sweep.confTarget b.rbfCache.SkipNextBump = true } @@ -571,8 +575,8 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // If this is the first sweep being added to the batch, make it the // primary sweep. - if b.primarySweepID == lntypes.ZeroHash { - b.primarySweepID = sweep.swapHash + if b.primarySweepID == zeroSweepID { + b.primarySweepID = sweep.outpoint b.cfg.batchConfTarget = sweep.confTarget b.rbfCache.FeeRate = sweep.minFeeRate b.rbfCache.SkipNextBump = true @@ -587,7 +591,7 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // Add the sweep to the batch's sweeps. b.Infof("adding sweep %x", sweep.swapHash[:6]) - b.sweeps[sweep.swapHash] = *sweep + b.sweeps[sweep.outpoint] = *sweep // Update FeeRate. Max(sweep.minFeeRate) for all the sweeps of // the batch is the basis for fee bumps. @@ -599,15 +603,16 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { return true, b.persistSweep(ctx, *sweep, false) } -// sweepExists returns true if the batch contains the sweep with the given hash. -func (b *batch) sweepExists(hash lntypes.Hash) bool { +// sweepExists returns true if the batch contains the sweep with the given +// outpoint. +func (b *batch) sweepExists(outpoint wire.OutPoint) bool { done, err := b.scheduleNextCall() defer done() if err != nil { return false } - _, ok := b.sweeps[hash] + _, ok := b.sweeps[outpoint] return ok } @@ -664,7 +669,7 @@ func (b *batch) Run(ctx context.Context) error { // If a primary sweep exists we immediately start monitoring for its // spend. - if b.primarySweepID != lntypes.ZeroHash { + if b.primarySweepID != zeroSweepID { sweep := b.sweeps[b.primarySweepID] err := b.monitorSpend(runCtx, sweep) if err != nil { @@ -694,8 +699,8 @@ func (b *batch) Run(ctx context.Context) error { // completes. timerChan := clock.TickAfter(b.cfg.batchPublishDelay) - b.Infof("started, primary %x, total sweeps %v", - b.primarySweepID[0:6], len(b.sweeps)) + b.Infof("started, primary %s, total sweeps %d", + b.primarySweepID, len(b.sweeps)) for { select { @@ -1152,7 +1157,7 @@ func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, // places we store the sweep. sweep.coopFailed = true sweeps[i] = sweep - b.sweeps[sweep.swapHash] = sweep + b.sweeps[sweep.outpoint] = sweep // Update newCoopFailures to know if we need // another attempt of cooperative signing. @@ -1700,7 +1705,7 @@ func (b *batch) handleSpend(ctx context.Context, spendTx *wire.MsgTx) error { // batch and feed it back to the batcher. if !found { newSweep := sweep - delete(b.sweeps, sweep.swapHash) + delete(b.sweeps, sweep.outpoint) purgeList = append(purgeList, SweepRequest{ SwapHash: newSweep.swapHash, Outpoint: newSweep.outpoint, diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index 4c7965127..d71387dc0 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -79,11 +79,12 @@ type BatcherStore interface { UpsertSweep(ctx context.Context, sweep *dbSweep) error // GetSweepStatus returns the completed status of the sweep. - GetSweepStatus(ctx context.Context, swapHash lntypes.Hash) (bool, error) + GetSweepStatus(ctx context.Context, + outpoint wire.OutPoint) (bool, error) // GetParentBatch returns the parent batch of a (completed) sweep. - GetParentBatch(ctx context.Context, swapHash lntypes.Hash) (*dbBatch, - error) + GetParentBatch(ctx context.Context, + outpoint wire.OutPoint) (*dbBatch, error) // TotalSweptAmount returns the total amount swept by a (confirmed) // batch. @@ -135,8 +136,10 @@ type SweepInfo struct { // SweepFetcher is used to get details of a sweep. type SweepFetcher interface { - // FetchSweep returns details of the sweep with the given hash. - FetchSweep(ctx context.Context, hash lntypes.Hash) (*SweepInfo, error) + // FetchSweep returns details of the sweep with the given hash or + // outpoint. The outpoint is used if hash is not unique. + FetchSweep(ctx context.Context, hash lntypes.Hash, + outpoint wire.OutPoint) (*SweepInfo, error) } // MuSig2SignSweep is a function that can be used to sign a sweep transaction @@ -611,7 +614,7 @@ func (b *Batcher) testRunInEventLoop(ctx context.Context, handler func()) { func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, notifier *SpendNotifier) error { - completed, err := b.store.GetSweepStatus(ctx, sweep.swapHash) + completed, err := b.store.GetSweepStatus(ctx, sweep.outpoint) if err != nil { return err } @@ -626,7 +629,7 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, // Verify that the parent batch is confirmed. Note that a batch // is only considered confirmed after it has received three // on-chain confirmations to prevent issues caused by reorgs. - parentBatch, err := b.store.GetParentBatch(ctx, sweep.swapHash) + parentBatch, err := b.store.GetParentBatch(ctx, sweep.outpoint) if err != nil { errorf("unable to get parent batch for sweep %x:"+ " %v", sweep.swapHash[:6], err) @@ -656,7 +659,7 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep, // Check if the sweep is already in a batch. If that is the case, we // provide the sweep to that batch and return. for _, batch := range b.batches { - if batch.sweepExists(sweep.swapHash) { + if batch.sweepExists(sweep.outpoint) { accepted, err := batch.addSweep(ctx, sweep) if err != nil && !errors.Is(err, ErrBatchShuttingDown) { return err @@ -803,7 +806,7 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error { primarySweep := dbSweeps[0] - sweeps := make(map[lntypes.Hash]sweep) + sweeps := make(map[wire.OutPoint]sweep) // Collect feeRate from sweeps and stored batch. feeRate := batch.rbfCache.FeeRate @@ -814,7 +817,7 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error { return err } - sweeps[sweep.swapHash] = *sweep + sweeps[sweep.outpoint] = *sweep // Set minFeeRate to max(sweep.minFeeRate) for all sweeps. if feeRate < sweep.minFeeRate { @@ -834,7 +837,7 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error { batchKit.batchTxid = batch.batchTxid batchKit.batchPkScript = batch.batchPkScript batchKit.state = batch.state - batchKit.primaryID = primarySweep.SwapHash + batchKit.primaryID = primarySweep.Outpoint batchKit.sweeps = sweeps batchKit.rbfCache = rbfCache batchKit.log = logger @@ -1031,9 +1034,10 @@ type SwapStoreWrapper struct { } // FetchSweep returns details of the sweep with the given hash. +// In LoopOut case, swap hashes are unique. // Implements SweepFetcher interface. func (f *SwapStoreWrapper) FetchSweep(ctx context.Context, - swapHash lntypes.Hash) (*SweepInfo, error) { + swapHash lntypes.Hash, _ wire.OutPoint) (*SweepInfo, error) { swap, err := f.swapStore.FetchLoopOutSwap(ctx, swapHash) if err != nil { @@ -1094,7 +1098,7 @@ func (b *Batcher) fetchSweep(ctx context.Context, func (b *Batcher) loadSweep(ctx context.Context, swapHash lntypes.Hash, outpoint wire.OutPoint, value btcutil.Amount) (*sweep, error) { - s, err := b.sweepStore.FetchSweep(ctx, swapHash) + s, err := b.sweepStore.FetchSweep(ctx, swapHash, outpoint) if err != nil { return nil, fmt.Errorf("failed to fetch sweep data for %x: %w", swapHash[:6], err) diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index f4b297647..1cf6cd702 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -163,9 +163,9 @@ func (b *batch) snapshot(ctx context.Context) *batch { var snapshot *batch b.testRunInEventLoop(ctx, func() { // Deep copy sweeps. - sweeps := make(map[lntypes.Hash]sweep, len(b.sweeps)) - for h, s := range b.sweeps { - sweeps[h] = s + sweeps := make(map[wire.OutPoint]sweep, len(b.sweeps)) + for o, s := range b.sweeps { + sweeps[o] = s } // Deep copy cfg. @@ -360,12 +360,12 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, for _, batch := range batches { batch := batch.snapshot(ctx) switch batch.primarySweepID { - case sweepReq1.SwapHash: + case sweepReq1.Outpoint: if len(batch.sweeps) != 2 { return false } - case sweepReq3.SwapHash: + case sweepReq3.Outpoint: if len(batch.sweeps) != 1 { return false } @@ -376,9 +376,9 @@ func testSweepBatcherBatchCreation(t *testing.T, store testStore, }, test.Timeout, eventuallyCheckFrequency) // Check that all sweeps were stored. - require.True(t, batcherStore.AssertSweepStored(sweepReq1.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq2.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq3.SwapHash)) + require.True(t, batcherStore.AssertSweepStored(sweepReq1.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq2.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq3.Outpoint)) } // testFeeBumping tests that sweep is RBFed with slightly higher fee rate after @@ -565,7 +565,7 @@ func testTxLabeler(t *testing.T, store testStore, var wantLabel string for _, btch := range getBatches(ctx, batcher) { btch := btch.snapshot(ctx) - if btch.primarySweepID == sweepReq1.SwapHash { + if btch.primarySweepID == sweepReq1.Outpoint { wantLabel = fmt.Sprintf( "BatchOutSweepSuccess -- %d", btch.id, ) @@ -794,7 +794,7 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, batch := &batch{} for _, btch := range getBatches(ctx, batcher) { btch.testRunInEventLoop(ctx, func() { - if btch.primarySweepID == sweepReq1.SwapHash { + if btch.primarySweepID == sweepReq1.Outpoint { batch = btch } }) @@ -806,7 +806,7 @@ func testSweepBatcherSimpleLifecycle(t *testing.T, store testStore, }, test.Timeout, eventuallyCheckFrequency) // The primary sweep id should be that of the first inserted sweep. - require.Equal(t, batch.primarySweepID, sweepReq1.SwapHash) + require.Equal(t, batch.primarySweepID, sweepReq1.Outpoint) // Wait for tx to be published. <-lnd.TxPublishChannel @@ -1057,7 +1057,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // batch. require.Eventually(t, func() bool { // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq.Outpoint) { return false } @@ -1135,7 +1135,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { // Wait for batch to load. require.Eventually(t, func() bool { // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq.Outpoint) { return false } @@ -1243,7 +1243,7 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { Value: 111, Outpoint: wire.OutPoint{ Hash: chainhash.Hash{2, 2}, - Index: 1, + Index: 2, }, Notifier: &dummyNotifier, } @@ -1323,8 +1323,8 @@ func testDelays(t *testing.T, store testStore, batcherStore testBatcherStore) { SwapHash: lntypes.Hash{3, 3, 3}, Value: 111, Outpoint: wire.OutPoint{ - Hash: chainhash.Hash{2, 2}, - Index: 1, + Hash: chainhash.Hash{3, 3}, + Index: 3, }, Notifier: &dummyNotifier, } @@ -1440,14 +1440,16 @@ func testMaxSweepsPerBatch(t *testing.T, store testStore, preimage := lntypes.Preimage{2, byte(i % 256), byte(i / 256)} swapHash := preimage.Hash() + outpoint := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 1)}, + Index: uint32(i + 1), + } + // Create a sweep request. sweepReq := SweepRequest{ SwapHash: swapHash, Value: 111, - Outpoint: wire.OutPoint{ - Hash: chainhash.Hash{1, 1}, - Index: 1, - }, + Outpoint: outpoint, Notifier: &dummyNotifier, } @@ -1681,7 +1683,7 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, b := &batch{} for _, btch := range getBatches(ctx, batcher) { btch.testRunInEventLoop(ctx, func() { - if btch.primarySweepID == sweepReq1.SwapHash { + if btch.primarySweepID == sweepReq1.Outpoint { b = btch } }) @@ -1694,7 +1696,7 @@ func testSweepBatcherSweepReentry(t *testing.T, store testStore, // Verify that the batch has a primary sweep id that matches the first // inserted sweep, sweep1. - require.Equal(t, b.primarySweepID, sweepReq1.SwapHash) + require.Equal(t, b.primarySweepID, sweepReq1.Outpoint) // Create the spending tx. In order to simulate an older version of the // batch transaction being confirmed, we only insert the primary sweep's @@ -1962,17 +1964,17 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, for _, batch := range batches { batch := batch.snapshot(ctx) switch batch.primarySweepID { - case sweepReq1.SwapHash: + case sweepReq1.Outpoint: if len(batch.sweeps) != 1 { return false } - case sweepReq2.SwapHash: + case sweepReq2.Outpoint: if len(batch.sweeps) != 1 { return false } - case sweepReq3.SwapHash: + case sweepReq3.Outpoint: if len(batch.sweeps) != 1 { return false } @@ -1983,9 +1985,9 @@ func testSweepBatcherNonWalletAddr(t *testing.T, store testStore, }, test.Timeout, eventuallyCheckFrequency) // Check that all sweeps were stored. - require.True(t, batcherStore.AssertSweepStored(sweepReq1.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq2.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq3.SwapHash)) + require.True(t, batcherStore.AssertSweepStored(sweepReq1.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq2.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq3.Outpoint)) } // testSweepBatcherComposite tests that sweep requests that sweep to both wallet @@ -2301,22 +2303,22 @@ func testSweepBatcherComposite(t *testing.T, store testStore, for _, batch := range batches { batch := batch.snapshot(ctx) switch batch.primarySweepID { - case sweepReq1.SwapHash: + case sweepReq1.Outpoint: if len(batch.sweeps) != 2 { return false } - case sweepReq3.SwapHash: + case sweepReq3.Outpoint: if len(batch.sweeps) != 1 { return false } - case sweepReq4.SwapHash: + case sweepReq4.Outpoint: if len(batch.sweeps) != 2 { return false } - case sweepReq6.SwapHash: + case sweepReq6.Outpoint: if len(batch.sweeps) != 1 { return false } @@ -2327,12 +2329,12 @@ func testSweepBatcherComposite(t *testing.T, store testStore, }, test.Timeout, eventuallyCheckFrequency) // Check that all sweeps were stored. - require.True(t, batcherStore.AssertSweepStored(sweepReq1.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq2.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq3.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq4.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq5.SwapHash)) - require.True(t, batcherStore.AssertSweepStored(sweepReq6.SwapHash)) + require.True(t, batcherStore.AssertSweepStored(sweepReq1.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq2.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq3.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq4.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq5.Outpoint)) + require.True(t, batcherStore.AssertSweepStored(sweepReq6.Outpoint)) } // makeTestTx creates a test transaction with a single output of the given @@ -2456,7 +2458,7 @@ func testRestoringEmptyBatch(t *testing.T, store testStore, require.Eventually(t, func() bool { // Make sure that the sweep was stored and we have exactly one // active batch. - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq.Outpoint) { return false } @@ -2676,10 +2678,10 @@ func testHandleSweepTwice(t *testing.T, backend testStore, require.Eventually(t, func() bool { // Make sure that the sweep was stored and we have exactly one // active batch. - if !batcherStore.AssertSweepStored(sweepReq1.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq1.Outpoint) { return false } - if !batcherStore.AssertSweepStored(sweepReq2.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq2.Outpoint) { return false } @@ -2730,7 +2732,7 @@ func testHandleSweepTwice(t *testing.T, backend testStore, snapshot := secondBatch.snapshot(ctx) // Make sure the second batch has the second sweep. - sweep2, has := snapshot.sweeps[sweepReq2.SwapHash] + sweep2, has := snapshot.sweeps[sweepReq2.Outpoint] if !has { return false } @@ -2833,7 +2835,7 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, // batch. require.Eventually(t, func() bool { // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq.Outpoint) { return false } @@ -2889,7 +2891,7 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, // Wait for batch to load. require.Eventually(t, func() bool { // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(sweepReq.SwapHash) { + if !batcherStore.AssertSweepStored(sweepReq.Outpoint) { return false } @@ -2915,24 +2917,24 @@ func testRestoringPreservesConfTarget(t *testing.T, store testStore, } type sweepFetcherMock struct { - store map[lntypes.Hash]*SweepInfo + store map[wire.OutPoint]*SweepInfo mu sync.Mutex } -func (f *sweepFetcherMock) setSweep(hash lntypes.Hash, info *SweepInfo) { +func (f *sweepFetcherMock) setSweep(outpoint wire.OutPoint, info *SweepInfo) { f.mu.Lock() defer f.mu.Unlock() - f.store[hash] = info + f.store[outpoint] = info } -func (f *sweepFetcherMock) FetchSweep(ctx context.Context, hash lntypes.Hash) ( - *SweepInfo, error) { +func (f *sweepFetcherMock) FetchSweep(ctx context.Context, _ lntypes.Hash, + outpoint wire.OutPoint) (*SweepInfo, error) { f.mu.Lock() defer f.mu.Unlock() - return f.store[hash], nil + return f.store[outpoint], nil } // testSweepFetcher tests providing custom sweep fetcher to Batcher. @@ -2986,12 +2988,6 @@ func testSweepFetcher(t *testing.T, store testStore, DestAddr: destAddr, } - sweepFetcher := &sweepFetcherMock{ - store: map[lntypes.Hash]*SweepInfo{ - swapHash: sweepInfo, - }, - } - // Create a sweep request. sweepReq := SweepRequest{ SwapHash: swapHash, @@ -3003,6 +2999,12 @@ func testSweepFetcher(t *testing.T, store testStore, Notifier: &dummyNotifier, } + sweepFetcher := &sweepFetcherMock{ + store: map[wire.OutPoint]*SweepInfo{ + sweepReq.Outpoint: sweepInfo, + }, + } + // Create a swap in the DB. It is needed to satisfy SQL constraints in // case of SQL test. The data is not actually used, since we pass sweep // fetcher, so put different conf target to make sure it is not used. @@ -3045,7 +3047,7 @@ func testSweepFetcher(t *testing.T, store testStore, // batch. require.Eventually(t, func() bool { // Make sure that the sweep was stored - if !batcherStore.AssertSweepStored(swapHash) { + if !batcherStore.AssertSweepStored(sweepReq.Outpoint) { return false } @@ -3286,7 +3288,7 @@ func testWithMixedBatch(t *testing.T, store testStore, // Use sweepFetcher to provide NonCoopHint for swapHash1. sweepFetcher := &sweepFetcherMock{ - store: map[lntypes.Hash]*SweepInfo{}, + store: map[wire.OutPoint]*SweepInfo{}, } // Create 3 sweeps: @@ -3350,6 +3352,11 @@ func testWithMixedBatch(t *testing.T, store testStore, // Create 3 swaps and 3 sweeps. for i, swapHash := range swapHashes { + outpoint := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 1)}, + Index: uint32(i + 1), + } + // Publish a block to trigger republishing. err = lnd.NotifyHeight(601 + int32(i)) require.NoError(t, err) @@ -3395,16 +3402,13 @@ func testWithMixedBatch(t *testing.T, store testStore, if i == 0 { sweepInfo.NonCoopHint = true } - sweepFetcher.setSweep(swapHash, sweepInfo) + sweepFetcher.setSweep(outpoint, sweepInfo) // Create sweep request. sweepReq := SweepRequest{ SwapHash: swapHash, Value: 1_000_000, - Outpoint: wire.OutPoint{ - Hash: chainhash.Hash{1, 1}, - Index: 1, - }, + Outpoint: outpoint, Notifier: &dummyNotifier, } require.NoError(t, batcher.AddSweep(&sweepReq)) @@ -3492,7 +3496,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, // Use sweepFetcher to provide NonCoopHint for swapHash1. sweepFetcher := &sweepFetcherMock{ - store: map[lntypes.Hash]*SweepInfo{}, + store: map[wire.OutPoint]*SweepInfo{}, } // Swap hashes must match the preimages, for non-cooperative spending @@ -3523,6 +3527,11 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, // Create swaps and sweeps. for i, swapHash := range swapHashes { + outpoint := wire.OutPoint{ + Hash: chainhash.Hash{byte(i + 1)}, + Index: uint32(i + 1), + } + // Put a swap into store to satisfy SQL constraints. swap := &loopdb.LoopOutContract{ SwapContract: loopdb.SwapContract{ @@ -3549,7 +3558,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, ) require.NoError(t, err) - sweepFetcher.setSweep(swapHash, &SweepInfo{ + sweepFetcher.setSweep(outpoint, &SweepInfo{ Preimage: preimages[i], NonCoopHint: nonCoopHints[i], @@ -3567,10 +3576,7 @@ func testWithMixedBatchCustom(t *testing.T, store testStore, sweepReq := SweepRequest{ SwapHash: swapHash, Value: 1_000_000, - Outpoint: wire.OutPoint{ - Hash: chainhash.Hash{1, 1}, - Index: 1, - }, + Outpoint: outpoint, Notifier: &dummyNotifier, } require.NoError(t, batcher.AddSweep(&sweepReq)) @@ -4148,13 +4154,13 @@ type testBatcherStore interface { BatcherStore // AssertSweepStored asserts that a sweep is stored. - AssertSweepStored(id lntypes.Hash) bool + AssertSweepStored(outpoint wire.OutPoint) bool } type loopdbBatcherStore struct { BatcherStore - sweepsSet map[lntypes.Hash]struct{} + sweepsSet map[wire.OutPoint]struct{} mu sync.Mutex } @@ -4169,17 +4175,17 @@ func (s *loopdbBatcherStore) UpsertSweep(ctx context.Context, err := s.BatcherStore.UpsertSweep(ctx, sweep) if err == nil { - s.sweepsSet[sweep.SwapHash] = struct{}{} + s.sweepsSet[sweep.Outpoint] = struct{}{} } return err } // AssertSweepStored asserts that a sweep is stored. -func (s *loopdbBatcherStore) AssertSweepStored(id lntypes.Hash) bool { +func (s *loopdbBatcherStore) AssertSweepStored(outpoint wire.OutPoint) bool { s.mu.Lock() defer s.mu.Unlock() - _, has := s.sweepsSet[id] + _, has := s.sweepsSet[outpoint] return has } @@ -4255,7 +4261,7 @@ func runTests(t *testing.T, testFn func(t *testing.T, store testStore, testStore := newLoopdbStore(t, sqlDB) testBatcherStore := &loopdbBatcherStore{ BatcherStore: batcherStore, - sweepsSet: make(map[lntypes.Hash]struct{}), + sweepsSet: make(map[wire.OutPoint]struct{}), } testFn(t, testStore, testBatcherStore) })