Skip to content

Commit d07d24e

Browse files
committed
TUN-5695: Define RPC method to update configuration
1 parent 0ab6867 commit d07d24e

File tree

8 files changed

+923
-223
lines changed

8 files changed

+923
-223
lines changed

connection/quic.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
192192
}
193193

194194
func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
195-
return rpcStream.Serve(q, q.logger)
195+
return rpcStream.Serve(q, q, q.logger)
196196
}
197197

198198
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
@@ -260,6 +260,11 @@ func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uui
260260
return q.sessionManager.UnregisterSession(ctx, sessionID, message, true)
261261
}
262262

263+
// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration
264+
func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
265+
return nil, fmt.Errorf("TODO: TUN-5698")
266+
}
267+
263268
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
264269
// the client.
265270
type streamReadWriteAcker struct {

connection/quic_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -560,12 +560,12 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.
560560
if closeType != closedByRemote {
561561
// Session was not closed by remote, so closeUDPSession should be invoked to unregister from remote
562562
unregisterFromEdgeChan := make(chan struct{})
563-
rpcServer := &mockSessionRPCServer{
563+
sessionRPCServer := &mockSessionRPCServer{
564564
sessionID: sessionID,
565565
unregisterReason: expectedReason,
566566
calledUnregisterChan: unregisterFromEdgeChan,
567567
}
568-
go runMockSessionRPCServer(ctx, edgeQUICSession, rpcServer, t)
568+
go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t)
569569

570570
<-unregisterFromEdgeChan
571571
}
@@ -581,7 +581,7 @@ const (
581581
closedByTimeout
582582
)
583583

584-
func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServer *mockSessionRPCServer, t *testing.T) {
584+
func runRPCServer(ctx context.Context, session quic.Session, sessionRPCServer tunnelpogs.SessionManager, configRPCServer tunnelpogs.ConfigurationManager, t *testing.T) {
585585
stream, err := session.AcceptStream(ctx)
586586
require.NoError(t, err)
587587

@@ -596,7 +596,7 @@ func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServe
596596
assert.NoError(t, err)
597597

598598
log := zerolog.New(os.Stdout)
599-
err = rpcServerStream.Serve(rpcServer, &log)
599+
err = rpcServerStream.Serve(sessionRPCServer, configRPCServer, &log)
600600
assert.NoError(t, err)
601601
}
602602

@@ -618,7 +618,6 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI
618618
return fmt.Errorf("expect unregister reason %s, got %s", s.unregisterReason, reason)
619619
}
620620
close(s.calledUnregisterChan)
621-
fmt.Println("unregister from edge")
622621
return nil
623622
}
624623

quic/quic_protocol.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, err
125125
return nil, err
126126
}
127127
if signature != DataStreamProtocolSignature {
128-
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
128+
return nil, fmt.Errorf("wrong protocol signature %v", signature)
129129
}
130130

131131
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
@@ -157,13 +157,13 @@ func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (
157157
return &RPCServerStream{stream}, nil
158158
}
159159

160-
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
160+
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, configManager tunnelpogs.ConfigurationManager, logger *zerolog.Logger) error {
161161
// RPC logs are very robust, create a new logger that only logs error to reduce noise
162162
rpcLogger := logger.Level(zerolog.ErrorLevel)
163163
rpcTransport := tunnelrpc.NewTransportLogger(&rpcLogger, rpc.StreamTransport(s))
164164
defer rpcTransport.Close()
165165

166-
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
166+
main := tunnelpogs.CloudflaredServer_ServerToClient(sessionManager, configManager)
167167
rpcConn := rpc.NewConn(
168168
rpcTransport,
169169
rpc.MainInterface(main.Client),
@@ -223,7 +223,7 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error {
223223

224224
// RPCClientStream is a stream to call methods of SessionManager
225225
type RPCClientStream struct {
226-
client tunnelpogs.SessionManager_PogsClient
226+
client tunnelpogs.CloudflaredServer_PogsClient
227227
transport rpc.Transport
228228
}
229229

@@ -241,7 +241,7 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
241241
tunnelrpc.ConnLog(logger),
242242
)
243243
return &RPCClientStream{
244-
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
244+
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
245245
transport: transport,
246246
}, nil
247247
}
@@ -258,6 +258,10 @@ func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID
258258
return rcs.client.UnregisterUdpSession(ctx, sessionID, message)
259259
}
260260

261+
func (rcs *RPCClientStream) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
262+
return rcs.client.UpdateConfiguration(ctx, version, config)
263+
}
264+
261265
func (rcs *RPCClientStream) Close() {
262266
_ = rcs.client.Close()
263267
_ = rcs.transport.Close()

quic/quic_protocol_test.go

Lines changed: 76 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"github.com/rs/zerolog"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
17+
18+
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
1719
)
1820

1921
const (
@@ -108,14 +110,10 @@ func TestConnectResponseMeta(t *testing.T) {
108110
}
109111

110112
func TestRegisterUdpSession(t *testing.T) {
111-
clientReader, serverWriter := io.Pipe()
112-
serverReader, clientWriter := io.Pipe()
113-
114-
clientStream := mockRPCStream{clientReader, clientWriter}
115-
serverStream := mockRPCStream{serverReader, serverWriter}
113+
clientStream, serverStream := newMockRPCStreams()
116114

117115
unregisterMessage := "closed by eyeball"
118-
rpcServer := mockRPCServer{
116+
sessionRPCServer := mockSessionRPCServer{
119117
sessionID: uuid.New(),
120118
dstIP: net.IP{172, 16, 0, 1},
121119
dstPort: 8000,
@@ -129,7 +127,7 @@ func TestRegisterUdpSession(t *testing.T) {
129127
assert.NoError(t, err)
130128
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
131129
assert.NoError(t, err)
132-
err = rpcServerStream.Serve(rpcServer, &logger)
130+
err = rpcServerStream.Serve(sessionRPCServer, nil, &logger)
133131
assert.NoError(t, err)
134132

135133
serverStream.Close()
@@ -139,12 +137,12 @@ func TestRegisterUdpSession(t *testing.T) {
139137
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
140138
assert.NoError(t, err)
141139

142-
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
140+
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), sessionRPCServer.sessionID, sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
143141

144142
// Different sessionID, the RPC server should reject the registraion
145-
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
143+
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
146144

147-
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID, unregisterMessage))
145+
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), sessionRPCServer.sessionID, unregisterMessage))
148146

149147
// Different sessionID, the RPC server should reject the unregistraion
150148
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
@@ -153,15 +151,56 @@ func TestRegisterUdpSession(t *testing.T) {
153151
<-sessionRegisteredChan
154152
}
155153

156-
type mockRPCServer struct {
154+
func TestManageConfiguration(t *testing.T) {
155+
var (
156+
version int32 = 168
157+
config = []byte(t.Name())
158+
)
159+
clientStream, serverStream := newMockRPCStreams()
160+
161+
configRPCServer := mockConfigRPCServer{
162+
version: version,
163+
config: config,
164+
}
165+
166+
logger := zerolog.Nop()
167+
updatedChan := make(chan struct{})
168+
go func() {
169+
protocol, err := DetermineProtocol(serverStream)
170+
assert.NoError(t, err)
171+
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
172+
assert.NoError(t, err)
173+
err = rpcServerStream.Serve(nil, configRPCServer, &logger)
174+
assert.NoError(t, err)
175+
176+
serverStream.Close()
177+
close(updatedChan)
178+
}()
179+
180+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
181+
defer cancel()
182+
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger)
183+
assert.NoError(t, err)
184+
185+
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
186+
assert.NoError(t, err)
187+
188+
require.Equal(t, version, result.LastAppliedVersion)
189+
require.NoError(t, result.Err)
190+
191+
rpcClientStream.Close()
192+
<-updatedChan
193+
}
194+
195+
type mockSessionRPCServer struct {
157196
sessionID uuid.UUID
158197
dstIP net.IP
159198
dstPort uint16
160199
closeIdleAfter time.Duration
161200
unregisterMessage string
162201
}
163202

164-
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
203+
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
165204
if s.sessionID != sessionID {
166205
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
167206
}
@@ -177,7 +216,7 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
177216
return nil
178217
}
179218

180-
func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
219+
func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error {
181220
if s.sessionID != sessionID {
182221
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
183222
}
@@ -187,11 +226,35 @@ func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.
187226
return nil
188227
}
189228

229+
type mockConfigRPCServer struct {
230+
version int32
231+
config []byte
232+
}
233+
234+
func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
235+
if s.version != version {
236+
return nil, fmt.Errorf("expect version %d, got %d", s.version, version)
237+
}
238+
if !bytes.Equal(s.config, config) {
239+
return nil, fmt.Errorf("expect config %v, got %v", s.config, config)
240+
}
241+
return &tunnelpogs.UpdateConfigurationResponse{LastAppliedVersion: version}, nil
242+
}
243+
190244
type mockRPCStream struct {
191245
io.ReadCloser
192246
io.WriteCloser
193247
}
194248

249+
func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) {
250+
clientReader, serverWriter := io.Pipe()
251+
serverReader, clientWriter := io.Pipe()
252+
253+
client = mockRPCStream{clientReader, clientWriter}
254+
server = mockRPCStream{serverReader, serverWriter}
255+
return
256+
}
257+
195258
func (s mockRPCStream) Close() error {
196259
_ = s.ReadCloser.Close()
197260
_ = s.WriteCloser.Close()

tunnelrpc/pogs/cloudflaredrpc.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package pogs
2+
3+
import (
4+
"github.com/cloudflare/cloudflared/tunnelrpc"
5+
capnp "zombiezen.com/go/capnproto2"
6+
"zombiezen.com/go/capnproto2/rpc"
7+
)
8+
9+
type CloudflaredServer interface {
10+
SessionManager
11+
ConfigurationManager
12+
}
13+
14+
type CloudflaredServer_PogsImpl struct {
15+
SessionManager_PogsImpl
16+
ConfigurationManager_PogsImpl
17+
}
18+
19+
func CloudflaredServer_ServerToClient(s SessionManager, c ConfigurationManager) tunnelrpc.CloudflaredServer {
20+
return tunnelrpc.CloudflaredServer_ServerToClient(CloudflaredServer_PogsImpl{
21+
SessionManager_PogsImpl: SessionManager_PogsImpl{s},
22+
ConfigurationManager_PogsImpl: ConfigurationManager_PogsImpl{c},
23+
})
24+
}
25+
26+
type CloudflaredServer_PogsClient struct {
27+
SessionManager_PogsClient
28+
ConfigurationManager_PogsClient
29+
Client capnp.Client
30+
Conn *rpc.Conn
31+
}
32+
33+
func NewCloudflaredServer_PogsClient(client capnp.Client, conn *rpc.Conn) CloudflaredServer_PogsClient {
34+
sessionManagerClient := SessionManager_PogsClient{
35+
Client: client,
36+
Conn: conn,
37+
}
38+
configManagerClient := ConfigurationManager_PogsClient{
39+
Client: client,
40+
Conn: conn,
41+
}
42+
return CloudflaredServer_PogsClient{
43+
SessionManager_PogsClient: sessionManagerClient,
44+
ConfigurationManager_PogsClient: configManagerClient,
45+
Client: client,
46+
Conn: conn,
47+
}
48+
}
49+
50+
func (c CloudflaredServer_PogsClient) Close() error {
51+
c.Client.Close()
52+
return c.Conn.Close()
53+
}

0 commit comments

Comments
 (0)