Skip to content

Commit f475a01

Browse files
committed
Merge pull request #1261 from fjl/p2p-no-writes-at-shutdown
p2p: prevent writes at shutdown time
2 parents 2639033 + 70da79f commit f475a01

File tree

2 files changed

+61
-28
lines changed

2 files changed

+61
-28
lines changed

p2p/peer.go

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -115,41 +115,60 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
115115
}
116116

117117
func (p *Peer) run() DiscReason {
118-
readErr := make(chan error, 1)
118+
var (
119+
writeStart = make(chan struct{}, 1)
120+
writeErr = make(chan error, 1)
121+
readErr = make(chan error, 1)
122+
reason DiscReason
123+
requested bool
124+
)
119125
p.wg.Add(2)
120126
go p.readLoop(readErr)
121127
go p.pingLoop()
122128

123-
p.startProtocols()
129+
// Start all protocol handlers.
130+
writeStart <- struct{}{}
131+
p.startProtocols(writeStart, writeErr)
124132

125133
// Wait for an error or disconnect.
126-
var (
127-
reason DiscReason
128-
requested bool
129-
)
130-
select {
131-
case err := <-readErr:
132-
if r, ok := err.(DiscReason); ok {
133-
reason = r
134-
} else {
135-
// Note: We rely on protocols to abort if there is a write
136-
// error. It might be more robust to handle them here as well.
137-
glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
138-
reason = DiscNetworkError
134+
loop:
135+
for {
136+
select {
137+
case err := <-writeErr:
138+
// A write finished. Allow the next write to start if
139+
// there was no error.
140+
if err != nil {
141+
glog.V(logger.Detail).Infof("%v: write error: %v\n", p, err)
142+
reason = DiscNetworkError
143+
break loop
144+
}
145+
writeStart <- struct{}{}
146+
case err := <-readErr:
147+
if r, ok := err.(DiscReason); ok {
148+
glog.V(logger.Debug).Infof("%v: remote requested disconnect: %v\n", p, r)
149+
requested = true
150+
reason = r
151+
} else {
152+
glog.V(logger.Detail).Infof("%v: read error: %v\n", p, err)
153+
reason = DiscNetworkError
154+
}
155+
break loop
156+
case err := <-p.protoErr:
157+
reason = discReasonForError(err)
158+
glog.V(logger.Debug).Infof("%v: protocol error: %v (%v)\n", p, err, reason)
159+
break loop
160+
case reason = <-p.disc:
161+
glog.V(logger.Debug).Infof("%v: locally requested disconnect: %v\n", p, reason)
162+
break loop
139163
}
140-
case err := <-p.protoErr:
141-
reason = discReasonForError(err)
142-
case reason = <-p.disc:
143-
requested = true
144164
}
165+
145166
close(p.closed)
146167
p.rw.close(reason)
147168
p.wg.Wait()
148-
149169
if requested {
150170
reason = DiscRequested
151171
}
152-
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
153172
return reason
154173
}
155174

@@ -196,7 +215,6 @@ func (p *Peer) handle(msg Msg) error {
196215
// This is the last message. We don't need to discard or
197216
// check errors because, the connection will be closed after it.
198217
rlp.Decode(msg.Payload, &reason)
199-
glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0])
200218
return reason[0]
201219
case msg.Code < baseProtocolLength:
202220
// ignore other base protocol messages
@@ -247,11 +265,13 @@ outer:
247265
return result
248266
}
249267

250-
func (p *Peer) startProtocols() {
268+
func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
251269
p.wg.Add(len(p.running))
252270
for _, proto := range p.running {
253271
proto := proto
254272
proto.closed = p.closed
273+
proto.wstart = writeStart
274+
proto.werr = writeErr
255275
glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
256276
go func() {
257277
err := proto.Run(p, proto)
@@ -280,18 +300,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
280300

281301
type protoRW struct {
282302
Protocol
283-
in chan Msg
284-
closed <-chan struct{}
303+
in chan Msg // receices read messages
304+
closed <-chan struct{} // receives when peer is shutting down
305+
wstart <-chan struct{} // receives when write may start
306+
werr chan<- error // for write results
285307
offset uint64
286308
w MsgWriter
287309
}
288310

289-
func (rw *protoRW) WriteMsg(msg Msg) error {
311+
func (rw *protoRW) WriteMsg(msg Msg) (err error) {
290312
if msg.Code >= rw.Length {
291313
return newPeerError(errInvalidMsgCode, "not handled")
292314
}
293315
msg.Code += rw.offset
294-
return rw.w.WriteMsg(msg)
316+
select {
317+
case <-rw.wstart:
318+
err = rw.w.WriteMsg(msg)
319+
// Report write status back to Peer.run. It will initiate
320+
// shutdown if the error is non-nil and unblock the next write
321+
// otherwise. The calling protocol code should exit for errors
322+
// as well but we don't want to rely on that.
323+
rw.werr <- err
324+
case <-rw.closed:
325+
err = fmt.Errorf("shutting down")
326+
}
327+
return err
295328
}
296329

297330
func (rw *protoRW) ReadMsg() (Msg, error) {

p2p/peer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func TestPeerDisconnect(t *testing.T) {
121121
}
122122
select {
123123
case reason := <-disc:
124-
if reason != DiscQuitting {
124+
if reason != DiscRequested {
125125
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
126126
}
127127
case <-time.After(500 * time.Millisecond):

0 commit comments

Comments
 (0)