Skip to content

Commit 8bfe111

Browse files
committed
TUN-8861: Add session limiter to TCP session manager
## Summary In order to make cloudflared behavior more predictable and prevent an exhaustion of resources, we have decided to add session limits that can be configured by the user. This commit adds the session limiter to the HTTP/TCP handling path. For now the limiter is set to run only in unlimited mode.
1 parent bf4954e commit 8bfe111

File tree

12 files changed

+275
-102
lines changed

12 files changed

+275
-102
lines changed

connection/connection_test.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@ package connection
22

33
import (
44
"context"
5+
"crypto/rand"
56
"fmt"
67
"io"
7-
"math/rand"
8+
"math/big"
89
"net/http"
910
"time"
1011

12+
pkgerrors "github.com/pkg/errors"
1113
"github.com/rs/zerolog"
1214

15+
cfdsession "github.com/cloudflare/cloudflared/session"
16+
1317
"github.com/cloudflare/cloudflared/stream"
1418
"github.com/cloudflare/cloudflared/tracing"
1519
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
7781
return wsFlakyEndpoint(w, req)
7882
default:
7983
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
80-
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
84+
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
8185
}
8286
}
8387
switch req.URL.Path {
@@ -95,14 +99,17 @@ func (moc *mockOriginProxy) ProxyHTTP(
9599
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
96100
}
97101
return nil
98-
99102
}
100103

101104
func (moc *mockOriginProxy) ProxyTCP(
102105
ctx context.Context,
103106
rwa ReadWriteAcker,
104107
r *TCPRequest,
105108
) error {
109+
if r.CfTraceID == "flow-rate-limited" {
110+
return pkgerrors.Wrap(cfdsession.ErrTooManyActiveSessions, "tcp flow rate limited")
111+
}
112+
106113
return nil
107114
}
108115

@@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
178185

179186
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
180187

181-
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
188+
rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
189+
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
182190
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
183191
stream.Pipe(wsConn, originConn, &log)
184192
cancel()

connection/header.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ var (
2222

2323
var (
2424
// pre-generate possible values for res
25-
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
26-
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
25+
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
26+
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
27+
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
2728
)
2829

2930
// HTTPHeader is a custom header struct that expects only ever one value for the header.
@@ -34,11 +35,12 @@ type HTTPHeader struct {
3435
}
3536

3637
type responseMetaHeader struct {
37-
Source string `json:"src"`
38+
Source string `json:"src"`
39+
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
3840
}
3941

40-
func mustInitRespMetaHeader(src string) string {
41-
header, err := json.Marshal(responseMetaHeader{Source: src})
42+
func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
43+
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
4244
if err != nil {
4345
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
4446
}
@@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string {
112114
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
113115
const unableToDeserializeErr = "Unable to deserialize headers"
114116

115-
var deserialized []HTTPHeader
117+
deserialized := make([]HTTPHeader, 0)
116118
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
117119
if len(serializedPair) == 0 {
118120
continue

connection/http2.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
"github.com/rs/zerolog"
1717
"golang.org/x/net/http2"
1818

19+
cfdsession "github.com/cloudflare/cloudflared/session"
20+
1921
"github.com/cloudflare/cloudflared/tracing"
2022
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
2123
)
@@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
156158
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
157159

158160
// WriteErrorResponse will return false if status was already written. we need to abort handler.
159-
if !respWriter.WriteErrorResponse() {
161+
if !respWriter.WriteErrorResponse(requestErr) {
160162
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
161163
panic(http.ErrAbortHandler)
162164
}
@@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
209211
w: w,
210212
log: log,
211213
}
212-
respWriter.WriteErrorResponse()
213-
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
214+
err := fmt.Errorf("%T doesn't implement http.Flusher", w)
215+
respWriter.WriteErrorResponse(err)
216+
return nil, err
214217
}
215218

216219
return &http2RespWriter{
@@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
295298
rp.log.Warn().Msg("WriteHeader after hijack")
296299
return
297300
}
298-
rp.WriteRespHeaders(status, rp.respHeaders)
301+
_ = rp.WriteRespHeaders(status, rp.respHeaders)
299302
}
300303

301304
func (rp *http2RespWriter) hijacked() bool {
@@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
328331
return conn, readWriter, nil
329332
}
330333

331-
func (rp *http2RespWriter) WriteErrorResponse() bool {
334+
func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
332335
if rp.statusWritten {
333336
return false
334337
}
335338

336-
rp.setResponseMetaHeader(responseMetaHeaderCfd)
339+
if errors.Is(err, cfdsession.ErrTooManyActiveSessions) {
340+
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
341+
} else {
342+
rp.setResponseMetaHeader(responseMetaHeaderCfd)
343+
}
337344
rp.w.WriteHeader(http.StatusBadGateway)
338345
rp.statusWritten = true
339346

connection/http2_test.go

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"github.com/stretchr/testify/require"
2121
"golang.org/x/net/http2"
2222

23+
"github.com/cloudflare/cloudflared/tracing"
24+
2325
"github.com/cloudflare/cloudflared/tunnelrpc"
2426
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
2527
)
@@ -65,31 +67,30 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
6567
wg.Add(1)
6668
go func() {
6769
defer wg.Done()
68-
http2Conn.Serve(ctx)
70+
_ = http2Conn.Serve(ctx)
6971
}()
7072

7173
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
7274
require.NoError(t, err)
7375

74-
endpoint := fmt.Sprintf("http://localhost:8080/ok")
7576
reqBody := []byte(`{
7677
"version": 2,
7778
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
7879
`)
7980
reader := bytes.NewReader(reqBody)
80-
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
81+
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
8182
require.NoError(t, err)
8283
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
8384

8485
resp, err := edgeHTTP2Conn.RoundTrip(req)
8586
require.NoError(t, err)
8687
require.Equal(t, http.StatusOK, resp.StatusCode)
8788
bdy, err := io.ReadAll(resp.Body)
89+
defer resp.Body.Close()
8890
require.NoError(t, err)
8991
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
9092
cancel()
9193
wg.Wait()
92-
9394
}
9495

9596
func TestServeHTTP(t *testing.T) {
@@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) {
134135
wg.Add(1)
135136
go func() {
136137
defer wg.Done()
137-
http2Conn.Serve(ctx)
138+
_ = http2Conn.Serve(ctx)
138139
}()
139140

140141
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
@@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) {
153154
require.NoError(t, err)
154155
require.Equal(t, test.expectedBody, respBody)
155156
}
157+
_ = resp.Body.Close()
156158
if test.isProxyError {
157159
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
158160
} else {
@@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) {
281283

282284
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
283285
require.NoError(t, err)
284-
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
286+
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))
285287

286288
cancel()
287289
resp := respWriter.Result()
290+
defer resp.Body.Close()
288291
// http2RespWriter should rewrite status 101 to 200
289292
require.Equal(t, http.StatusOK, resp.StatusCode)
290293
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
@@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
304307
serverDone := make(chan struct{})
305308
go func() {
306309
defer close(serverDone)
307-
cfdHTTP2Conn.Serve(ctx)
310+
_ = cfdHTTP2Conn.Serve(ctx)
308311
}()
309312

310313
edgeTransport := http2.Transport{}
@@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
319322
readPipe, writePipe := io.Pipe()
320323
reqCtx, reqCancel := context.WithCancel(ctx)
321324
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
322-
require.NoError(t, err)
325+
assert.NoError(t, err)
326+
323327
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
324328

325329
resp, err := edgeHTTP2Conn.RoundTrip(req)
326-
require.NoError(t, err)
330+
assert.NoError(t, err)
331+
_ = resp.Body.Close()
332+
327333
// http2RespWriter should rewrite status 101 to 200
328-
require.Equal(t, http.StatusOK, resp.StatusCode)
334+
assert.Equal(t, http.StatusOK, resp.StatusCode)
329335

330336
wg.Add(1)
331337
go func() {
@@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) {
378384
wg.Add(1)
379385
go func() {
380386
defer wg.Done()
381-
http2Conn.Serve(ctx)
387+
_ = http2Conn.Serve(ctx)
382388
}()
383389

384390
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) {
391397
wg.Add(1)
392398
go func() {
393399
defer wg.Done()
394-
edgeHTTP2Conn.RoundTrip(req)
400+
// nolint: bodyclose
401+
_, _ = edgeHTTP2Conn.RoundTrip(req)
395402
}()
396403

397404
<-rpcClientFactory.registered
@@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) {
431438
wg.Add(1)
432439
go func() {
433440
defer wg.Done()
434-
http2Conn.Serve(ctx)
441+
_ = http2Conn.Serve(ctx)
435442
}()
436443

437444
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) {
442449
require.NoError(t, err)
443450
resp, err := edgeHTTP2Conn.RoundTrip(req)
444451
require.NoError(t, err)
452+
defer resp.Body.Close()
445453
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
446454

447-
assert.NotNil(t, http2Conn.controlStreamErr)
455+
require.Error(t, http2Conn.controlStreamErr)
448456
cancel()
449457
wg.Wait()
450458
}
@@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
481489
wg.Add(1)
482490
go func() {
483491
defer wg.Done()
484-
http2Conn.Serve(ctx)
492+
_ = http2Conn.Serve(ctx)
485493
}()
486494

487495
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
494502
wg.Add(1)
495503
go func() {
496504
defer wg.Done()
505+
// nolint: bodyclose
497506
_, _ = edgeHTTP2Conn.RoundTrip(req)
498507
}()
499508

@@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
524533
})
525534
}
526535

536+
func TestServeTCP_RateLimited(t *testing.T) {
537+
ctx, cancel := context.WithCancel(context.Background())
538+
http2Conn, edgeConn := newTestHTTP2Connection()
539+
540+
var wg sync.WaitGroup
541+
wg.Add(1)
542+
go func() {
543+
defer wg.Done()
544+
_ = http2Conn.Serve(ctx)
545+
}()
546+
547+
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
548+
require.NoError(t, err)
549+
550+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
551+
require.NoError(t, err)
552+
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
553+
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
554+
555+
resp, err := edgeHTTP2Conn.RoundTrip(req)
556+
require.NoError(t, err)
557+
defer resp.Body.Close()
558+
559+
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
560+
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
561+
562+
cancel()
563+
wg.Wait()
564+
}
565+
527566
func benchmarkServeHTTP(b *testing.B, test testRequest) {
528567
http2Conn, edgeConn := newTestHTTP2Connection()
529568

@@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
532571
wg.Add(1)
533572
go func() {
534573
defer wg.Done()
535-
http2Conn.Serve(ctx)
574+
_ = http2Conn.Serve(ctx)
536575
}()
537576

538577
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)

0 commit comments

Comments
 (0)