Skip to content

Commit 76badfa

Browse files
committed
TUN-8236: Add write timeout to quic and tcp connections
## Summary To prevent bad eyeballs and severs to be able to exhaust the quic control flows we are adding the possibility of having a timeout for a write operation to be acknowledged. This will prevent hanging connections from exhausting the quic control flows, creating a DDoS.
1 parent 56aeb6b commit 76badfa

File tree

18 files changed

+146
-54
lines changed

18 files changed

+146
-54
lines changed

cmd/cloudflared/tunnel/cmd.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ const (
8181
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
8282
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
8383

84+
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
85+
writeStreamTimeout = "write-stream-timeout"
86+
8487
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
8588
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
8689
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
@@ -696,6 +699,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
696699
Value: 5 * time.Second,
697700
Hidden: true,
698701
}),
702+
altsrc.NewDurationFlag(&cli.DurationFlag{
703+
Name: writeStreamTimeout,
704+
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
705+
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
706+
Value: 0 * time.Second,
707+
Hidden: true,
708+
}),
699709
altsrc.NewBoolFlag(&cli.BoolFlag{
700710
Name: quicDisablePathMTUDiscovery,
701711
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},

cmd/cloudflared/tunnel/configuration.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ func prepareTunnelConfig(
247247
FeatureSelector: featureSelector,
248248
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
249249
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
250+
WriteStreamTimeout: c.Duration(writeStreamTimeout),
250251
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
251252
}
252253
packetConfig, err := newPacketConfig(c, log)
@@ -259,6 +260,7 @@ func prepareTunnelConfig(
259260
Ingress: &ingressRules,
260261
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
261262
ConfigurationFlags: parseConfigFlags(c),
263+
WriteTimeout: c.Duration(writeStreamTimeout),
262264
}
263265
return tunnelConfig, orchestratorConfig, nil
264266
}

connection/quic.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ type QUICConnection struct {
6666
connIndex uint8
6767

6868
udpUnregisterTimeout time.Duration
69+
streamWriteTimeout time.Duration
6970
}
7071

7172
// NewQUICConnection returns a new instance of QUICConnection.
@@ -82,6 +83,7 @@ func NewQUICConnection(
8283
logger *zerolog.Logger,
8384
packetRouterConfig *ingress.GlobalRouterConfig,
8485
udpUnregisterTimeout time.Duration,
86+
streamWriteTimeout time.Duration,
8587
) (*QUICConnection, error) {
8688
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
8789
if err != nil {
@@ -117,6 +119,7 @@ func NewQUICConnection(
117119
connOptions: connOptions,
118120
connIndex: connIndex,
119121
udpUnregisterTimeout: udpUnregisterTimeout,
122+
streamWriteTimeout: streamWriteTimeout,
120123
}, nil
121124
}
122125

@@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
195198

196199
func (q *QUICConnection) runStream(quicStream quic.Stream) {
197200
ctx := quicStream.Context()
198-
stream := quicpogs.NewSafeStreamCloser(quicStream)
201+
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
199202
defer stream.Close()
200203

201204
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
@@ -373,7 +376,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
373376
return
374377
}
375378

376-
stream := quicpogs.NewSafeStreamCloser(quicStream)
379+
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
377380
defer stream.Close()
378381
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
379382
if err != nil {

connection/quic_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ var (
3535
KeepAlivePeriod: 5 * time.Second,
3636
EnableDatagrams: true,
3737
}
38+
defaultQUICTimeout = 30 * time.Second
3839
)
3940

4041
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
@@ -197,7 +198,7 @@ func quicServer(
197198

198199
quicStream, err := session.OpenStreamSync(context.Background())
199200
require.NoError(t, err)
200-
stream := quicpogs.NewSafeStreamCloser(quicStream)
201+
stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
201202

202203
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
203204
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
@@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
726727
&log,
727728
nil,
728729
5*time.Second,
730+
0*time.Second,
729731
)
730732
require.NoError(t, err)
731733
return qc

ingress/constants_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package ingress
2+
3+
import "github.com/cloudflare/cloudflared/logger"
4+
5+
var (
6+
TestLogger = logger.Create(nil)
7+
)

ingress/origin_connection.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"io"
66
"net"
7+
"time"
78

89
"github.com/rs/zerolog"
910

@@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
3132

3233
// tcpConnection is an OriginConnection that directly streams to raw TCP.
3334
type tcpConnection struct {
34-
conn net.Conn
35+
net.Conn
36+
writeTimeout time.Duration
37+
logger *zerolog.Logger
3538
}
3639

37-
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
38-
stream.Pipe(tunnelConn, tc.conn, log)
40+
func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
41+
stream.Pipe(tunnelConn, tc, tc.logger)
42+
}
43+
44+
func (tc *tcpConnection) Write(b []byte) (int, error) {
45+
if tc.writeTimeout > 0 {
46+
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
47+
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
48+
}
49+
}
50+
51+
nBytes, err := tc.Conn.Write(b)
52+
if err != nil {
53+
tc.logger.Err(err).Msg("Error writing to the TCP connection")
54+
}
55+
56+
return nBytes, err
3957
}
4058

4159
func (tc *tcpConnection) Close() {
42-
tc.conn.Close()
60+
tc.Conn.Close()
4361
}
4462

4563
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.

ingress/origin_connection_test.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"golang.org/x/net/proxy"
2020
"golang.org/x/sync/errgroup"
2121

22-
"github.com/cloudflare/cloudflared/logger"
2322
"github.com/cloudflare/cloudflared/socks"
2423
"github.com/cloudflare/cloudflared/stream"
2524
"github.com/cloudflare/cloudflared/websocket"
@@ -31,15 +30,15 @@ const (
3130
)
3231

3332
var (
34-
testLogger = logger.Create(nil)
3533
testMessage = []byte("TestStreamOriginConnection")
3634
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
3735
)
3836

3937
func TestStreamTCPConnection(t *testing.T) {
4038
cfdConn, originConn := net.Pipe()
4139
tcpConn := tcpConnection{
42-
conn: cfdConn,
40+
Conn: cfdConn,
41+
writeTimeout: 30 * time.Second,
4342
}
4443

4544
eyeballConn, edgeConn := net.Pipe()
@@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
6665
return nil
6766
})
6867

69-
tcpConn.Stream(ctx, edgeConn, testLogger)
68+
tcpConn.Stream(ctx, edgeConn, TestLogger)
7069
require.NoError(t, errGroup.Wait())
7170
}
7271

@@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
9392
return nil
9493
})
9594

96-
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
95+
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
9796
require.NoError(t, errGroup.Wait())
9897
}
9998

@@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
147146

148147
errGroup, ctx := errgroup.WithContext(ctx)
149148
errGroup.Go(func() error {
150-
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
149+
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
151150
return nil
152151
})
153152

@@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
159158
require.NoError(t, err)
160159
defer wsForwarderInConn.Close()
161160

162-
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
161+
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
163162
return nil
164163
})
165164

@@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
209208
originConn.Close()
210209
}()
211210
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
212-
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
211+
tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
213212
})
214213
server := httptest.NewServer(handler)
215214
defer server.Close()

ingress/origin_proxy.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
8+
"github.com/rs/zerolog"
79
)
810

911
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@@ -14,7 +16,7 @@ type HTTPOriginProxy interface {
1416

1517
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
1618
type StreamBasedOriginProxy interface {
17-
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error)
19+
EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
1820
}
1921

2022
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
@@ -62,19 +64,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
6264
return resp, nil
6365
}
6466

65-
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
67+
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
6668
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
6769
if err != nil {
6870
return nil, err
6971
}
7072

7173
originConn := &tcpConnection{
72-
conn: conn,
74+
Conn: conn,
75+
writeTimeout: o.writeTimeout,
76+
logger: logger,
7377
}
7478
return originConn, nil
7579
}
7680

77-
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
81+
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
7882
var err error
7983
if !o.isBastion {
8084
dest = o.dest
@@ -92,6 +96,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)
9296

9397
}
9498

95-
func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) {
99+
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
96100
return o.conn, nil
97101
}

ingress/origin_proxy_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
3636
require.NoError(t, err)
3737

3838
// Origin not listening for new connection, should return an error
39-
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String())
39+
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
4040
require.Error(t, err)
4141
}
4242

@@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
8787
t.Run(test.testCase, func(t *testing.T) {
8888
if test.expectErr {
8989
bastionHost, _ := carrier.ResolveBastionDest(test.req)
90-
_, err := test.service.EstablishConnection(context.Background(), bastionHost)
90+
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
9191
assert.Error(t, err)
9292
}
9393
})
@@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
9999
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
100100
// Origin not listening for new connection, should return an error
101101
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
102-
_, err := service.EstablishConnection(context.Background(), bastionHost)
102+
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
103103
assert.Error(t, err)
104104
}
105105
}
@@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
132132
url: originURL,
133133
}
134134
shutdownC := make(chan struct{})
135-
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
135+
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
136136

137137
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
138138
require.NoError(t, err)
@@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
167167
url: originURL,
168168
}
169169
shutdownC := make(chan struct{})
170-
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
170+
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
171171

172172
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
173173
protos := []string{"https", "http", "dne"}

ingress/origin_service.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,17 @@ func (o httpService) MarshalJSON() ([]byte, error) {
9494
// rawTCPService dials TCP to the destination specified by the client
9595
// It's used by warp routing
9696
type rawTCPService struct {
97-
name string
98-
dialer net.Dialer
97+
name string
98+
dialer net.Dialer
99+
writeTimeout time.Duration
100+
logger *zerolog.Logger
99101
}
100102

101103
func (o *rawTCPService) String() string {
102104
return o.name
103105
}
104106

105-
func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
107+
func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
106108
return nil
107109
}
108110

@@ -285,13 +287,14 @@ type WarpRoutingService struct {
285287
Proxy StreamBasedOriginProxy
286288
}
287289

288-
func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService {
290+
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
289291
svc := &rawTCPService{
290292
name: ServiceWarpRouting,
291293
dialer: net.Dialer{
292294
Timeout: config.ConnectTimeout.Duration,
293295
KeepAlive: config.TCPKeepAlive.Duration,
294296
},
297+
writeTimeout: writeTimeout,
295298
}
296299

297300
return &WarpRoutingService{Proxy: svc}

0 commit comments

Comments
 (0)