Skip to content

Commit 2b135ff

Browse files
authored
feat(websocket): change websocket lib to nhooyr.io/websocket (#815)
Fixes #713, #543, and #664.
1 parent f9ddeb1 commit 2b135ff

File tree

4 files changed

+66
-50
lines changed

4 files changed

+66
-50
lines changed

Gopkg.lock

Lines changed: 32 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Gopkg.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ required = [
1515
version = "1.1.0"
1616

1717
[[constraint]]
18-
name = "github.com/gorilla/websocket"
19-
version = "1.2.0"
18+
name = "nhooyr.io/websocket"
19+
version = "1.8.6"
2020

2121
[[constraint]]
2222
branch = "master"

go/grpcweb/websocket_wrapper.go

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ import (
1313
"time"
1414

1515
"github.com/desertbit/timer"
16-
"github.com/gorilla/websocket"
1716
"golang.org/x/net/http2"
17+
"nhooyr.io/websocket"
1818
)
1919

2020
type webSocketResponseWriter struct {
@@ -24,40 +24,34 @@ type webSocketResponseWriter struct {
2424
flushedHeaders http.Header
2525
timeOutInterval time.Duration
2626
timer *timer.Timer
27+
context context.Context
2728
}
2829

29-
func newWebSocketResponseWriter(wsConn *websocket.Conn) *webSocketResponseWriter {
30+
func newWebSocketResponseWriter(ctx context.Context, wsConn *websocket.Conn) *webSocketResponseWriter {
3031
return &webSocketResponseWriter{
3132
writtenHeaders: false,
3233
headers: make(http.Header),
3334
flushedHeaders: make(http.Header),
3435
wsConn: wsConn,
36+
context: ctx,
3537
}
3638
}
3739

3840
func (w *webSocketResponseWriter) enablePing(timeOutInterval time.Duration) {
3941
w.timeOutInterval = timeOutInterval
4042
w.timer = timer.NewTimer(w.timeOutInterval)
41-
dispose := make(chan bool)
42-
w.wsConn.SetCloseHandler(func(code int, text string) error {
43-
close(dispose)
44-
return nil
45-
})
46-
go w.ping(dispose)
43+
go w.ping()
4744
}
4845

49-
func (w *webSocketResponseWriter) ping(dispose chan bool) {
50-
if dispose == nil {
51-
return
52-
}
46+
func (w *webSocketResponseWriter) ping() {
5347
defer w.timer.Stop()
5448
for {
5549
select {
56-
case <-dispose:
50+
case <-w.context.Done():
5751
return
5852
case <-w.timer.C:
5953
w.timer.Reset(w.timeOutInterval)
60-
w.wsConn.WriteMessage(websocket.PingMessage, []byte{})
54+
w.wsConn.Ping(w.context)
6155
}
6256
}
6357
}
@@ -73,16 +67,16 @@ func (w *webSocketResponseWriter) Write(b []byte) (int, error) {
7367
if w.timeOutInterval > time.Second && w.timer != nil {
7468
w.timer.Reset(w.timeOutInterval)
7569
}
76-
return len(b), w.wsConn.WriteMessage(websocket.BinaryMessage, b)
70+
return len(b), w.wsConn.Write(w.context, websocket.MessageBinary, b)
7771
}
7872

7973
func (w *webSocketResponseWriter) writeHeaderFrame(headers http.Header) {
8074
headerBuffer := new(bytes.Buffer)
8175
headers.Write(headerBuffer)
8276
headerGrpcDataHeader := []byte{1 << 7, 0, 0, 0, 0} // MSB=1 indicates this is a header data frame.
8377
binary.BigEndian.PutUint32(headerGrpcDataHeader[1:5], uint32(headerBuffer.Len()))
84-
w.wsConn.WriteMessage(websocket.BinaryMessage, headerGrpcDataHeader)
85-
w.wsConn.WriteMessage(websocket.BinaryMessage, headerBuffer.Bytes())
78+
w.wsConn.Write(w.context, websocket.MessageBinary, headerGrpcDataHeader)
79+
w.wsConn.Write(w.context, websocket.MessageBinary, headerBuffer.Bytes())
8680
}
8781

8882
func (w *webSocketResponseWriter) copyFlushedHeaders() {
@@ -127,12 +121,13 @@ type webSocketWrappedReader struct {
127121
respWriter *webSocketResponseWriter
128122
remainingBuffer []byte
129123
remainingError error
124+
context context.Context
130125
cancel context.CancelFunc
131126
}
132127

133128
func (w *webSocketWrappedReader) Close() error {
134129
w.respWriter.FlushTrailers()
135-
return w.wsConn.Close()
130+
return w.wsConn.Close(websocket.StatusNormalClosure, "request body closed")
136131
}
137132

138133
// First byte of a binary WebSocket frame is used for control flow:
@@ -167,15 +162,15 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) {
167162
}
168163

169164
// Read a whole frame from the WebSocket connection
170-
messageType, framePayload, err := w.wsConn.ReadMessage()
171-
if err == io.EOF || messageType == -1 {
165+
messageType, framePayload, err := w.wsConn.Read(w.context)
166+
if err == io.EOF || messageType == 0 {
172167
// The client has closed the connection. Indicate to the response writer that it should close
173168
w.cancel()
174169
return 0, io.EOF
175170
}
176171

177172
// Only Binary frames are valid
178-
if messageType != websocket.BinaryMessage {
173+
if messageType != websocket.MessageBinary {
179174
return 0, errors.New("websocket frame was not a binary frame")
180175
}
181176

@@ -211,12 +206,13 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) {
211206
return len(p), nil
212207
}
213208

214-
func newWebsocketWrappedReader(wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader {
209+
func newWebsocketWrappedReader(ctx context.Context, wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader {
215210
return &webSocketWrappedReader{
216211
wsConn: wsConn,
217212
respWriter: respWriter,
218213
remainingBuffer: nil,
219214
remainingError: nil,
215+
context: ctx,
220216
cancel: cancel,
221217
}
222218
}

go/grpcweb/wrapper.go

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ import (
1111
"strings"
1212
"time"
1313

14-
"github.com/gorilla/websocket"
1514
"github.com/rs/cors"
1615
"google.golang.org/grpc"
1716
"google.golang.org/grpc/grpclog"
17+
"nhooyr.io/websocket"
1818
)
1919

2020
var (
@@ -147,18 +147,15 @@ func (w *WrappedGrpcServer) HandleGrpcWebRequest(resp http.ResponseWriter, req *
147147
intResp.finishRequest(req)
148148
}
149149

150-
var websocketUpgrader = websocket.Upgrader{
151-
ReadBufferSize: 1024,
152-
WriteBufferSize: 1024,
153-
CheckOrigin: func(r *http.Request) bool { return true },
154-
Subprotocols: []string{"grpc-websockets"},
155-
}
156-
157150
// HandleGrpcWebsocketRequest takes a HTTP request that is assumed to be a gRPC-Websocket request and wraps it with a
158151
// compatibility layer to transform it to a standard gRPC request for the wrapped gRPC server and transforms the
159152
// response to comply with the gRPC-Web protocol.
160153
func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter, req *http.Request) {
161-
wsConn, err := websocketUpgrader.Upgrade(resp, req, nil)
154+
155+
wsConn, err := websocket.Accept(resp, req, &websocket.AcceptOptions{
156+
InsecureSkipVerify: true, // managed by ServeHTTP
157+
Subprotocols: []string{"grpc-websockets"},
158+
})
162159
if err != nil {
163160
grpclog.Errorf("Unable to upgrade websocket request: %v", err)
164161
return
@@ -170,13 +167,16 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter,
170167
}
171168
}
172169

173-
messageType, readBytes, err := wsConn.ReadMessage()
170+
ctx, cancelFunc := context.WithCancel(req.Context())
171+
defer cancelFunc()
172+
173+
messageType, readBytes, err := wsConn.Read(ctx)
174174
if err != nil {
175-
grpclog.Errorf("Unable to read first websocket message: %v", err)
175+
grpclog.Errorf("Unable to read first websocket message: %v %v %v", messageType, readBytes, err)
176176
return
177177
}
178178

179-
if messageType != websocket.BinaryMessage {
179+
if messageType != websocket.MessageBinary {
180180
grpclog.Errorf("First websocket message is non-binary")
181181
return
182182
}
@@ -187,14 +187,11 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter,
187187
return
188188
}
189189

190-
ctx, cancelFunc := context.WithCancel(req.Context())
191-
defer cancelFunc()
192-
193-
respWriter := newWebSocketResponseWriter(wsConn)
190+
respWriter := newWebSocketResponseWriter(ctx, wsConn)
194191
if w.opts.websocketPingInterval >= time.Second {
195192
respWriter.enablePing(w.opts.websocketPingInterval)
196193
}
197-
wrappedReader := newWebsocketWrappedReader(wsConn, respWriter, cancelFunc)
194+
wrappedReader := newWebsocketWrappedReader(ctx, wsConn, respWriter, cancelFunc)
198195

199196
for name, values := range wsHeaders {
200197
headers[name] = values

0 commit comments

Comments
 (0)