@@ -11,6 +11,7 @@ import (
1111 "github.com/golang/mock/gomock"
1212 "github.com/stretchr/testify/assert"
1313 "github.com/stretchr/testify/require"
14+ "golang.org/x/sys/windows/svc"
1415)
1516
1617var errTestFailure = errors .New ("test failure" )
@@ -146,3 +147,205 @@ func TestExecuteCommandTimeout(t *testing.T) {
146147 _ , err := client .ExecuteCommand (context .Background (), "ping" , "-t" , "localhost" )
147148 require .Error (t , err )
148149}
150+
151+ type mockManagedService struct {
152+ queryFuncs []func () (svc.Status , error )
153+ controlFunc func (svc.Cmd ) (svc.Status , error )
154+ startFunc func (args ... string ) error
155+ }
156+
157+ func (m * mockManagedService ) Query () (svc.Status , error ) {
158+ queryFunc := m .queryFuncs [0 ]
159+ m .queryFuncs = m .queryFuncs [1 :]
160+ return queryFunc ()
161+ }
162+
163+ func (m * mockManagedService ) Control (cmd svc.Cmd ) (svc.Status , error ) {
164+ return m .controlFunc (cmd )
165+ }
166+
167+ func (m * mockManagedService ) Start (args ... string ) error {
168+ return m .startFunc (args ... )
169+ }
170+
171+ func TestTryStopServiceFn (t * testing.T ) {
172+ tests := []struct {
173+ name string
174+ queryFuncs []func () (svc.Status , error )
175+ controlFunc func (svc.Cmd ) (svc.Status , error )
176+ expectError bool
177+ }{
178+ {
179+ name : "Service already stopped" ,
180+ queryFuncs : []func () (svc.Status , error ){
181+ func () (svc.Status , error ) {
182+ return svc.Status {State : svc .Stopped }, nil
183+ },
184+ func () (svc.Status , error ) {
185+ return svc.Status {State : svc .Stopped }, nil
186+ },
187+ },
188+ controlFunc : nil ,
189+ expectError : false ,
190+ },
191+ {
192+ name : "Service running and stops successfully" ,
193+ queryFuncs : []func () (svc.Status , error ){
194+ func () (svc.Status , error ) {
195+ return svc.Status {State : svc .Running }, nil
196+ },
197+ func () (svc.Status , error ) {
198+ return svc.Status {State : svc .Stopped }, nil
199+ },
200+ },
201+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
202+ return svc.Status {State : svc .Stopped }, nil
203+ },
204+ expectError : false ,
205+ },
206+ {
207+ name : "Service running and stops after multiple attempts" ,
208+ queryFuncs : []func () (svc.Status , error ){
209+ func () (svc.Status , error ) {
210+ return svc.Status {State : svc .Running }, nil
211+ },
212+ func () (svc.Status , error ) {
213+ return svc.Status {State : svc .Running }, nil
214+ },
215+ func () (svc.Status , error ) {
216+ return svc.Status {State : svc .Running }, nil
217+ },
218+ func () (svc.Status , error ) {
219+ return svc.Status {State : svc .Stopped }, nil
220+ },
221+ },
222+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
223+ return svc.Status {State : svc .Stopped }, nil
224+ },
225+ expectError : false ,
226+ },
227+ {
228+ name : "Service running and fails to stop" ,
229+ queryFuncs : []func () (svc.Status , error ){
230+ func () (svc.Status , error ) {
231+ return svc.Status {State : svc .Running }, nil
232+ },
233+ },
234+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
235+ return svc.Status {State : svc .Running }, errors .New ("failed to stop service" ) //nolint:err113 // test error
236+ },
237+ expectError : true ,
238+ },
239+ {
240+ name : "Service query fails" ,
241+ queryFuncs : []func () (svc.Status , error ){
242+ func () (svc.Status , error ) {
243+ return svc.Status {}, errors .New ("failed to query service status" ) //nolint:err113 // test error
244+ },
245+ },
246+ controlFunc : nil ,
247+ expectError : true ,
248+ },
249+ }
250+ for _ , tt := range tests {
251+ t .Run (tt .name , func (t * testing.T ) {
252+ service := & mockManagedService {
253+ queryFuncs : tt .queryFuncs ,
254+ controlFunc : tt .controlFunc ,
255+ }
256+ err := tryStopServiceFn (context .Background (), service )()
257+ if tt .expectError {
258+ assert .Error (t , err )
259+ return
260+ }
261+ assert .NoError (t , err )
262+ })
263+ }
264+ }
265+
266+ func TestTryStartServiceFn (t * testing.T ) {
267+ tests := []struct {
268+ name string
269+ queryFuncs []func () (svc.Status , error )
270+ startFunc func (... string ) error
271+ expectError bool
272+ }{
273+ {
274+ name : "Service already running" ,
275+ queryFuncs : []func () (svc.Status , error ){
276+ func () (svc.Status , error ) {
277+ return svc.Status {State : svc .Running }, nil
278+ },
279+ func () (svc.Status , error ) {
280+ return svc.Status {State : svc .Running }, nil
281+ },
282+ },
283+ startFunc : nil ,
284+ expectError : false ,
285+ },
286+ {
287+ name : "Service already starting" ,
288+ queryFuncs : []func () (svc.Status , error ){
289+ func () (svc.Status , error ) {
290+ return svc.Status {State : svc .StartPending }, nil
291+ },
292+ func () (svc.Status , error ) {
293+ return svc.Status {State : svc .Running }, nil
294+ },
295+ },
296+ startFunc : nil ,
297+ expectError : false ,
298+ },
299+ {
300+ name : "Service starts successfully" ,
301+ queryFuncs : []func () (svc.Status , error ){
302+ func () (svc.Status , error ) {
303+ return svc.Status {State : svc .Stopped }, nil
304+ },
305+ func () (svc.Status , error ) {
306+ return svc.Status {State : svc .Running }, nil
307+ },
308+ },
309+ startFunc : func (... string ) error {
310+ return nil
311+ },
312+ expectError : false ,
313+ },
314+ {
315+ name : "Service fails to start" ,
316+ queryFuncs : []func () (svc.Status , error ){
317+ func () (svc.Status , error ) {
318+ return svc.Status {State : svc .Stopped }, nil
319+ },
320+ },
321+ startFunc : func (... string ) error {
322+ return errors .New ("failed to start service" ) //nolint:err113 // test error
323+ },
324+ expectError : true ,
325+ },
326+ {
327+ name : "Service query fails" ,
328+ queryFuncs : []func () (svc.Status , error ){
329+ func () (svc.Status , error ) {
330+ return svc.Status {}, errors .New ("failed to query service status" ) //nolint:err113 // test error
331+ },
332+ },
333+ startFunc : nil ,
334+ expectError : true ,
335+ },
336+ }
337+ for _ , tt := range tests {
338+ t .Run (tt .name , func (t * testing.T ) {
339+ service := & mockManagedService {
340+ queryFuncs : tt .queryFuncs ,
341+ startFunc : tt .startFunc ,
342+ }
343+ err := tryStartServiceFn (context .Background (), service )()
344+ if tt .expectError {
345+ assert .Error (t , err )
346+ return
347+ }
348+ assert .NoError (t , err )
349+ })
350+ }
351+ }
0 commit comments