Skip to content

Commit 26e239c

Browse files
committed
loop+test: enhance epoch subscription for multiple subscribers
1 parent b43fa11 commit 26e239c

File tree

4 files changed

+62
-27
lines changed

4 files changed

+62
-27
lines changed

test/chainnotifier_mock.go

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,31 +73,40 @@ func (c *mockChainNotifier) RegisterBlockEpochNtfn(ctx context.Context) (
7373
chan int32, chan error, error) {
7474

7575
blockErrorChan := make(chan error, 1)
76-
blockEpochChan := make(chan int32)
76+
blockEpochChan := make(chan int32, 1)
77+
78+
c.lnd.lock.Lock()
79+
c.lnd.blockHeightListeners = append(
80+
c.lnd.blockHeightListeners, blockEpochChan,
81+
)
82+
c.lnd.lock.Unlock()
7783

7884
c.wg.Add(1)
7985
go func() {
8086
defer c.wg.Done()
87+
defer func() {
88+
c.lnd.lock.Lock()
89+
defer c.lnd.lock.Unlock()
90+
for i := 0; i < len(c.lnd.blockHeightListeners); i++ {
91+
if c.lnd.blockHeightListeners[i] == blockEpochChan {
92+
c.lnd.blockHeightListeners = append(
93+
c.lnd.blockHeightListeners[:i],
94+
c.lnd.blockHeightListeners[i+1:]...,
95+
)
96+
break
97+
}
98+
}
99+
}()
81100

82101
// Send initial block height
102+
c.lnd.lock.Lock()
83103
select {
84104
case blockEpochChan <- c.lnd.Height:
85105
case <-ctx.Done():
86-
return
87106
}
107+
c.lnd.lock.Unlock()
88108

89-
for {
90-
select {
91-
case m := <-c.lnd.epochChannel:
92-
select {
93-
case blockEpochChan <- m:
94-
case <-ctx.Done():
95-
return
96-
}
97-
case <-ctx.Done():
98-
return
99-
}
100-
}
109+
<-ctx.Done()
101110
}()
102111

103112
return blockEpochChan, blockErrorChan, nil

test/context.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,9 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx,
259259
func (ctx *Context) NotifyServerHeight(height int32) {
260260
require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height))
261261
}
262+
263+
func (ctx *Context) AssertEpochListeners(numListeners int32) {
264+
require.Eventually(ctx.T, func() bool {
265+
return ctx.Lnd.EpochSubscribers() == numListeners
266+
}, Timeout, time.Millisecond*250)
267+
}

test/lnd_services_mock.go

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"errors"
66
"sync"
7-
"time"
87

98
"github.com/btcsuite/btcd/chaincfg"
109
"github.com/btcsuite/btcd/wire"
@@ -63,13 +62,13 @@ func NewMockLnd() *LndMockServices {
6362

6463
SignOutputRawChannel: make(chan SignOutputRawRequest),
6564

66-
FailInvoiceChannel: make(chan lntypes.Hash, 2),
67-
epochChannel: make(chan int32),
68-
Height: testStartingHeight,
69-
NodePubkey: testNodePubkey,
70-
Signature: testSignature,
71-
SignatureMsg: testSignatureMsg,
72-
Invoices: make(map[lntypes.Hash]*lndclient.Invoice),
65+
FailInvoiceChannel: make(chan lntypes.Hash, 2),
66+
blockHeightListeners: make([]chan int32, 0),
67+
Height: testStartingHeight,
68+
NodePubkey: testNodePubkey,
69+
Signature: testSignature,
70+
SignatureMsg: testSignatureMsg,
71+
Invoices: make(map[lntypes.Hash]*lndclient.Invoice),
7372
}
7473

7574
lightningClient.lnd = &lnd
@@ -139,7 +138,7 @@ type LndMockServices struct {
139138
SendOutputsChannel chan wire.MsgTx
140139
SettleInvoiceChannel chan lntypes.Preimage
141140
FailInvoiceChannel chan lntypes.Hash
142-
epochChannel chan int32
141+
blockHeightListeners []chan int32
143142

144143
ConfChannel chan *chainntnfs.TxConfirmation
145144
RegisterConfChannel chan *ConfRegistration
@@ -177,15 +176,28 @@ type LndMockServices struct {
177176
lock sync.Mutex
178177
}
179178

179+
// EpochSubscribers returns the number of subscribers to block epoch
180+
// notifications.
181+
func (s *LndMockServices) EpochSubscribers() int32 {
182+
s.lock.Lock()
183+
defer s.lock.Unlock()
184+
185+
return int32(len(s.blockHeightListeners))
186+
}
187+
180188
// NotifyHeight notifies a new block height.
181189
func (s *LndMockServices) NotifyHeight(height int32) error {
190+
s.lock.Lock()
191+
defer s.lock.Unlock()
182192
s.Height = height
183193

184-
select {
185-
case s.epochChannel <- height:
186-
case <-time.After(Timeout):
187-
return ErrTimeout
194+
for _, listener := range s.blockHeightListeners {
195+
lis := listener
196+
go func() {
197+
lis <- height
198+
}()
188199
}
200+
189201
return nil
190202
}
191203

testcontext_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,11 @@ func (ctx *testContext) assertPreimagePush(preimage lntypes.Preimage) {
250250
ctx.Context.T.Fatalf("preimage not pushed")
251251
}
252252
}
253+
254+
func (ctx *testContext) AssertEpochListeners(numListeners int32) {
255+
ctx.Context.T.Helper()
256+
257+
require.Eventually(ctx.Context.T, func() bool {
258+
return ctx.Lnd.EpochSubscribers() == numListeners
259+
}, test.Timeout, time.Millisecond*250)
260+
}

0 commit comments

Comments
 (0)