Skip to content

Commit c1c2507

Browse files
lorenzo-dev1fjl
andauthored
p2p: fix DiscReason encoding/decoding (#30855)
This fixes an issue where the disconnect message was not wrapped in a list. The specification requires it to be a list like any other message. In order to remain compatible with legacy geth versions, we now accept both encodings when parsing a disconnect message. --------- Co-authored-by: Felix Lange <[email protected]>
1 parent c7e740f commit c1c2507

File tree

4 files changed

+44
-20
lines changed

4 files changed

+44
-20
lines changed

p2p/peer.go

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,7 @@ func (p *Peer) handle(msg Msg) error {
345345
case msg.Code == discMsg:
346346
// This is the last message. We don't need to discard or
347347
// check errors because, the connection will be closed after it.
348-
var m struct{ R DiscReason }
349-
rlp.Decode(msg.Payload, &m)
350-
return m.R
348+
return decodeDisconnectMessage(msg.Payload)
351349
case msg.Code < baseProtocolLength:
352350
// ignore other base protocol messages
353351
return msg.Discard()
@@ -372,6 +370,27 @@ func (p *Peer) handle(msg Msg) error {
372370
return nil
373371
}
374372

373+
// decodeDisconnectMessage decodes the payload of discMsg.
374+
func decodeDisconnectMessage(r io.Reader) (reason DiscReason) {
375+
s := rlp.NewStream(r, 100)
376+
k, _, err := s.Kind()
377+
if err != nil {
378+
return DiscInvalid
379+
}
380+
if k == rlp.List {
381+
s.List()
382+
err = s.Decode(&reason)
383+
} else {
384+
// Legacy path: some implementations, including geth, used to send the disconnect
385+
// reason as a byte array by accident.
386+
err = s.Decode(&reason)
387+
}
388+
if err != nil {
389+
reason = DiscInvalid
390+
}
391+
return reason
392+
}
393+
375394
func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
376395
n := 0
377396
for _, cap := range caps {

p2p/peer_error.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ const (
7070
DiscSelf
7171
DiscReadTimeout
7272
DiscSubprotocolError = DiscReason(0x10)
73+
74+
DiscInvalid = 0xff
7375
)
7476

7577
var discReasonToString = [...]string{
@@ -86,10 +88,11 @@ var discReasonToString = [...]string{
8688
DiscSelf: "connected to self",
8789
DiscReadTimeout: "read timeout",
8890
DiscSubprotocolError: "subprotocol error",
91+
DiscInvalid: "invalid disconnect reason",
8992
}
9093

9194
func (d DiscReason) String() string {
92-
if len(discReasonToString) <= int(d) {
95+
if len(discReasonToString) <= int(d) || discReasonToString[d] == "" {
9396
return fmt.Sprintf("unknown disconnect reason %d", d)
9497
}
9598
return discReasonToString[d]

p2p/transport.go

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,14 @@ func (t *rlpxTransport) close(err error) {
113113
// Tell the remote end why we're disconnecting if possible.
114114
// We only bother doing this if the underlying connection supports
115115
// setting a timeout tough.
116-
if t.conn != nil {
117-
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
118-
deadline := time.Now().Add(discWriteTimeout)
119-
if err := t.conn.SetWriteDeadline(deadline); err == nil {
120-
// Connection supports write deadline.
121-
t.wbuf.Reset()
122-
rlp.Encode(&t.wbuf, []DiscReason{r})
123-
t.conn.Write(discMsg, t.wbuf.Bytes())
124-
}
116+
if reason, ok := err.(DiscReason); ok && reason != DiscNetworkError {
117+
// We do not use the WriteMsg func since we want a custom deadline
118+
deadline := time.Now().Add(discWriteTimeout)
119+
if err := t.conn.SetWriteDeadline(deadline); err == nil {
120+
// Connection supports write deadline.
121+
t.wbuf.Reset()
122+
rlp.Encode(&t.wbuf, []any{reason})
123+
t.conn.Write(discMsg, t.wbuf.Bytes())
125124
}
126125
}
127126
t.conn.Close()
@@ -163,11 +162,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
163162
if msg.Code == discMsg {
164163
// Disconnect before protocol handshake is valid according to the
165164
// spec and we send it ourself if the post-handshake checks fail.
166-
// We can't return the reason directly, though, because it is echoed
167-
// back otherwise. Wrap it in a string instead.
168-
var reason [1]DiscReason
169-
rlp.Decode(msg.Payload, &reason)
170-
return nil, reason[0]
165+
r := decodeDisconnectMessage(msg.Payload)
166+
return nil, r
171167
}
172168
if msg.Code != handshakeMsg {
173169
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)

p2p/transport_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestProtocolHandshake(t *testing.T) {
9797
return
9898
}
9999

100-
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
100+
if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil {
101101
t.Errorf("error receiving disconnect: %v", err)
102102
}
103103
}()
@@ -112,7 +112,13 @@ func TestProtocolHandshakeErrors(t *testing.T) {
112112
}{
113113
{
114114
code: discMsg,
115-
msg: []DiscReason{DiscQuitting},
115+
msg: []any{DiscQuitting},
116+
err: DiscQuitting,
117+
},
118+
{
119+
// legacy disconnect encoding as byte array
120+
code: discMsg,
121+
msg: []byte{byte(DiscQuitting)},
116122
err: DiscQuitting,
117123
},
118124
{

0 commit comments

Comments
 (0)