Skip to content

Commit 22dd2e8

Browse files
authored
Merge pull request #759 from starius/sweepbatcher-avoid-adding-to-two-batches
sweepbatcher: exit early in handleSweep
2 parents a135eb8 + 4258b95 commit 22dd2e8

File tree

5 files changed

+219
-17
lines changed

5 files changed

+219
-17
lines changed

sweepbatcher/log.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ func init() {
1717
UseLogger(build.NewSubLogger("SWEEP", nil))
1818
}
1919

20-
// batchPrefixLogger returns a logger that prefixes all log messages with the ID.
20+
// batchPrefixLogger returns a logger that prefixes all log messages with
21+
// the ID.
2122
func batchPrefixLogger(batchID string) btclog.Logger {
2223
return build.NewPrefixLog(fmt.Sprintf("[Batch %s]", batchID), log)
2324
}

sweepbatcher/store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ func NewSQLStore(db BaseDB, network *chaincfg.Params) *SQLStore {
8282

8383
// FetchUnconfirmedSweepBatches fetches all the batches from the database that
8484
// are not in a confirmed state.
85-
func (s *SQLStore) FetchUnconfirmedSweepBatches(ctx context.Context) ([]*dbBatch,
86-
error) {
85+
func (s *SQLStore) FetchUnconfirmedSweepBatches(ctx context.Context) (
86+
[]*dbBatch, error) {
8787

8888
var batches []*dbBatch
8989

sweepbatcher/sweep_batch.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -592,9 +592,11 @@ func (b *batch) publishBatch(ctx context.Context) (btcutil.Amount, error) {
592592
batchTx.LockTime = uint32(b.currentHeight)
593593

594594
var (
595-
batchAmt btcutil.Amount
596-
prevOuts = make([]*wire.TxOut, 0, len(b.sweeps))
597-
signDescs = make([]*lndclient.SignDescriptor, 0, len(b.sweeps))
595+
batchAmt btcutil.Amount
596+
prevOuts = make([]*wire.TxOut, 0, len(b.sweeps))
597+
signDescs = make(
598+
[]*lndclient.SignDescriptor, 0, len(b.sweeps),
599+
)
598600
sweeps = make([]sweep, 0, len(b.sweeps))
599601
fee btcutil.Amount
600602
inputCounter int

sweepbatcher/sweep_batcher.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,12 @@ func (b *Batcher) handleSweep(ctx context.Context, sweep *sweep,
328328

329329
if !accepted {
330330
return fmt.Errorf("existing sweep %x was not "+
331-
"accepted by batch %d", sweep.swapHash[:6],
332-
batch.id)
331+
"accepted by batch %d",
332+
sweep.swapHash[:6], batch.id)
333333
}
334+
335+
// The sweep was updated in the batch, our job is done.
336+
return nil
334337
}
335338
}
336339

@@ -461,6 +464,8 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
461464
FeeRate: batch.rbfCache.FeeRate,
462465
}
463466

467+
logger := batchPrefixLogger(fmt.Sprintf("%d", batch.id))
468+
464469
batchKit := batchKit{
465470
id: batch.id,
466471
batchTxid: batch.batchTxid,
@@ -477,7 +482,7 @@ func (b *Batcher) spinUpBatchFromDB(ctx context.Context, batch *batch) error {
477482
verifySchnorrSig: b.VerifySchnorrSig,
478483
purger: b.AddSweep,
479484
store: b.store,
480-
log: batchPrefixLogger(fmt.Sprintf("%d", batch.id)),
485+
log: logger,
481486
quit: b.quit,
482487
}
483488

@@ -598,15 +603,17 @@ func (b *Batcher) monitorSpendAndNotify(ctx context.Context, sweep *sweep,
598603
totalSwept,
599604
)
600605

606+
onChainFeePortion := getFeePortionPaidBySweep(
607+
spendTx, feePortionPerSweep,
608+
roundingDifference, sweep,
609+
)
610+
601611
// Notify the requester of the spend
602612
// with the spend details, including the fee
603613
// portion for this particular sweep.
604614
spendDetail := &SpendDetail{
605-
Tx: spendTx,
606-
OnChainFeePortion: getFeePortionPaidBySweep( // nolint:lll
607-
spendTx, feePortionPerSweep,
608-
roundingDifference, sweep,
609-
),
615+
Tx: spendTx,
616+
OnChainFeePortion: onChainFeePortion,
610617
}
611618

612619
select {

sweepbatcher/sweep_batcher_test.go

Lines changed: 195 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ func TestSweepBatcherBatchCreation(t *testing.T) {
182182
<-lnd.RegisterSpendChannel
183183

184184
require.Eventually(t, func() bool {
185-
// Verify that each batch has the correct number of sweeps in it.
185+
// Verify that each batch has the correct number of sweeps
186+
// in it.
186187
for _, batch := range batcher.batches {
187188
switch batch.primarySweepID {
188189
case sweepReq1.SwapHash:
@@ -481,7 +482,9 @@ func TestSweepBatcherSweepReentry(t *testing.T) {
481482
},
482483
TxOut: []*wire.TxOut{
483484
{
484-
Value: int64(sweepReq1.Value.ToUnit(btcutil.AmountSatoshi)),
485+
Value: int64(sweepReq1.Value.ToUnit(
486+
btcutil.AmountSatoshi,
487+
)),
485488
PkScript: []byte{3, 2, 1},
486489
},
487490
},
@@ -683,7 +686,8 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) {
683686
<-lnd.RegisterSpendChannel
684687

685688
require.Eventually(t, func() bool {
686-
// Verify that each batch has the correct number of sweeps in it.
689+
// Verify that each batch has the correct number of sweeps
690+
// in it.
687691
for _, batch := range batcher.batches {
688692
switch batch.primarySweepID {
689693
case sweepReq1.SwapHash:
@@ -1111,3 +1115,191 @@ func TestRestoringEmptyBatch(t *testing.T) {
11111115

11121116
checkBatcherError(t, runErr)
11131117
}
1118+
1119+
type loopStoreMock struct {
1120+
loops map[lntypes.Hash]*loopdb.LoopOut
1121+
mu sync.Mutex
1122+
}
1123+
1124+
func newLoopStoreMock() *loopStoreMock {
1125+
return &loopStoreMock{
1126+
loops: make(map[lntypes.Hash]*loopdb.LoopOut),
1127+
}
1128+
}
1129+
1130+
func (s *loopStoreMock) FetchLoopOutSwap(ctx context.Context,
1131+
hash lntypes.Hash) (*loopdb.LoopOut, error) {
1132+
1133+
s.mu.Lock()
1134+
defer s.mu.Unlock()
1135+
1136+
out, has := s.loops[hash]
1137+
if !has {
1138+
return nil, errors.New("loop not found")
1139+
}
1140+
1141+
return out, nil
1142+
}
1143+
1144+
func (s *loopStoreMock) putLoopOutSwap(hash lntypes.Hash, out *loopdb.LoopOut) {
1145+
s.mu.Lock()
1146+
defer s.mu.Unlock()
1147+
1148+
s.loops[hash] = out
1149+
}
1150+
1151+
// TestHandleSweepTwice tests that handing the same sweep twice must not
1152+
// add it to different batches.
1153+
func TestHandleSweepTwice(t *testing.T) {
1154+
defer test.Guard(t)()
1155+
1156+
lnd := test.NewMockLnd()
1157+
ctx, cancel := context.WithCancel(context.Background())
1158+
1159+
store := newLoopStoreMock()
1160+
1161+
batcherStore := NewStoreMock()
1162+
1163+
batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer,
1164+
testMuSig2SignSweep, nil, lnd.ChainParams, batcherStore, store)
1165+
1166+
var wg sync.WaitGroup
1167+
wg.Add(1)
1168+
1169+
var runErr error
1170+
go func() {
1171+
defer wg.Done()
1172+
runErr = batcher.Run(ctx)
1173+
}()
1174+
1175+
// Wait for the batcher to be initialized.
1176+
<-batcher.initDone
1177+
1178+
const shortCltv = 111
1179+
const longCltv = 111 + defaultMaxTimeoutDistance + 6
1180+
1181+
// Create two sweep requests with CltvExpiry distant from each other
1182+
// to go assigned to separate batches.
1183+
sweepReq1 := SweepRequest{
1184+
SwapHash: lntypes.Hash{1, 1, 1},
1185+
Value: 111,
1186+
Outpoint: wire.OutPoint{
1187+
Hash: chainhash.Hash{1, 1},
1188+
Index: 1,
1189+
},
1190+
Notifier: &dummyNotifier,
1191+
}
1192+
1193+
loopOut1 := &loopdb.LoopOut{
1194+
Loop: loopdb.Loop{
1195+
Hash: lntypes.Hash{1, 1, 1},
1196+
},
1197+
Contract: &loopdb.LoopOutContract{
1198+
SwapContract: loopdb.SwapContract{
1199+
CltvExpiry: shortCltv,
1200+
AmountRequested: 111,
1201+
},
1202+
SwapInvoice: swapInvoice,
1203+
},
1204+
}
1205+
1206+
sweepReq2 := SweepRequest{
1207+
SwapHash: lntypes.Hash{2, 2, 2},
1208+
Value: 222,
1209+
Outpoint: wire.OutPoint{
1210+
Hash: chainhash.Hash{2, 2},
1211+
Index: 2,
1212+
},
1213+
Notifier: &dummyNotifier,
1214+
}
1215+
1216+
loopOut2 := &loopdb.LoopOut{
1217+
Loop: loopdb.Loop{
1218+
Hash: lntypes.Hash{2, 2, 2},
1219+
},
1220+
Contract: &loopdb.LoopOutContract{
1221+
SwapContract: loopdb.SwapContract{
1222+
CltvExpiry: longCltv,
1223+
AmountRequested: 222,
1224+
},
1225+
SwapInvoice: swapInvoice,
1226+
},
1227+
}
1228+
1229+
store.putLoopOutSwap(sweepReq1.SwapHash, loopOut1)
1230+
store.putLoopOutSwap(sweepReq2.SwapHash, loopOut2)
1231+
1232+
// Deliver sweep request to batcher.
1233+
require.NoError(t, batcher.AddSweep(&sweepReq1))
1234+
1235+
// Since two batches were created we check that it registered for its
1236+
// primary sweep's spend.
1237+
<-lnd.RegisterSpendChannel
1238+
1239+
// Deliver the second sweep. It will go to a separate batch,
1240+
// since CltvExpiry values are distant enough.
1241+
require.NoError(t, batcher.AddSweep(&sweepReq2))
1242+
<-lnd.RegisterSpendChannel
1243+
1244+
// Once batcher receives sweep request it will eventually spin up
1245+
// batches.
1246+
require.Eventually(t, func() bool {
1247+
// Make sure that the sweep was stored and we have exactly one
1248+
// active batch.
1249+
return batcherStore.AssertSweepStored(sweepReq1.SwapHash) &&
1250+
batcherStore.AssertSweepStored(sweepReq2.SwapHash) &&
1251+
len(batcher.batches) == 2
1252+
}, test.Timeout, eventuallyCheckFrequency)
1253+
1254+
// Change the second sweep so that it can be added to the first batch.
1255+
// Change CltvExpiry.
1256+
loopOut2 = &loopdb.LoopOut{
1257+
Loop: loopdb.Loop{
1258+
Hash: lntypes.Hash{2, 2, 2},
1259+
},
1260+
Contract: &loopdb.LoopOutContract{
1261+
SwapContract: loopdb.SwapContract{
1262+
CltvExpiry: shortCltv,
1263+
AmountRequested: 222,
1264+
},
1265+
SwapInvoice: swapInvoice,
1266+
},
1267+
}
1268+
store.putLoopOutSwap(sweepReq2.SwapHash, loopOut2)
1269+
1270+
// Re-add the second sweep. It is expected to stay in second batch,
1271+
// not added to both batches.
1272+
require.NoError(t, batcher.AddSweep(&sweepReq2))
1273+
1274+
require.Eventually(t, func() bool {
1275+
// Make sure there are two batches.
1276+
batches := batcher.batches
1277+
if len(batches) != 2 {
1278+
return false
1279+
}
1280+
1281+
// Make sure the second batch has the second sweep.
1282+
sweep2, has := batches[1].sweeps[sweepReq2.SwapHash]
1283+
if !has {
1284+
return false
1285+
}
1286+
1287+
// Make sure the second sweep's timeout has been updated.
1288+
if sweep2.timeout != shortCltv {
1289+
return false
1290+
}
1291+
1292+
return true
1293+
}, test.Timeout, eventuallyCheckFrequency)
1294+
1295+
// Make sure each batch has one sweep. If the second sweep was added to
1296+
// both batches, the following check won't pass.
1297+
require.Equal(t, 1, len(batcher.batches[0].sweeps))
1298+
require.Equal(t, 1, len(batcher.batches[1].sweeps))
1299+
1300+
// Now make it quit by canceling the context.
1301+
cancel()
1302+
wg.Wait()
1303+
1304+
checkBatcherError(t, runErr)
1305+
}

0 commit comments

Comments
 (0)