Skip to content

Commit d7b580d

Browse files
authored
fix: protect against send on closed channel in protocols (#783)
1 parent b8c2661 commit d7b580d

File tree

4 files changed

+60
-0
lines changed

4 files changed

+60
-0
lines changed

protocol/blockfetch/client.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ func (c *Client) handleStartBatch() error {
206206
"role", "client",
207207
"connection_id", c.callbackContext.ConnectionId.String(),
208208
)
209+
// Check for shutdown
210+
select {
211+
case <-c.Protocol.DoneChan():
212+
return protocol.ProtocolShuttingDownError
213+
default:
214+
}
209215
c.startBatchResultChan <- nil
210216
return nil
211217
}
@@ -218,6 +224,12 @@ func (c *Client) handleNoBlocks() error {
218224
"role", "client",
219225
"connection_id", c.callbackContext.ConnectionId.String(),
220226
)
227+
// Check for shutdown
228+
select {
229+
case <-c.Protocol.DoneChan():
230+
return protocol.ProtocolShuttingDownError
231+
default:
232+
}
221233
err := fmt.Errorf("block(s) not found")
222234
c.startBatchResultChan <- err
223235
return nil
@@ -244,6 +256,12 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error {
244256
if err != nil {
245257
return err
246258
}
259+
// Check for shutdown
260+
select {
261+
case <-c.Protocol.DoneChan():
262+
return protocol.ProtocolShuttingDownError
263+
default:
264+
}
247265
// We use the callback when requesting ranges and the internal channel for a single block
248266
if c.blockUseCallback {
249267
if err := c.config.BlockFunc(c.callbackContext, wrappedBlock.Type, blk); err != nil {

protocol/localstatequery/client.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,12 @@ func (c *Client) handleAcquired() error {
849849
"role", "client",
850850
"connection_id", c.callbackContext.ConnectionId.String(),
851851
)
852+
// Check for shutdown
853+
select {
854+
case <-c.Protocol.DoneChan():
855+
return protocol.ProtocolShuttingDownError
856+
default:
857+
}
852858
c.acquired = true
853859
c.acquireResultChan <- nil
854860
c.currentEra = -1
@@ -863,6 +869,12 @@ func (c *Client) handleFailure(msg protocol.Message) error {
863869
"role", "client",
864870
"connection_id", c.callbackContext.ConnectionId.String(),
865871
)
872+
// Check for shutdown
873+
select {
874+
case <-c.Protocol.DoneChan():
875+
return protocol.ProtocolShuttingDownError
876+
default:
877+
}
866878
msgFailure := msg.(*MsgFailure)
867879
switch msgFailure.Failure {
868880
case AcquireFailurePointTooOld:
@@ -883,6 +895,12 @@ func (c *Client) handleResult(msg protocol.Message) error {
883895
"role", "client",
884896
"connection_id", c.callbackContext.ConnectionId.String(),
885897
)
898+
// Check for shutdown
899+
select {
900+
case <-c.Protocol.DoneChan():
901+
return protocol.ProtocolShuttingDownError
902+
default:
903+
}
886904
msgResult := msg.(*MsgResult)
887905
c.queryResultChan <- msgResult.Result
888906
return nil

protocol/localtxsubmission/client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ func (c *Client) handleAcceptTx() error {
155155
"role", "client",
156156
"connection_id", c.callbackContext.ConnectionId.String(),
157157
)
158+
// Check for shutdown
159+
select {
160+
case <-c.Protocol.DoneChan():
161+
return protocol.ProtocolShuttingDownError
162+
default:
163+
}
158164
c.submitResultChan <- nil
159165
return nil
160166
}
@@ -167,6 +173,12 @@ func (c *Client) handleRejectTx(msg protocol.Message) error {
167173
"role", "client",
168174
"connection_id", c.callbackContext.ConnectionId.String(),
169175
)
176+
// Check for shutdown
177+
select {
178+
case <-c.Protocol.DoneChan():
179+
return protocol.ProtocolShuttingDownError
180+
default:
181+
}
170182
msgRejectTx := msg.(*MsgRejectTx)
171183
rejectErr, err := ledger.NewTxSubmitErrorFromCbor(msgRejectTx.Reason)
172184
if err != nil {

protocol/txsubmission/server.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ func (s *Server) handleReplyTxIds(msg protocol.Message) error {
163163
"role", "server",
164164
"connection_id", s.callbackContext.ConnectionId.String(),
165165
)
166+
// Check for shutdown
167+
select {
168+
case <-s.Protocol.DoneChan():
169+
return protocol.ProtocolShuttingDownError
170+
default:
171+
}
166172
msgReplyTxIds := msg.(*MsgReplyTxIds)
167173
s.requestTxIdsResultChan <- msgReplyTxIds.TxIds
168174
return nil
@@ -176,6 +182,12 @@ func (s *Server) handleReplyTxs(msg protocol.Message) error {
176182
"role", "server",
177183
"connection_id", s.callbackContext.ConnectionId.String(),
178184
)
185+
// Check for shutdown
186+
select {
187+
case <-s.Protocol.DoneChan():
188+
return protocol.ProtocolShuttingDownError
189+
default:
190+
}
179191
msgReplyTxs := msg.(*MsgReplyTxs)
180192
s.requestTxsResultChan <- msgReplyTxs.Txs
181193
return nil

0 commit comments

Comments
 (0)