Skip to content

Commit ed2bac0

Browse files
committed
TUN-5621: Correctly manage QUIC stream closing
Until this PR, we were naively closing the quic.Stream whenever the callstack for handling the request (HTTP or TCP) finished. However, our proxy handler may still be reading or writing from the quic.Stream at that point, because we return the callstack if either side finishes, but not necessarily both. This is a problem for quic-go library because quic.Stream#Close cannot be called concurrently with quic.Stream#Write Furthermore, we also noticed that quic.Stream#Close does nothing to do receiving stream (since, underneath, quic.Stream has 2 streams, 1 for each direction), thus leaking memory, as explained in: quic-go/quic-go#3322 This PR addresses both problems by wrapping the quic.Stream that is passed down to the proxying logic and handle all these concerns.
1 parent e09dcf6 commit ed2bac0

File tree

7 files changed

+244
-45
lines changed

7 files changed

+244
-45
lines changed

connection/quic.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream q
122122
func (q *QUICConnection) acceptStream(ctx context.Context) error {
123123
defer q.Close()
124124
for {
125-
stream, err := q.session.AcceptStream(ctx)
125+
quicStream, err := q.session.AcceptStream(ctx)
126126
if err != nil {
127127
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
128128
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
@@ -131,7 +131,9 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
131131
return fmt.Errorf("failed to accept QUIC stream: %w", err)
132132
}
133133
go func() {
134+
stream := quicpogs.NewSafeStreamCloser(quicStream)
134135
defer stream.Close()
136+
135137
if err = q.handleStream(stream); err != nil {
136138
q.logger.Err(err).Msg("Failed to handle QUIC stream")
137139
}
@@ -144,7 +146,7 @@ func (q *QUICConnection) Close() {
144146
q.session.CloseWithError(0, "")
145147
}
146148

147-
func (q *QUICConnection) handleStream(stream quic.Stream) error {
149+
func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
148150
signature, err := quicpogs.DetermineProtocol(stream)
149151
if err != nil {
150152
return err

connection/quic_test.go

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,9 @@ package connection
33
import (
44
"bytes"
55
"context"
6-
"crypto/rand"
7-
"crypto/rsa"
86
"crypto/tls"
9-
"crypto/x509"
10-
"encoding/pem"
117
"fmt"
128
"io"
13-
"math/big"
149
"net"
1510
"net/http"
1611
"net/url"
@@ -33,7 +28,7 @@ import (
3328
)
3429

3530
var (
36-
testTLSServerConfig = generateTLSConfig()
31+
testTLSServerConfig = quicpogs.GenerateTLSConfig()
3732
testQUICConfig = &quic.Config{
3833
KeepAlive: true,
3934
EnableDatagrams: true,
@@ -84,7 +79,7 @@ func TestQUICServer(t *testing.T) {
8479
},
8580
{
8681
desc: "test http body request streaming",
87-
dest: "/echo_body",
82+
dest: "/slow_echo_body",
8883
connectionType: quicpogs.ConnectionTypeHTTP,
8984
metadata: []quicpogs.Metadata{
9085
{
@@ -195,8 +190,9 @@ func quicServer(
195190
session, err := earlyListener.Accept(ctx)
196191
require.NoError(t, err)
197192

198-
stream, err := session.OpenStreamSync(context.Background())
193+
quicStream, err := session.OpenStreamSync(context.Background())
199194
require.NoError(t, err)
195+
stream := quicpogs.NewSafeStreamCloser(quicStream)
200196

201197
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
202198
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
@@ -207,42 +203,20 @@ func quicServer(
207203

208204
if message != nil {
209205
// ALPN successful. Write data.
210-
_, err := stream.Write([]byte(message))
206+
_, err := stream.Write(message)
211207
require.NoError(t, err)
212208
}
213209

214210
response := make([]byte, len(expectedResponse))
215-
stream.Read(response)
216-
require.NoError(t, err)
211+
_, err = stream.Read(response)
212+
if err != io.EOF {
213+
require.NoError(t, err)
214+
}
217215

218216
// For now it is an echo server. Verify if the same data is returned.
219217
assert.Equal(t, expectedResponse, response)
220218
}
221219

222-
// Setup a bare-bones TLS config for the server
223-
func generateTLSConfig() *tls.Config {
224-
key, err := rsa.GenerateKey(rand.Reader, 1024)
225-
if err != nil {
226-
panic(err)
227-
}
228-
template := x509.Certificate{SerialNumber: big.NewInt(1)}
229-
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
230-
if err != nil {
231-
panic(err)
232-
}
233-
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
234-
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
235-
236-
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
237-
if err != nil {
238-
panic(err)
239-
}
240-
return &tls.Config{
241-
Certificates: []tls.Certificate{tlsCert},
242-
NextProtos: []string{"argotunnel"},
243-
}
244-
}
245-
246220
type mockOriginProxyWithRequest struct{}
247221

248222
func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Request, isWebsocket bool) error {
@@ -264,6 +238,9 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
264238
switch r.URL.Path {
265239
case "/ok":
266240
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
241+
case "/slow_echo_body":
242+
time.Sleep(5)
243+
fallthrough
267244
case "/echo_body":
268245
resp := &http.Response{
269246
StatusCode: http.StatusOK,

origin/tunnel.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ const (
3131
dialTimeout = 15 * time.Second
3232
FeatureSerializedHeaders = "serialized_headers"
3333
FeatureQuickReconnects = "quick_reconnects"
34-
quicHandshakeIdleTimeout = 5 * time.Second
35-
quicMaxIdleTimeout = 15 * time.Second
3634
)
3735

3836
type TunnelConfig struct {
@@ -523,8 +521,8 @@ func ServeQUIC(
523521
) (err error, recoverable bool) {
524522
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
525523
quicConfig := &quic.Config{
526-
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
527-
MaxIdleTimeout: quicMaxIdleTimeout,
524+
HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout,
525+
MaxIdleTimeout: quicpogs.MaxIdleTimeout,
528526
MaxIncomingStreams: connection.MaxConcurrentStreams,
529527
MaxIncomingUniStreams: connection.MaxConcurrentStreams,
530528
KeepAlive: true,

quic/quic_protocol.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ import (
1717
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
1818
)
1919

20-
// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does
21-
// not write data before writing the metadata.
20+
// ProtocolSignature defines the first 6 bytes of the stream, which is used to distinguish the type of stream. It
21+
// ensures whoever performs a handshake does not write data before writing the metadata.
2222
type ProtocolSignature [6]byte
2323

2424
var (
@@ -29,12 +29,15 @@ var (
2929
RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
3030
)
3131

32-
const protocolVersionLength = 2
33-
3432
type protocolVersion string
3533

3634
const (
3735
protocolV1 protocolVersion = "01"
36+
37+
protocolVersionLength = 2
38+
39+
HandshakeIdleTimeout = 5 * time.Second
40+
MaxIdleTimeout = 15 * time.Second
3841
)
3942

4043
// RequestServerStream is a stream to serve requests

quic/safe_stream.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package quic
2+
3+
import (
4+
"sync"
5+
"time"
6+
7+
"github.com/lucas-clemente/quic-go"
8+
)
9+
10+
type SafeStreamCloser struct {
11+
lock sync.Mutex
12+
stream quic.Stream
13+
}
14+
15+
func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser {
16+
return &SafeStreamCloser{
17+
stream: stream,
18+
}
19+
}
20+
21+
func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
22+
return s.stream.Read(p)
23+
}
24+
25+
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
26+
s.lock.Lock()
27+
defer s.lock.Unlock()
28+
return s.stream.Write(p)
29+
}
30+
31+
func (s *SafeStreamCloser) Close() error {
32+
// Make sure a possible writer does not block the lock forever. We need it, so we can close the writer
33+
// side of the stream safely.
34+
_ = s.stream.SetWriteDeadline(time.Now())
35+
36+
// This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes.
37+
s.lock.Lock()
38+
defer s.lock.Unlock()
39+
40+
// We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that.
41+
s.stream.CancelRead(0)
42+
return s.stream.Close()
43+
}

quic/safe_stream_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
package quic
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"io"
7+
"net"
8+
"sync"
9+
"testing"
10+
11+
"github.com/lucas-clemente/quic-go"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
var (
16+
testTLSServerConfig = GenerateTLSConfig()
17+
testQUICConfig = &quic.Config{
18+
KeepAlive: true,
19+
EnableDatagrams: true,
20+
}
21+
exchanges = 1000
22+
msgsPerExchange = 10
23+
testMsg = "Ok message"
24+
)
25+
26+
func TestSafeStreamClose(t *testing.T) {
27+
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
28+
require.NoError(t, err)
29+
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
30+
require.NoError(t, err)
31+
defer udpListener.Close()
32+
33+
var serverReady sync.WaitGroup
34+
serverReady.Add(1)
35+
36+
var done sync.WaitGroup
37+
done.Add(1)
38+
go func() {
39+
defer done.Done()
40+
quicServer(t, &serverReady, udpListener)
41+
}()
42+
43+
done.Add(1)
44+
go func() {
45+
serverReady.Wait()
46+
defer done.Done()
47+
quicClient(t, udpListener.LocalAddr())
48+
}()
49+
50+
done.Wait()
51+
}
52+
53+
func quicClient(t *testing.T, addr net.Addr) {
54+
tlsConf := &tls.Config{
55+
InsecureSkipVerify: true,
56+
NextProtos: []string{"argotunnel"},
57+
}
58+
session, err := quic.DialAddr(addr.String(), tlsConf, testQUICConfig)
59+
require.NoError(t, err)
60+
61+
var wg sync.WaitGroup
62+
for exchange := 0; exchange < exchanges; exchange++ {
63+
quicStream, err := session.AcceptStream(context.Background())
64+
require.NoError(t, err)
65+
wg.Add(1)
66+
67+
go func(iter int) {
68+
defer wg.Done()
69+
70+
stream := NewSafeStreamCloser(quicStream)
71+
defer stream.Close()
72+
73+
// Do a bunch of round trips over this stream that should work.
74+
for msg := 0; msg < msgsPerExchange; msg++ {
75+
clientRoundTrip(t, stream, true)
76+
}
77+
// And one that won't work necessarily, but shouldn't break other streams in the session.
78+
if iter%2 == 0 {
79+
clientRoundTrip(t, stream, false)
80+
}
81+
}(exchange)
82+
}
83+
84+
wg.Wait()
85+
}
86+
87+
func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn) {
88+
ctx, cancel := context.WithCancel(context.Background())
89+
defer cancel()
90+
91+
earlyListener, err := quic.Listen(conn, testTLSServerConfig, testQUICConfig)
92+
require.NoError(t, err)
93+
94+
serverReady.Done()
95+
session, err := earlyListener.Accept(ctx)
96+
require.NoError(t, err)
97+
98+
var wg sync.WaitGroup
99+
for exchange := 0; exchange < exchanges; exchange++ {
100+
quicStream, err := session.OpenStreamSync(context.Background())
101+
require.NoError(t, err)
102+
wg.Add(1)
103+
104+
go func(iter int) {
105+
defer wg.Done()
106+
107+
stream := NewSafeStreamCloser(quicStream)
108+
defer stream.Close()
109+
110+
// Do a bunch of round trips over this stream that should work.
111+
for msg := 0; msg < msgsPerExchange; msg++ {
112+
serverRoundTrip(t, stream, true)
113+
}
114+
// And one that won't work necessarily, but shouldn't break other streams in the session.
115+
if iter%2 == 1 {
116+
serverRoundTrip(t, stream, false)
117+
}
118+
}(exchange)
119+
}
120+
121+
wg.Wait()
122+
}
123+
124+
func clientRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) {
125+
response := make([]byte, len(testMsg))
126+
_, err := stream.Read(response)
127+
if !mustWork {
128+
return
129+
}
130+
if err != io.EOF {
131+
require.NoError(t, err)
132+
}
133+
require.Equal(t, testMsg, string(response))
134+
}
135+
136+
func serverRoundTrip(t *testing.T, stream io.ReadWriteCloser, mustWork bool) {
137+
_, err := stream.Write([]byte(testMsg))
138+
if !mustWork {
139+
return
140+
}
141+
require.NoError(t, err)
142+
}

0 commit comments

Comments
 (0)