Skip to content

Commit 8136da0

Browse files
Add integration tests for data channel and control channel
cr: https://code.amazon.com/reviews/CR-116169644
1 parent 0958b67 commit 8136da0

File tree

3 files changed

+693
-0
lines changed

3 files changed

+693
-0
lines changed
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"). You may not
4+
// use this file except in compliance with the License. A copy of the
5+
// License is located at
6+
//
7+
// http://aws.amazon.com/apache2.0/
8+
//
9+
// or in the "license" file accompanying this file. This file is distributed
10+
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
11+
// either express or implied. See the License for the specific language governing
12+
// permissions and limitations under the License.
13+
14+
//go:build integration
15+
// +build integration
16+
17+
// controlchannel package implement control communicator for web socket connection.
18+
package controlchannel
19+
20+
import (
21+
"fmt"
22+
"math"
23+
"math/rand"
24+
"net/http"
25+
"net/http/httptest"
26+
"net/url"
27+
"os"
28+
"runtime"
29+
"testing"
30+
"time"
31+
32+
"github.com/aws/amazon-ssm-agent/agent/contracts"
33+
mgsConfig "github.com/aws/amazon-ssm-agent/agent/session/config"
34+
mgsContracts "github.com/aws/amazon-ssm-agent/agent/session/contracts"
35+
"github.com/aws/amazon-ssm-agent/agent/session/retry"
36+
"github.com/aws/amazon-ssm-agent/agent/session/service"
37+
serviceMock "github.com/aws/amazon-ssm-agent/agent/session/service/mocks"
38+
"github.com/aws/amazon-ssm-agent/agent/ssmconnectionchannel"
39+
"github.com/gorilla/websocket"
40+
"github.com/stretchr/testify/assert"
41+
"github.com/stretchr/testify/mock"
42+
)
43+
44+
var (
45+
wsUpgrader = &websocket.Upgrader{ReadBufferSize: 2048, WriteBufferSize: 2084}
46+
)
47+
48+
func TestMain(m *testing.M) {
49+
resetConnectionChannel()
50+
code := m.Run()
51+
os.Exit(code)
52+
}
53+
54+
func TestOpenControlChannel_MultiThread(t *testing.T) {
55+
httpErrorHandler := func(hw http.ResponseWriter, request *http.Request) {
56+
httpConn, err := wsUpgrader.Upgrade(hw, request, nil)
57+
if err != nil {
58+
http.Error(hw, fmt.Sprintf("no upgrade: %v", err), http.StatusGatewayTimeout)
59+
panic("Connection should be successful. Should not enter here.")
60+
}
61+
62+
for {
63+
_, _, err = httpConn.ReadMessage()
64+
if err != nil {
65+
return
66+
}
67+
// close connection to simulate any issue on service side
68+
httpConn.Close()
69+
}
70+
}
71+
72+
// launch local HTTP Server
73+
srv := httptest.NewServer(http.HandlerFunc(httpErrorHandler))
74+
defer srv.Close()
75+
u, _ := url.Parse(srv.URL)
76+
u.Scheme = "ws"
77+
78+
controlChannel := getControlChannel()
79+
messageChan := make(chan mgsContracts.AgentMessage)
80+
mockEventLog.On("SendAuditMessage")
81+
var ableToOpenMGSConnection uint32
82+
createControlChannelOutput := service.CreateControlChannelOutput{TokenValue: &token}
83+
mockService = &serviceMock.Service{}
84+
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, nil)
85+
mockService.On("GetRegion").Return(region)
86+
mockService.On("GetV4Signer").Return(signer)
87+
88+
controlChannel.Initialize(mockContext, mockService, instanceId, messageChan)
89+
var err error
90+
err = controlChannel.SetWebSocket(mockContext, mockService, &ableToOpenMGSConnection)
91+
assert.Nil(t, err, "should not throw error during websocket creation")
92+
93+
// Set local server URL
94+
controlChannel.wsChannel.SetUrl(u.String())
95+
96+
// Get number of go-routines running
97+
initialGRNumber := runtime.NumGoroutine()
98+
stop := make(chan bool)
99+
startConnectionChannelReader(stop, contracts.MGS)
100+
// start the control channel Open
101+
err = controlChannel.Open(mockContext, &ableToOpenMGSConnection)
102+
defer controlChannel.Close(mockLog)
103+
104+
// sleep for 1 minute to wait for the goroutines to run
105+
time.Sleep(60 * time.Second)
106+
assert.Nil(t, err, "should not throw error during channel open")
107+
108+
controlChannel.AuditLogScheduler.ScheduleAuditEvents()
109+
stop <- true
110+
111+
completedGRNumber := runtime.NumGoroutine()
112+
assert.True(t, initialGRNumber+5 >= completedGRNumber) // tests run in parallel at times hence adding some buffer
113+
}
114+
115+
func startConnectionChannelReader(stop chan bool, expectedStatus contracts.SSMConnectionChannel) {
116+
go func() {
117+
for {
118+
select {
119+
case status := <-ssmconnectionchannel.GetMDSSwitchChannel():
120+
if expectedStatus == contracts.MDS {
121+
if status {
122+
break
123+
}
124+
panic("should not reach this spot for MGS")
125+
}
126+
if expectedStatus == contracts.MGS {
127+
if !status {
128+
break
129+
}
130+
panic("should not reach this spot for MGS")
131+
}
132+
break
133+
case <-stop:
134+
return
135+
}
136+
}
137+
}()
138+
}
139+
140+
func TestOpenControlChannel_OpenControlChannelError(t *testing.T) {
141+
httpErrorHandler := func(hw http.ResponseWriter, request *http.Request) {
142+
http.Error(hw, fmt.Sprintf("no upgrade: %v", fmt.Errorf("test")), http.StatusGatewayTimeout)
143+
return
144+
}
145+
146+
// launch local HTTP Server
147+
srv := httptest.NewServer(http.HandlerFunc(httpErrorHandler))
148+
defer srv.Close()
149+
u, _ := url.Parse(srv.URL)
150+
u.Scheme = "ws"
151+
152+
controlChannel := getControlChannel()
153+
messageChan := make(chan mgsContracts.AgentMessage)
154+
mockEventLog.On("SendAuditMessage")
155+
var ableToOpenMGSConnection uint32
156+
createControlChannelOutput := service.CreateControlChannelOutput{TokenValue: &token}
157+
mockService = &serviceMock.Service{}
158+
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, nil)
159+
mockService.On("GetRegion").Return(region)
160+
mockService.On("GetV4Signer").Return(signer)
161+
162+
stop := make(chan bool)
163+
startConnectionChannelReader(stop, contracts.MDS)
164+
165+
controlChannel.Initialize(mockContext, mockService, instanceId, messageChan)
166+
var err error
167+
err = controlChannel.SetWebSocket(mockContext, mockService, &ableToOpenMGSConnection)
168+
assert.Nil(t, err, "should not throw error during websocket creation")
169+
170+
// Set local server URL
171+
controlChannel.wsChannel.SetUrl(u.String())
172+
defer controlChannel.wsChannel.Close(mockContext.Log())
173+
174+
// Get number of go-routines running
175+
initialGRNumber := runtime.NumGoroutine()
176+
177+
// start control channel Open
178+
err = controlChannel.Open(mockContext, &ableToOpenMGSConnection)
179+
180+
defer controlChannel.Close(mockLog)
181+
assert.NotNil(t, err, "should throw error during channel open")
182+
183+
time.Sleep(10 * time.Second)
184+
completedGRNumber := runtime.NumGoroutine()
185+
stop <- true
186+
187+
// tests run in parallel at times hence adding some buffer
188+
assert.True(t, initialGRNumber+5 >= completedGRNumber)
189+
}
190+
191+
func TestOpenControlChannel_CreateControlChannelError(t *testing.T) {
192+
controlChannel := getControlChannel()
193+
messageChan := make(chan mgsContracts.AgentMessage)
194+
mockEventLog.On("SendAuditMessage")
195+
var ableToOpenMGSConnection uint32
196+
createControlChannelOutput := service.CreateControlChannelOutput{TokenValue: &token}
197+
mockService = &serviceMock.Service{}
198+
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, fmt.Errorf("throw error"))
199+
mockService.On("GetRegion").Return(region)
200+
mockService.On("GetV4Signer").Return(signer)
201+
stop := make(chan bool)
202+
startConnectionChannelReader(stop, contracts.MGS)
203+
// Get number of go-routines running
204+
initialGRNumber := runtime.NumGoroutine()
205+
controlChannel.Initialize(mockContext, mockService, instanceId, messageChan)
206+
var err error
207+
err = controlChannel.SetWebSocket(mockContext, mockService, &ableToOpenMGSConnection)
208+
defer controlChannel.Close(mockLog)
209+
assert.Contains(t, err.Error(), "throw error")
210+
time.Sleep(2 * time.Second)
211+
completedGRNumber := runtime.NumGoroutine()
212+
stop <- true
213+
// tests run in parallel at times hence adding some buffer
214+
assert.True(t, initialGRNumber+5 >= completedGRNumber)
215+
}
216+
217+
func TestOpenControlChannel_CreateControlChannelError_RetryCount(t *testing.T) {
218+
httpTempHandler := func(hw http.ResponseWriter, request *http.Request) {
219+
httpConn, err := wsUpgrader.Upgrade(hw, request, nil)
220+
if err != nil {
221+
http.Error(hw, fmt.Sprintf("no upgrade: %v", err), http.StatusGatewayTimeout)
222+
panic("Connection should be successful. Should not enter here.")
223+
}
224+
for {
225+
_, _, err = httpConn.ReadMessage()
226+
if err != nil {
227+
return
228+
}
229+
}
230+
}
231+
// launch local HTTP Server
232+
srv := httptest.NewServer(http.HandlerFunc(httpTempHandler))
233+
u, _ := url.Parse(srv.URL)
234+
u.Scheme = "ws"
235+
defer srv.Close()
236+
237+
controlChannel := getControlChannel()
238+
messageChan := make(chan mgsContracts.AgentMessage)
239+
240+
// Set local server URL
241+
mockEventLog.On("SendAuditMessage")
242+
243+
var ableToOpenMGSConnection uint32
244+
createControlChannelOutput := service.CreateControlChannelOutput{TokenValue: &token}
245+
246+
mockService = &serviceMock.Service{}
247+
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, fmt.Errorf("test")).Times(3)
248+
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, nil)
249+
mockService.On("GetRegion").Return(region)
250+
mockService.On("GetV4Signer").Return(signer)
251+
252+
counter := 0
253+
// Get number of go-routines running
254+
initialGRNumber := runtime.NumGoroutine()
255+
256+
startTime := time.Now()
257+
258+
stop := make(chan bool)
259+
startConnectionChannelReader(stop, contracts.MGS)
260+
261+
// copied over from MGSInteractor
262+
retryer := retry.ExponentialRetryer{
263+
CallableFunc: func() (channel interface{}, err error) {
264+
counter++
265+
controlChannel = getControlChannel()
266+
controlChannel.Initialize(mockContext, mockService, instanceId, messageChan)
267+
if err = controlChannel.SetWebSocket(mockContext, mockService, &ableToOpenMGSConnection); err != nil {
268+
return nil, err
269+
}
270+
271+
controlChannel.wsChannel.SetUrl(u.String())
272+
if err = controlChannel.Open(mockContext, &ableToOpenMGSConnection); err != nil {
273+
return nil, err
274+
}
275+
276+
controlChannel.AuditLogScheduler.ScheduleAuditEvents()
277+
return controlChannel, nil
278+
},
279+
GeometricRatio: mgsConfig.RetryGeometricRatio,
280+
JitterRatio: mgsConfig.RetryJitterRatio,
281+
InitialDelayInMilli: rand.Intn(mgsConfig.ControlChannelRetryInitialDelayMillis) + mgsConfig.ControlChannelRetryInitialDelayMillis,
282+
MaxDelayInMilli: mgsConfig.ControlChannelRetryMaxIntervalMillis,
283+
MaxAttempts: 30,
284+
}
285+
retryer.Init()
286+
_, err1 := retryer.Call()
287+
288+
stop <- true
289+
assert.Nil(t, err1)
290+
time.Sleep(10 * time.Second)
291+
completedGRNumber := runtime.NumGoroutine()
292+
assert.True(t, math.Abs(startTime.Sub(time.Now()).Seconds()) > 50)
293+
294+
// tests run in parallel at times hence adding some buffer
295+
assert.True(t, initialGRNumber+5 >= completedGRNumber)
296+
assert.Equal(t, 4, counter)
297+
}
298+
299+
func TestOpenControlChannel_OpenControlChannelError_RetryCount(t *testing.T) {
300+
httpTempHandler := func(hw http.ResponseWriter, request *http.Request) {
301+
http.Error(hw, fmt.Sprintf("no upgrade: %v", fmt.Errorf("err1")), http.StatusGatewayTimeout)
302+
}
303+
// launch local HTTP Server
304+
srv := httptest.NewServer(http.HandlerFunc(httpTempHandler))
305+
u, _ := url.Parse(srv.URL)
306+
u.Scheme = "ws"
307+
defer srv.Close()
308+
309+
controlChannel := getControlChannel()
310+
messageChan := make(chan mgsContracts.AgentMessage)
311+
312+
// Set local server URL
313+
mockEventLog.On("SendAuditMessage")
314+
315+
var ableToOpenMGSConnection uint32
316+
createControlChannelOutput := service.CreateControlChannelOutput{TokenValue: &token}
317+
318+
mockService = &serviceMock.Service{}
319+
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, nil).Times(4)
320+
mockService.On("GetRegion").Return(region)
321+
mockService.On("GetV4Signer").Return(signer)
322+
323+
counter := 0
324+
// Get number of go-routines running
325+
initialGRNumber := runtime.NumGoroutine()
326+
327+
startTime := time.Now()
328+
stop := make(chan bool)
329+
startConnectionChannelReader(stop, contracts.MDS)
330+
// copied over from MGSInteractor
331+
retryer := retry.ExponentialRetryer{
332+
CallableFunc: func() (channel interface{}, err error) {
333+
counter++
334+
controlChannel = getControlChannel()
335+
controlChannel.Initialize(mockContext, mockService, instanceId, messageChan)
336+
if err = controlChannel.SetWebSocket(mockContext, mockService, &ableToOpenMGSConnection); err != nil {
337+
return nil, err
338+
}
339+
340+
controlChannel.wsChannel.SetUrl(u.String())
341+
if err = controlChannel.Open(mockContext, &ableToOpenMGSConnection); err != nil {
342+
return nil, err
343+
}
344+
345+
controlChannel.AuditLogScheduler.ScheduleAuditEvents()
346+
return controlChannel, nil
347+
},
348+
GeometricRatio: mgsConfig.RetryGeometricRatio,
349+
JitterRatio: mgsConfig.RetryJitterRatio,
350+
InitialDelayInMilli: rand.Intn(mgsConfig.ControlChannelRetryInitialDelayMillis) + mgsConfig.ControlChannelRetryInitialDelayMillis,
351+
MaxDelayInMilli: mgsConfig.ControlChannelRetryMaxIntervalMillis,
352+
MaxAttempts: 3,
353+
}
354+
retryer.Init()
355+
_, err1 := retryer.Call()
356+
357+
assert.NotNil(t, err1)
358+
stop <- true
359+
time.Sleep(10 * time.Second)
360+
completedGRNumber := runtime.NumGoroutine()
361+
assert.True(t, math.Abs(startTime.Sub(time.Now()).Seconds()) > 50)
362+
363+
// tests run in parallel at times hence adding some buffer
364+
assert.True(t, initialGRNumber+5 >= completedGRNumber)
365+
assert.Equal(t, 4, counter)
366+
mockService.AssertExpectations(t)
367+
}

agent/session/controlchannel/controlchannel_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func TestSetWebSocket(t *testing.T) {
8181
initializeMocks()
8282
controlChannel := getControlChannel()
8383
createControlChannelOutput := service.CreateControlChannelOutput{TokenValue: &token}
84+
mockService = &serviceMock.Service{}
8485
mockService.On("CreateControlChannel", mock.Anything, mock.Anything, mock.AnythingOfType("string")).Return(&createControlChannelOutput, nil)
8586
mockService.On("GetRegion").Return(region)
8687
mockService.On("GetV4Signer").Return(signer)

0 commit comments

Comments
 (0)