Skip to content

Commit 48fda66

Browse files
authored
Merge pull request #48 from nhooyr/fixes
Improvements
2 parents bac4153 + 0c05d25 commit 48fda66

File tree

9 files changed

+466
-148
lines changed

9 files changed

+466
-148
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# websocket
22

33
[![GoDoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket)
4-
[![Codecov](https://img.shields.io/codecov/c/github/nhooyr/websocket.svg)](https://codecov.io/gh/nhooyr/websocket)
54
[![GitHub release](https://img.shields.io/github/release-pre/nhooyr/websocket.svg)](https://github.com/nhooyr/websocket/releases)
65

76
websocket is a minimal and idiomatic WebSocket library for Go.
@@ -23,6 +22,7 @@ go get nhooyr.io/websocket
2322
- net/http is used for WebSocket dials and upgrades
2423
- Passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
2524
- JSON helpers
25+
- Thoroughly testing
2626

2727
## Roadmap
2828

@@ -112,7 +112,9 @@ c.Close(websocket.StatusNormalClosure, "")
112112

113113
While I believe nhooyr/websocket has a better API than existing libraries,
114114
both gorilla/websocket and gobwas/ws were extremely useful in implementing the
115-
WebSocket protocol correctly so big thanks to the authors of both.
115+
WebSocket protocol correctly so big thanks to the authors of both. In particular,
116+
I made sure to go through the issue tracker of gorilla/websocket to make sure
117+
I implemented details correctly.
116118

117119
### gorilla/websocket
118120

accept.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ func (o acceptSubprotocols) acceptOption() {}
2323

2424
// AcceptSubprotocols list the subprotocols that Accept will negotiate with a client.
2525
// The first protocol that a client supports will be negotiated.
26-
// Pass "" as a subprotocol if you would like to allow the default protocol along with
27-
// specific subprotocols.
26+
// The empty protocol will always be negotiated as per RFC 6455. If you would like to
27+
// reject it, close the connection is c.Subprotocol() == "".
2828
func AcceptSubprotocols(subprotocols ...string) AcceptOption {
2929
return acceptSubprotocols(subprotocols)
3030
}
@@ -42,7 +42,7 @@ func (o acceptOrigins) acceptOption() {}
4242
// See https://stackoverflow.com/a/37837709/4283659
4343
// You can use a * for wildcards.
4444
func AcceptOrigins(origins ...string) AcceptOption {
45-
return AcceptOrigins(origins...)
45+
return acceptOrigins(origins)
4646
}
4747

4848
// Accept accepts a WebSocket handshake from a client and upgrades the
@@ -118,7 +118,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
118118

119119
netConn, brw, err := hj.Hijack()
120120
if err != nil {
121-
err = xerrors.Errorf("websocket: failed to hijack connection: %v", err)
121+
err = xerrors.Errorf("websocket: failed to hijack connection: %w", err)
122122
http.Error(w, err.Error(), http.StatusInternalServerError)
123123
return nil, err
124124
}
@@ -135,7 +135,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
135135
}
136136

137137
func selectSubprotocol(w http.ResponseWriter, r *http.Request, subprotocols []string) {
138-
clientSubprotocols := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), "\n")
138+
clientSubprotocols := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",")
139139
for _, sp := range subprotocols {
140140
for _, cp := range clientSubprotocols {
141141
if sp == strings.TrimSpace(cp) {
@@ -165,12 +165,12 @@ func authenticateOrigin(r *http.Request, origins []string) error {
165165
}
166166
u, err := url.Parse(origin)
167167
if err != nil {
168-
return xerrors.Errorf("failed to parse Origin header %q: %v", origin, err)
168+
return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
169169
}
170170
for _, o := range origins {
171-
if u.Host == o {
171+
if strings.EqualFold(u.Host, o) {
172172
return nil
173173
}
174174
}
175-
return xerrors.New("request origin is not authorized")
175+
return xerrors.Errorf("request origin %q is not authorized", r.Header.Get("Origin"))
176176
}

datatype.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package websocket
22

33
// DataType represents the Opcode of a WebSocket data frame.
4-
//go:generate go run golang.org/x/tools/cmd/stringer -type=DataType
54
type DataType int
65

6+
//go:generate go run golang.org/x/tools/cmd/stringer -type=DataType
7+
78
// DataType constants.
89
const (
910
Text DataType = DataType(opText)

dial.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,19 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R
7474

7575
parsedURL, err := url.Parse(u)
7676
if err != nil {
77-
return nil, nil, xerrors.Errorf("failed to parse websocket url: %v", err)
77+
return nil, nil, xerrors.Errorf("failed to parse websocket url: %w", err)
7878
}
7979

8080
switch parsedURL.Scheme {
81-
case "ws", "http":
81+
case "ws":
8282
parsedURL.Scheme = "http"
83-
case "wss", "https":
83+
case "wss":
8484
parsedURL.Scheme = "https"
8585
default:
8686
return nil, nil, xerrors.Errorf("unknown scheme in url: %q", parsedURL.Scheme)
8787
}
8888

89-
req, _ := http.NewRequest("GET", u, nil)
89+
req, _ := http.NewRequest("GET", parsedURL.String(), nil)
9090
req = req.WithContext(ctx)
9191
req.Header = header
9292
req.Header.Set("Connection", "Upgrade")
@@ -113,7 +113,7 @@ func Dial(ctx context.Context, u string, opts ...DialOption) (_ *Conn, _ *http.R
113113
}()
114114

115115
if resp.StatusCode != http.StatusSwitchingProtocols {
116-
return nil, resp, xerrors.Errorf("websocket: expected status code %v but got %v", http.StatusSwitchingProtocols)
116+
return nil, resp, xerrors.Errorf("websocket: expected status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
117117
}
118118

119119
if !httpguts.HeaderValuesContainsToken(resp.Header["Connection"], "Upgrade") {

json.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@ import (
1212
func ReadJSON(ctx context.Context, c *Conn, v interface{}) error {
1313
typ, r, err := c.ReadMessage(ctx)
1414
if err != nil {
15-
return xerrors.Errorf("failed to read json: %v", err)
15+
return xerrors.Errorf("failed to read json: %w", err)
1616
}
1717

1818
if typ != websocket.TextFrame {
1919
return xerrors.Errorf("unexpected frame type for json (expected TextFrame): %v", typ)
2020
}
2121

22+
r.Limit(131072)
23+
r.SetContext(ctx)
24+
2225
d := json.NewDecoder(r)
2326
err = d.Decode(v)
2427
if err != nil {
25-
return xerrors.Errorf("failed to read json: %v", err)
28+
return xerrors.Errorf("failed to read json: %w", err)
2629
}
2730
return nil
2831
}
@@ -31,14 +34,16 @@ func ReadJSON(ctx context.Context, c *Conn, v interface{}) error {
3134
func WriteJSON(ctx context.Context, c *Conn, v interface{}) error {
3235
w := c.MessageWriter(websocket.TextFrame)
3336
w.SetContext(ctx)
37+
3438
e := json.NewEncoder(w)
3539
err := e.Encode(v)
3640
if err != nil {
37-
return xerrors.Errorf("failed to write json: %v", err)
41+
return xerrors.Errorf("failed to write json: %w", err)
3842
}
43+
3944
err = w.Close()
4045
if err != nil {
41-
return xerrors.Errorf("failed to write json: %v", err)
46+
return xerrors.Errorf("failed to write json: %w", err)
4247
}
4348
return nil
4449
}

opcode.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
package websocket
22

33
// opcode represents a WebSocket Opcode.
4-
//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode
54
type opcode int
65

6+
//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode
7+
78
// opcode constants.
89
const (
910
opContinuation opcode = iota
1011
opText
1112
opBinary
1213
// 3 - 7 are reserved for further non-control frames.
13-
opClose opcode = 8 + iota - 3
14+
_
15+
_
16+
_
17+
_
18+
_
19+
opClose
1420
opPing
1521
opPong
1622
// 11-16 are reserved for further control frames.

statuscode.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,19 @@ import (
1111
)
1212

1313
// StatusCode represents a WebSocket status code.
14-
//go:generate go run golang.org/x/tools/cmd/stringer -type=StatusCode
1514
type StatusCode int
1615

16+
//go:generate go run golang.org/x/tools/cmd/stringer -type=StatusCode
17+
1718
// These codes were retrieved from:
1819
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
1920
const (
2021
StatusNormalClosure StatusCode = 1000 + iota
2122
StatusGoingAway
2223
StatusProtocolError
2324
StatusUnsupportedData
24-
// 1004 is reserved.
25-
StatusNoStatusRcvd StatusCode = 1005 + iota - 4
25+
_ // 1004 is reserved.
26+
StatusNoStatusRcvd
2627
StatusAbnormalClosure
2728
StatusInvalidFramePayloadData
2829
StatusPolicyViolation

websocket.go

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (c *Conn) getCloseErr() error {
5858

5959
func (c *Conn) close(err error) {
6060
if err != nil {
61-
err = xerrors.Errorf("websocket: connection broken: %v", err)
61+
err = xerrors.Errorf("websocket: connection broken: %w", err)
6262
}
6363

6464
c.closeOnce.Do(func() {
@@ -102,20 +102,20 @@ func (c *Conn) writeFrame(h header, p []byte) {
102102
b2 := marshalHeader(h)
103103
_, err := c.bw.Write(b2)
104104
if err != nil {
105-
c.close(xerrors.Errorf("failed to write to connection: %v", err))
105+
c.close(xerrors.Errorf("failed to write to connection: %w", err))
106106
return
107107
}
108108

109109
_, err = c.bw.Write(p)
110110
if err != nil {
111-
c.close(xerrors.Errorf("failed to write to connection: %v", err))
111+
c.close(xerrors.Errorf("failed to write to connection: %w", err))
112112
return
113113
}
114114

115115
if h.opcode.controlOp() {
116116
err := c.bw.Flush()
117117
if err != nil {
118-
c.close(xerrors.Errorf("failed to write to connection: %v", err))
118+
c.close(xerrors.Errorf("failed to write to connection: %w", err))
119119
return
120120
}
121121
}
@@ -139,7 +139,11 @@ messageLoop:
139139
masked: c.client,
140140
}
141141
c.writeFrame(h, control.payload)
142-
c.writeDone <- struct{}{}
142+
select {
143+
case <-c.closed:
144+
return
145+
case c.writeDone <- struct{}{}:
146+
}
143147
continue
144148
}
145149

@@ -176,7 +180,7 @@ messageLoop:
176180
if !ok {
177181
err := c.bw.Flush()
178182
if err != nil {
179-
c.close(xerrors.Errorf("failed to write to connection: %v", err))
183+
c.close(xerrors.Errorf("failed to write to connection: %w", err))
180184
return
181185
}
182186
}
@@ -210,7 +214,7 @@ func (c *Conn) handleControl(h header) {
210214
b := make([]byte, h.payloadLength)
211215
_, err := io.ReadFull(c.br, b)
212216
if err != nil {
213-
c.close(xerrors.Errorf("failed to read control frame payload: %v", err))
217+
c.close(xerrors.Errorf("failed to read control frame payload: %w", err))
214218
return
215219
}
216220

@@ -226,7 +230,7 @@ func (c *Conn) handleControl(h header) {
226230
if len(b) > 0 {
227231
code, reason, err := parseClosePayload(b)
228232
if err != nil {
229-
c.close(xerrors.Errorf("read invalid close payload: %v", err))
233+
c.close(xerrors.Errorf("read invalid close payload: %w", err))
230234
return
231235
}
232236
c.Close(code, reason)
@@ -247,7 +251,7 @@ func (c *Conn) readLoop() {
247251
for {
248252
h, err := readHeader(c.br)
249253
if err != nil {
250-
c.close(xerrors.Errorf("failed to read header: %v", err))
254+
c.close(xerrors.Errorf("failed to read header: %w", err))
251255
return
252256
}
253257

@@ -280,6 +284,7 @@ func (c *Conn) readLoop() {
280284
return
281285
}
282286
default:
287+
// TODO send back protocol violation message or figure out what RFC wants.
283288
c.close(xerrors.Errorf("unexpected opcode in header: %#v", h))
284289
return
285290
}
@@ -298,7 +303,7 @@ func (c *Conn) readLoop() {
298303

299304
_, err = io.ReadFull(c.br, b)
300305
if err != nil {
301-
c.close(xerrors.Errorf("failed to read from connection: %v", err))
306+
c.close(xerrors.Errorf("failed to read from connection: %w", err))
302307
return
303308
}
304309
left -= int64(len(b))
@@ -355,14 +360,14 @@ func (c *Conn) MessageWriter(dataType DataType) *MessageWriter {
355360
func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) {
356361
select {
357362
case <-c.closed:
358-
return 0, nil, xerrors.Errorf("failed to read message: %v", c.getCloseErr())
363+
return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr())
359364
case opcode := <-c.read:
360365
return DataType(opcode), &MessageReader{
361366
ctx: context.Background(),
362367
c: c,
363368
}, nil
364369
case <-ctx.Done():
365-
return 0, nil, xerrors.Errorf("failed to read message: %v", ctx.Err())
370+
return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err())
366371
}
367372
}
368373

@@ -481,8 +486,14 @@ func (w *MessageWriter) Close() error {
481486
}
482487
}
483488
close(w.c.writeBytes)
484-
<-w.c.writeDone
485-
return nil
489+
select {
490+
case <-w.c.closed:
491+
return w.c.getCloseErr()
492+
case <-w.ctx.Done():
493+
return w.ctx.Err()
494+
case <-w.c.writeDone:
495+
return nil
496+
}
486497
}
487498

488499
// MessageReader enables reading a data frame from the WebSocket connection.
@@ -501,6 +512,14 @@ func (r *MessageReader) SetContext(ctx context.Context) {
501512
}
502513

503514
// Limit limits the number of bytes read by the reader.
515+
//
516+
// Why not use io.LimitReader? io.LimitReader returns a io.EOF
517+
// after the limit bytes which means its not possible to tell
518+
// whether the message has been read or a limit has been hit.
519+
// This results in unclear error and log messages.
520+
// This function will cause the connection to be closed if the limit is hit
521+
// with a close reason explaining the error and also an error
522+
// indicating the limit was hit.
504523
func (r *MessageReader) Limit(bytes int) {
505524
r.limit = bytes
506525
}

0 commit comments

Comments
 (0)