Skip to content

Commit 791d82e

Browse files
authored
Merge pull request #56 from nhooyr/better
Fix remaining major bugs and UX issues
2 parents d67546a + fb12459 commit 791d82e

File tree

11 files changed

+426
-215
lines changed

11 files changed

+426
-215
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
coverage.html
22
wstest_reports
3+
websocket.test

accept.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ func AcceptInsecureOrigin() AcceptOption {
5353

5454
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
5555
if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") {
56-
err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
56+
err := xerrors.Errorf("websocket: protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
5757
http.Error(w, err.Error(), http.StatusBadRequest)
5858
return err
5959
}
6060

6161
if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") {
62-
err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
62+
err := xerrors.Errorf("websocket: protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
6363
http.Error(w, err.Error(), http.StatusBadRequest)
6464
return err
6565
}
6666

6767
if r.Method != "GET" {
68-
err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
68+
err := xerrors.Errorf("websocket: protocol violation: handshake request method %q is not GET", r.Method)
6969
http.Error(w, err.Error(), http.StatusBadRequest)
7070
return err
7171
}
@@ -88,7 +88,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
8888
// Accept accepts a WebSocket handshake from a client and upgrades the
8989
// the connection to WebSocket.
9090
// Accept will reject the handshake if the Origin is not the same as the Host unless
91-
// InsecureAcceptOrigin is passed.
91+
// the AcceptInsecureOrigin option is passed.
9292
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
9393
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
9494
var subprotocols []string

bench_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package websocket_test
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"strconv"
8+
"strings"
9+
"testing"
10+
"time"
11+
12+
"nhooyr.io/websocket"
13+
)
14+
15+
func BenchmarkConn(b *testing.B) {
16+
b.StopTimer()
17+
18+
s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
c, err := websocket.Accept(w, r,
20+
websocket.AcceptSubprotocols("echo"),
21+
)
22+
if err != nil {
23+
b.Logf("server handshake failed: %+v", err)
24+
return
25+
}
26+
echoLoop(r.Context(), c)
27+
}))
28+
defer closeFn()
29+
30+
wsURL := strings.Replace(s.URL, "http", "ws", 1)
31+
32+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
33+
defer cancel()
34+
35+
c, _, err := websocket.Dial(ctx, wsURL)
36+
if err != nil {
37+
b.Fatalf("failed to dial: %v", err)
38+
}
39+
defer c.Close(websocket.StatusInternalError, "")
40+
41+
runN := func(n int) {
42+
b.Run(strconv.Itoa(n), func(b *testing.B) {
43+
msg := []byte(strings.Repeat("2", n))
44+
buf := make([]byte, len(msg))
45+
b.SetBytes(int64(len(msg)))
46+
b.ResetTimer()
47+
for i := 0; i < b.N; i++ {
48+
w, err := c.Write(ctx, websocket.MessageText)
49+
if err != nil {
50+
b.Fatal(err)
51+
}
52+
53+
_, err = w.Write(msg)
54+
if err != nil {
55+
b.Fatal(err)
56+
}
57+
58+
err = w.Close()
59+
if err != nil {
60+
b.Fatal(err)
61+
}
62+
63+
_, r, err := c.Read(ctx)
64+
if err != nil {
65+
b.Fatal(err, b.N)
66+
}
67+
68+
_, err = io.ReadFull(r, buf)
69+
if err != nil {
70+
b.Fatal(err)
71+
}
72+
}
73+
b.StopTimer()
74+
})
75+
}
76+
77+
runN(32)
78+
runN(128)
79+
runN(512)
80+
runN(1024)
81+
runN(4096)
82+
runN(16384)
83+
runN(65536)
84+
runN(131072)
85+
86+
c.Close(websocket.StatusNormalClosure, "")
87+
}

dial_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
)
88

99
func Test_verifyServerHandshake(t *testing.T) {
10+
t.Parallel()
11+
1012
testCases := []struct {
1113
name string
1214
response func(w http.ResponseWriter)

example_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ func ExampleAccept_echo() {
3434
if err != nil {
3535
return err
3636
}
37-
3837
r = io.LimitReader(r, 32768)
3938

40-
w := c.Write(ctx, typ)
39+
w, err := c.Write(ctx, typ)
40+
if err != nil {
41+
return err
42+
}
43+
4144
_, err = io.Copy(w, r)
4245
if err != nil {
4346
return err
@@ -76,7 +79,7 @@ func ExampleAccept() {
7679
log.Printf("server handshake failed: %v", err)
7780
return
7881
}
79-
defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error.
82+
defer c.Close(websocket.StatusInternalError, "")
8083

8184
jc := websocket.JSONConn{
8285
Conn: c,

header.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/binary"
55
"fmt"
66
"io"
7+
"math"
78

89
"golang.org/x/xerrors"
910
)
@@ -55,7 +56,7 @@ func marshalHeader(h header) []byte {
5556
panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength))
5657
case h.payloadLength <= 125:
5758
b[1] = byte(h.payloadLength)
58-
case h.payloadLength <= 1<<16:
59+
case h.payloadLength <= math.MaxUint16:
5960
b[1] = 126
6061
b = b[:len(b)+2]
6162
binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength))
@@ -105,10 +106,8 @@ func readHeader(r io.Reader) (header, error) {
105106
case payloadLength < 126:
106107
h.payloadLength = int64(payloadLength)
107108
case payloadLength == 126:
108-
h.payloadLength = 126
109109
extra += 2
110110
case payloadLength == 127:
111-
h.payloadLength = 127
112111
extra += 8
113112
}
114113

header_test.go

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package websocket
33
import (
44
"bytes"
55
"math/rand"
6+
"strconv"
67
"testing"
78
"time"
89

@@ -36,10 +37,38 @@ func TestHeader(t *testing.T) {
3637
t.Fatalf("unexpected error value: %+v", err)
3738
}
3839
})
40+
41+
t.Run("lengths", func(t *testing.T) {
42+
t.Parallel()
43+
44+
lengths := []int{
45+
124,
46+
125,
47+
126,
48+
4096,
49+
16384,
50+
65535,
51+
65536,
52+
65537,
53+
131072,
54+
}
55+
56+
for _, n := range lengths {
57+
n := n
58+
t.Run(strconv.Itoa(n), func(t *testing.T) {
59+
t.Parallel()
60+
61+
testHeader(t, header{
62+
payloadLength: int64(n),
63+
})
64+
})
65+
}
66+
})
67+
3968
t.Run("fuzz", func(t *testing.T) {
4069
t.Parallel()
4170

42-
for i := 0; i < 1000; i++ {
71+
for i := 0; i < 10000; i++ {
4372
h := header{
4473
fin: randBool(),
4574
rsv1: randBool(),
@@ -55,20 +84,24 @@ func TestHeader(t *testing.T) {
5584
rand.Read(h.maskKey[:])
5685
}
5786

58-
b := marshalHeader(h)
59-
r := bytes.NewReader(b)
60-
h2, err := readHeader(r)
61-
if err != nil {
62-
t.Logf("header: %#v", h)
63-
t.Logf("bytes: %b", b)
64-
t.Fatalf("failed to read header: %v", err)
65-
}
66-
67-
if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) {
68-
t.Logf("header: %#v", h)
69-
t.Logf("bytes: %b", b)
70-
t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
71-
}
87+
testHeader(t, h)
7288
}
7389
})
7490
}
91+
92+
func testHeader(t *testing.T, h header) {
93+
b := marshalHeader(h)
94+
r := bytes.NewReader(b)
95+
h2, err := readHeader(r)
96+
if err != nil {
97+
t.Logf("header: %#v", h)
98+
t.Logf("bytes: %b", b)
99+
t.Fatalf("failed to read header: %v", err)
100+
}
101+
102+
if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) {
103+
t.Logf("header: %#v", h)
104+
t.Logf("bytes: %b", b)
105+
t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{})))
106+
}
107+
}

json.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func (jc JSONConn) Read(ctx context.Context, v interface{}) error {
2222
return nil
2323
}
2424

25-
func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
25+
func (jc JSONConn) read(ctx context.Context, v interface{}) error {
2626
typ, r, err := jc.Conn.Read(ctx)
2727
if err != nil {
2828
return err
@@ -53,10 +53,13 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
5353
}
5454

5555
func (jc JSONConn) write(ctx context.Context, v interface{}) error {
56-
w := jc.Conn.Write(ctx, MessageText)
56+
w, err := jc.Conn.Write(ctx, MessageText)
57+
if err != nil {
58+
return xerrors.Errorf("failed to get message writer: %w", err)
59+
}
5760

5861
e := json.NewEncoder(w)
59-
err := e.Encode(v)
62+
err = e.Encode(v)
6063
if err != nil {
6164
return xerrors.Errorf("failed to encode json: %w", err)
6265
}

statuscode.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"math/bits"
8-
"unicode/utf8"
98

109
"golang.org/x/xerrors"
1110
)
@@ -54,6 +53,12 @@ func (ce CloseError) Error() string {
5453
}
5554

5655
func parseClosePayload(p []byte) (CloseError, error) {
56+
if len(p) == 0 {
57+
return CloseError{
58+
Code: StatusNoStatusRcvd,
59+
}, nil
60+
}
61+
5762
if len(p) < 2 {
5863
return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
5964
}
@@ -63,9 +68,6 @@ func parseClosePayload(p []byte) (CloseError, error) {
6368
Reason: string(p[2:]),
6469
}
6570

66-
if !utf8.ValidString(ce.Reason) {
67-
return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason)
68-
}
6971
if !validWireCloseCode(ce.Code) {
7072
return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code)
7173
}

0 commit comments

Comments
 (0)