@@ -12,7 +12,7 @@ import (
12
12
13
13
"golang.org/x/xerrors"
14
14
15
- "nhooyr.io/websocket/internal/test/cmp "
15
+ "nhooyr.io/websocket/internal/test/assert "
16
16
)
17
17
18
18
func TestAccept (t * testing.T ) {
@@ -25,9 +25,7 @@ func TestAccept(t *testing.T) {
25
25
r := httptest .NewRequest ("GET" , "/" , nil )
26
26
27
27
_ , err := Accept (w , r , nil )
28
- if ! cmp .ErrorContains (err , "protocol violation" ) {
29
- t .Fatal (err )
30
- }
28
+ assert .Contains (t , err , "protocol violation" )
31
29
})
32
30
33
31
t .Run ("badOrigin" , func (t * testing.T ) {
@@ -42,9 +40,7 @@ func TestAccept(t *testing.T) {
42
40
r .Header .Set ("Origin" , "harhar.com" )
43
41
44
42
_ , err := Accept (w , r , nil )
45
- if ! cmp .ErrorContains (err , `request Origin "harhar.com" is not authorized for Host` ) {
46
- t .Fatal (err )
47
- }
43
+ assert .Contains (t , err , `request Origin "harhar.com" is not authorized for Host` )
48
44
})
49
45
50
46
t .Run ("badCompression" , func (t * testing.T ) {
@@ -61,9 +57,7 @@ func TestAccept(t *testing.T) {
61
57
r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
62
58
63
59
_ , err := Accept (w , r , nil )
64
- if ! cmp .ErrorContains (err , `unsupported permessage-deflate parameter` ) {
65
- t .Fatal (err )
66
- }
60
+ assert .Contains (t , err , `unsupported permessage-deflate parameter` )
67
61
})
68
62
69
63
t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -77,9 +71,7 @@ func TestAccept(t *testing.T) {
77
71
r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
78
72
79
73
_ , err := Accept (w , r , nil )
80
- if ! cmp .ErrorContains (err , `http.ResponseWriter does not implement http.Hijacker` ) {
81
- t .Fatal (err )
82
- }
74
+ assert .Contains (t , err , `http.ResponseWriter does not implement http.Hijacker` )
83
75
})
84
76
85
77
t .Run ("badHijack" , func (t * testing.T ) {
@@ -99,9 +91,7 @@ func TestAccept(t *testing.T) {
99
91
r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
100
92
101
93
_ , err := Accept (w , r , nil )
102
- if ! cmp .ErrorContains (err , `failed to hijack connection` ) {
103
- t .Fatal (err )
104
- }
94
+ assert .Contains (t , err , `failed to hijack connection` )
105
95
})
106
96
}
107
97
@@ -193,8 +183,10 @@ func Test_verifyClientHandshake(t *testing.T) {
193
183
}
194
184
195
185
_ , err := verifyClientRequest (httptest .NewRecorder (), r )
196
- if tc .success != (err == nil ) {
197
- t .Fatalf ("unexpected error value: %v" , err )
186
+ if tc .success {
187
+ assert .Success (t , err )
188
+ } else {
189
+ assert .Error (t , err )
198
190
}
199
191
})
200
192
}
@@ -244,9 +236,7 @@ func Test_selectSubprotocol(t *testing.T) {
244
236
r .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (tc .clientProtocols , "," ))
245
237
246
238
negotiated := selectSubprotocol (r , tc .serverProtocols )
247
- if ! cmp .Equal (tc .negotiated , negotiated ) {
248
- t .Fatalf ("unexpected negotiated: %v" , cmp .Diff (tc .negotiated , negotiated ))
249
- }
239
+ assert .Equal (t , "negotiated" , tc .negotiated , negotiated )
250
240
})
251
241
}
252
242
}
@@ -300,8 +290,10 @@ func Test_authenticateOrigin(t *testing.T) {
300
290
r .Header .Set ("Origin" , tc .origin )
301
291
302
292
err := authenticateOrigin (r )
303
- if tc .success != (err == nil ) {
304
- t .Fatalf ("unexpected error value: %v" , err )
293
+ if tc .success {
294
+ assert .Success (t , err )
295
+ } else {
296
+ assert .Error (t , err )
305
297
}
306
298
})
307
299
}
@@ -373,21 +365,13 @@ func Test_acceptCompression(t *testing.T) {
373
365
w := httptest .NewRecorder ()
374
366
copts , err := acceptCompression (r , w , tc .mode )
375
367
if tc .error {
376
- if err == nil {
377
- t .Fatalf ("expected error: %v" , copts )
378
- }
368
+ assert .Error (t , err )
379
369
return
380
370
}
381
371
382
- if err != nil {
383
- t .Fatal (err )
384
- }
385
- if ! cmp .Equal (tc .expCopts , copts ) {
386
- t .Fatalf ("unexpected compression options: %v" , cmp .Diff (tc .expCopts , copts ))
387
- }
388
- if ! cmp .Equal (tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" )) {
389
- t .Fatalf ("unexpected respHeader: %v" , cmp .Diff (tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" )))
390
- }
372
+ assert .Success (t , err )
373
+ assert .Equal (t , "compression options" , tc .expCopts , copts )
374
+ assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
391
375
})
392
376
}
393
377
}
0 commit comments