@@ -12,8 +12,10 @@ import (
1212)
1313
1414type mockChainNotifier struct {
15- lnd * LndMockServices
16- wg sync.WaitGroup
15+ sync.Mutex
16+ lnd * LndMockServices
17+ confRegistrations []* ConfRegistration
18+ wg sync.WaitGroup
1719}
1820
1921// SpendRegistration contains registration details.
@@ -29,6 +31,7 @@ type ConfRegistration struct {
2931 PkScript []byte
3032 HeightHint int32
3133 NumConfs int32
34+ ConfChan chan * chainntnfs.TxConfirmation
3235}
3336
3437func (c * mockChainNotifier ) RegisterSpendNtfn (ctx context.Context ,
@@ -103,7 +106,18 @@ func (c *mockChainNotifier) RegisterConfirmationsNtfn(ctx context.Context,
103106 txid * chainhash.Hash , pkScript []byte , numConfs , heightHint int32 ) (
104107 chan * chainntnfs.TxConfirmation , chan error , error ) {
105108
106- confChan := make (chan * chainntnfs.TxConfirmation , 1 )
109+ reg := & ConfRegistration {
110+ PkScript : pkScript ,
111+ TxID : txid ,
112+ HeightHint : heightHint ,
113+ NumConfs : numConfs ,
114+ ConfChan : make (chan * chainntnfs.TxConfirmation , 1 ),
115+ }
116+
117+ c .Lock ()
118+ c .confRegistrations = append (c .confRegistrations , reg )
119+ c .Unlock ()
120+
107121 errChan := make (chan error , 1 )
108122
109123 c .wg .Add (1 )
@@ -112,26 +126,35 @@ func (c *mockChainNotifier) RegisterConfirmationsNtfn(ctx context.Context,
112126
113127 select {
114128 case m := <- c .lnd .ConfChannel :
115- if bytes .Equal (m .Tx .TxOut [0 ].PkScript , pkScript ) {
116- select {
117- case confChan <- m :
118- case <- ctx .Done ():
129+ c .Lock ()
130+ for i := 0 ; i < len (c .confRegistrations ); i ++ {
131+ r := c .confRegistrations [i ]
132+
133+ // Whichever conf notifier catches the confirmation
134+ // will forward it to all matching subscibers.
135+ if bytes .Equal (m .Tx .TxOut [0 ].PkScript , r .PkScript ) {
136+ // Unregister the "notifier".
137+ c .confRegistrations = append (
138+ c .confRegistrations [:i ], c .confRegistrations [i + 1 :]... ,
139+ )
140+ i --
141+
142+ select {
143+ case r .ConfChan <- m :
144+ case <- ctx .Done ():
145+ }
119146 }
120147 }
148+ c .Unlock ()
121149 case <- ctx .Done ():
122150 }
123151 }()
124152
125153 select {
126- case c .lnd .RegisterConfChannel <- & ConfRegistration {
127- PkScript : pkScript ,
128- TxID : txid ,
129- HeightHint : heightHint ,
130- NumConfs : numConfs ,
131- }:
154+ case c .lnd .RegisterConfChannel <- reg :
132155 case <- time .After (Timeout ):
133156 return nil , nil , ErrTimeout
134157 }
135158
136- return confChan , errChan , nil
159+ return reg . ConfChan , errChan , nil
137160}
0 commit comments