Skip to content

Commit ebae7a7

Browse files
committed
TUN-5494: Send a RPC with terminate reason to edge if the session is closed locally
1 parent 70e675f commit ebae7a7

File tree

12 files changed

+562
-296
lines changed

12 files changed

+562
-296
lines changed

connection/quic.go

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
168168
return rpcStream.Serve(q, q.logger)
169169
}
170170

171+
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
171172
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration) error {
172173
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
173174
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
@@ -178,22 +179,60 @@ func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.
178179
}
179180
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
180181
if err != nil {
181-
q.logger.Err(err).Msgf("Failed to register udp session %s", sessionID)
182+
q.logger.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session")
182183
return err
183184
}
184-
go func() {
185-
defer q.sessionManager.UnregisterSession(q.session.Context(), sessionID)
186-
if err := session.Serve(q.session.Context(), closeAfterIdleHint); err != nil {
187-
q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).Msg("session terminated")
188-
}
189-
}()
185+
186+
go q.serveUDPSession(session, closeAfterIdleHint)
187+
190188
q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort)
191189
return nil
192190
}
193191

194-
func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID) error {
195-
q.sessionManager.UnregisterSession(ctx, sessionID)
196-
return nil
192+
func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) {
193+
ctx := q.session.Context()
194+
closedByRemote, err := session.Serve(ctx, closeAfterIdleHint)
195+
// If session is terminated by remote, then we know it has been unregistered from session manager and edge
196+
if !closedByRemote {
197+
if err != nil {
198+
q.closeUDPSession(ctx, session.ID, err.Error())
199+
} else {
200+
q.closeUDPSession(ctx, session.ID, "terminated without error")
201+
}
202+
q.logger.Debug().Err(err).Str("sessionID", session.ID.String()).Msg("session terminated")
203+
return
204+
}
205+
q.logger.Debug().Err(err).Msg("Session terminated by edge")
206+
}
207+
208+
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
209+
func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
210+
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
211+
stream, err := q.session.OpenStream()
212+
if err != nil {
213+
// Log this at debug because this is not an error if session was closed due to lost connection
214+
// with edge
215+
q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).
216+
Msgf("Failed to open quic stream to unregister udp session with edge")
217+
return
218+
}
219+
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger)
220+
if err != nil {
221+
// Log this at debug because this is not an error if session was closed due to lost connection
222+
// with edge
223+
q.logger.Err(err).Str("sessionID", sessionID.String()).
224+
Msgf("Failed to open rpc stream to unregister udp session with edge")
225+
return
226+
}
227+
if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil {
228+
q.logger.Err(err).Str("sessionID", sessionID.String()).
229+
Msgf("Failed to unregister udp session with edge")
230+
}
231+
}
232+
233+
// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion
234+
func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
235+
return q.sessionManager.UnregisterSession(ctx, sessionID, message, true)
197236
}
198237

199238
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to

connection/quic_test.go

Lines changed: 169 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,39 @@ import (
1717
"os"
1818
"sync"
1919
"testing"
20+
"time"
2021

2122
"github.com/gobwas/ws/wsutil"
23+
"github.com/google/uuid"
2224
"github.com/lucas-clemente/quic-go"
2325
"github.com/pkg/errors"
2426
"github.com/rs/zerolog"
2527
"github.com/stretchr/testify/assert"
2628
"github.com/stretchr/testify/require"
2729

30+
"github.com/cloudflare/cloudflared/datagramsession"
2831
quicpogs "github.com/cloudflare/cloudflared/quic"
2932
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
3033
)
3134

32-
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
33-
// It also serves as a demonstration for communication with the QUIC connection started by a cloudflared.
34-
func TestQUICServer(t *testing.T) {
35-
quicConfig := &quic.Config{
35+
var (
36+
testTLSServerConfig = generateTLSConfig()
37+
testQUICConfig = &quic.Config{
3638
KeepAlive: true,
3739
EnableDatagrams: true,
3840
}
41+
)
3942

40-
// Setup test.
41-
log := zerolog.New(os.Stdout)
42-
43+
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
44+
// It also serves as a demonstration for communication with the QUIC connection started by a cloudflared.
45+
func TestQUICServer(t *testing.T) {
4346
// Start a UDP Listener for QUIC.
4447
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
4548
require.NoError(t, err)
4649
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
4750
require.NoError(t, err)
4851
defer udpListener.Close()
4952

50-
// Create a simple tls config.
51-
tlsConfig := generateTLSConfig()
52-
53-
// Create a client config
54-
tlsClientConfig := &tls.Config{
55-
InsecureSkipVerify: true,
56-
NextProtos: []string{"argotunnel"},
57-
}
58-
59-
// Start a mock httpProxy
60-
originProxy := &mockOriginProxyWithRequest{}
61-
6253
// This is simply a sample websocket frame message.
6354
wsBuf := &bytes.Buffer{}
6455
wsutil.WriteClientText(wsBuf, []byte("Hello"))
@@ -158,25 +149,13 @@ func TestQUICServer(t *testing.T) {
158149
go func() {
159150
defer wg.Done()
160151
quicServer(
161-
t, udpListener, tlsConfig, quicConfig,
152+
t, udpListener, testTLSServerConfig, testQUICConfig,
162153
test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
163154
)
164155
}()
165156

166-
controlStream := fakeControlStream{}
167-
168-
qC, err := NewQUICConnection(
169-
ctx,
170-
quicConfig,
171-
udpListener.LocalAddr(),
172-
tlsClientConfig,
173-
originProxy,
174-
&tunnelpogs.ConnectionOptions{},
175-
controlStream,
176-
NewObserver(&log, &log, false),
177-
)
178-
require.NoError(t, err)
179-
go qC.Serve(ctx)
157+
qc := testQUICConnection(ctx, udpListener.LocalAddr(), t)
158+
go qc.Serve(ctx)
180159

181160
wg.Wait()
182161
cancel()
@@ -531,3 +510,159 @@ func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWri
531510
io.Copy(rwa, rwa)
532511
return nil
533512
}
513+
514+
func TestServeUDPSession(t *testing.T) {
515+
// Start a UDP Listener for QUIC.
516+
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
517+
require.NoError(t, err)
518+
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
519+
require.NoError(t, err)
520+
defer udpListener.Close()
521+
522+
ctx, cancel := context.WithCancel(context.Background())
523+
524+
// Establish QUIC connection with edge
525+
edgeQUICSessionChan := make(chan quic.Session)
526+
go func() {
527+
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
528+
require.NoError(t, err)
529+
530+
edgeQUICSession, err := earlyListener.Accept(ctx)
531+
require.NoError(t, err)
532+
edgeQUICSessionChan <- edgeQUICSession
533+
}()
534+
535+
qc := testQUICConnection(ctx, udpListener.LocalAddr(), t)
536+
go qc.Serve(ctx)
537+
538+
edgeQUICSession := <-edgeQUICSessionChan
539+
serveSession(ctx, qc, edgeQUICSession, closedByOrigin, io.EOF.Error(), t)
540+
serveSession(ctx, qc, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t)
541+
serveSession(ctx, qc, edgeQUICSession, closedByRemote, "eyeball closed connection", t)
542+
cancel()
543+
}
544+
545+
func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Session, closeType closeReason, expectedReason string, t *testing.T) {
546+
var (
547+
payload = []byte(t.Name())
548+
)
549+
sessionID := uuid.New()
550+
cfdConn, originConn := net.Pipe()
551+
// Registers and run a new session
552+
session, err := qc.sessionManager.RegisterSession(ctx, sessionID, cfdConn)
553+
require.NoError(t, err)
554+
555+
sessionDone := make(chan struct{})
556+
go func() {
557+
qc.serveUDPSession(session, time.Millisecond*50)
558+
close(sessionDone)
559+
}()
560+
561+
// Send a message to the quic session on edge side, it should be deumx to this datagram session
562+
muxedPayload, err := quicpogs.SuffixSessionID(sessionID, payload)
563+
require.NoError(t, err)
564+
err = edgeQUICSession.SendMessage(muxedPayload)
565+
require.NoError(t, err)
566+
567+
readBuffer := make([]byte, len(payload)+1)
568+
n, err := originConn.Read(readBuffer)
569+
require.NoError(t, err)
570+
require.Equal(t, len(payload), n)
571+
require.True(t, bytes.Equal(payload, readBuffer[:n]))
572+
573+
// Close connection to terminate session
574+
switch closeType {
575+
case closedByOrigin:
576+
originConn.Close()
577+
case closedByRemote:
578+
err = qc.UnregisterUdpSession(ctx, sessionID, expectedReason)
579+
require.NoError(t, err)
580+
case closedByTimeout:
581+
}
582+
583+
if closeType != closedByRemote {
584+
// Session was not closed by remote, so closeUDPSession should be invoked to unregister from remote
585+
unregisterFromEdgeChan := make(chan struct{})
586+
rpcServer := &mockSessionRPCServer{
587+
sessionID: sessionID,
588+
unregisterReason: expectedReason,
589+
calledUnregisterChan: unregisterFromEdgeChan,
590+
}
591+
go runMockSessionRPCServer(ctx, edgeQUICSession, rpcServer, t)
592+
593+
<-unregisterFromEdgeChan
594+
}
595+
596+
<-sessionDone
597+
}
598+
599+
type closeReason uint8
600+
601+
const (
602+
closedByOrigin closeReason = iota
603+
closedByRemote
604+
closedByTimeout
605+
)
606+
607+
func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServer *mockSessionRPCServer, t *testing.T) {
608+
stream, err := session.AcceptStream(ctx)
609+
require.NoError(t, err)
610+
611+
if stream.StreamID() == 0 {
612+
// Skip the first stream, it's the control stream of the QUIC connection
613+
stream, err = session.AcceptStream(ctx)
614+
require.NoError(t, err)
615+
}
616+
protocol, err := quicpogs.DetermineProtocol(stream)
617+
assert.NoError(t, err)
618+
rpcServerStream, err := quicpogs.NewRPCServerStream(stream, protocol)
619+
assert.NoError(t, err)
620+
621+
log := zerolog.New(os.Stdout)
622+
err = rpcServerStream.Serve(rpcServer, &log)
623+
assert.NoError(t, err)
624+
}
625+
626+
type mockSessionRPCServer struct {
627+
sessionID uuid.UUID
628+
unregisterReason string
629+
calledUnregisterChan chan struct{}
630+
}
631+
632+
func (s mockSessionRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
633+
return fmt.Errorf("mockSessionRPCServer doesn't implement RegisterUdpSession")
634+
}
635+
636+
func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, reason string) error {
637+
if s.sessionID != sessionID {
638+
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
639+
}
640+
if s.unregisterReason != reason {
641+
return fmt.Errorf("expect unregister reason %s, got %s", s.unregisterReason, reason)
642+
}
643+
close(s.calledUnregisterChan)
644+
fmt.Println("unregister from edge")
645+
return nil
646+
}
647+
648+
func testQUICConnection(ctx context.Context, udpListenerAddr net.Addr, t *testing.T) *QUICConnection {
649+
tlsClientConfig := &tls.Config{
650+
InsecureSkipVerify: true,
651+
NextProtos: []string{"argotunnel"},
652+
}
653+
// Start a mock httpProxy
654+
originProxy := &mockOriginProxyWithRequest{}
655+
log := zerolog.New(os.Stdout)
656+
qc, err := NewQUICConnection(
657+
ctx,
658+
testQUICConfig,
659+
udpListenerAddr,
660+
tlsClientConfig,
661+
originProxy,
662+
&tunnelpogs.ConnectionOptions{},
663+
fakeControlStream{},
664+
NewObserver(&log, &log, false),
665+
)
666+
require.NoError(t, err)
667+
return qc
668+
}

datagramsession/event.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package datagramsession
22

33
import (
4+
"fmt"
45
"io"
56

67
"github.com/google/uuid"
@@ -24,6 +25,22 @@ func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser
2425
// unregisterSessionEvent is an event to stop tracking and terminate the session.
2526
type unregisterSessionEvent struct {
2627
sessionID uuid.UUID
28+
err *errClosedSession
29+
}
30+
31+
// ClosedSessionError represent a condition that closes the session other than I/O
32+
// I/O error is not included, because the side that closes the session is ambiguous.
33+
type errClosedSession struct {
34+
message string
35+
byRemote bool
36+
}
37+
38+
func (sc *errClosedSession) Error() string {
39+
if sc.byRemote {
40+
return fmt.Sprintf("session closed by remote due to %s", sc.message)
41+
} else {
42+
return fmt.Sprintf("session closed by local due to %s", sc.message)
43+
}
2744
}
2845

2946
// newDatagram is an event when transport receives new datagram

datagramsession/manager.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type Manager interface {
2020
// RegisterSession starts tracking a session. Caller is responsible for starting the session
2121
RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
2222
// UnregisterSession stops tracking the session and terminates it
23-
UnregisterSession(ctx context.Context, sessionID uuid.UUID) error
23+
UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error
2424
}
2525

2626
type manager struct {
@@ -100,8 +100,14 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
100100
registration.resultChan <- session
101101
}
102102

103-
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID) error {
104-
event := &unregisterSessionEvent{sessionID: sessionID}
103+
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
104+
event := &unregisterSessionEvent{
105+
sessionID: sessionID,
106+
err: &errClosedSession{
107+
message: message,
108+
byRemote: byRemote,
109+
},
110+
}
105111
select {
106112
case <-ctx.Done():
107113
return ctx.Err()
@@ -114,7 +120,7 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
114120
session, ok := m.sessions[unregistration.sessionID]
115121
if ok {
116122
delete(m.sessions, unregistration.sessionID)
117-
session.close()
123+
session.close(unregistration.err)
118124
}
119125
}
120126

0 commit comments

Comments
 (0)