diff --git a/conn.go b/conn.go index 9562ffd4..d1c92c97 100644 --- a/conn.go +++ b/conn.go @@ -252,6 +252,8 @@ type Conn struct { writer io.WriteCloser // the current writer returned to the application isWriting bool // for best-effort concurrent write detection + disableClientMask bool + writeErrMu sync.Mutex writeErr error @@ -315,6 +317,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufSize: writeBufferSize, enableWriteCompression: true, compressionLevel: defaultCompressionLevel, + disableClientMask: false, } c.SetCloseHandler(nil) c.SetPingHandler(nil) @@ -432,7 +435,12 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er if c.isServer { buf = append(buf, data...) } else { - key := newMaskKey() + var key [4]byte + if c.disableClientMask { + key = [4]byte{0, 0, 0, 0} + } else { + key = newMaskKey() + } buf = append(buf, key[:]...) buf = append(buf, data...) maskBytes(key, 0, buf[6:]) @@ -610,7 +618,12 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } if !c.isServer { - key := newMaskKey() + var key [4]byte + if c.disableClientMask { + key = [4]byte{0, 0, 0, 0} + } else { + key = newMaskKey() + } copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { @@ -743,9 +756,10 @@ func (w *messageWriter) Close() error { // WritePreparedMessage writes prepared message into connection. func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { frameType, frameData, err := pm.frame(prepareKey{ - isServer: c.isServer, - compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), - compressionLevel: c.compressionLevel, + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + disableClientMask: c.disableClientMask, }) if err != nil { return err @@ -1230,6 +1244,20 @@ func (c *Conn) SetCompressionLevel(level int) error { return nil } +// SetDisableClientMask configures WebSocket payload masking behavior for client-mode frames. +// When enabled (true), implements protocol-allowed optimization +// by generating zero-value mask keys ([4]byte{0,0,0,0}), effectively omitting XOR operations +// while maintaining formal protocol compliance. +// +// Security Advisory: +// - Safe to enable ONLY when using secure transport layers (TLS 1.2+/SSL) +// - May expose vulnerabilities to network intermediaries when unprotected +// +// Default: false (masking enabled) - Maintains protocol compliance for plaintext connections +func (c *Conn) SetDisableClientMask(value bool) { + c.disableClientMask = value +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. // An empty message is returned for code CloseNoStatusReceived. func FormatCloseMessage(closeCode int, text string) []byte { diff --git a/mask.go b/mask.go index d0742bf2..91df11f9 100644 --- a/mask.go +++ b/mask.go @@ -13,7 +13,7 @@ const wordSize = int(unsafe.Sizeof(uintptr(0))) func maskBytes(key [4]byte, pos int, b []byte) int { // Mask one byte at a time for small buffers. - if len(b) < 2*wordSize { + if len(b) <= 2*wordSize { for i := range b { b[i] ^= key[pos&3] pos++ @@ -21,6 +21,9 @@ func maskBytes(key [4]byte, pos int, b []byte) int { return pos & 3 } + if key == [4]byte{} { + return (pos + len(b)) & 3 + } // Mask one byte at a time to word boundary. if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { n = wordSize - n @@ -32,16 +35,27 @@ func maskBytes(key [4]byte, pos int, b []byte) int { } // Create aligned word size key. + var kw uintptr var k [wordSize]byte - for i := range k { - k[i] = key[(pos+i)&3] + if wordSize == 8 { + k[0] = key[(pos+0)&3] + k[1] = key[(pos+1)&3] + k[2] = key[(pos+2)&3] + k[3] = key[(pos+3)&3] + kw = *(*uintptr)(unsafe.Pointer(&k)) + kw = (kw << 32) | kw + } else { + for i := range k { + k[i] = key[(pos+i)&3] + } + kw = *(*uintptr)(unsafe.Pointer(&k)) } - kw := *(*uintptr)(unsafe.Pointer(&k)) // Mask one word at a time. n := (len(b) / wordSize) * wordSize + p0 := unsafe.Pointer(&b[0]) for i := 0; i < n; i += wordSize { - *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + *(*uintptr)(unsafe.Pointer(uintptr(p0) + uintptr(i))) ^= kw } // Mask one byte at a time for remaining bytes. diff --git a/mask_test.go b/mask_test.go index 6389f436..b8ae00d5 100644 --- a/mask_test.go +++ b/mask_test.go @@ -7,8 +7,11 @@ package websocket import ( + "bytes" "fmt" + "math/rand" "testing" + "unsafe" ) func maskBytesByByte(key [4]byte, pos int, b []byte) int { @@ -28,6 +31,49 @@ func notzero(b []byte) int { return -1 } +func maskBytesV1(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} + func TestMaskBytes(t *testing.T) { key := [4]byte{1, 2, 3, 4} for size := 1; size <= 1024; size++ { @@ -44,8 +90,39 @@ func TestMaskBytes(t *testing.T) { } } +func TestMaskBytesWithRandomMessage(t *testing.T) { + keys := [][4]byte{ + {1, 2, 3, 4}, + {0, 0, 0, 0}, + } + for _, key := range keys { + for size := 1; size <= 1024; size++ { + for align := 0; align < wordSize; align++ { + for pos := 0; pos < 4; pos++ { + byteMessage := make([]byte, size+align)[align:] + for i := 0; i < len(byteMessage); i++ { + byteMessage[i] = uint8(rand.Uint32()) + } + byteMessageCopy := make([]byte, len(byteMessage)) + copy(byteMessageCopy, byteMessage) + posBytes := maskBytes(key, pos, byteMessage) + posBytesByByte := maskBytesByByte(key, pos, byteMessageCopy) + if posBytes != posBytesByByte { + t.Errorf("keys:%v, size:%d, align:%d, pos:%d", key, size, align, pos) + return + } + if !bytes.Equal(byteMessage, byteMessageCopy) { + t.Errorf("keys:%v, size:%d, align:%d, pos:%d", key, size, align, pos) + return + } + } + } + } + } +} + func BenchmarkMaskBytes(b *testing.B) { - for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { + for _, size := range []int{2, 4, 8, 16, 32, 512, 1024, 1048576} { b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { for _, align := range []int{wordSize / 2} { b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { @@ -54,11 +131,13 @@ func BenchmarkMaskBytes(b *testing.B) { fn func(key [4]byte, pos int, b []byte) int }{ {"byte", maskBytesByByte}, + {"wordV1", maskBytesV1}, {"word", maskBytes}, } { b.Run(fn.name, func(b *testing.B) { key := newMaskKey() data := make([]byte, size+align)[align:] + b.ResetTimer() for i := 0; i < b.N; i++ { fn.fn(key, 0, data) } diff --git a/prepared.go b/prepared.go index c854225e..a7e93290 100644 --- a/prepared.go +++ b/prepared.go @@ -25,9 +25,10 @@ type PreparedMessage struct { // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. type prepareKey struct { - isServer bool - compress bool - compressionLevel int + isServer bool + compress bool + compressionLevel int + disableClientMask bool } // preparedFrame contains data in wire representation. @@ -83,6 +84,7 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { compressionLevel: key.compressionLevel, enableWriteCompression: true, writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + disableClientMask: key.disableClientMask, } if key.compress { c.newCompressionWriter = compressNoContextTakeover