Skip to content

Commit 5b740c2

Browse files
authored
Read Limit Fix (#537)
This fix addresses a potential denial-of-service (DoS) vector that can cause an integer overflow in the presence of malicious WebSocket frames. The fix adds additional checks against the remaining bytes on a connection, as well as a test to prevent regression. Credit to Max Justicz (https://justi.cz/) for discovering and reporting this, as well as providing a robust PoC and review. * build: go.mod to go1.12 * bugfix: fix DoS vector caused by readLimit bypass * test: update TestReadLimit sub-test * bugfix: payload length 127 should read bytes as uint64 * bugfix: defend against readLength overflows
1 parent 7e9819d commit 5b740c2

File tree

3 files changed

+138
-37
lines changed

3 files changed

+138
-37
lines changed

conn.go

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,12 @@ type Conn struct {
260260
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
261261

262262
// Read fields
263-
reader io.ReadCloser // the current reader returned to the application
264-
readErr error
265-
br *bufio.Reader
266-
readRemaining int64 // bytes remaining in current frame.
263+
reader io.ReadCloser // the current reader returned to the application
264+
readErr error
265+
br *bufio.Reader
266+
// bytes remaining in current frame.
267+
// set setReadRemaining to safely update this value and prevent overflow
268+
readRemaining int64
267269
readFinal bool // true the current message has more frames.
268270
readLength int64 // Message size.
269271
readLimit int64 // Maximum message size.
@@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
320322
return c
321323
}
322324

325+
// setReadRemaining tracks the number of bytes remaining on the connection. If n
326+
// overflows, an ErrReadLimit is returned.
327+
func (c *Conn) setReadRemaining(n int64) error {
328+
if n < 0 {
329+
return ErrReadLimit
330+
}
331+
332+
c.readRemaining = n
333+
return nil
334+
}
335+
323336
// Subprotocol returns the negotiated protocol for the connection.
324337
func (c *Conn) Subprotocol() string {
325338
return c.subprotocol
@@ -790,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) {
790803
final := p[0]&finalBit != 0
791804
frameType := int(p[0] & 0xf)
792805
mask := p[1]&maskBit != 0
793-
c.readRemaining = int64(p[1] & 0x7f)
806+
c.setReadRemaining(int64(p[1] & 0x7f))
794807

795808
c.readDecompress = false
796809
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
@@ -824,21 +837,37 @@ func (c *Conn) advanceFrame() (int, error) {
824837
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
825838
}
826839

827-
// 3. Read and parse frame length.
840+
// 3. Read and parse frame length as per
841+
// https://tools.ietf.org/html/rfc6455#section-5.2
842+
//
843+
// The length of the "Payload data", in bytes: if 0-125, that is the payload
844+
// length.
845+
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
846+
// integer are the payload length.
847+
// - If 127, the following 8 bytes interpreted as
848+
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
849+
// payload length. Multibyte length quantities are expressed in network byte
850+
// order.
828851

829852
switch c.readRemaining {
830853
case 126:
831854
p, err := c.read(2)
832855
if err != nil {
833856
return noFrame, err
834857
}
835-
c.readRemaining = int64(binary.BigEndian.Uint16(p))
858+
859+
if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
860+
return noFrame, err
861+
}
836862
case 127:
837863
p, err := c.read(8)
838864
if err != nil {
839865
return noFrame, err
840866
}
841-
c.readRemaining = int64(binary.BigEndian.Uint64(p))
867+
868+
if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
869+
return noFrame, err
870+
}
842871
}
843872

844873
// 4. Handle frame masking.
@@ -861,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) {
861890
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
862891

863892
c.readLength += c.readRemaining
893+
// Don't allow readLength to overflow in the presence of a large readRemaining
894+
// counter.
895+
if c.readLength < 0 {
896+
return noFrame, ErrReadLimit
897+
}
898+
864899
if c.readLimit > 0 && c.readLength > c.readLimit {
865900
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
866901
return noFrame, ErrReadLimit
@@ -874,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) {
874909
var payload []byte
875910
if c.readRemaining > 0 {
876911
payload, err = c.read(int(c.readRemaining))
877-
c.readRemaining = 0
912+
c.setReadRemaining(0)
878913
if err != nil {
879914
return noFrame, err
880915
}
@@ -947,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
947982
c.readErr = hideTempErr(err)
948983
break
949984
}
985+
950986
if frameType == TextMessage || frameType == BinaryMessage {
951987
c.messageReader = &messageReader{c}
952988
c.reader = c.messageReader
@@ -987,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
9871023
if c.isServer {
9881024
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
9891025
}
990-
c.readRemaining -= int64(n)
1026+
rem := c.readRemaining
1027+
rem -= int64(n)
1028+
c.setReadRemaining(rem)
9911029
if c.readRemaining > 0 && c.readErr == io.EOF {
9921030
c.readErr = errUnexpectedEOF
9931031
}

conn_test.go

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
5555
}
5656

5757
func TestFraming(t *testing.T) {
58-
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
58+
frameSizes := []int{
59+
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
60+
// 65536, 65537
61+
}
5962
var readChunkers = []struct {
6063
name string
6164
f func(io.Reader) io.Reader
@@ -120,6 +123,8 @@ func TestFraming(t *testing.T) {
120123
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
121124
continue
122125
}
126+
127+
t.Logf("frame size: %d", n)
123128
rbuf, err := ioutil.ReadAll(r)
124129
if err != nil {
125130
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
@@ -458,37 +463,93 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
458463
}
459464

460465
func TestReadLimit(t *testing.T) {
466+
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
467+
const readLimit = 512
468+
message := make([]byte, readLimit+1)
461469

462-
const readLimit = 512
463-
message := make([]byte, readLimit+1)
470+
var b1, b2 bytes.Buffer
471+
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
472+
rc := newTestConn(&b1, &b2, true)
473+
rc.SetReadLimit(readLimit)
464474

465-
var b1, b2 bytes.Buffer
466-
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
467-
rc := newTestConn(&b1, &b2, true)
468-
rc.SetReadLimit(readLimit)
475+
// Send message at the limit with interleaved pong.
476+
w, _ := wc.NextWriter(BinaryMessage)
477+
w.Write(message[:readLimit-1])
478+
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
479+
w.Write(message[:1])
480+
w.Close()
469481

470-
// Send message at the limit with interleaved pong.
471-
w, _ := wc.NextWriter(BinaryMessage)
472-
w.Write(message[:readLimit-1])
473-
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
474-
w.Write(message[:1])
475-
w.Close()
482+
// Send message larger than the limit.
483+
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
476484

477-
// Send message larger than the limit.
478-
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
485+
op, _, err := rc.NextReader()
486+
if op != BinaryMessage || err != nil {
487+
t.Fatalf("1: NextReader() returned %d, %v", op, err)
488+
}
489+
op, r, err := rc.NextReader()
490+
if op != BinaryMessage || err != nil {
491+
t.Fatalf("2: NextReader() returned %d, %v", op, err)
492+
}
493+
_, err = io.Copy(ioutil.Discard, r)
494+
if err != ErrReadLimit {
495+
t.Fatalf("io.Copy() returned %v", err)
496+
}
497+
})
479498

480-
op, _, err := rc.NextReader()
481-
if op != BinaryMessage || err != nil {
482-
t.Fatalf("1: NextReader() returned %d, %v", op, err)
483-
}
484-
op, r, err := rc.NextReader()
485-
if op != BinaryMessage || err != nil {
486-
t.Fatalf("2: NextReader() returned %d, %v", op, err)
487-
}
488-
_, err = io.Copy(ioutil.Discard, r)
489-
if err != ErrReadLimit {
490-
t.Fatalf("io.Copy() returned %v", err)
491-
}
499+
t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
500+
const readLimit = 1
501+
502+
var b1, b2 bytes.Buffer
503+
rc := newTestConn(&b1, &b2, true)
504+
rc.SetReadLimit(readLimit)
505+
506+
// First, send a non-final binary message
507+
b1.Write([]byte("\x02\x81"))
508+
509+
// Mask key
510+
b1.Write([]byte("\x00\x00\x00\x00"))
511+
512+
// First payload
513+
b1.Write([]byte("A"))
514+
515+
// Next, send a negative-length, non-final continuation frame
516+
b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
517+
518+
// Mask key
519+
b1.Write([]byte("\x00\x00\x00\x00"))
520+
521+
// Next, send a too long, final continuation frame
522+
b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
523+
524+
// Mask key
525+
b1.Write([]byte("\x00\x00\x00\x00"))
526+
527+
// Too-long payload
528+
b1.Write([]byte("BCDEF"))
529+
530+
op, r, err := rc.NextReader()
531+
if op != BinaryMessage || err != nil {
532+
t.Fatalf("1: NextReader() returned %d, %v", op, err)
533+
}
534+
535+
var buf [10]byte
536+
var read int
537+
n, err := r.Read(buf[:])
538+
if err != nil && err != ErrReadLimit {
539+
t.Fatalf("unexpected error testing read limit: %v", err)
540+
}
541+
read += n
542+
543+
n, err = r.Read(buf[:])
544+
if err != nil && err != ErrReadLimit {
545+
t.Fatalf("unexpected error testing read limit: %v", err)
546+
}
547+
read += n
548+
549+
if err == nil && read > readLimit {
550+
t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
551+
}
552+
})
492553
}
493554

494555
func TestAddrs(t *testing.T) {

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
module github.com/gorilla/websocket
2+
3+
go 1.12

0 commit comments

Comments
 (0)