@@ -14,6 +14,8 @@ import (
1414 "github.com/rs/zerolog"
1515 "github.com/stretchr/testify/assert"
1616 "github.com/stretchr/testify/require"
17+
18+ tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
1719)
1820
1921const (
@@ -108,14 +110,10 @@ func TestConnectResponseMeta(t *testing.T) {
108110}
109111
110112func TestRegisterUdpSession (t * testing.T ) {
111- clientReader , serverWriter := io .Pipe ()
112- serverReader , clientWriter := io .Pipe ()
113-
114- clientStream := mockRPCStream {clientReader , clientWriter }
115- serverStream := mockRPCStream {serverReader , serverWriter }
113+ clientStream , serverStream := newMockRPCStreams ()
116114
117115 unregisterMessage := "closed by eyeball"
118- rpcServer := mockRPCServer {
116+ sessionRPCServer := mockSessionRPCServer {
119117 sessionID : uuid .New (),
120118 dstIP : net.IP {172 , 16 , 0 , 1 },
121119 dstPort : 8000 ,
@@ -129,7 +127,7 @@ func TestRegisterUdpSession(t *testing.T) {
129127 assert .NoError (t , err )
130128 rpcServerStream , err := NewRPCServerStream (serverStream , protocol )
131129 assert .NoError (t , err )
132- err = rpcServerStream .Serve (rpcServer , & logger )
130+ err = rpcServerStream .Serve (sessionRPCServer , nil , & logger )
133131 assert .NoError (t , err )
134132
135133 serverStream .Close ()
@@ -139,12 +137,12 @@ func TestRegisterUdpSession(t *testing.T) {
139137 rpcClientStream , err := NewRPCClientStream (context .Background (), clientStream , & logger )
140138 assert .NoError (t , err )
141139
142- assert .NoError (t , rpcClientStream .RegisterUdpSession (context .Background (), rpcServer .sessionID , rpcServer .dstIP , rpcServer .dstPort , testCloseIdleAfterHint ))
140+ assert .NoError (t , rpcClientStream .RegisterUdpSession (context .Background (), sessionRPCServer .sessionID , sessionRPCServer .dstIP , sessionRPCServer .dstPort , testCloseIdleAfterHint ))
143141
144142 // Different sessionID, the RPC server should reject the registraion
145- assert .Error (t , rpcClientStream .RegisterUdpSession (context .Background (), uuid .New (), rpcServer .dstIP , rpcServer .dstPort , testCloseIdleAfterHint ))
143+ assert .Error (t , rpcClientStream .RegisterUdpSession (context .Background (), uuid .New (), sessionRPCServer .dstIP , sessionRPCServer .dstPort , testCloseIdleAfterHint ))
146144
147- assert .NoError (t , rpcClientStream .UnregisterUdpSession (context .Background (), rpcServer .sessionID , unregisterMessage ))
145+ assert .NoError (t , rpcClientStream .UnregisterUdpSession (context .Background (), sessionRPCServer .sessionID , unregisterMessage ))
148146
149147 // Different sessionID, the RPC server should reject the unregistraion
150148 assert .Error (t , rpcClientStream .UnregisterUdpSession (context .Background (), uuid .New (), unregisterMessage ))
@@ -153,15 +151,56 @@ func TestRegisterUdpSession(t *testing.T) {
153151 <- sessionRegisteredChan
154152}
155153
156- type mockRPCServer struct {
154+ func TestManageConfiguration (t * testing.T ) {
155+ var (
156+ version int32 = 168
157+ config = []byte (t .Name ())
158+ )
159+ clientStream , serverStream := newMockRPCStreams ()
160+
161+ configRPCServer := mockConfigRPCServer {
162+ version : version ,
163+ config : config ,
164+ }
165+
166+ logger := zerolog .Nop ()
167+ updatedChan := make (chan struct {})
168+ go func () {
169+ protocol , err := DetermineProtocol (serverStream )
170+ assert .NoError (t , err )
171+ rpcServerStream , err := NewRPCServerStream (serverStream , protocol )
172+ assert .NoError (t , err )
173+ err = rpcServerStream .Serve (nil , configRPCServer , & logger )
174+ assert .NoError (t , err )
175+
176+ serverStream .Close ()
177+ close (updatedChan )
178+ }()
179+
180+ ctx , cancel := context .WithTimeout (context .Background (), time .Second )
181+ defer cancel ()
182+ rpcClientStream , err := NewRPCClientStream (ctx , clientStream , & logger )
183+ assert .NoError (t , err )
184+
185+ result , err := rpcClientStream .UpdateConfiguration (ctx , version , config )
186+ assert .NoError (t , err )
187+
188+ require .Equal (t , version , result .LastAppliedVersion )
189+ require .NoError (t , result .Err )
190+
191+ rpcClientStream .Close ()
192+ <- updatedChan
193+ }
194+
195+ type mockSessionRPCServer struct {
157196 sessionID uuid.UUID
158197 dstIP net.IP
159198 dstPort uint16
160199 closeIdleAfter time.Duration
161200 unregisterMessage string
162201}
163202
164- func (s mockRPCServer ) RegisterUdpSession (ctx context.Context , sessionID uuid.UUID , dstIP net.IP , dstPort uint16 , closeIdleAfter time.Duration ) error {
203+ func (s mockSessionRPCServer ) RegisterUdpSession (_ context.Context , sessionID uuid.UUID , dstIP net.IP , dstPort uint16 , closeIdleAfter time.Duration ) error {
165204 if s .sessionID != sessionID {
166205 return fmt .Errorf ("expect session ID %s, got %s" , s .sessionID , sessionID )
167206 }
@@ -177,7 +216,7 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
177216 return nil
178217}
179218
180- func (s mockRPCServer ) UnregisterUdpSession (ctx context.Context , sessionID uuid.UUID , message string ) error {
219+ func (s mockSessionRPCServer ) UnregisterUdpSession (_ context.Context , sessionID uuid.UUID , message string ) error {
181220 if s .sessionID != sessionID {
182221 return fmt .Errorf ("expect session ID %s, got %s" , s .sessionID , sessionID )
183222 }
@@ -187,11 +226,35 @@ func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.
187226 return nil
188227}
189228
229+ type mockConfigRPCServer struct {
230+ version int32
231+ config []byte
232+ }
233+
234+ func (s mockConfigRPCServer ) UpdateConfiguration (_ context.Context , version int32 , config []byte ) (* tunnelpogs.UpdateConfigurationResponse , error ) {
235+ if s .version != version {
236+ return nil , fmt .Errorf ("expect version %d, got %d" , s .version , version )
237+ }
238+ if ! bytes .Equal (s .config , config ) {
239+ return nil , fmt .Errorf ("expect config %v, got %v" , s .config , config )
240+ }
241+ return & tunnelpogs.UpdateConfigurationResponse {LastAppliedVersion : version }, nil
242+ }
243+
190244type mockRPCStream struct {
191245 io.ReadCloser
192246 io.WriteCloser
193247}
194248
249+ func newMockRPCStreams () (client io.ReadWriteCloser , server io.ReadWriteCloser ) {
250+ clientReader , serverWriter := io .Pipe ()
251+ serverReader , clientWriter := io .Pipe ()
252+
253+ client = mockRPCStream {clientReader , clientWriter }
254+ server = mockRPCStream {serverReader , serverWriter }
255+ return
256+ }
257+
195258func (s mockRPCStream ) Close () error {
196259 _ = s .ReadCloser .Close ()
197260 _ = s .WriteCloser .Close ()
0 commit comments