11package lnd
22
33import (
4- "net"
5- "sync/atomic"
6- "testing"
7- "time"
8-
9- "github.com/lightningnetwork/lnd/fn/v2"
104 "github.com/lightningnetwork/lnd/htlcswitch"
11- "github.com/lightningnetwork/lnd/lntypes"
125 "github.com/lightningnetwork/lnd/lnwallet"
136 "github.com/lightningnetwork/lnd/lnwire"
14- "github.com/lightningnetwork/lnd/peer"
15- "github.com/stretchr/testify/require"
16- )
17-
18- const (
19- timeout = time .Second * 5
207)
218
229// mockMessageSwitch is a mock implementation of the messageSwitch interface
@@ -36,195 +23,18 @@ func (m *mockMessageSwitch) CircuitModifier() htlcswitch.CircuitModifier {
3623}
3724
3825// RemoveLink currently does nothing.
39- func (m * mockMessageSwitch ) RemoveLink (cid lnwire.ChannelID ) {}
26+ func (m * mockMessageSwitch ) RemoveLink (lnwire.ChannelID ) {}
4027
4128// CreateAndAddLink currently returns a dummy value.
42- func (m * mockMessageSwitch ) CreateAndAddLink (cfg htlcswitch.ChannelLinkConfig ,
43- lnChan * lnwallet.LightningChannel ) error {
29+ func (m * mockMessageSwitch ) CreateAndAddLink (htlcswitch.ChannelLinkConfig ,
30+ * lnwallet.LightningChannel ) error {
4431
4532 return nil
4633}
4734
4835// GetLinksByInterface returns the active links.
49- func (m * mockMessageSwitch ) GetLinksByInterface (pub [33 ]byte ) (
36+ func (m * mockMessageSwitch ) GetLinksByInterface ([33 ]byte ) (
5037 []htlcswitch.ChannelUpdateHandler , error ) {
5138
5239 return m .links , nil
5340}
54-
55- // mockUpdateHandler is a mock implementation of the ChannelUpdateHandler
56- // interface. It is used in mockMessageSwitch's GetLinksByInterface method.
57- type mockUpdateHandler struct {
58- cid lnwire.ChannelID
59- isOutgoingAddBlocked atomic.Bool
60- isIncomingAddBlocked atomic.Bool
61- }
62-
63- // newMockUpdateHandler creates a new mockUpdateHandler.
64- func newMockUpdateHandler (cid lnwire.ChannelID ) * mockUpdateHandler {
65- return & mockUpdateHandler {
66- cid : cid ,
67- }
68- }
69-
70- // HandleChannelUpdate currently does nothing.
71- func (m * mockUpdateHandler ) HandleChannelUpdate (msg lnwire.Message ) {}
72-
73- // ChanID returns the mockUpdateHandler's cid.
74- func (m * mockUpdateHandler ) ChanID () lnwire.ChannelID { return m .cid }
75-
76- // Bandwidth currently returns a dummy value.
77- func (m * mockUpdateHandler ) Bandwidth () lnwire.MilliSatoshi { return 0 }
78-
79- // EligibleToForward currently returns a dummy value.
80- func (m * mockUpdateHandler ) EligibleToForward () bool { return false }
81-
82- // MayAddOutgoingHtlc currently returns nil.
83- func (m * mockUpdateHandler ) MayAddOutgoingHtlc (lnwire.MilliSatoshi ) error { return nil }
84-
85- type mockMessageConn struct {
86- t * testing.T
87-
88- // MessageConn embeds our interface so that the mock does not need to
89- // implement every function. The mock will panic if an unspecified function
90- // is called.
91- peer.MessageConn
92-
93- // writtenMessages is a channel that our mock pushes written messages into.
94- writtenMessages chan []byte
95-
96- readMessages chan []byte
97- curReadMessage []byte
98-
99- // writeRaceDetectingCounter is incremented on any function call
100- // associated with writing to the connection. The race detector will
101- // trigger on this counter if a data race exists.
102- writeRaceDetectingCounter int
103-
104- // readRaceDetectingCounter is incremented on any function call
105- // associated with reading from the connection. The race detector will
106- // trigger on this counter if a data race exists.
107- readRaceDetectingCounter int
108- }
109-
110- func (m * mockUpdateHandler ) EnableAdds (dir htlcswitch.LinkDirection ) bool {
111- if dir == htlcswitch .Outgoing {
112- return m .isOutgoingAddBlocked .Swap (false )
113- }
114-
115- return m .isIncomingAddBlocked .Swap (false )
116- }
117-
118- func (m * mockUpdateHandler ) DisableAdds (dir htlcswitch.LinkDirection ) bool {
119- if dir == htlcswitch .Outgoing {
120- return ! m .isOutgoingAddBlocked .Swap (true )
121- }
122-
123- return ! m .isIncomingAddBlocked .Swap (true )
124- }
125-
126- func (m * mockUpdateHandler ) IsFlushing (dir htlcswitch.LinkDirection ) bool {
127- switch dir {
128- case htlcswitch .Outgoing :
129- return m .isOutgoingAddBlocked .Load ()
130- case htlcswitch .Incoming :
131- return m .isIncomingAddBlocked .Load ()
132- }
133-
134- return false
135- }
136-
137- func (m * mockUpdateHandler ) OnFlushedOnce (hook func ()) {
138- hook ()
139- }
140- func (m * mockUpdateHandler ) OnCommitOnce (
141- _ htlcswitch.LinkDirection , hook func (),
142- ) {
143-
144- hook ()
145- }
146- func (m * mockUpdateHandler ) InitStfu () <- chan fn.Result [lntypes.ChannelParty ] {
147- // TODO(proofofkeags): Implement
148- c := make (chan fn.Result [lntypes.ChannelParty ], 1 )
149-
150- c <- fn.Errf [lntypes.ChannelParty ]("InitStfu not yet implemented" )
151-
152- return c
153- }
154-
155- func newMockConn (t * testing.T , expectedMessages int ) * mockMessageConn {
156- return & mockMessageConn {
157- t : t ,
158- writtenMessages : make (chan []byte , expectedMessages ),
159- readMessages : make (chan []byte , 1 ),
160- }
161- }
162-
163- // SetWriteDeadline mocks setting write deadline for our conn.
164- func (m * mockMessageConn ) SetWriteDeadline (time.Time ) error {
165- m .writeRaceDetectingCounter ++
166- return nil
167- }
168-
169- // Flush mocks a message conn flush.
170- func (m * mockMessageConn ) Flush () (int , error ) {
171- m .writeRaceDetectingCounter ++
172- return 0 , nil
173- }
174-
175- // WriteMessage mocks sending of a message on our connection. It will push
176- // the bytes sent into the mock's writtenMessages channel.
177- func (m * mockMessageConn ) WriteMessage (msg []byte ) error {
178- m .writeRaceDetectingCounter ++
179-
180- msgCopy := make ([]byte , len (msg ))
181- copy (msgCopy , msg )
182-
183- select {
184- case m .writtenMessages <- msgCopy :
185- case <- time .After (timeout ):
186- m .t .Fatalf ("timeout sending message: %v" , msgCopy )
187- }
188-
189- return nil
190- }
191-
192- // assertWrite asserts that our mock as had WriteMessage called with the byte
193- // slice we expect.
194- func (m * mockMessageConn ) assertWrite (expected []byte ) {
195- select {
196- case actual := <- m .writtenMessages :
197- require .Equal (m .t , expected , actual )
198-
199- case <- time .After (timeout ):
200- m .t .Fatalf ("timeout waiting for write: %v" , expected )
201- }
202- }
203-
204- func (m * mockMessageConn ) SetReadDeadline (t time.Time ) error {
205- m .readRaceDetectingCounter ++
206- return nil
207- }
208-
209- func (m * mockMessageConn ) ReadNextHeader () (uint32 , error ) {
210- m .readRaceDetectingCounter ++
211- m .curReadMessage = <- m .readMessages
212- return uint32 (len (m .curReadMessage )), nil
213- }
214-
215- func (m * mockMessageConn ) ReadNextBody (buf []byte ) ([]byte , error ) {
216- m .readRaceDetectingCounter ++
217- return m .curReadMessage , nil
218- }
219-
220- func (m * mockMessageConn ) RemoteAddr () net.Addr {
221- return nil
222- }
223-
224- func (m * mockMessageConn ) LocalAddr () net.Addr {
225- return nil
226- }
227-
228- func (m * mockMessageConn ) Close () error {
229- return nil
230- }
0 commit comments