Skip to content

Commit 2c9b736

Browse files
committed
TUN-3427: Define a struct that only implements RegistrationServer in tunnelpogs
1 parent 8e8513e commit 2c9b736

File tree

9 files changed

+186
-145
lines changed

9 files changed

+186
-145
lines changed

connection/rpc.go

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,40 @@ package connection
22

33
import (
44
"context"
5-
"fmt"
6-
"time"
5+
"io"
76

87
rpc "zombiezen.com/go/capnproto2/rpc"
98

10-
"github.com/cloudflare/cloudflared/h2mux"
119
"github.com/cloudflare/cloudflared/logger"
1210
"github.com/cloudflare/cloudflared/tunnelrpc"
1311
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
1412
)
1513

16-
// NewRPCClient creates and returns a new RPC client, which will communicate
14+
// NewTunnelRPCClient creates and returns a new RPC client, which will communicate
1715
// using a stream on the given muxer
18-
func NewRPCClient(
16+
func NewTunnelRPCClient(
1917
ctx context.Context,
20-
muxer *h2mux.Muxer,
18+
stream io.ReadWriteCloser,
2119
logger logger.Service,
22-
openStreamTimeout time.Duration,
2320
) (client tunnelpogs.TunnelServer_PogsClient, err error) {
24-
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
25-
defer openStreamCancel()
26-
stream, err := muxer.OpenRPCStream(openStreamCtx)
27-
if err != nil {
28-
return
29-
}
30-
31-
if !isRPCStreamResponse(stream.Headers) {
32-
stream.Close()
33-
err = fmt.Errorf("rpc: bad response headers: %v", stream.Headers)
34-
return
35-
}
36-
3721
conn := rpc.NewConn(
3822
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
3923
tunnelrpc.ConnLog(logger),
4024
)
41-
client = tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
25+
registrationClient := tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
26+
client = tunnelpogs.TunnelServer_PogsClient{RegistrationServer_PogsClient: registrationClient, Client: conn.Bootstrap(ctx), Conn: conn}
4227
return client, nil
4328
}
4429

45-
func isRPCStreamResponse(headers []h2mux.Header) bool {
46-
return len(headers) == 1 &&
47-
headers[0].Name == ":status" &&
48-
headers[0].Value == "200"
30+
func NewRegistrationRPCClient(
31+
ctx context.Context,
32+
stream io.ReadWriteCloser,
33+
logger logger.Service,
34+
) (client tunnelpogs.RegistrationServer_PogsClient, err error) {
35+
conn := rpc.NewConn(
36+
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
37+
tunnelrpc.ConnLog(logger),
38+
)
39+
client = tunnelpogs.RegistrationServer_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn}
40+
return client, nil
4941
}

h2mux/error.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ var (
1919
ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol}
2020
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
2121
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
22+
ErrNotRPCStream = MuxerProtocolError{"2004 not RPC stream", http2.ErrCodeProtocol}
2223

2324
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
2425
ErrStreamRequestConnectionClosed = MuxerApplicationError{"3001 connection closed while opening stream"}

h2mux/h2mux.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const (
2222
defaultTimeout time.Duration = 5 * time.Second
2323
defaultRetries uint64 = 5
2424
defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb
25-
writeBufferInitialSize int = 16 * 1024 // 16KB
25+
writeBufferInitialSize int = 16 * 1024 // 16KB
2626

2727
SettingMuxerMagic http2.SettingID = 0x42db
2828
MuxerMagicOrigin uint32 = 0xa2e43c8b
@@ -441,11 +441,17 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader
441441
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
442442
stream := m.NewStream(RPCHeaders())
443443
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
444+
stream.Close()
444445
return nil, err
445446
}
446447
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
448+
stream.Close()
447449
return nil, err
448450
}
451+
if !IsRPCStreamResponse(stream) {
452+
stream.Close()
453+
return nil, ErrNotRPCStream
454+
}
449455
return stream, nil
450456
}
451457

@@ -499,3 +505,10 @@ func (m *Muxer) abort() {
499505
func (m *Muxer) TimerRetries() uint64 {
500506
return m.muxWriter.idleTimer.RetryCount()
501507
}
508+
509+
func IsRPCStreamResponse(stream *MuxedStream) bool {
510+
headers := stream.Headers
511+
return len(headers) == 1 &&
512+
headers[0].Name == ":status" &&
513+
headers[0].Value == "200"
514+
}

origin/reconnect.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"sync"
88
"time"
99

10-
"github.com/cloudflare/cloudflared/connection"
1110
"github.com/cloudflare/cloudflared/h2mux"
1211
"github.com/cloudflare/cloudflared/logger"
1312
"github.com/cloudflare/cloudflared/tunnelrpc"
@@ -164,18 +163,17 @@ func ReconnectTunnel(
164163
}
165164

166165
config.TransportLogger.Debug("initiating RPC stream to reconnect")
167-
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
166+
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
168167
if err != nil {
169-
// RPC stream open error
170-
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
168+
return err
171169
}
172-
defer tunnelServer.Close()
170+
defer rpcClient.Close()
173171
// Request server info without blocking tunnel registration; must use capnp library directly.
174-
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
172+
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
175173
return nil
176174
})
177175
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
178-
registration := tunnelServer.ReconnectTunnel(
176+
registration := rpcClient.ReconnectTunnel(
179177
ctx,
180178
token,
181179
eventDigest,

origin/supervisor.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,16 +323,16 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
323323
<-muxer.Shutdown()
324324
}()
325325

326-
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger, openStreamTimeout)
326+
rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate)
327327
if err != nil {
328328
return nil, err
329329
}
330-
defer tunnelServer.Close()
330+
defer rpcClient.Close()
331331

332332
const arbitraryConnectionID = uint8(0)
333333
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
334334
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
335-
authResponse, err := tunnelServer.Authenticate(
335+
authResponse, err := rpcClient.Authenticate(
336336
ctx,
337337
s.config.OriginCert,
338338
s.config.Hostname,

origin/tunnel.go

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ const (
4444
FeatureQuickReconnects = "quick_reconnects"
4545
)
4646

47-
type registerRPCName string
47+
type rpcName string
4848

4949
const (
50-
register registerRPCName = "register"
51-
reconnect registerRPCName = "reconnect"
50+
register rpcName = "register"
51+
reconnect rpcName = "reconnect"
52+
unregister rpcName = "unregister"
53+
authenticate rpcName = " authenticate"
5254
)
5355

5456
type TunnelConfig struct {
@@ -121,7 +123,7 @@ type clientRegisterTunnelError struct {
121123
cause error
122124
}
123125

124-
func newClientRegisterTunnelError(cause error, counter *prometheus.CounterVec, name registerRPCName) clientRegisterTunnelError {
126+
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
125127
counter.WithLabelValues(cause.Error(), string(name)).Inc()
126128
return clientRegisterTunnelError{cause: cause}
127129
}
@@ -337,7 +339,7 @@ func ServeTunnel(
337339
if config.NamedTunnel != nil {
338340
_ = UnregisterConnection(ctx, handler.muxer, config)
339341
} else {
340-
_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
342+
_ = UnregisterTunnel(handler.muxer, config)
341343
}
342344
}
343345
handler.muxer.Shutdown()
@@ -417,14 +419,13 @@ func RegisterConnection(
417419
const registerConnection = "registerConnection"
418420

419421
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
420-
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
422+
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection)
421423
if err != nil {
422-
// RPC stream open error
423-
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
424+
return err
424425
}
425-
defer rpc.Close()
426+
defer rpcClient.Close()
426427

427-
conn, err := rpc.RegisterConnection(
428+
conn, err := rpcClient.RegisterConnection(
428429
ctx,
429430
config.NamedTunnel.Auth,
430431
config.NamedTunnel.ID,
@@ -470,14 +471,14 @@ func UnregisterConnection(
470471
config *TunnelConfig,
471472
) error {
472473
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
473-
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
474+
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
474475
if err != nil {
475476
// RPC stream open error
476-
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
477+
return err
477478
}
478-
defer rpc.Close()
479+
defer rpcClient.Close()
479480

480-
return rpc.UnregisterConnection(ctx)
481+
return rpcClient.UnregisterConnection(ctx)
481482
}
482483

483484
func RegisterTunnel(
@@ -494,18 +495,18 @@ func RegisterTunnel(
494495
if config.TunnelEventChan != nil {
495496
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
496497
}
497-
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
498+
499+
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
498500
if err != nil {
499-
// RPC stream open error
500-
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
501+
return err
501502
}
502-
defer tunnelServer.Close()
503+
defer rpcClient.Close()
503504
// Request server info without blocking tunnel registration; must use capnp library directly.
504-
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
505+
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
505506
return nil
506507
})
507508
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
508-
registration := tunnelServer.RegisterTunnel(
509+
registration := rpcClient.RegisterTunnel(
509510
ctx,
510511
config.OriginCert,
511512
config.Hostname,
@@ -529,7 +530,7 @@ func processRegistrationSuccess(
529530
logger logger.Service,
530531
connectionID uint8,
531532
registration *tunnelpogs.TunnelRegistration,
532-
name registerRPCName,
533+
name rpcName,
533534
credentialManager *reconnectCredentialManager,
534535
) error {
535536
for _, logLine := range registration.LogLines {
@@ -563,7 +564,7 @@ func processRegistrationSuccess(
563564
return nil
564565
}
565566

566-
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
567+
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error {
567568
if err.Error() == DuplicateConnectionError {
568569
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
569570
return errDuplicationConnection
@@ -575,18 +576,18 @@ func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics
575576
}
576577
}
577578

578-
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger logger.Service) error {
579-
logger.Debug("initiating RPC stream to unregister")
579+
func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error {
580+
config.TransportLogger.Debug("initiating RPC stream to unregister")
580581
ctx := context.Background()
581-
tunnelServer, err := connection.NewRPCClient(ctx, muxer, logger, openStreamTimeout)
582+
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister)
582583
if err != nil {
583584
// RPC stream open error
584585
return err
585586
}
586-
defer tunnelServer.Close()
587+
defer rpcClient.Close()
587588

588589
// gracePeriod is encoded in int64 using capnproto
589-
return tunnelServer.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
590+
return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds())
590591
}
591592

592593
func LogServerInfo(
@@ -909,3 +910,18 @@ func findCfRayHeader(h1 *http.Request) string {
909910
func isLBProbeRequest(req *http.Request) bool {
910911
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
911912
}
913+
914+
func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) {
915+
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
916+
defer openStreamCancel()
917+
stream, err := muxer.OpenRPCStream(openStreamCtx)
918+
if err != nil {
919+
return tunnelpogs.TunnelServer_PogsClient{}, err
920+
}
921+
rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger)
922+
if err != nil {
923+
// RPC stream open error
924+
return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName)
925+
}
926+
return rpcClient, nil
927+
}

0 commit comments

Comments
 (0)