diff --git a/client.go b/client.go index 04fdafee..4d07aa8b 100644 --- a/client.go +++ b/client.go @@ -6,6 +6,7 @@ package websocket import ( "bytes" + "compress/flate" "context" "crypto/tls" "errors" @@ -100,16 +101,19 @@ type Dialer struct { // Subprotocols specifies the client's requested subprotocols. Subprotocols []string - // EnableCompression specifies if the client should attempt to negotiate - // per message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. Currently only "no context - // takeover" modes are supported. + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). EnableCompression bool // Jar specifies the cookie jar. // If Jar is nil, cookies are not sent in requests and ignored // in responses. Jar http.CookieJar + + // AllowClientContextTakeover specifies whether the server will negotiate client context + // takeover for per message compression. Context takeover improves compression at the + // the cost of using more memory. + AllowClientContextTakeover bool } // Dial creates a new client connection by calling DialContext with a background context. @@ -235,8 +239,11 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } } - if d.EnableCompression { - req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} + switch { + case d.EnableCompression && d.AllowClientContextTakeover: + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_max_window_bits=15; client_max_window_bits=15") + case d.EnableCompression: + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") } if d.HandshakeTimeout != 0 { @@ -408,13 +415,24 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h if ext[""] != "permessage-deflate" { continue } - _, snct := ext["server_no_context_takeover"] - _, cnct := ext["client_no_context_takeover"] - if !snct || !cnct { - return nil, resp, errInvalidCompression + + _, cmwb := ext["client_max_window_bits"] + _, smwb := ext["server_max_window_bits"] + + switch { + case cmwb && smwb: + var wf contextTakeoverWriterFactory + conn.newCompressionWriter = wf.newCompressionWriter + + var rf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + rf.fr = fr + conn.newDecompressionReader = rf.newDeCompressionReader + default: + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover } - conn.newCompressionWriter = compressNoContextTakeover - conn.newDecompressionReader = decompressNoContextTakeover + break } diff --git a/client_server_test.go b/client_server_test.go index a47df488..73f0e330 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -45,7 +45,14 @@ var cstDialer = Dialer{ HandshakeTimeout: 30 * time.Second, } -type cstHandler struct{ *testing.T } +type cstHandlerConfig struct { + contextTakeover bool +} + +type cstHandler struct { + *testing.T + cstHandlerConfig +} type cstServer struct { *httptest.Server @@ -59,17 +66,17 @@ const ( cstRequestURI = cstPath + "?" + cstRawQuery ) -func newServer(t *testing.T) *cstServer { +func newServer(t *testing.T, c cstHandlerConfig) *cstServer { var s cstServer - s.Server = httptest.NewServer(cstHandler{t}) + s.Server = httptest.NewServer(cstHandler{t, c}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s } -func newTLSServer(t *testing.T) *cstServer { +func newTLSServer(t *testing.T, c cstHandlerConfig) *cstServer { var s cstServer - s.Server = httptest.NewTLSServer(cstHandler{t}) + s.Server = httptest.NewTLSServer(cstHandler{t, c}) s.Server.URL += cstRequestURI s.URL = makeWsProto(s.Server.URL) return &s @@ -92,6 +99,9 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "bad protocol", http.StatusBadRequest) return } + if t.contextTakeover { + cstUpgrader.AllowServerContextTakeover = true + } ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) if err != nil { t.Logf("Upgrade: %v", err) @@ -122,6 +132,28 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.Logf("Close: %v", err) return } + + // for multipleSendRecv when context takeover. + if t.contextTakeover { + op, rd, err := ws.NextReader() + if err != nil { + t.Logf("NextReader: %v", err) + return + } + wr, err := ws.NextWriter(op) + if err != nil { + t.Logf("NextWriter: %v", err) + return + } + if _, err = io.Copy(wr, rd); err != nil { + t.Logf("NextWriter: %v", err) + return + } + if err := wr.Close(); err != nil { + t.Logf("Close: %v", err) + return + } + } } func makeWsProto(s string) string { @@ -148,9 +180,30 @@ func sendRecv(t *testing.T, ws *Conn) { } } +func multipleSendRecv(t *testing.T, ws *Conn) { + for _, message := range []string{"Hello World", "Can you read message?"} { + if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("WriteMessage: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + if string(p) != message { + t.Fatalf("message=%s, want %s", p, message) + } + } +} + func TestProxyDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() surl, _ := url.Parse(s.Server.URL) @@ -187,7 +240,7 @@ func TestProxyDial(t *testing.T) { } func TestProxyAuthorizationDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() surl, _ := url.Parse(s.Server.URL) @@ -227,7 +280,7 @@ func TestProxyAuthorizationDial(t *testing.T) { } func TestDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, _, err := cstDialer.Dial(s.URL, nil) @@ -239,7 +292,7 @@ func TestDial(t *testing.T) { } func TestDialCookieJar(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() jar, _ := cookiejar.New(nil) @@ -301,7 +354,7 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool { } func TestDialTLS(t *testing.T) { - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -315,7 +368,7 @@ func TestDialTLS(t *testing.T) { } func TestDialTimeout(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -371,7 +424,7 @@ func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } func TestHandshakeTimeout(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -387,7 +440,7 @@ func TestHandshakeTimeout(t *testing.T) { } func TestHandshakeTimeoutInContext(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -408,7 +461,7 @@ func TestHandshakeTimeoutInContext(t *testing.T) { } func TestDialBadScheme(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, _, err := cstDialer.Dial(s.Server.URL, nil) @@ -419,7 +472,7 @@ func TestDialBadScheme(t *testing.T) { } func TestDialBadOrigin(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) @@ -436,7 +489,7 @@ func TestDialBadOrigin(t *testing.T) { } func TestDialBadHeader(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() for _, k := range []string{"Upgrade", @@ -500,7 +553,7 @@ func TestDialExtraTokensInRespHeaders(t *testing.T) { } func TestHandshake(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) @@ -751,7 +804,7 @@ func TestHost(t *testing.T) { } func TestDialCompression(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() dialer := cstDialer @@ -764,8 +817,24 @@ func TestDialCompression(t *testing.T) { sendRecv(t, ws) } +func TestDialCompressionOfContextTakeover(t *testing.T) { + s := newServer(t, cstHandlerConfig{true}) + defer s.Close() + + dialer := cstDialer + dialer.EnableCompression = true + dialer.AllowClientContextTakeover = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + multipleSendRecv(t, ws) +} + func TestSocksProxyDial(t *testing.T) { - s := newServer(t) + s := newServer(t, cstHandlerConfig{}) defer s.Close() proxyListener, err := net.Listen("tcp", "127.0.0.1:0") @@ -868,7 +937,7 @@ func TestTracingDialWithContext(t *testing.T) { } ctx := httptrace.WithClientTrace(context.Background(), trace) - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer @@ -907,7 +976,7 @@ func TestEmptyTracingDialWithContext(t *testing.T) { trace := &httptrace.ClientTrace{} ctx := httptrace.WithClientTrace(context.Background(), trace) - s := newTLSServer(t) + s := newTLSServer(t, cstHandlerConfig{}) defer s.Close() d := cstDialer diff --git a/compression.go b/compression.go index 813ffb1e..882b46b1 100644 --- a/compression.go +++ b/compression.go @@ -5,17 +5,24 @@ package websocket import ( - "compress/flate" "errors" "io" "strings" "sync" + + "compress/flate" ) const ( minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 maxCompressionLevel = flate.BestCompression defaultCompressionLevel = 1 + + tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" ) var ( @@ -26,12 +33,6 @@ var ( ) func decompressNoContextTakeover(r io.Reader) io.ReadCloser { - const tail = - // Add four bytes as specified in RFC - "\x00\x00\xff\xff" + - // Add final block to squelch unexpected EOF error from flate reader. - "\x01\x00\x00\xff\xff" - fr, _ := flateReaderPool.Get().(io.ReadCloser) fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) return &flateReadWrapper{fr} @@ -112,6 +113,8 @@ func (w *flateWriteWrapper) Close() error { if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { return errors.New("websocket: internal error, unexpected bytes at end of flate stream") } + w.tw.p = [4]byte{} + w.tw.n = 0 err2 := w.tw.w.Close() if err1 != nil { return err1 @@ -146,3 +149,105 @@ func (r *flateReadWrapper) Close() error { r.fr = nil return err } + +type ( + contextTakeoverWriterFactory struct { + fw *flate.Writer + tw truncWriter + } + + flateTakeoverWriteWrapper struct { + f *contextTakeoverWriterFactory + } +) + +func (wf *contextTakeoverWriterFactory) newCompressionWriter(w io.WriteCloser, level int) io.WriteCloser { + // Set writer on first write. + // In order to guarantee the consistency of compression with the client, + // do not reassign later. + if wf.fw == nil { + wf.fw, _ = flate.NewWriter(&wf.tw, level) + } + + wf.tw.w = w + wf.tw.n = 0 + return &flateTakeoverWriteWrapper{wf} +} + +func (w *flateTakeoverWriteWrapper) Write(p []byte) (int, error) { + if w.f == nil { + return 0, errWriteClosed + } + return w.f.fw.Write(p) +} + +func (w *flateTakeoverWriteWrapper) Close() error { + if w.f == nil { + return errWriteClosed + } + f := w.f + w.f = nil + err1 := f.fw.Flush() + if f.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := f.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +// modules for compression context takeover +type ( + contextTakeoverReaderFactory struct { + fr io.ReadCloser + + // this window is used in compress/flate.decompressor. + // since there is no interface for updating the dictionary in the structure, + // window is rewritten with this structure. + // although there is a Reset(), it becomes initialization of a dictionary. + window []byte + } + + flateTakeoverReadWrapper struct { + f *contextTakeoverReaderFactory + } +) + +func (f *contextTakeoverReaderFactory) newDeCompressionReader(r io.Reader) io.ReadCloser { + f.fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), f.window) + return &flateTakeoverReadWrapper{f} +} + +func (r *flateTakeoverReadWrapper) Read(p []byte) (int, error) { + if r.f.fr == nil { + return 0, io.ErrClosedPipe + } + + n, err := r.f.fr.Read(p) + + // add window + r.f.window = append(r.f.window, p[:n]...) + if len(r.f.window) > maxWindowBits { + offset := len(r.f.window) - maxWindowBits + r.f.window = r.f.window[offset:] + } + + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + + return n, err +} + +func (r *flateTakeoverReadWrapper) Close() error { + if r.f.fr == nil { + return io.ErrClosedPipe + } + err := r.f.fr.Close() + return err +} diff --git a/compression_test.go b/compression_test.go index 8a26b30f..9a6f1343 100644 --- a/compression_test.go +++ b/compression_test.go @@ -65,6 +65,20 @@ func BenchmarkWriteWithCompression(b *testing.B) { b.ReportAllocs() } +func BenchmarkWriteWithCompressionOfContextTakeover(b *testing.B) { + w := ioutil.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + c.enableWriteCompression = true + var f contextTakeoverWriterFactory + c.newCompressionWriter = f.newCompressionWriter + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + func TestValidCompressionLevel(t *testing.T) { c := newTestConn(nil, nil, false) for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { diff --git a/conn.go b/conn.go index 5161ef81..ae4e77c0 100644 --- a/conn.go +++ b/conn.go @@ -39,6 +39,8 @@ const ( continuationFrame = 0 noFrame = -1 + + maxWindowBits = 1 << 15 ) // Close codes defined in RFC 6455, section 11.7. diff --git a/conn_test.go b/conn_test.go index 06e51849..fd4acc4b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -7,6 +7,7 @@ package websocket import ( "bufio" "bytes" + "compress/flate" "errors" "fmt" "io" @@ -87,20 +88,50 @@ func TestFraming(t *testing.T) { }}, } - for _, compress := range []bool{false, true} { + compressConditions := []struct { + compress bool + contextTakeover bool + }{ + { + compress: false, + contextTakeover: false, + }, + { + compress: true, + contextTakeover: false, + }, + { + compress: true, + contextTakeover: true, + }, + } + + for _, compressCondition := range compressConditions { for _, isServer := range []bool{true, false} { for _, chunker := range readChunkers { var connBuf bytes.Buffer wc := newTestConn(nil, &connBuf, isServer) rc := newTestConn(chunker.f(&connBuf), nil, !isServer) - if compress { + switch { + case compressCondition.compress && compressCondition.contextTakeover: + + var wf contextTakeoverWriterFactory + wf.fw, _ = flate.NewWriter(&wf.tw, defaultCompressionLevel) + wc.newCompressionWriter = wf.newCompressionWriter + + var rf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + rf.fr = fr + rc.newDecompressionReader = rf.newDeCompressionReader + + case compressCondition.compress: wc.newCompressionWriter = compressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover } for _, n := range frameSizes { for _, writer := range writers { - name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + name := fmt.Sprintf("z:%v, c:%v, s:%v, r:%s, n:%d w:%s", compressCondition.compress, compressCondition.contextTakeover, isServer, chunker.name, n, writer.name) w, err := wc.NextWriter(TextMessage) if err != nil { diff --git a/doc.go b/doc.go index 8db0cef9..9cdc8a0a 100644 --- a/doc.go +++ b/doc.go @@ -218,10 +218,17 @@ // // conn.EnableWriteCompression(false) // -// Currently this package does not support compression with "context takeover". +// Currently this package supports compression with "context takeover". // This means that messages must be compressed and decompressed in isolation, // without retaining sliding window or dictionary state across messages. For // more details refer to RFC 7692. // +// If you want to use it, please do as follows. +// +// var upgrader = websocket.Upgrader{ +// EnableCompression: true, +// AllowServerContextTakeover: true, +// } +// // Use of compression is experimental and may result in decreased performance. package websocket diff --git a/server.go b/server.go index bb335974..028f55ff 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ package websocket import ( "bufio" + "compress/flate" "errors" "io" "net/http" @@ -68,10 +69,13 @@ type Upgrader struct { CheckOrigin func(r *http.Request) bool // EnableCompression specify if the server should attempt to negotiate per - // message compression (RFC 7692). Setting this value to true does not - // guarantee that compression will be supported. Currently only "no context - // takeover" modes are supported. + // message compression (RFC 7692). EnableCompression bool + + // AllowServerContextTakeover specifies whether the server will negotiate server context + // takeover for per message compression. Context takeover improves compression at the + // cost of using more memory. + AllowServerContextTakeover bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -161,14 +165,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade subprotocol := u.selectSubprotocol(r, responseHeader) // Negotiate PMCE - var compress bool + var ( + compress bool + contextTakeover bool + ) if u.EnableCompression { for _, ext := range parseExtensions(r.Header) { - if ext[""] != "permessage-deflate" { - continue + // map[string]string{"":"permessage-deflate", "client_max_window_bits":""} + // detect context-takeover from client_max_window_bits + if ext[""] == "permessage-deflate" { + compress = true + } + + if _, ok := ext["client_max_window_bits"]; ok { + contextTakeover = true } - compress = true - break } } @@ -205,8 +216,19 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c.subprotocol = subprotocol if compress { - c.newCompressionWriter = compressNoContextTakeover - c.newDecompressionReader = decompressNoContextTakeover + switch { + case contextTakeover && u.AllowServerContextTakeover: + var wf contextTakeoverWriterFactory + c.newCompressionWriter = wf.newCompressionWriter + + var rf contextTakeoverReaderFactory + fr := flate.NewReader(nil) + rf.fr = fr + c.newDecompressionReader = rf.newDeCompressionReader + default: + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } } // Use larger of hijacked buffer and connection write buffer for header. @@ -225,7 +247,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade p = append(p, "\r\n"...) } if compress { - p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + switch { + case contextTakeover && u.AllowServerContextTakeover: + p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_max_window_bits=15; client_max_window_bits=15\r\n"...) + default: + p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" {