Skip to content

Commit db01127

Browse files
chungthuangnmldiegues
authored andcommitted
TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown
1 parent 1ff5fd3 commit db01127

File tree

6 files changed

+212
-70
lines changed

6 files changed

+212
-70
lines changed

connection/connection_test.go

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"math/rand"
78
"net/http"
89
"net/url"
910
"testing"
1011
"time"
1112

12-
"github.com/gobwas/ws/wsutil"
1313
"github.com/rs/zerolog"
1414
"github.com/stretchr/testify/assert"
1515

1616
"github.com/cloudflare/cloudflared/ingress"
17+
"github.com/cloudflare/cloudflared/websocket"
1718
)
1819

1920
const (
@@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
5051
isWebsocket bool,
5152
) error {
5253
if isWebsocket {
53-
return wsEndpoint(w, req)
54+
switch req.URL.Path {
55+
case "/ws/echo":
56+
return wsEchoEndpoint(w, req)
57+
case "/ws/flaky":
58+
return wsFlakyEndpoint(w, req)
59+
default:
60+
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
61+
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
62+
}
5463
}
5564
switch req.URL.Path {
5665
case "/ok":
@@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
7887
return nil
7988
}
8089

81-
type nowriter struct {
82-
io.Reader
90+
type echoPipe struct {
91+
reader *io.PipeReader
92+
writer *io.PipeWriter
93+
}
94+
95+
func (ep *echoPipe) Read(p []byte) (int, error) {
96+
return ep.reader.Read(p)
8397
}
8498

85-
func (nowriter) Write(p []byte) (int, error) {
86-
return 0, fmt.Errorf("Writer not implemented")
99+
func (ep *echoPipe) Write(p []byte) (int, error) {
100+
return ep.writer.Write(p)
87101
}
88102

89-
func wsEndpoint(w ResponseWriter, r *http.Request) error {
103+
// A mock origin that echos data by streaming like a tcpOverWSConnection
104+
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
105+
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
90106
resp := &http.Response{
91107
StatusCode: http.StatusSwitchingProtocols,
92108
}
93-
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
94-
clientReader := nowriter{r.Body}
109+
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
110+
return err
111+
}
112+
wsCtx, cancel := context.WithCancel(r.Context())
113+
readPipe, writePipe := io.Pipe()
114+
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
95115
go func() {
96-
for {
97-
data, err := wsutil.ReadClientText(clientReader)
98-
if err != nil {
99-
return
100-
}
101-
if err := wsutil.WriteServerText(w, data); err != nil {
102-
return
103-
}
116+
select {
117+
case <-wsCtx.Done():
118+
case <-r.Context().Done():
104119
}
120+
readPipe.Close()
121+
writePipe.Close()
105122
}()
106-
<-r.Context().Done()
123+
124+
originConn := &echoPipe{reader: readPipe, writer: writePipe}
125+
websocket.Stream(wsConn, originConn, &log)
126+
cancel()
127+
wsConn.Close()
128+
return nil
129+
}
130+
131+
type flakyConn struct {
132+
closeAt time.Time
133+
}
134+
135+
func (fc *flakyConn) Read(p []byte) (int, error) {
136+
if time.Now().After(fc.closeAt) {
137+
return 0, io.EOF
138+
}
139+
n := copy(p, "Read from flaky connection")
140+
return n, nil
141+
}
142+
143+
func (fc *flakyConn) Write(p []byte) (int, error) {
144+
if time.Now().After(fc.closeAt) {
145+
return 0, fmt.Errorf("flaky connection closed")
146+
}
147+
return len(p), nil
148+
}
149+
150+
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
151+
resp := &http.Response{
152+
StatusCode: http.StatusSwitchingProtocols,
153+
}
154+
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
155+
return err
156+
}
157+
wsCtx, cancel := context.WithCancel(r.Context())
158+
159+
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
160+
161+
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
162+
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
163+
websocket.Stream(wsConn, originConn, &log)
164+
cancel()
165+
wsConn.Close()
107166
return nil
108167
}
109168

connection/h2mux_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
147147
headers := []h2mux.Header{
148148
{
149149
Name: ":path",
150-
Value: "/ws",
150+
Value: "/ws/echo",
151151
},
152152
{
153153
Name: "connection",
@@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
167167
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
168168

169169
data := []byte("test websocket")
170-
err = wsutil.WriteClientText(writePipe, data)
170+
err = wsutil.WriteClientBinary(writePipe, data)
171171
require.NoError(t, err)
172172

173-
respBody, err := wsutil.ReadServerText(stream)
173+
respBody, err := wsutil.ReadServerBinary(stream)
174174
require.NoError(t, err)
175175
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
176176

connection/http2_test.go

Lines changed: 88 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package connection
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"io/ioutil"
@@ -27,7 +28,7 @@ var (
2728
)
2829

2930
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
30-
edgeConn, originConn := net.Pipe()
31+
edgeConn, cfdConn := net.Pipe()
3132
var connIndex = uint8(0)
3233
log := zerolog.Nop()
3334
obs := NewObserver(&log, &log, false)
@@ -41,7 +42,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
4142
1*time.Second,
4243
)
4344
return NewHTTP2Connection(
44-
originConn,
45+
cfdConn,
46+
// OriginProxy is set in testConfig
4547
testConfig,
4648
&pogs.ConnectionOptions{},
4749
obs,
@@ -166,6 +168,8 @@ type wsRespWriter struct {
166168
*httptest.ResponseRecorder
167169
readPipe *io.PipeReader
168170
writePipe *io.PipeWriter
171+
closed bool
172+
panicked bool
169173
}
170174

171175
func newWSRespWriter() *wsRespWriter {
@@ -174,46 +178,59 @@ func newWSRespWriter() *wsRespWriter {
174178
httptest.NewRecorder(),
175179
readPipe,
176180
writePipe,
181+
false,
182+
false,
177183
}
178184
}
179185

186+
type nowriter struct {
187+
io.Reader
188+
}
189+
190+
func (nowriter) Write(_ []byte) (int, error) {
191+
return 0, fmt.Errorf("writer not implemented")
192+
}
193+
180194
func (w *wsRespWriter) RespBody() io.ReadWriter {
181195
return nowriter{w.readPipe}
182196
}
183197

184198
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
199+
if w.closed {
200+
w.panicked = true
201+
return 0, errors.New("wsRespWriter panicked")
202+
}
185203
return w.writePipe.Write(data)
186204
}
187205

206+
func (w *wsRespWriter) close() {
207+
w.closed = true
208+
}
209+
188210
func TestServeWS(t *testing.T) {
189211
http2Conn, _ := newTestHTTP2Connection()
190212

191213
ctx, cancel := context.WithCancel(context.Background())
192-
var wg sync.WaitGroup
193-
wg.Add(1)
194-
go func() {
195-
defer wg.Done()
196-
http2Conn.Serve(ctx)
197-
}()
198214

199215
respWriter := newWSRespWriter()
200216
readPipe, writePipe := io.Pipe()
201217

202-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
218+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
203219
require.NoError(t, err)
204220
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
205221

206-
wg.Add(1)
222+
serveDone := make(chan struct{})
207223
go func() {
208-
defer wg.Done()
224+
defer close(serveDone)
209225
http2Conn.ServeHTTP(respWriter, req)
226+
respWriter.close()
210227
}()
211228

212229
data := []byte("test websocket")
213-
err = wsutil.WriteClientText(writePipe, data)
230+
err = wsutil.WriteClientBinary(writePipe, data)
214231
require.NoError(t, err)
215232

216-
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
233+
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
217234
require.NoError(t, err)
218235
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
219236

@@ -223,7 +240,65 @@ func TestServeWS(t *testing.T) {
223240
require.Equal(t, http.StatusOK, resp.StatusCode)
224241
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
225242

243+
<-serveDone
244+
require.False(t, respWriter.panicked)
245+
}
246+
247+
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
248+
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
249+
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
250+
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
251+
252+
ctx, cancel := context.WithCancel(context.Background())
253+
var wg sync.WaitGroup
254+
255+
serverDone := make(chan struct{})
256+
go func() {
257+
defer close(serverDone)
258+
cfdHTTP2Conn.Serve(ctx)
259+
}()
260+
261+
edgeTransport := http2.Transport{}
262+
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
263+
require.NoError(t, err)
264+
message := []byte(t.Name())
265+
266+
for i := 0; i < 100; i++ {
267+
wg.Add(1)
268+
go func() {
269+
defer wg.Done()
270+
readPipe, writePipe := io.Pipe()
271+
reqCtx, reqCancel := context.WithCancel(ctx)
272+
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
273+
require.NoError(t, err)
274+
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
275+
276+
resp, err := edgeHTTP2Conn.RoundTrip(req)
277+
require.NoError(t, err)
278+
// http2RespWriter should rewrite status 101 to 200
279+
require.Equal(t, http.StatusOK, resp.StatusCode)
280+
281+
wg.Add(1)
282+
go func() {
283+
defer wg.Done()
284+
for {
285+
select {
286+
case <-reqCtx.Done():
287+
return
288+
default:
289+
}
290+
_ = wsutil.WriteClientBinary(writePipe, message)
291+
}
292+
}()
293+
294+
time.Sleep(time.Millisecond * 100)
295+
reqCancel()
296+
}()
297+
}
298+
226299
wg.Wait()
300+
cancel()
301+
<-serverDone
227302
}
228303

229304
func TestServeControlStream(t *testing.T) {

connection/quic_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestQUICServer(t *testing.T) {
4747

4848
// This is simply a sample websocket frame message.
4949
wsBuf := &bytes.Buffer{}
50-
wsutil.WriteClientText(wsBuf, []byte("Hello"))
50+
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
5151

5252
var tests = []struct {
5353
desc string
@@ -104,7 +104,7 @@ func TestQUICServer(t *testing.T) {
104104
},
105105
{
106106
desc: "test ws proxy",
107-
dest: "/ok",
107+
dest: "/ws/echo",
108108
connectionType: quicpogs.ConnectionTypeWebsocket,
109109
metadata: []quicpogs.Metadata{
110110
{
@@ -125,7 +125,7 @@ func TestQUICServer(t *testing.T) {
125125
},
126126
},
127127
message: wsBuf.Bytes(),
128-
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
128+
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
129129
},
130130
{
131131
desc: "test tcp proxy",
@@ -233,7 +233,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
233233
}
234234

235235
if isWebsocket {
236-
return wsEndpoint(w, r)
236+
return wsEchoEndpoint(w, r)
237237
}
238238
switch r.URL.Path {
239239
case "/ok":

ingress/origin_connection.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
5353
wc.streamHandler(wsConn, wc.conn, log)
5454
cancel()
5555
// Makes sure wsConn stops sending ping before terminating the stream
56-
wsConn.WaitForShutdown()
56+
wsConn.Close()
5757
}
5858

5959
func (wc *tcpOverWSConnection) Close() {
@@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
7373
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
7474
cancel()
7575
// Makes sure wsConn stops sending ping before terminating the stream
76-
wsConn.WaitForShutdown()
76+
wsConn.Close()
7777
}
7878

7979
func (sp *socksProxyOverWSConnection) Close() {

0 commit comments

Comments
 (0)