Skip to content

Commit f685c8d

Browse files
committed
Improve speed and add a benchmark
1 parent 696af24 commit f685c8d

File tree

7 files changed

+282
-164
lines changed

7 files changed

+282
-164
lines changed

accept.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ func AcceptInsecureOrigin() AcceptOption {
5353

5454
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
5555
if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") {
56-
err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection"))
56+
err := xerrors.Errorf("websocket: protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
5757
http.Error(w, err.Error(), http.StatusBadRequest)
5858
return err
5959
}
6060

6161
if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") {
62-
err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade"))
62+
err := xerrors.Errorf("websocket: protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
6363
http.Error(w, err.Error(), http.StatusBadRequest)
6464
return err
6565
}
6666

6767
if r.Method != "GET" {
68-
err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method)
68+
err := xerrors.Errorf("websocket: protocol violation: handshake request method %q is not GET", r.Method)
6969
http.Error(w, err.Error(), http.StatusBadRequest)
7070
return err
7171
}
@@ -88,7 +88,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
8888
// Accept accepts a WebSocket handshake from a client and upgrades the
8989
// the connection to WebSocket.
9090
// Accept will reject the handshake if the Origin is not the same as the Host unless
91-
// InsecureAcceptOrigin is passed.
91+
// the AcceptInsecureOrigin option is passed.
9292
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
9393
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
9494
var subprotocols []string

bench_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package websocket_test
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"nhooyr.io/websocket"
8+
"strings"
9+
"testing"
10+
"time"
11+
)
12+
13+
func BenchmarkConn(b *testing.B) {
14+
b.StopTimer()
15+
16+
s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
17+
c, err := websocket.Accept(w, r,
18+
websocket.AcceptSubprotocols("echo"),
19+
)
20+
if err != nil {
21+
b.Logf("server handshake failed: %+v", err)
22+
return
23+
}
24+
echoLoop(r.Context(), c)
25+
}))
26+
defer closeFn()
27+
28+
wsURL := strings.Replace(s.URL, "http", "ws", 1)
29+
30+
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
31+
defer cancel()
32+
33+
c, _, err := websocket.Dial(ctx, wsURL)
34+
if err != nil {
35+
b.Fatalf("failed to dial: %v", err)
36+
}
37+
defer c.Close(websocket.StatusInternalError, "")
38+
39+
msg := strings.Repeat("2", 4096*16)
40+
buf := make([]byte, len(msg))
41+
b.SetBytes(int64(len(msg)))
42+
b.StartTimer()
43+
for i := 0; i < b.N; i++ {
44+
w, err := c.Write(ctx, websocket.MessageText)
45+
if err != nil {
46+
b.Fatal(err)
47+
}
48+
49+
_, err = io.WriteString(w, msg)
50+
if err != nil {
51+
b.Fatal(err)
52+
}
53+
54+
err = w.Close()
55+
if err != nil {
56+
b.Fatal(err)
57+
}
58+
59+
_, r, err := c.Read(ctx)
60+
if err != nil {
61+
b.Fatal(err, b.N)
62+
}
63+
64+
_, err = io.ReadFull(r, buf)
65+
if err != nil {
66+
b.Fatal(err)
67+
}
68+
69+
// TODO jank
70+
_, err = r.Read(nil)
71+
if err != io.EOF {
72+
b.Fatalf("wtf %q", err)
73+
}
74+
}
75+
b.StopTimer()
76+
c.Close(websocket.StatusNormalClosure, "")
77+
}

example_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ func ExampleAccept_echo() {
3434
if err != nil {
3535
return err
3636
}
37-
3837
r = io.LimitReader(r, 32768)
3938

40-
w := c.Write(ctx, typ)
39+
w, err := c.Write(ctx, typ)
40+
if err != nil {
41+
return err
42+
}
43+
4144
_, err = io.Copy(w, r)
4245
if err != nil {
4346
return err

json.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func (jc JSONConn) Read(ctx context.Context, v interface{}) error {
2222
return nil
2323
}
2424

25-
func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
25+
func (jc JSONConn) read(ctx context.Context, v interface{}) error {
2626
typ, r, err := jc.Conn.Read(ctx)
2727
if err != nil {
2828
return err
@@ -53,10 +53,13 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
5353
}
5454

5555
func (jc JSONConn) write(ctx context.Context, v interface{}) error {
56-
w := jc.Conn.Write(ctx, MessageText)
56+
w, err := jc.Conn.Write(ctx, MessageText)
57+
if err != nil {
58+
return xerrors.Errorf("failed to get message writer: %w", err)
59+
}
5760

5861
e := json.NewEncoder(w)
59-
err := e.Encode(v)
62+
err = e.Encode(v)
6063
if err != nil {
6164
return xerrors.Errorf("failed to encode json: %w", err)
6265
}

statuscode.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"math/bits"
8-
"unicode/utf8"
98

109
"golang.org/x/xerrors"
1110
)
@@ -54,6 +53,12 @@ func (ce CloseError) Error() string {
5453
}
5554

5655
func parseClosePayload(p []byte) (CloseError, error) {
56+
if len(p) == 0 {
57+
return CloseError{
58+
Code: StatusNoStatusRcvd,
59+
}, nil
60+
}
61+
5762
if len(p) < 2 {
5863
return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
5964
}
@@ -63,9 +68,6 @@ func parseClosePayload(p []byte) (CloseError, error) {
6368
Reason: string(p[2:]),
6469
}
6570

66-
if !utf8.ValidString(ce.Reason) {
67-
return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason)
68-
}
6971
if !validWireCloseCode(ce.Code) {
7072
return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code)
7173
}

0 commit comments

Comments
 (0)