Skip to content

Commit 543169c

Browse files
committed
TUN-3490: Make sure OriginClient implementation doesn't write after Proxy return
1 parent d576951 commit 543169c

File tree

4 files changed

+79
-7
lines changed

4 files changed

+79
-7
lines changed

connection/http2.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
190190
}
191191

192192
func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
193+
defer func() {
194+
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
195+
// Register a recover routine just in case.
196+
if r := recover(); r != nil {
197+
println("Recover from http2 response writer panic, error", r)
198+
}
199+
}()
193200
n, err = rp.w.Write(p)
194201
if err == nil && rp.shouldFlush {
195202
rp.flusher.Flush()

ingress/origin_service.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,20 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Tra
318318

319319
return &httpTransport, nil
320320
}
321+
322+
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
323+
type MockOriginService struct {
324+
Transport http.RoundTripper
325+
}
326+
327+
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
328+
return mos.Transport.RoundTrip(req)
329+
}
330+
331+
func (mos MockOriginService) String() string {
332+
return "MockOriginService"
333+
}
334+
335+
func (mos MockOriginService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
336+
return nil
337+
}

origin/proxy.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"fmt"
77
"io"
8+
"net"
89
"net/http"
910
"strconv"
1011
"strings"
@@ -124,19 +125,31 @@ func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request,
124125
}
125126

126127
serveCtx, cancel := context.WithCancel(req.Context())
127-
defer cancel()
128+
connClosedChan := make(chan struct{})
128129
go func() {
130+
// serveCtx is done if req is cancelled, or streamWebsocket returns
129131
<-serveCtx.Done()
130132
conn.Close()
133+
close(connClosedChan)
131134
}()
132-
err = w.WriteRespHeaders(resp)
133-
if err != nil {
134-
return nil, errors.Wrap(err, "Error writing response header")
135-
}
135+
136136
// Copy to/from stream to the undelying connection. Use the underlying
137137
// connection because cloudflared doesn't operate on the message themselves
138-
websocket.Stream(conn.UnderlyingConn(), w)
139-
return resp, nil
138+
err = c.streamWebsocket(w, conn.UnderlyingConn(), resp)
139+
cancel()
140+
141+
// We need to make sure conn is closed before returning, otherwise we might write to conn after Proxy returns
142+
<-connClosedChan
143+
return resp, err
144+
}
145+
146+
func (c *client) streamWebsocket(w connection.ResponseWriter, conn net.Conn, resp *http.Response) error {
147+
err := w.WriteRespHeaders(resp)
148+
if err != nil {
149+
return errors.Wrap(err, "Error writing websocket response header")
150+
}
151+
websocket.Stream(conn, w)
152+
return nil
140153
}
141154

142155
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {

origin/proxy_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
4949

5050
func (w *mockHTTPRespWriter) WriteErrorResponse() {
5151
w.WriteHeader(http.StatusBadGateway)
52+
w.Write([]byte("http response error"))
5253
}
5354

5455
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
@@ -315,3 +316,37 @@ func (ma mockAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
315316
w.WriteHeader(http.StatusCreated)
316317
w.Write([]byte("Created"))
317318
}
319+
320+
type errorOriginTransport struct{}
321+
322+
func (errorOriginTransport) RoundTrip(*http.Request) (*http.Response, error) {
323+
return nil, fmt.Errorf("Proxy error")
324+
}
325+
326+
func TestProxyError(t *testing.T) {
327+
ingress := ingress.Ingress{
328+
Rules: []ingress.Rule{
329+
{
330+
Hostname: "*",
331+
Path: nil,
332+
Service: ingress.MockOriginService{
333+
Transport: errorOriginTransport{},
334+
},
335+
},
336+
},
337+
}
338+
339+
logger, err := logger.New()
340+
require.NoError(t, err)
341+
342+
client := NewClient(ingress, testTags, logger)
343+
344+
respWriter := newMockHTTPRespWriter()
345+
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
346+
assert.NoError(t, err)
347+
348+
err = client.Proxy(respWriter, req, false)
349+
assert.Error(t, err)
350+
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
351+
assert.Equal(t, "http response error", respWriter.Body.String())
352+
}

0 commit comments

Comments
 (0)