@@ -182,7 +182,8 @@ func TestSweepBatcherBatchCreation(t *testing.T) {
182182 <- lnd .RegisterSpendChannel
183183
184184 require .Eventually (t , func () bool {
185- // Verify that each batch has the correct number of sweeps in it.
185+ // Verify that each batch has the correct number of sweeps
186+ // in it.
186187 for _ , batch := range batcher .batches {
187188 switch batch .primarySweepID {
188189 case sweepReq1 .SwapHash :
@@ -481,7 +482,9 @@ func TestSweepBatcherSweepReentry(t *testing.T) {
481482 },
482483 TxOut : []* wire.TxOut {
483484 {
484- Value : int64 (sweepReq1 .Value .ToUnit (btcutil .AmountSatoshi )),
485+ Value : int64 (sweepReq1 .Value .ToUnit (
486+ btcutil .AmountSatoshi ,
487+ )),
485488 PkScript : []byte {3 , 2 , 1 },
486489 },
487490 },
@@ -683,7 +686,8 @@ func TestSweepBatcherNonWalletAddr(t *testing.T) {
683686 <- lnd .RegisterSpendChannel
684687
685688 require .Eventually (t , func () bool {
686- // Verify that each batch has the correct number of sweeps in it.
689+ // Verify that each batch has the correct number of sweeps
690+ // in it.
687691 for _ , batch := range batcher .batches {
688692 switch batch .primarySweepID {
689693 case sweepReq1 .SwapHash :
@@ -1111,3 +1115,191 @@ func TestRestoringEmptyBatch(t *testing.T) {
11111115
11121116 checkBatcherError (t , runErr )
11131117}
1118+
1119+ type loopStoreMock struct {
1120+ loops map [lntypes.Hash ]* loopdb.LoopOut
1121+ mu sync.Mutex
1122+ }
1123+
1124+ func newLoopStoreMock () * loopStoreMock {
1125+ return & loopStoreMock {
1126+ loops : make (map [lntypes.Hash ]* loopdb.LoopOut ),
1127+ }
1128+ }
1129+
1130+ func (s * loopStoreMock ) FetchLoopOutSwap (ctx context.Context ,
1131+ hash lntypes.Hash ) (* loopdb.LoopOut , error ) {
1132+
1133+ s .mu .Lock ()
1134+ defer s .mu .Unlock ()
1135+
1136+ out , has := s .loops [hash ]
1137+ if ! has {
1138+ return nil , errors .New ("loop not found" )
1139+ }
1140+
1141+ return out , nil
1142+ }
1143+
1144+ func (s * loopStoreMock ) putLoopOutSwap (hash lntypes.Hash , out * loopdb.LoopOut ) {
1145+ s .mu .Lock ()
1146+ defer s .mu .Unlock ()
1147+
1148+ s .loops [hash ] = out
1149+ }
1150+
1151+ // TestHandleSweepTwice tests that handing the same sweep twice must not
1152+ // add it to different batches.
1153+ func TestHandleSweepTwice (t * testing.T ) {
1154+ defer test .Guard (t )()
1155+
1156+ lnd := test .NewMockLnd ()
1157+ ctx , cancel := context .WithCancel (context .Background ())
1158+
1159+ store := newLoopStoreMock ()
1160+
1161+ batcherStore := NewStoreMock ()
1162+
1163+ batcher := NewBatcher (lnd .WalletKit , lnd .ChainNotifier , lnd .Signer ,
1164+ testMuSig2SignSweep , nil , lnd .ChainParams , batcherStore , store )
1165+
1166+ var wg sync.WaitGroup
1167+ wg .Add (1 )
1168+
1169+ var runErr error
1170+ go func () {
1171+ defer wg .Done ()
1172+ runErr = batcher .Run (ctx )
1173+ }()
1174+
1175+ // Wait for the batcher to be initialized.
1176+ <- batcher .initDone
1177+
1178+ const shortCltv = 111
1179+ const longCltv = 111 + defaultMaxTimeoutDistance + 6
1180+
1181+ // Create two sweep requests with CltvExpiry distant from each other
1182+ // to go assigned to separate batches.
1183+ sweepReq1 := SweepRequest {
1184+ SwapHash : lntypes.Hash {1 , 1 , 1 },
1185+ Value : 111 ,
1186+ Outpoint : wire.OutPoint {
1187+ Hash : chainhash.Hash {1 , 1 },
1188+ Index : 1 ,
1189+ },
1190+ Notifier : & dummyNotifier ,
1191+ }
1192+
1193+ loopOut1 := & loopdb.LoopOut {
1194+ Loop : loopdb.Loop {
1195+ Hash : lntypes.Hash {1 , 1 , 1 },
1196+ },
1197+ Contract : & loopdb.LoopOutContract {
1198+ SwapContract : loopdb.SwapContract {
1199+ CltvExpiry : shortCltv ,
1200+ AmountRequested : 111 ,
1201+ },
1202+ SwapInvoice : swapInvoice ,
1203+ },
1204+ }
1205+
1206+ sweepReq2 := SweepRequest {
1207+ SwapHash : lntypes.Hash {2 , 2 , 2 },
1208+ Value : 222 ,
1209+ Outpoint : wire.OutPoint {
1210+ Hash : chainhash.Hash {2 , 2 },
1211+ Index : 2 ,
1212+ },
1213+ Notifier : & dummyNotifier ,
1214+ }
1215+
1216+ loopOut2 := & loopdb.LoopOut {
1217+ Loop : loopdb.Loop {
1218+ Hash : lntypes.Hash {2 , 2 , 2 },
1219+ },
1220+ Contract : & loopdb.LoopOutContract {
1221+ SwapContract : loopdb.SwapContract {
1222+ CltvExpiry : longCltv ,
1223+ AmountRequested : 222 ,
1224+ },
1225+ SwapInvoice : swapInvoice ,
1226+ },
1227+ }
1228+
1229+ store .putLoopOutSwap (sweepReq1 .SwapHash , loopOut1 )
1230+ store .putLoopOutSwap (sweepReq2 .SwapHash , loopOut2 )
1231+
1232+ // Deliver sweep request to batcher.
1233+ require .NoError (t , batcher .AddSweep (& sweepReq1 ))
1234+
1235+ // Since two batches were created we check that it registered for its
1236+ // primary sweep's spend.
1237+ <- lnd .RegisterSpendChannel
1238+
1239+ // Deliver the second sweep. It will go to a separate batch,
1240+ // since CltvExpiry values are distant enough.
1241+ require .NoError (t , batcher .AddSweep (& sweepReq2 ))
1242+ <- lnd .RegisterSpendChannel
1243+
1244+ // Once batcher receives sweep request it will eventually spin up
1245+ // batches.
1246+ require .Eventually (t , func () bool {
1247+ // Make sure that the sweep was stored and we have exactly one
1248+ // active batch.
1249+ return batcherStore .AssertSweepStored (sweepReq1 .SwapHash ) &&
1250+ batcherStore .AssertSweepStored (sweepReq2 .SwapHash ) &&
1251+ len (batcher .batches ) == 2
1252+ }, test .Timeout , eventuallyCheckFrequency )
1253+
1254+ // Change the second sweep so that it can be added to the first batch.
1255+ // Change CltvExpiry.
1256+ loopOut2 = & loopdb.LoopOut {
1257+ Loop : loopdb.Loop {
1258+ Hash : lntypes.Hash {2 , 2 , 2 },
1259+ },
1260+ Contract : & loopdb.LoopOutContract {
1261+ SwapContract : loopdb.SwapContract {
1262+ CltvExpiry : shortCltv ,
1263+ AmountRequested : 222 ,
1264+ },
1265+ SwapInvoice : swapInvoice ,
1266+ },
1267+ }
1268+ store .putLoopOutSwap (sweepReq2 .SwapHash , loopOut2 )
1269+
1270+ // Re-add the second sweep. It is expected to stay in second batch,
1271+ // not added to both batches.
1272+ require .NoError (t , batcher .AddSweep (& sweepReq2 ))
1273+
1274+ require .Eventually (t , func () bool {
1275+ // Make sure there are two batches.
1276+ batches := batcher .batches
1277+ if len (batches ) != 2 {
1278+ return false
1279+ }
1280+
1281+ // Make sure the second batch has the second sweep.
1282+ sweep2 , has := batches [1 ].sweeps [sweepReq2 .SwapHash ]
1283+ if ! has {
1284+ return false
1285+ }
1286+
1287+ // Make sure the second sweep's timeout has been updated.
1288+ if sweep2 .timeout != shortCltv {
1289+ return false
1290+ }
1291+
1292+ return true
1293+ }, test .Timeout , eventuallyCheckFrequency )
1294+
1295+ // Make sure each batch has one sweep. If the second sweep was added to
1296+ // both batches, the following check won't pass.
1297+ require .Equal (t , 1 , len (batcher .batches [0 ].sweeps ))
1298+ require .Equal (t , 1 , len (batcher .batches [1 ].sweeps ))
1299+
1300+ // Now make it quit by canceling the context.
1301+ cancel ()
1302+ wg .Wait ()
1303+
1304+ checkBatcherError (t , runErr )
1305+ }
0 commit comments