@@ -310,12 +310,12 @@ type MockChainBridge struct {
310310
311311 NewBlocks chan int32
312312
313- ReqCount int
313+ ReqCount atomic. Int32
314314 ConfReqs map [int ]* chainntnfs.ConfirmationEvent
315315
316316 failFeeEstimates atomic.Bool
317- emptyConf bool
318- errConf bool
317+ errConf atomic. Int32
318+ emptyConf atomic. Int32
319319 confErr chan error
320320}
321321
@@ -334,19 +334,30 @@ func (m *MockChainBridge) FailFeeEstimatesOnce() {
334334 m .failFeeEstimates .Store (true )
335335}
336336
337- func (m * MockChainBridge ) FailConf (enable bool ) {
338- m .errConf = enable
337+ // FailConfOnce updates the ChainBridge such that the next call to
338+ // RegisterConfirmationNtfn will fail by returning an error on the error channel
339+ // returned from RegisterConfirmationNtfn.
340+ func (m * MockChainBridge ) FailConfOnce () {
341+ // Store the incremented request count so we never store 0 as a value.
342+ m .errConf .Store (m .ReqCount .Load () + 1 )
339343}
340- func (m * MockChainBridge ) EmptyConf (enable bool ) {
341- m .emptyConf = enable
344+
345+ // EmptyConfOnce updates the ChainBridge such that the next confirmation event
346+ // sent via SendConfNtfn will have an empty confirmation.
347+ func (m * MockChainBridge ) EmptyConfOnce () {
348+ // Store the incremented request count so we never store 0 as a value.
349+ m .emptyConf .Store (m .ReqCount .Load () + 1 )
342350}
343351
344352func (m * MockChainBridge ) SendConfNtfn (reqNo int , blockHash * chainhash.Hash ,
345353 blockHeight , blockIndex int , block * wire.MsgBlock ,
346354 tx * wire.MsgTx ) {
347355
356+ // Compare to the incremented request count since we incremented it
357+ // when storing the request number.
348358 req := m .ConfReqs [reqNo ]
349- if m .emptyConf {
359+ if m .emptyConf .Load () == int32 (reqNo )+ 1 {
360+ m .emptyConf .Store (0 )
350361 req .Confirmed <- nil
351362 return
352363 }
@@ -371,7 +382,7 @@ func (m *MockChainBridge) RegisterConfirmationsNtfn(ctx context.Context,
371382 }
372383
373384 defer func () {
374- m .ReqCount ++
385+ m .ReqCount . Add ( 1 )
375386 }()
376387
377388 req := & chainntnfs.ConfirmationEvent {
@@ -380,15 +391,18 @@ func (m *MockChainBridge) RegisterConfirmationsNtfn(ctx context.Context,
380391 }
381392 m .confErr = make (chan error , 1 )
382393
383- m .ConfReqs [m .ReqCount ] = req
394+ currentReqCount := m .ReqCount .Load ()
395+ m .ConfReqs [int (currentReqCount )] = req
384396
385397 select {
386- case m .ConfReqSignal <- m . ReqCount :
398+ case m .ConfReqSignal <- int ( currentReqCount ) :
387399 case <- ctx .Done ():
388400 }
389401
390- if m .errConf {
391- m .confErr <- fmt .Errorf ("confirmation error" )
402+ // Compare to the incremented request count since we incremented it
403+ // when storing the request number.
404+ if m .errConf .CompareAndSwap (currentReqCount + 1 , 0 ) {
405+ m .confErr <- fmt .Errorf ("confirmation registration error" )
392406 }
393407
394408 return req , m .confErr , nil
@@ -661,7 +675,7 @@ func (m *MockKeyRing) IsLocalKey(context.Context, keychain.KeyDescriptor) bool {
661675
662676type MockGenSigner struct {
663677 KeyRing * MockKeyRing
664- FailSigning bool
678+ failSigning atomic. Bool
665679}
666680
667681func NewMockGenSigner (keyRing * MockKeyRing ) * MockGenSigner {
@@ -670,11 +684,17 @@ func NewMockGenSigner(keyRing *MockKeyRing) *MockGenSigner {
670684 }
671685}
672686
687+ // FailSigningOnce updates the GenSigner such that the next call to
688+ // SignVirtualTx will fail by returning an error.
689+ func (m * MockGenSigner ) FailSigningOnce () {
690+ m .failSigning .Store (true )
691+ }
692+
673693func (m * MockGenSigner ) SignVirtualTx (signDesc * lndclient.SignDescriptor ,
674694 virtualTx * wire.MsgTx , prevOut * wire.TxOut ) (* schnorr.Signature ,
675695 error ) {
676696
677- if m .FailSigning {
697+ if m .failSigning . CompareAndSwap ( true , false ) {
678698 return nil , fmt .Errorf ("failed to sign virtual tx" )
679699 }
680700
0 commit comments