3
3
package websocket
4
4
5
5
import (
6
+ "bufio"
7
+ "errors"
8
+ "net"
9
+ "net/http"
6
10
"net/http/httptest"
7
11
"strings"
8
12
"testing"
@@ -23,6 +27,38 @@ func TestAccept(t *testing.T) {
23
27
assert .ErrorContains (t , "Accept" , err , "protocol violation" )
24
28
})
25
29
30
+ t .Run ("badOrigin" , func (t * testing.T ) {
31
+ t .Parallel ()
32
+
33
+ w := httptest .NewRecorder ()
34
+ r := httptest .NewRequest ("GET" , "/" , nil )
35
+ r .Header .Set ("Connection" , "Upgrade" )
36
+ r .Header .Set ("Upgrade" , "websocket" )
37
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
38
+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
39
+ r .Header .Set ("Origin" , "harhar.com" )
40
+
41
+ _ , err := Accept (w , r , nil )
42
+ assert .ErrorContains (t , "Accept" , err , "request Origin \" harhar.com\" is not authorized for Host" )
43
+ })
44
+
45
+ t .Run ("badCompression" , func (t * testing.T ) {
46
+ t .Parallel ()
47
+
48
+ w := mockHijacker {
49
+ ResponseWriter : httptest .NewRecorder (),
50
+ }
51
+ r := httptest .NewRequest ("GET" , "/" , nil )
52
+ r .Header .Set ("Connection" , "Upgrade" )
53
+ r .Header .Set ("Upgrade" , "websocket" )
54
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
55
+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
56
+ r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
57
+
58
+ _ , err := Accept (w , r , nil )
59
+ assert .ErrorContains (t , "Accept" , err , "unsupported permessage-deflate parameter" )
60
+ })
61
+
26
62
t .Run ("requireHttpHijacker" , func (t * testing.T ) {
27
63
t .Parallel ()
28
64
@@ -36,6 +72,26 @@ func TestAccept(t *testing.T) {
36
72
_ , err := Accept (w , r , nil )
37
73
assert .ErrorContains (t , "Accept" , err , "http.ResponseWriter does not implement http.Hijacker" )
38
74
})
75
+
76
+ t .Run ("badHijack" , func (t * testing.T ) {
77
+ t .Parallel ()
78
+
79
+ w := mockHijacker {
80
+ ResponseWriter : httptest .NewRecorder (),
81
+ hijack : func () (conn net.Conn , writer * bufio.ReadWriter , err error ) {
82
+ return nil , nil , errors .New ("haha" )
83
+ },
84
+ }
85
+
86
+ r := httptest .NewRequest ("GET" , "/" , nil )
87
+ r .Header .Set ("Connection" , "Upgrade" )
88
+ r .Header .Set ("Upgrade" , "websocket" )
89
+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
90
+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
91
+
92
+ _ , err := Accept (w , r , nil )
93
+ assert .ErrorContains (t , "Accept" , err , "failed to hijack connection" )
94
+ })
39
95
}
40
96
41
97
func Test_verifyClientHandshake (t * testing.T ) {
@@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) {
243
299
}
244
300
245
301
func Test_acceptCompression (t * testing.T ) {
302
+ t .Parallel ()
303
+
304
+ testCases := []struct {
305
+ name string
306
+ mode CompressionMode
307
+ reqSecWebSocketExtensions string
308
+ respSecWebSocketExtensions string
309
+ expCopts * compressionOptions
310
+ error bool
311
+ }{
312
+ {
313
+ name : "disabled" ,
314
+ mode : CompressionDisabled ,
315
+ expCopts : nil ,
316
+ },
317
+ {
318
+ name : "noClientSupport" ,
319
+ mode : CompressionNoContextTakeover ,
320
+ expCopts : nil ,
321
+ },
322
+ {
323
+ name : "permessage-deflate" ,
324
+ mode : CompressionNoContextTakeover ,
325
+ reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
326
+ respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
327
+ expCopts : & compressionOptions {
328
+ clientNoContextTakeover : true ,
329
+ serverNoContextTakeover : true ,
330
+ },
331
+ },
332
+ {
333
+ name : "permessage-deflate/error" ,
334
+ mode : CompressionNoContextTakeover ,
335
+ reqSecWebSocketExtensions : "permessage-deflate; meow" ,
336
+ error : true ,
337
+ },
338
+ {
339
+ name : "x-webkit-deflate-frame" ,
340
+ mode : CompressionNoContextTakeover ,
341
+ reqSecWebSocketExtensions : "x-webkit-deflate-frame; no_context_takeover" ,
342
+ respSecWebSocketExtensions : "x-webkit-deflate-frame; no_context_takeover" ,
343
+ expCopts : & compressionOptions {
344
+ clientNoContextTakeover : true ,
345
+ serverNoContextTakeover : true ,
346
+ },
347
+ },
348
+ {
349
+ name : "x-webkit-deflate/error" ,
350
+ mode : CompressionNoContextTakeover ,
351
+ reqSecWebSocketExtensions : "x-webkit-deflate-frame; max_window_bits" ,
352
+ error : true ,
353
+ },
354
+ }
355
+
356
+ for _ , tc := range testCases {
357
+ tc := tc
358
+ t .Run (tc .name , func (t * testing.T ) {
359
+ t .Parallel ()
360
+
361
+ r := httptest .NewRequest (http .MethodGet , "/" , nil )
362
+ r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
363
+
364
+ w := httptest .NewRecorder ()
365
+ copts , err := acceptCompression (r , w , tc .mode )
366
+ if tc .error {
367
+ assert .Error (t , "acceptCompression" , err )
368
+ return
369
+ }
370
+
371
+ assert .Success (t , "acceptCompression" , err )
372
+ assert .Equal (t , "compresssionOpts" , tc .expCopts , copts )
373
+ assert .Equal (t , "respHeader" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
374
+ })
375
+ }
376
+ }
377
+
378
+ type mockHijacker struct {
379
+ http.ResponseWriter
380
+ hijack func () (net.Conn , * bufio.ReadWriter , error )
381
+ }
382
+
383
+ var _ http.Hijacker = mockHijacker {}
246
384
385
+ func (mj mockHijacker ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
386
+ return mj .hijack ()
247
387
}
0 commit comments