Skip to content

Commit 519e970

Browse files
committed
Cookie and CloseError unit tests
1 parent 5f76559 commit 519e970

File tree

9 files changed

+171
-62
lines changed

9 files changed

+171
-62
lines changed

accept.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
112112
hj, ok := w.(http.Hijacker)
113113
if !ok {
114114
err = xerrors.New("websocket: response writer does not implement http.Hijacker")
115-
http.Error(w, err.Error(), http.StatusInternalServerError)
115+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
116116
return nil, err
117117
}
118118

@@ -131,7 +131,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
131131
netConn, brw, err := hj.Hijack()
132132
if err != nil {
133133
err = xerrors.Errorf("websocket: failed to hijack connection: %w", err)
134-
http.Error(w, err.Error(), http.StatusInternalServerError)
134+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
135135
return nil, err
136136
}
137137

example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func ExampleAccept() {
7171
log.Printf("server handshake failed: %v", err)
7272
return
7373
}
74-
defer c.Close(websocket.StatusInternalError, "")
74+
defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error.
7575

7676
jc := websocket.JSONConn{
7777
Conn: c,

header.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ type header struct {
3030
maskKey [4]byte
3131
}
3232

33-
// TODO bitwise helpers
34-
3533
// bytes returns the bytes of the header.
3634
// See https://tools.ietf.org/html/rfc6455#section-5.2
3735
func marshalHeader(h header) []byte {

header_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func randBool() bool {
2020
func TestHeader(t *testing.T) {
2121
t.Parallel()
2222

23-
t.Run("negative", func(t *testing.T) {
23+
t.Run("readNegativeLength", func(t *testing.T) {
2424
t.Parallel()
2525

2626
b := marshalHeader(header{

statuscode.go

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ const (
2424
StatusUnsupportedData
2525
_ // 1004 is reserved.
2626
StatusNoStatusRcvd
27-
StatusAbnormalClosure
27+
// statusAbnormalClosure is unexported because it isn't necessary, at least until WASM.
28+
// The error returned will indicate whether the connection was closed or not or what happened.
29+
// It only makes sense for browser clients.
30+
statusAbnormalClosure
2831
StatusInvalidFramePayloadData
2932
StatusPolicyViolation
3033
StatusMessageTooBig
@@ -33,7 +36,10 @@ const (
3336
StatusServiceRestart
3437
StatusTryAgainLater
3538
StatusBadGateway
36-
StatusTLSHandshake
39+
// statusTLSHandshake is unexported because we just return
40+
// handshake error in dial. We do not return a conn
41+
// so there is nothing to use this on. At least until WASM.
42+
statusTLSHandshake
3743
)
3844

3945
// CloseError represents an error from a WebSocket close frame.
@@ -43,68 +49,63 @@ type CloseError struct {
4349
Reason string
4450
}
4551

46-
func (e CloseError) Error() string {
47-
return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", e.Code, e.Reason)
52+
func (ce CloseError) Error() string {
53+
return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", ce.Code, ce.Reason)
4854
}
4955

50-
func parseClosePayload(p []byte) (code StatusCode, reason string, err error) {
56+
func parseClosePayload(p []byte) (CloseError, error) {
5157
if len(p) < 2 {
52-
return 0, "", fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
58+
return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
5359
}
5460

55-
code = StatusCode(binary.BigEndian.Uint16(p))
56-
reason = string(p[2:])
61+
ce := CloseError{
62+
Code: StatusCode(binary.BigEndian.Uint16(p)),
63+
Reason: string(p[2:]),
64+
}
5765

58-
if !utf8.ValidString(reason) {
59-
return 0, "", xerrors.Errorf("invalid utf-8: %q", reason)
66+
if !utf8.ValidString(ce.Reason) {
67+
return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason)
6068
}
61-
if !validCloseCode(code) {
62-
return 0, "", xerrors.Errorf("invalid code %v", code)
69+
if !validWireCloseCode(ce.Code) {
70+
return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code)
6371
}
6472

65-
return code, reason, nil
73+
return ce, nil
6674
}
6775

6876
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
6977
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
70-
var validReceivedCloseCodes = map[StatusCode]bool{
71-
StatusNormalClosure: true,
72-
StatusGoingAway: true,
73-
StatusProtocolError: true,
74-
StatusUnsupportedData: true,
75-
StatusNoStatusRcvd: false,
76-
// TODO use
77-
StatusAbnormalClosure: false,
78-
StatusInvalidFramePayloadData: true,
79-
StatusPolicyViolation: true,
80-
StatusMessageTooBig: true,
81-
StatusMandatoryExtension: true,
82-
StatusInternalError: true,
83-
StatusServiceRestart: true,
84-
StatusTryAgainLater: true,
85-
StatusTLSHandshake: false,
86-
}
78+
func validWireCloseCode(code StatusCode) bool {
79+
if code >= StatusNormalClosure && code <= statusTLSHandshake {
80+
switch code {
81+
case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake:
82+
return false
83+
default:
84+
return true
85+
}
86+
}
87+
if code >= 3000 && code <= 4999 {
88+
return true
89+
}
8790

88-
func validCloseCode(code StatusCode) bool {
89-
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
91+
return false
9092
}
9193

9294
const maxControlFramePayload = 125
9395

94-
// TODO make method on CloseError
95-
func closePayload(code StatusCode, reason string) ([]byte, error) {
96-
if len(reason) > maxControlFramePayload-2 {
97-
return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, reason, len(reason))
96+
func (ce CloseError) bytes() ([]byte, error) {
97+
if len(ce.Reason) > maxControlFramePayload-2 {
98+
return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason))
9899
}
99-
if bits.Len(uint(code)) > 16 {
100+
if bits.Len(uint(ce.Code)) > 16 {
100101
return nil, errors.New("status code is larger than 2 bytes")
101102
}
102-
if !validCloseCode(code) {
103-
return nil, fmt.Errorf("status code %v cannot be set", code)
103+
if !validWireCloseCode(ce.Code) {
104+
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
104105
}
105106

106-
buf := make([]byte, 2+len(reason))
107-
binary.BigEndian.PutUint16(buf[:], uint16(code))
108-
copy(buf[2:], reason)
107+
buf := make([]byte, 2+len(ce.Reason))
108+
binary.BigEndian.PutUint16(buf[:], uint16(ce.Code))
109+
copy(buf[2:], ce.Reason)
109110
return buf, nil
110111
}

statuscode_string.go

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

statuscode_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package websocket
2+
3+
import (
4+
"math"
5+
"strings"
6+
"testing"
7+
)
8+
9+
func TestCloseError(t *testing.T) {
10+
t.Parallel()
11+
12+
// Other parts of close error are tested by websocket_test.go right now
13+
// with the autobahn tests.
14+
15+
testCases := []struct {
16+
name string
17+
ce CloseError
18+
success bool
19+
}{
20+
{
21+
name: "normal",
22+
ce: CloseError{
23+
Code: StatusNormalClosure,
24+
Reason: strings.Repeat("x", maxControlFramePayload-2),
25+
},
26+
success: true,
27+
},
28+
{
29+
name: "bigReason",
30+
ce: CloseError{
31+
Code: StatusNormalClosure,
32+
Reason: strings.Repeat("x", maxControlFramePayload-1),
33+
},
34+
success: false,
35+
},
36+
{
37+
name: "bigCode",
38+
ce: CloseError{
39+
Code: math.MaxUint16,
40+
Reason: strings.Repeat("x", maxControlFramePayload-2),
41+
},
42+
success: false,
43+
},
44+
}
45+
46+
for _, tc := range testCases {
47+
tc := tc
48+
t.Run(tc.name, func(t *testing.T) {
49+
t.Parallel()
50+
51+
_, err := tc.ce.bytes()
52+
if (err == nil) != tc.success {
53+
t.Fatalf("unexpected error value: %v", err)
54+
}
55+
})
56+
}
57+
}

websocket.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ type control struct {
2323
type Conn struct {
2424
subprotocol string
2525
br *bufio.Reader
26-
// TODO Cannot use bufio writer because for compression we need to know how much is buffered and compress it if large.
26+
// TODO switch to []byte for write buffering because for messages larger than buffers, there will always be 3 writes. One for the frame, one for the message, one for the fin.
27+
// Also will help for compression.
2728
bw *bufio.Writer
2829
closer io.Closer
2930
client bool
@@ -225,12 +226,12 @@ func (c *Conn) handleControl(h header) {
225226
case opPong:
226227
case opClose:
227228
if len(b) > 0 {
228-
code, reason, err := parseClosePayload(b)
229+
ce, err := parseClosePayload(b)
229230
if err != nil {
230231
c.close(xerrors.Errorf("read invalid close payload: %w", err))
231232
return
232233
}
233-
c.Close(code, reason)
234+
c.Close(ce.Code, ce.Reason)
234235
} else {
235236
c.writeClose(nil, CloseError{
236237
Code: StatusNoStatusRcvd,
@@ -279,8 +280,7 @@ func (c *Conn) readLoop() {
279280
return
280281
}
281282
default:
282-
// TODO send back protocol violation message or figure out what RFC wants.
283-
c.close(xerrors.Errorf("unexpected opcode in header: %#v", h))
283+
c.Close(StatusProtocolError, fmt.Sprintf("unknown opcode %v", h.opcode))
284284
return
285285
}
286286

@@ -338,18 +338,23 @@ func (c *Conn) writePong(p []byte) error {
338338
// Close closes the WebSocket connection with the given status code and reason.
339339
// It will write a WebSocket close frame with a timeout of 5 seconds.
340340
func (c *Conn) Close(code StatusCode, reason string) error {
341+
ce := CloseError{
342+
Code: code,
343+
Reason: reason,
344+
}
345+
341346
// This function also will not wait for a close frame from the peer like the RFC
342347
// wants because that makes no sense and I don't think anyone actually follows that.
343348
// Definitely worth seeing what popular browsers do later.
344-
p, err := closePayload(code, reason)
349+
p, err := ce.bytes()
345350
if err != nil {
346-
p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code))
351+
ce = CloseError{
352+
Code: StatusInternalError,
353+
}
354+
p, _ = ce.bytes()
347355
}
348356

349-
cerr := c.writeClose(p, CloseError{
350-
Code: code,
351-
Reason: reason,
352-
})
357+
cerr := c.writeClose(p, ce)
353358
if err != nil {
354359
return err
355360
}

websocket_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import (
77
"io"
88
"io/ioutil"
99
"net/http"
10+
"net/http/cookiejar"
1011
"net/http/httptest"
12+
"net/url"
1113
"os"
1214
"os/exec"
1315
"reflect"
@@ -216,6 +218,52 @@ func TestHandshake(t *testing.T) {
216218
return nil
217219
},
218220
},
221+
{
222+
name: "cookies",
223+
server: func(w http.ResponseWriter, r *http.Request) error {
224+
cookie, err := r.Cookie("mycookie")
225+
if err != nil {
226+
return xerrors.Errorf("request is missing mycookie: %w", err)
227+
}
228+
if cookie.Value != "myvalue" {
229+
return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
230+
}
231+
c, err := websocket.Accept(w, r)
232+
if err != nil {
233+
return err
234+
}
235+
c.Close(websocket.StatusInternalError, "")
236+
return nil
237+
},
238+
client: func(ctx context.Context, u string) error {
239+
jar, err := cookiejar.New(nil)
240+
if err != nil {
241+
return xerrors.Errorf("failed to create cookie jar: %w", err)
242+
}
243+
parsedURL, err := url.Parse(u)
244+
if err != nil {
245+
return xerrors.Errorf("failed to parse url: %w", err)
246+
}
247+
parsedURL.Scheme = "http"
248+
jar.SetCookies(parsedURL, []*http.Cookie{
249+
{
250+
Name: "mycookie",
251+
Value: "myvalue",
252+
},
253+
})
254+
hc := &http.Client{
255+
Jar: jar,
256+
}
257+
c, _, err := websocket.Dial(ctx, u,
258+
websocket.DialHTTPClient(hc),
259+
)
260+
if err != nil {
261+
return err
262+
}
263+
c.Close(websocket.StatusInternalError, "")
264+
return nil
265+
},
266+
},
219267
}
220268

221269
for _, tc := range testCases {

0 commit comments

Comments
 (0)