@@ -14,93 +14,147 @@ import (
1414 "github.com/stretchr/testify/require"
1515)
1616
17- func TestSendCommand (t * testing.T ) {
18- // commandSuccessInstance indicates an instance for which the command should succeed
19- // regardless of whether `waitError` is set.
20- const commandSuccessInstance = "inst-success"
21- cases := []struct {
22- name string
23- sendOutput * ssm.SendCommandOutput
24- sendError error
25- expectedError string
26- expectedOut string
27- waitError error
28- instances []string
29- }{
30- {
31- name : "send success" ,
32- sendOutput : & ssm.SendCommandOutput {
33- Command : & ssm.Command {CommandId : aws .String ("id1" )},
34- },
35- instances : []string {"inst-id-1" },
36- expectedOut : "id1" ,
17+ func TestSendCommandSuccess (t * testing.T ) {
18+ instances := []string {"inst-id-1" , "inst-id-2" }
19+ waitInstanceIDs := []string {}
20+ mockSSM := MockSSM {
21+ SendCommandFn : func (input * ssm.SendCommandInput ) (* ssm.SendCommandOutput , error ) {
22+ assert .Equal (t , "test-doc" , aws .StringValue (input .DocumentName ))
23+ assert .Equal (t , "$DEFAULT" , aws .StringValue (input .DocumentVersion ))
24+ assert .Equal (t , aws .StringSlice (instances ), input .InstanceIds )
25+ return & ssm.SendCommandOutput {Command : & ssm.Command {CommandId : aws .String ("command-id" )}}, nil
3726 },
38- {
39- name : "send fail" ,
40- sendError : errors .New ("failed to send command" ),
41- expectedError : "send command failed" ,
42- instances : []string {"inst-id-1" },
27+ WaitUntilCommandExecutedWithContextFn : func (ctx aws.Context , input * ssm.GetCommandInvocationInput , opts ... request.WaiterOption ) error {
28+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
29+ waitInstanceIDs = append (waitInstanceIDs , aws .StringValue (input .InstanceId ))
30+ return nil
4331 },
44- {
45- name : "wait single failure" ,
46- waitError : errors .New ("exceeded max attempts" ),
47- sendOutput : & ssm.SendCommandOutput {
48- Command : & ssm.Command {CommandId : aws .String ("" )},
49- },
50- expectedError : "too many failures while awaiting document execution" ,
51- instances : []string {"inst-id-1" },
32+ }
33+ u := updater {ssm : mockSSM }
34+ commandID , err := u .sendCommand (instances , "test-doc" )
35+ require .NoError (t , err )
36+ assert .EqualValues (t , "command-id" , commandID )
37+ assert .Equal (t , instances , waitInstanceIDs )
38+ }
39+
40+ func TestSendCommandErr (t * testing.T ) {
41+ instances := []string {"inst-id-1" , "inst-id-2" }
42+ sendError := errors .New ("failed to send command" )
43+ mockSSM := MockSSM {
44+ SendCommandFn : func (input * ssm.SendCommandInput ) (* ssm.SendCommandOutput , error ) {
45+ assert .Equal (t , "test-doc" , aws .StringValue (input .DocumentName ))
46+ assert .Equal (t , "$DEFAULT" , aws .StringValue (input .DocumentVersion ))
47+ assert .Equal (t , aws .StringSlice (instances ), input .InstanceIds )
48+ return nil , sendError
5249 },
50+ }
51+ u := updater {ssm : mockSSM }
52+ commandID , err := u .sendCommand (instances , "test-doc" )
53+ require .Error (t , err )
54+ assert .Equal (t , "" , commandID )
55+ assert .ErrorIs (t , err , sendError )
56+
57+ }
58+
59+ func TestSendCommandWaitErr (t * testing.T ) {
60+ cases := []struct {
61+ name string
62+ instances []string
63+ }{
5364 {
54- name : "wait one succcess" ,
55- waitError : errors .New ("exceeded max attempts" ),
56- sendOutput : & ssm.SendCommandOutput {
57- Command : & ssm.Command {CommandId : aws .String ("id1" )},
58- },
59- instances : []string {"inst-id-1" , "inst-id-2" , commandSuccessInstance },
60- expectedOut : "id1" ,
65+ name : "wait single failure" ,
66+ instances : []string {"inst-id-1" },
6167 },
6268 {
6369 name : "wait fail all" ,
64- waitError : errors .New ("exceeded max attempts" ),
65- sendOutput : & ssm.SendCommandOutput {
66- Command : & ssm.Command {CommandId : aws .String ("id1" )},
67- },
68- expectedError : "too many failures while awaiting document execution" ,
69- instances : []string {"inst-id-1" , "inst-id-2" , "inst-id-3" },
70+ instances : []string {"inst-id-1" , "inst-id-2" , "inst-id-3" },
7071 },
7172 }
7273 for _ , tc := range cases {
7374 t .Run (tc .name , func (t * testing.T ) {
75+ waitError := errors .New ("exceeded max attempts" )
76+ failedInstanceIDs := []string {}
7477 mockSSM := MockSSM {
7578 SendCommandFn : func (input * ssm.SendCommandInput ) (* ssm.SendCommandOutput , error ) {
7679 assert .Equal (t , "test-doc" , aws .StringValue (input .DocumentName ))
77- assert .Equal (t , "$DEFAULT" , aws .StringValue (input .DocumentVersion ))
7880 assert .Equal (t , aws .StringSlice (tc .instances ), input .InstanceIds )
79- return tc .sendOutput , tc .sendError
81+ return & ssm.SendCommandOutput {
82+ Command : & ssm.Command {CommandId : aws .String ("command-id" )},
83+ }, nil
8084 },
8185 WaitUntilCommandExecutedWithContextFn : func (ctx aws.Context , input * ssm.GetCommandInvocationInput , opts ... request.WaiterOption ) error {
82- if aws .StringValue (input .InstanceId ) == commandSuccessInstance {
83- return nil
84- }
85- return tc .waitError
86+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
87+ return waitError
88+ },
89+ GetCommandInvocationFn : func (input * ssm.GetCommandInvocationInput ) (* ssm.GetCommandInvocationOutput , error ) {
90+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
91+ failedInstanceIDs = append (failedInstanceIDs , aws .StringValue (input .InstanceId ))
92+ return & ssm.GetCommandInvocationOutput {}, nil
8693 },
8794 }
8895 u := updater {ssm : mockSSM }
89- actual , err := u .sendCommand (tc .instances , "test-doc" )
90- if tc .expectedOut != "" {
91- require .NoError (t , err )
92- assert .EqualValues (t , tc .expectedOut , actual )
93- } else if tc .sendError != nil {
94- assert .ErrorIs (t , err , tc .sendError )
95- assert .Contains (t , err .Error (), tc .expectedError )
96- } else {
97- assert .ErrorIs (t , err , tc .waitError )
98- assert .Contains (t , err .Error (), tc .expectedError )
99- }
96+ commandID , err := u .sendCommand (tc .instances , "test-doc" )
97+ require .Error (t , err )
98+ assert .ErrorIs (t , err , waitError )
99+ assert .Equal (t , "" , commandID )
100+ assert .Equal (t , tc .instances , failedInstanceIDs , "should match instances for which wait fail" )
100101 })
101102 }
102103}
103104
105+ func TestSendCommandWaitSuccess (t * testing.T ) {
106+ mockSendCommand := func (input * ssm.SendCommandInput ) (* ssm.SendCommandOutput , error ) {
107+ assert .Equal (t , "test-doc" , aws .StringValue (input .DocumentName ))
108+ return & ssm.SendCommandOutput {
109+ Command : & ssm.Command {CommandId : aws .String ("command-id" )},
110+ }, nil
111+ }
112+ t .Run ("wait one success" , func (t * testing.T ) {
113+ // commandSuccessInstance indicates an instance for which the command should succeed
114+ const commandSuccessInstance = "inst-success"
115+ instances := []string {"inst-id-1" , "inst-id-1" , commandSuccessInstance }
116+ expectedFailInstances := []string {"inst-id-1" , "inst-id-1" }
117+ failedInstanceIDs := []string {}
118+ mockSSM := MockSSM {
119+ SendCommandFn : mockSendCommand ,
120+ WaitUntilCommandExecutedWithContextFn : func (ctx aws.Context , input * ssm.GetCommandInvocationInput , opts ... request.WaiterOption ) error {
121+ if aws .StringValue (input .InstanceId ) == commandSuccessInstance {
122+ return nil
123+ }
124+ return errors .New ("exceeded max attempts" )
125+ },
126+ GetCommandInvocationFn : func (input * ssm.GetCommandInvocationInput ) (* ssm.GetCommandInvocationOutput , error ) {
127+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
128+ failedInstanceIDs = append (failedInstanceIDs , aws .StringValue (input .InstanceId ))
129+ return & ssm.GetCommandInvocationOutput {}, nil
130+ },
131+ }
132+ u := updater {ssm : mockSSM }
133+ commandID , err := u .sendCommand (instances , "test-doc" )
134+ require .NoError (t , err )
135+ assert .Equal (t , "command-id" , commandID )
136+ assert .Equal (t , expectedFailInstances , failedInstanceIDs , "should match instances for which wait fail" )
137+ })
138+ t .Run ("wait all success" , func (t * testing.T ) {
139+ instances := []string {"inst-id-1" , "inst-id-1" }
140+ waitInstanceIDs := []string {}
141+ mockSSM := MockSSM {
142+ SendCommandFn : mockSendCommand ,
143+ WaitUntilCommandExecutedWithContextFn : func (ctx aws.Context , input * ssm.GetCommandInvocationInput , opts ... request.WaiterOption ) error {
144+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
145+ waitInstanceIDs = append (waitInstanceIDs , aws .StringValue (input .InstanceId ))
146+ return nil
147+ },
148+ }
149+ u := updater {ssm : mockSSM }
150+ commandID , err := u .sendCommand (instances , "test-doc" )
151+ require .NoError (t , err )
152+ assert .Equal (t , "command-id" , commandID )
153+ assert .Equal (t , instances , waitInstanceIDs )
154+ })
155+
156+ }
157+
104158func TestListContainerInstances (t * testing.T ) {
105159 cases := []struct {
106160 name string
@@ -695,6 +749,11 @@ func TestUpdateInstanceErr(t *testing.T) {
695749 assert .Equal (t , "instance-id" , aws .StringValue (input .InstanceId ))
696750 return waitExecErr
697751 },
752+ GetCommandInvocationFn : func (input * ssm.GetCommandInvocationInput ) (* ssm.GetCommandInvocationOutput , error ) {
753+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
754+ assert .Equal (t , "instance-id" , aws .StringValue (input .InstanceId ))
755+ return & ssm.GetCommandInvocationOutput {}, nil
756+ },
698757 }
699758 u := updater {ssm : mockSSM , checkDocument : "check-document" }
700759 err := u .updateInstance (instance {
@@ -842,6 +901,11 @@ func TestVerifyUpdateErr(t *testing.T) {
842901 assert .Equal (t , "instance-id" , aws .StringValue (input .InstanceId ))
843902 return waitExecErr
844903 },
904+ GetCommandInvocationFn : func (input * ssm.GetCommandInvocationInput ) (* ssm.GetCommandInvocationOutput , error ) {
905+ assert .Equal (t , "command-id" , aws .StringValue (input .CommandId ))
906+ assert .Equal (t , "instance-id" , aws .StringValue (input .InstanceId ))
907+ return & ssm.GetCommandInvocationOutput {}, nil
908+ },
845909 }
846910 u := updater {ssm : mockSSM , checkDocument : "check-document" }
847911 ok , err := u .verifyUpdate (instance {
0 commit comments