Skip to content

Commit 6238fd9

Browse files
committed
TUN-5141: Make sure websocket pinger returns before streaming returns
1 parent f985ed5 commit 6238fd9

File tree

4 files changed

+95
-13
lines changed

4 files changed

+95
-13
lines changed

connection/http2.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"math"
88
"net"
99
"net/http"
10+
"runtime/debug"
1011
"strings"
1112
"sync"
1213

@@ -100,7 +101,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
100101
connType := determineHTTP2Type(r)
101102
handleMissingRequestParts(connType, r)
102103

103-
respWriter, err := newHTTP2RespWriter(r, w, connType)
104+
respWriter, err := NewHTTP2RespWriter(r, w, connType)
104105
if err != nil {
105106
c.observer.log.Error().Msg(err.Error())
106107
return
@@ -159,7 +160,7 @@ type http2RespWriter struct {
159160
shouldFlush bool
160161
}
161162

162-
func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) {
163+
func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) {
163164
flusher, isFlusher := w.(http.Flusher)
164165
if !isFlusher {
165166
respWriter := &http2RespWriter{
@@ -231,7 +232,7 @@ func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
231232
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
232233
// Register a recover routine just in case.
233234
if r := recover(); r != nil {
234-
println("Recover from http2 response writer panic, error", r)
235+
println(fmt.Sprintf("Recover from http2 response writer panic, error %s", debug.Stack()))
235236
}
236237
}()
237238
n, err = rp.w.Write(p)

ingress/origin_connection.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ type tcpOverWSConnection struct {
4848
}
4949

5050
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
51-
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
51+
wsCtx, cancel := context.WithCancel(ctx)
52+
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
53+
wc.streamHandler(wsConn, wc.conn, log)
54+
cancel()
55+
// Makes sure wsConn stops sending ping before terminating the stream
56+
wsConn.WaitForShutdown()
5257
}
5358

5459
func (wc *tcpOverWSConnection) Close() {
@@ -63,7 +68,12 @@ type socksProxyOverWSConnection struct {
6368
}
6469

6570
func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
66-
socks.StreamNetHandler(websocket.NewConn(ctx, tunnelConn, log), sp.accessPolicy, log)
71+
wsCtx, cancel := context.WithCancel(ctx)
72+
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
73+
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
74+
cancel()
75+
// Makes sure wsConn stops sending ping before terminating the stream
76+
wsConn.WaitForShutdown()
6777
}
6878

6979
func (sp *socksProxyOverWSConnection) Close() {

ingress/origin_connection_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"golang.org/x/net/proxy"
2020
"golang.org/x/sync/errgroup"
2121

22+
"github.com/cloudflare/cloudflared/connection"
2223
"github.com/cloudflare/cloudflared/logger"
2324
"github.com/cloudflare/cloudflared/socks"
2425
"github.com/cloudflare/cloudflared/websocket"
@@ -189,6 +190,53 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
189190
}
190191
}
191192

193+
func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
194+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
195+
eyeballConn, err := connection.NewHTTP2RespWriter(r, w, connection.TypeWebsocket)
196+
assert.NoError(t, err)
197+
198+
cfdConn, originConn := net.Pipe()
199+
tcpOverWSConn := tcpOverWSConnection{
200+
conn: cfdConn,
201+
streamHandler: DefaultStreamHandler,
202+
}
203+
go func() {
204+
time.Sleep(time.Millisecond * 10)
205+
// Simulate losing connection to origin
206+
originConn.Close()
207+
}()
208+
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
209+
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
210+
})
211+
server := httptest.NewServer(handler)
212+
defer server.Close()
213+
client := server.Client()
214+
215+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
216+
defer cancel()
217+
218+
errGroup, ctx := errgroup.WithContext(ctx)
219+
for i := 0; i < 50; i++ {
220+
eyeballConn, edgeConn := net.Pipe()
221+
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn)
222+
assert.NoError(t, err)
223+
224+
resp, err := client.Transport.RoundTrip(req)
225+
assert.NoError(t, err)
226+
assert.Equal(t, resp.StatusCode, http.StatusOK)
227+
228+
errGroup.Go(func() error {
229+
for {
230+
if err := wsutil.WriteClientBinary(eyeballConn, testMessage); err != nil {
231+
return nil
232+
}
233+
}
234+
})
235+
}
236+
237+
assert.NoError(t, errGroup.Wait())
238+
}
239+
192240
type wsEyeball struct {
193241
conn net.Conn
194242
}

websocket/connection.go

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@ const (
1818
writeWait = 10 * time.Second
1919

2020
// Time allowed to read the next pong message from the peer.
21-
pongWait = 60 * time.Second
21+
defaultPongWait = 60 * time.Second
2222

2323
// Send pings to peer with this period. Must be less than pongWait.
24-
pingPeriod = (pongWait * 9) / 10
24+
defaultPingPeriod = (defaultPongWait * 9) / 10
25+
26+
PingPeriodContextKey = PingPeriodContext("pingPeriod")
2527
)
2628

29+
type PingPeriodContext string
30+
2731
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
2832
// This is still used by access carrier
2933
type GorillaConn struct {
@@ -77,7 +81,7 @@ func (c *GorillaConn) SetDeadline(t time.Time) error {
7781

7882
// pinger simulates the websocket connection to keep it alive
7983
func (c *GorillaConn) pinger(ctx context.Context) {
80-
ticker := time.NewTicker(pingPeriod)
84+
ticker := time.NewTicker(defaultPingPeriod)
8185
defer ticker.Stop()
8286
for {
8387
select {
@@ -94,12 +98,15 @@ func (c *GorillaConn) pinger(ctx context.Context) {
9498
type Conn struct {
9599
rw io.ReadWriter
96100
log *zerolog.Logger
101+
// closed is a channel to indicate if Conn has been fully terminated
102+
shutdownC chan struct{}
97103
}
98104

99105
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
100106
c := &Conn{
101-
rw: rw,
102-
log: log,
107+
rw: rw,
108+
log: log,
109+
shutdownC: make(chan struct{}),
103110
}
104111
go c.pinger(ctx)
105112
return c
@@ -123,23 +130,39 @@ func (c *Conn) Write(p []byte) (int, error) {
123130
}
124131

125132
func (c *Conn) pinger(ctx context.Context) {
133+
defer close(c.shutdownC)
126134
pongMessge := wsutil.Message{
127135
OpCode: gobwas.OpPong,
128136
Payload: []byte{},
129137
}
130-
ticker := time.NewTicker(pingPeriod)
138+
139+
ticker := time.NewTicker(c.pingPeriod(ctx))
131140
defer ticker.Stop()
132141
for {
133142
select {
134143
case <-ticker.C:
135144
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
136-
c.log.Err(err).Msgf("failed to write ping message")
145+
c.log.Debug().Err(err).Msgf("failed to write ping message")
137146
}
138147
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
139-
c.log.Err(err).Msgf("failed to write pong message")
148+
c.log.Debug().Err(err).Msgf("failed to write pong message")
140149
}
141150
case <-ctx.Done():
142151
return
143152
}
144153
}
145154
}
155+
156+
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
157+
if val := ctx.Value(PingPeriodContextKey); val != nil {
158+
if period, ok := val.(time.Duration); ok {
159+
return period
160+
}
161+
}
162+
return defaultPingPeriod
163+
}
164+
165+
// Close waits for pinger to terminate
166+
func (c *Conn) WaitForShutdown() {
167+
<-c.shutdownC
168+
}

0 commit comments

Comments
 (0)