Skip to content

Commit 573d410

Browse files
committed
Revert "TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown"
This reverts commit f8fbbcd.
1 parent f6f1030 commit 573d410

File tree

6 files changed

+52
-195
lines changed

6 files changed

+52
-195
lines changed

connection/connection_test.go

Lines changed: 18 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@ import (
44
"context"
55
"fmt"
66
"io"
7-
"math/rand"
87
"net/http"
98
"net/url"
109
"testing"
1110
"time"
1211

12+
"github.com/gobwas/ws/wsutil"
1313
"github.com/rs/zerolog"
1414
"github.com/stretchr/testify/assert"
1515

1616
"github.com/cloudflare/cloudflared/ingress"
17-
"github.com/cloudflare/cloudflared/websocket"
1817
)
1918

2019
const (
@@ -51,15 +50,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
5150
isWebsocket bool,
5251
) error {
5352
if isWebsocket {
54-
switch req.URL.Path {
55-
case "/ws/echo":
56-
return wsEchoEndpoint(w, req)
57-
case "/ws/flaky":
58-
return wsFlakyEndpoint(w, req)
59-
default:
60-
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
61-
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
62-
}
53+
return wsEndpoint(w, req)
6354
}
6455
switch req.URL.Path {
6556
case "/ok":
@@ -87,82 +78,32 @@ func (moc *mockOriginProxy) ProxyTCP(
8778
return nil
8879
}
8980

90-
type echoPipe struct {
91-
reader *io.PipeReader
92-
writer *io.PipeWriter
93-
}
94-
95-
func (ep *echoPipe) Read(p []byte) (int, error) {
96-
return ep.reader.Read(p)
81+
type nowriter struct {
82+
io.Reader
9783
}
9884

99-
func (ep *echoPipe) Write(p []byte) (int, error) {
100-
return ep.writer.Write(p)
85+
func (nowriter) Write(p []byte) (int, error) {
86+
return 0, fmt.Errorf("Writer not implemented")
10187
}
10288

103-
// A mock origin that echos data by streaming like a tcpOverWSConnection
104-
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
105-
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
89+
func wsEndpoint(w ResponseWriter, r *http.Request) error {
10690
resp := &http.Response{
10791
StatusCode: http.StatusSwitchingProtocols,
10892
}
109-
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
110-
return err
111-
}
112-
wsCtx, cancel := context.WithCancel(r.Context())
113-
readPipe, writePipe := io.Pipe()
114-
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
93+
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
94+
clientReader := nowriter{r.Body}
11595
go func() {
116-
select {
117-
case <-wsCtx.Done():
118-
case <-r.Context().Done():
96+
for {
97+
data, err := wsutil.ReadClientText(clientReader)
98+
if err != nil {
99+
return
100+
}
101+
if err := wsutil.WriteServerText(w, data); err != nil {
102+
return
103+
}
119104
}
120-
readPipe.Close()
121-
writePipe.Close()
122105
}()
123-
124-
originConn := &echoPipe{reader: readPipe, writer: writePipe}
125-
websocket.Stream(wsConn, originConn, &log)
126-
cancel()
127-
wsConn.Close()
128-
return nil
129-
}
130-
131-
type flakyConn struct {
132-
closeAt time.Time
133-
}
134-
135-
func (fc *flakyConn) Read(p []byte) (int, error) {
136-
if time.Now().After(fc.closeAt) {
137-
return 0, io.EOF
138-
}
139-
n := copy(p, []byte("Read from flaky connection"))
140-
return n, nil
141-
}
142-
143-
func (fc *flakyConn) Write(p []byte) (int, error) {
144-
if time.Now().After(fc.closeAt) {
145-
return 0, fmt.Errorf("Flaky connection closed")
146-
}
147-
return len(p), nil
148-
}
149-
150-
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
151-
resp := &http.Response{
152-
StatusCode: http.StatusSwitchingProtocols,
153-
}
154-
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
155-
return err
156-
}
157-
wsCtx, cancel := context.WithCancel(r.Context())
158-
159-
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
160-
161-
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
162-
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
163-
websocket.Stream(wsConn, originConn, &log)
164-
cancel()
165-
wsConn.Close()
106+
<-r.Context().Done()
166107
return nil
167108
}
168109

connection/h2mux_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestServeStreamWS(t *testing.T) {
147147
headers := []h2mux.Header{
148148
{
149149
Name: ":path",
150-
Value: "/ws/echo",
150+
Value: "/ws",
151151
},
152152
{
153153
Name: "connection",
@@ -167,10 +167,10 @@ func TestServeStreamWS(t *testing.T) {
167167
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
168168

169169
data := []byte("test websocket")
170-
err = wsutil.WriteClientBinary(writePipe, data)
170+
err = wsutil.WriteClientText(writePipe, data)
171171
require.NoError(t, err)
172172

173-
respBody, err := wsutil.ReadServerBinary(stream)
173+
respBody, err := wsutil.ReadServerText(stream)
174174
require.NoError(t, err)
175175
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
176176

connection/http2_test.go

Lines changed: 13 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ var (
2727
)
2828

2929
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
30-
edgeConn, cfdConn := net.Pipe()
30+
edgeConn, originConn := net.Pipe()
3131
var connIndex = uint8(0)
3232
log := zerolog.Nop()
3333
obs := NewObserver(&log, &log, false)
@@ -41,8 +41,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
4141
1*time.Second,
4242
)
4343
return NewHTTP2Connection(
44-
cfdConn,
45-
// OriginProxy is set in testConfig
44+
originConn,
4645
testConfig,
4746
&pogs.ConnectionOptions{},
4847
obs,
@@ -167,7 +166,6 @@ type wsRespWriter struct {
167166
*httptest.ResponseRecorder
168167
readPipe *io.PipeReader
169168
writePipe *io.PipeWriter
170-
closed bool
171169
}
172170

173171
func newWSRespWriter() *wsRespWriter {
@@ -176,58 +174,46 @@ func newWSRespWriter() *wsRespWriter {
176174
httptest.NewRecorder(),
177175
readPipe,
178176
writePipe,
179-
false,
180177
}
181178
}
182179

183-
type nowriter struct {
184-
io.Reader
185-
}
186-
187-
func (nowriter) Write(p []byte) (int, error) {
188-
return 0, fmt.Errorf("Writer not implemented")
189-
}
190-
191180
func (w *wsRespWriter) RespBody() io.ReadWriter {
192181
return nowriter{w.readPipe}
193182
}
194183

195184
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
196-
if w.closed {
197-
// Simulate writing to http2 ResponseWriter after ServeHTTP has returned
198-
panic("Write to closed ResponseWriter")
199-
}
200185
return w.writePipe.Write(data)
201186
}
202187

203-
func (w *wsRespWriter) close() {
204-
w.closed = true
205-
}
206-
207188
func TestServeWS(t *testing.T) {
208189
http2Conn, _ := newTestHTTP2Connection()
209190

210191
ctx, cancel := context.WithCancel(context.Background())
192+
var wg sync.WaitGroup
193+
wg.Add(1)
194+
go func() {
195+
defer wg.Done()
196+
http2Conn.Serve(ctx)
197+
}()
211198

212199
respWriter := newWSRespWriter()
213200
readPipe, writePipe := io.Pipe()
214201

215-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
202+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
216203
require.NoError(t, err)
217204
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
218205

219-
serveDone := make(chan struct{})
206+
wg.Add(1)
220207
go func() {
221-
defer close(serveDone)
208+
defer wg.Done()
222209
http2Conn.ServeHTTP(respWriter, req)
223-
respWriter.close()
224210
}()
225211

226212
data := []byte("test websocket")
227-
err = wsutil.WriteClientBinary(writePipe, data)
213+
err = wsutil.WriteClientText(writePipe, data)
228214
require.NoError(t, err)
229215

230-
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
216+
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
231217
require.NoError(t, err)
232218
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
233219

@@ -237,64 +223,7 @@ func TestServeWS(t *testing.T) {
237223
require.Equal(t, http.StatusOK, resp.StatusCode)
238224
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
239225

240-
<-serveDone
241-
}
242-
243-
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
244-
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
245-
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
246-
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
247-
248-
ctx, cancel := context.WithCancel(context.Background())
249-
var wg sync.WaitGroup
250-
251-
serverDone := make(chan struct{})
252-
go func() {
253-
defer close(serverDone)
254-
cfdHTTP2Conn.Serve(ctx)
255-
}()
256-
257-
edgeTransport := http2.Transport{}
258-
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
259-
require.NoError(t, err)
260-
message := []byte(t.Name())
261-
262-
for i := 0; i < 100; i++ {
263-
wg.Add(1)
264-
go func() {
265-
defer wg.Done()
266-
readPipe, writePipe := io.Pipe()
267-
reqCtx, reqCancel := context.WithCancel(ctx)
268-
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
269-
require.NoError(t, err)
270-
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
271-
272-
resp, err := edgeHTTP2Conn.RoundTrip(req)
273-
require.NoError(t, err)
274-
// http2RespWriter should rewrite status 101 to 200
275-
require.Equal(t, http.StatusOK, resp.StatusCode)
276-
277-
wg.Add(1)
278-
go func() {
279-
defer wg.Done()
280-
for {
281-
select {
282-
case <-reqCtx.Done():
283-
return
284-
default:
285-
}
286-
_ = wsutil.WriteClientBinary(writePipe, message)
287-
}
288-
}()
289-
290-
time.Sleep(time.Millisecond * 100)
291-
reqCancel()
292-
}()
293-
}
294-
295226
wg.Wait()
296-
cancel()
297-
<-serverDone
298227
}
299228

300229
func TestServeControlStream(t *testing.T) {

connection/quic_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func TestQUICServer(t *testing.T) {
6161

6262
// This is simply a sample websocket frame message.
6363
wsBuf := &bytes.Buffer{}
64-
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
64+
wsutil.WriteClientText(wsBuf, []byte("Hello"))
6565

6666
var tests = []struct {
6767
desc string
@@ -118,7 +118,7 @@ func TestQUICServer(t *testing.T) {
118118
},
119119
{
120120
desc: "test ws proxy",
121-
dest: "/ws/echo",
121+
dest: "/ok",
122122
connectionType: quicpogs.ConnectionTypeWebsocket,
123123
metadata: []quicpogs.Metadata{
124124
quicpogs.Metadata{
@@ -139,7 +139,7 @@ func TestQUICServer(t *testing.T) {
139139
},
140140
},
141141
message: wsBuf.Bytes(),
142-
expectedResponse: []byte{0x82, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
142+
expectedResponse: []byte{0x81, 0x5, 0x48, 0x65, 0x6c, 0x6c, 0x6f},
143143
},
144144
{
145145
desc: "test tcp proxy",
@@ -278,7 +278,7 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
278278
}
279279

280280
if isWebsocket {
281-
return wsEchoEndpoint(w, r)
281+
return wsEndpoint(w, r)
282282
}
283283
switch r.URL.Path {
284284
case "/ok":

ingress/origin_connection.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri
5353
wc.streamHandler(wsConn, wc.conn, log)
5454
cancel()
5555
// Makes sure wsConn stops sending ping before terminating the stream
56-
wsConn.Close()
56+
wsConn.WaitForShutdown()
5757
}
5858

5959
func (wc *tcpOverWSConnection) Close() {
@@ -73,7 +73,7 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
7373
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
7474
cancel()
7575
// Makes sure wsConn stops sending ping before terminating the stream
76-
wsConn.Close()
76+
wsConn.WaitForShutdown()
7777
}
7878

7979
func (sp *socksProxyOverWSConnection) Close() {

0 commit comments

Comments
 (0)