Skip to content

Commit f8fbbcd

Browse files
committed
TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown
1 parent 2ca4633 commit f8fbbcd

File tree

6 files changed

+195
-52
lines changed

6 files changed

+195
-52
lines changed

connection/connection_test.go

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"math/rand"
78
"net/http"
89
"net/url"
910
"testing"
1011
"time"
1112

12-
"github.com/gobwas/ws/wsutil"
1313
"github.com/rs/zerolog"
1414
"github.com/stretchr/testify/assert"
1515

1616
"github.com/cloudflare/cloudflared/ingress"
17+
"github.com/cloudflare/cloudflared/websocket"
1718
)
1819

1920
const (
@@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
5051
isWebsocket bool,
5152
) error {
5253
if isWebsocket {
53-
return wsEndpoint(w, req)
54+
switch req.URL.Path {
55+
case "/ws/echo":
56+
return wsEchoEndpoint(w, req)
57+
case "/ws/flaky":
58+
return wsFlakyEndpoint(w, req)
59+
default:
60+
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
61+
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
62+
}
5463
}
5564
switch req.URL.Path {
5665
case "/ok":
@@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
7887
return nil
7988
}
8089

81-
type nowriter struct {
82-
io.Reader
90+
type echoPipe struct {
91+
reader *io.PipeReader
92+
writer *io.PipeWriter
93+
}
94+
95+
func (ep *echoPipe) Read(p []byte) (int, error) {
96+
return ep.reader.Read(p)
8397
}
8498

85-
func (nowriter) Write(p []byte) (int, error) {
86-
return 0, fmt.Errorf("Writer not implemented")
99+
func (ep *echoPipe) Write(p []byte) (int, error) {
100+
return ep.writer.Write(p)
87101
}
88102

89-
func wsEndpoint(w ResponseWriter, r *http.Request) error {
103+
// A mock origin that echos data by streaming like a tcpOverWSConnection
104+
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
105+
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
90106
resp := &http.Response{
91107
StatusCode: http.StatusSwitchingProtocols,
92108
}
93-
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
94-
clientReader := nowriter{r.Body}
109+
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
110+
return err
111+
}
112+
wsCtx, cancel := context.WithCancel(r.Context())
113+
readPipe, writePipe := io.Pipe()
114+
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
95115
go func() {
96-
for {
97-
data, err := wsutil.ReadClientText(clientReader)
98-
if err != nil {
99-
return
100-
}
101-
if err := wsutil.WriteServerText(w, data); err != nil {
102-
return
103-
}
116+
select {
117+
case <-wsCtx.Done():
118+
case <-r.Context().Done():
104119
}
120+
readPipe.Close()
121+
writePipe.Close()
105122
}()
106-
<-r.Context().Done()
123+
124+
originConn := &echoPipe{reader: readPipe, writer: writePipe}
125+
websocket.Stream(wsConn, originConn, &log)
126+
cancel()
127+
wsConn.Close()
128+
return nil
129+
}
130+
131+
type flakyConn struct {
132+
closeAt time.Time
133+
}
134+
135+
func (fc *flakyConn) Read(p []byte) (int, error) {
136+
if time.Now().After(fc.closeAt) {
137+
return 0, io.EOF
138+
}
139+
n := copy(p, []byte("Read from flaky connection"))
140+
return n, nil
141+
}
142+
143+
func (fc *flakyConn) Write(p []byte) (int, error) {
144+
if time.Now().After(fc.closeAt) {
145+
return 0, fmt.Errorf("Flaky connection closed")
146+
}
147+
return len(p), nil
148+
}
149+
150+
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
151+
resp := &http.Response{
152+
StatusCode: http.StatusSwitchingProtocols,
153+
}
154+
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
155+
return err
156+
}
157+
wsCtx, cancel := context.WithCancel(r.Context())
158+
159+
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
160+
161+
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
162+
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
163+
websocket.Stream(wsConn, originConn, &log)
164+
cancel()
165+
wsConn.Close()
107166
return nil
108167
}
109168

connection/h2mux_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
147147
headers := []h2mux.Header{
148148
{
149149
Name: ":path",
150-
Value: "/ws",
150+
Value: "/ws/echo",
151151
},
152152
{
153153
Name: "connection",
@@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
167167
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
168168

169169
data := []byte("test websocket")
170-
err = wsutil.WriteClientText(writePipe, data)
170+
err = wsutil.WriteClientBinary(writePipe, data)
171171
require.NoError(t, err)
172172

173-
respBody, err := wsutil.ReadServerText(stream)
173+
respBody, err := wsutil.ReadServerBinary(stream)
174174
require.NoError(t, err)
175175
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
176176

connection/http2_test.go

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ var (
2727
)
2828

2929
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
30-
edgeConn, originConn := net.Pipe()
30+
edgeConn, cfdConn := net.Pipe()
3131
var connIndex = uint8(0)
3232
log := zerolog.Nop()
3333
obs := NewObserver(&log, &log, false)
@@ -41,7 +41,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
4141
1*time.Second,
4242
)
4343
return NewHTTP2Connection(
44-
originConn,
44+
cfdConn,
45+
// OriginProxy is set in testConfig
4546
testConfig,
4647
&pogs.ConnectionOptions{},
4748
obs,
@@ -166,6 +167,7 @@ type wsRespWriter struct {
166167
*httptest.ResponseRecorder
167168
readPipe *io.PipeReader
168169
writePipe *io.PipeWriter
170+
closed bool
169171
}
170172

171173
func newWSRespWriter() *wsRespWriter {
@@ -174,46 +176,58 @@ func newWSRespWriter() *wsRespWriter {
174176
httptest.NewRecorder(),
175177
readPipe,
176178
writePipe,
179+
false,
177180
}
178181
}
179182

183+
type nowriter struct {
184+
io.Reader
185+
}
186+
187+
func (nowriter) Write(p []byte) (int, error) {
188+
return 0, fmt.Errorf("Writer not implemented")
189+
}
190+
180191
func (w *wsRespWriter) RespBody() io.ReadWriter {
181192
return nowriter{w.readPipe}
182193
}
183194

184195
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
196+
if w.closed {
197+
// Simulate writing to http2 ResponseWriter after ServeHTTP has returned
198+
panic("Write to closed ResponseWriter")
199+
}
185200
return w.writePipe.Write(data)
186201
}
187202

203+
func (w *wsRespWriter) close() {
204+
w.closed = true
205+
}
206+
188207
func TestServeWS(t *testing.T) {
189208
http2Conn, _ := newTestHTTP2Connection()
190209

191210
ctx, cancel := context.WithCancel(context.Background())
192-
var wg sync.WaitGroup
193-
wg.Add(1)
194-
go func() {
195-
defer wg.Done()
196-
http2Conn.Serve(ctx)
197-
}()
198211

199212
respWriter := newWSRespWriter()
200213
readPipe, writePipe := io.Pipe()
201214

202-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
215+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
203216
require.NoError(t, err)
204217
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
205218

206-
wg.Add(1)
219+
serveDone := make(chan struct{})
207220
go func() {
208-
defer wg.Done()
221+
defer close(serveDone)
209222
http2Conn.ServeHTTP(respWriter, req)
223+
respWriter.close()
210224
}()
211225

212226
data := []byte("test websocket")
213-
err = wsutil.WriteClientText(writePipe, data)
227+
err = wsutil.WriteClientBinary(writePipe, data)
214228
require.NoError(t, err)
215229

216-
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
230+
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
217231
require.NoError(t, err)
218232
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
219233

@@ -223,7 +237,64 @@ func TestServeWS(t *testing.T) {
223237
require.Equal(t, http.StatusOK, resp.StatusCode)
224238
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
225239

240+
<-serveDone
241+
}
242+
243+
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
244+
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
245+
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
246+
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
247+
248+
ctx, cancel := context.WithCancel(context.Background())
249+
var wg sync.WaitGroup
250+
251+
serverDone := make(chan struct{})
252+
go func() {
253+
defer close(serverDone)
254+
cfdHTTP2Conn.Serve(ctx)
255+
}()
256+
257+
edgeTransport := http2.Transport{}
258+
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
259+
require.NoError(t, err)
260+
message := []byte(t.Name())
261+
262+
for i := 0; i < 100; i++ {
263+
wg.Add(1)
264+
go func() {
265+
defer wg.Done()
266+
readPipe, writePipe := io.Pipe()
267+
reqCtx, reqCancel := context.WithCancel(ctx)
268+
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
269+
require.NoError(t, err)
270+
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
271+
272+
resp, err := edgeHTTP2Conn.RoundTrip(req)
273+
require.NoError(t, err)
274+
// http2RespWriter should rewrite status 101 to 200
275+
require.Equal(t, http.StatusOK, resp.StatusCode)
276+
277+
wg.Add(1)
278+
go func() {
279+
defer wg.Done()
280+
for {
281+
select {
282+
case <-reqCtx.Done():
283+
return
284+
default:
285+
}
286+
_ = wsutil.WriteClientBinary(writePipe, message)
287+
}
288+
}()
289+
290+
time.Sleep(time.Millisecond * 100)
291+
reqCancel()
292+
}()
293+
}
294+
226295
wg.Wait()
296+
cancel()
297+
<-serverDone
227298
}
228299

229300
func TestServeControlStream(t *testing.T) {

connection/quic_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestQUICServer(t *testing.T) {
6161

6262
// This is simply a sample websocket frame message.
6363
wsBuf := &bytes.Buffer{}
64-
wsutil.WriteClientText(wsBuf, []byte("Hello"))
64+
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
6565

6666
var tests = []struct {
6767
desc string
@@ -118,7 +118,7 @@ func TestQUICServer(t *testing.T) {
118118
},
119119
{
120120
desc: "test ws proxy",
121-
dest: "/ok",
121+
dest: "/ws/echo",
122122
connectionType: quicpogs.ConnectionTypeWebsocket,
123123
metadata: []quicpogs.Metadata{
124124
quicpogs.Metadata{
@@ -139,7 +139,7 @@ func TestQUICServer(t *testing.T) {
139139
},
140140
},
141141
message: wsBuf.Bytes(),
142-
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
142+
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
143143
},
144144
{
145145
desc: "test tcp proxy",
@@ -278,7 +278,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
278278
}
279279

280280
if isWebsocket {
281-
return wsEndpoint(w, r)
281+
return wsEchoEndpoint(w, r)
282282
}
283283
switch r.URL.Path {
284284
case "/ok":

ingress/origin_connection.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
5353
wc.streamHandler(wsConn, wc.conn, log)
5454
cancel()
5555
// Makes sure wsConn stops sending ping before terminating the stream
56-
wsConn.WaitForShutdown()
56+
wsConn.Close()
5757
}
5858

5959
func (wc *tcpOverWSConnection) Close() {
@@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
7373
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
7474
cancel()
7575
// Makes sure wsConn stops sending ping before terminating the stream
76-
wsConn.WaitForShutdown()
76+
wsConn.Close()
7777
}
7878

7979
func (sp *socksProxyOverWSConnection) Close() {

0 commit comments

Comments
 (0)