Skip to content

Commit aaf4b45

Browse files
committed
Up test coverage of accept.go to 100%
1 parent 8c87970 commit aaf4b45

File tree

6 files changed

+164
-12
lines changed

6 files changed

+164
-12
lines changed

accept.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
9292
w.Header().Set("Sec-WebSocket-Protocol", subproto)
9393
}
9494

95-
copts, err := acceptCompression(r, w, opts.CompressionMode)
95+
copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode)
9696
if err != nil {
9797
return nil, err
9898
}
@@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
201201
case "server_no_context_takeover":
202202
copts.serverNoContextTakeover = true
203203
continue
204-
case "client_max_window_bits", "server-max-window-bits":
204+
}
205+
206+
if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") {
205207
continue
206208
}
207209

accept_test.go

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
package websocket
44

55
import (
6+
"bufio"
7+
"errors"
8+
"net"
9+
"net/http"
610
"net/http/httptest"
711
"strings"
812
"testing"
@@ -23,6 +27,38 @@ func TestAccept(t *testing.T) {
2327
assert.ErrorContains(t, "Accept", err, "protocol violation")
2428
})
2529

30+
t.Run("badOrigin", func(t *testing.T) {
31+
t.Parallel()
32+
33+
w := httptest.NewRecorder()
34+
r := httptest.NewRequest("GET", "/", nil)
35+
r.Header.Set("Connection", "Upgrade")
36+
r.Header.Set("Upgrade", "websocket")
37+
r.Header.Set("Sec-WebSocket-Version", "13")
38+
r.Header.Set("Sec-WebSocket-Key", "meow123")
39+
r.Header.Set("Origin", "harhar.com")
40+
41+
_, err := Accept(w, r, nil)
42+
assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host")
43+
})
44+
45+
t.Run("badCompression", func(t *testing.T) {
46+
t.Parallel()
47+
48+
w := mockHijacker{
49+
ResponseWriter: httptest.NewRecorder(),
50+
}
51+
r := httptest.NewRequest("GET", "/", nil)
52+
r.Header.Set("Connection", "Upgrade")
53+
r.Header.Set("Upgrade", "websocket")
54+
r.Header.Set("Sec-WebSocket-Version", "13")
55+
r.Header.Set("Sec-WebSocket-Key", "meow123")
56+
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
57+
58+
_, err := Accept(w, r, nil)
59+
assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter")
60+
})
61+
2662
t.Run("requireHttpHijacker", func(t *testing.T) {
2763
t.Parallel()
2864

@@ -36,6 +72,26 @@ func TestAccept(t *testing.T) {
3672
_, err := Accept(w, r, nil)
3773
assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker")
3874
})
75+
76+
t.Run("badHijack", func(t *testing.T) {
77+
t.Parallel()
78+
79+
w := mockHijacker{
80+
ResponseWriter: httptest.NewRecorder(),
81+
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
82+
return nil, nil, errors.New("haha")
83+
},
84+
}
85+
86+
r := httptest.NewRequest("GET", "/", nil)
87+
r.Header.Set("Connection", "Upgrade")
88+
r.Header.Set("Upgrade", "websocket")
89+
r.Header.Set("Sec-WebSocket-Version", "13")
90+
r.Header.Set("Sec-WebSocket-Key", "meow123")
91+
92+
_, err := Accept(w, r, nil)
93+
assert.ErrorContains(t, "Accept", err, "failed to hijack connection")
94+
})
3995
}
4096

4197
func Test_verifyClientHandshake(t *testing.T) {
@@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) {
243299
}
244300

245301
func Test_acceptCompression(t *testing.T) {
302+
t.Parallel()
303+
304+
testCases := []struct {
305+
name string
306+
mode CompressionMode
307+
reqSecWebSocketExtensions string
308+
respSecWebSocketExtensions string
309+
expCopts *compressionOptions
310+
error bool
311+
}{
312+
{
313+
name: "disabled",
314+
mode: CompressionDisabled,
315+
expCopts: nil,
316+
},
317+
{
318+
name: "noClientSupport",
319+
mode: CompressionNoContextTakeover,
320+
expCopts: nil,
321+
},
322+
{
323+
name: "permessage-deflate",
324+
mode: CompressionNoContextTakeover,
325+
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
326+
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
327+
expCopts: &compressionOptions{
328+
clientNoContextTakeover: true,
329+
serverNoContextTakeover: true,
330+
},
331+
},
332+
{
333+
name: "permessage-deflate/error",
334+
mode: CompressionNoContextTakeover,
335+
reqSecWebSocketExtensions: "permessage-deflate; meow",
336+
error: true,
337+
},
338+
{
339+
name: "x-webkit-deflate-frame",
340+
mode: CompressionNoContextTakeover,
341+
reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
342+
respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
343+
expCopts: &compressionOptions{
344+
clientNoContextTakeover: true,
345+
serverNoContextTakeover: true,
346+
},
347+
},
348+
{
349+
name: "x-webkit-deflate/error",
350+
mode: CompressionNoContextTakeover,
351+
reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
352+
error: true,
353+
},
354+
}
355+
356+
for _, tc := range testCases {
357+
tc := tc
358+
t.Run(tc.name, func(t *testing.T) {
359+
t.Parallel()
360+
361+
r := httptest.NewRequest(http.MethodGet, "/", nil)
362+
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
363+
364+
w := httptest.NewRecorder()
365+
copts, err := acceptCompression(r, w, tc.mode)
366+
if tc.error {
367+
assert.Error(t, "acceptCompression", err)
368+
return
369+
}
370+
371+
assert.Success(t, "acceptCompression", err)
372+
assert.Equal(t, "compresssionOpts", tc.expCopts, copts)
373+
assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
374+
})
375+
}
376+
}
377+
378+
type mockHijacker struct {
379+
http.ResponseWriter
380+
hijack func() (net.Conn, *bufio.ReadWriter, error)
381+
}
382+
383+
var _ http.Hijacker = mockHijacker{}
246384

385+
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
386+
return mj.hijack()
247387
}

compress.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,22 @@ import (
99
"sync"
1010
)
1111

12+
// CompressionOptions represents the available deflate extension options.
13+
// See https://tools.ietf.org/html/rfc7692
1214
type CompressionOptions struct {
1315
// Mode controls the compression mode.
16+
//
17+
// See docs on CompressionMode.
1418
Mode CompressionMode
1519

1620
// Threshold controls the minimum size of a message before compression is applied.
21+
//
22+
// Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes
23+
// for CompressionContextTakeover.
1724
Threshold int
1825
}
1926

20-
// CompressionMode controls the modes available RFC 7692's deflate extension.
27+
// CompressionMode represents the modes available to the deflate extension.
2128
// See https://tools.ietf.org/html/rfc7692
2229
//
2330
// A compatibility layer is implemented for the older deflate-frame extension used
@@ -31,7 +38,7 @@ const (
3138
// for every message. This applies to both server and client side.
3239
//
3340
// This means less efficient compression as the sliding window from previous messages
34-
// will not be used but the memory overhead will be much lower if the connections
41+
// will not be used but the memory overhead will be lower if the connections
3542
// are long lived and seldom used.
3643
//
3744
// The message will only be compressed if greater than 512 bytes.
@@ -40,8 +47,7 @@ const (
4047
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
4148
// This enables reusing the sliding window from previous messages.
4249
// As most WebSocket protocols are repetitive, this can be very efficient.
43-
//
44-
// The message will only be compressed if greater than 128 bytes.
50+
// It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover.
4551
//
4652
// If the peer negotiates NoContextTakeover on the client or server side, it will be
4753
// used instead as this is required by the RFC.

conn_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ func TestConn(t *testing.T) {
2626
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
2727
Subprotocols: []string{"echo"},
2828
InsecureSkipVerify: true,
29-
CompressionMode: websocket.CompressionNoContextTakeover,
29+
CompressionOptions: websocket.CompressionOptions{
30+
Mode: websocket.CompressionNoContextTakeover,
31+
},
3032
})
3133
assert.Success(t, "accept", err)
3234
defer c.Close(websocket.StatusInternalError, "")
@@ -42,8 +44,10 @@ func TestConn(t *testing.T) {
4244
defer cancel()
4345

4446
opts := &websocket.DialOptions{
45-
Subprotocols: []string{"echo"},
46-
CompressionMode: websocket.CompressionNoContextTakeover,
47+
Subprotocols: []string{"echo"},
48+
CompressionOptions: websocket.CompressionOptions{
49+
Mode: websocket.CompressionNoContextTakeover,
50+
},
4751
}
4852
opts.HTTPClient = s.Client()
4953

dial.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
136136
if len(opts.Subprotocols) > 0 {
137137
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
138138
}
139-
if opts.CompressionMode != CompressionDisabled {
140-
copts := opts.CompressionMode.opts()
139+
if opts.CompressionOptions.Mode != CompressionDisabled {
140+
copts := opts.CompressionOptions.Mode.opts()
141141
copts.setHeader(req.Header)
142142
}
143143

write.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter {
6464

6565
func (mw *msgWriter) ensureFlateWriter() {
6666
if mw.flateWriter == nil {
67-
mw.flateWriter = getFlateWriter(mw.trimWriter)
67+
mw.flateWriter = getFlateWriter(mw.trimWriter, nil)
6868
}
6969
}
7070

0 commit comments

Comments
 (0)