Skip to content

Commit 0d0817d

Browse files
authored
retry start, check stop pending
Signed-off-by: Evan Baker <[email protected]>
1 parent f11b9a0 commit 0d0817d

File tree

2 files changed

+138
-12
lines changed

2 files changed

+138
-12
lines changed

platform/os_windows.go

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -311,48 +311,88 @@ func restartHNS(ctx context.Context) error {
311311
)
312312
// Start the service again
313313
log.Printf("Starting HNS service")
314-
if err := service.Start(); err != nil {
315-
return errors.Wrap(err, "could not start service")
316-
}
314+
_ = retry.Do(
315+
tryStartServiceFn(ctx, service),
316+
retry.UntilSucceeded(),
317+
retry.Context(ctx),
318+
)
317319
log.Printf("HNS service started")
318320
return nil
319321
}
320322

321323
type managedService interface {
322324
Control(control svc.Cmd) (svc.Status, error)
323325
Query() (svc.Status, error)
326+
Start(args ...string) error
324327
}
325328

326-
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
329+
func tryStartServiceFn(ctx context.Context, service managedService) func() error {
330+
shouldStart := func(state svc.State) bool {
331+
return !(state == svc.Running || state == svc.StartPending)
332+
}
327333
return func() error {
328334
status, err := service.Query()
329335
if err != nil {
330336
return errors.Wrap(err, "could not query service status")
331337
}
332-
// If the service is already stopped, no need to stop it again
333-
if status.State == svc.Stopped {
334-
return nil
338+
if shouldStart(status.State) {
339+
err = service.Start()
340+
if err != nil {
341+
return errors.Wrap(err, "could not start service")
342+
}
335343
}
336-
_, err = service.Control(svc.Stop)
344+
// Wait for the service to start
345+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
346+
defer ticker.Stop()
347+
for {
348+
status, err := service.Query()
349+
if err != nil {
350+
return errors.Wrap(err, "could not query service status")
351+
}
352+
if status.State == svc.Running {
353+
log.Printf("service started")
354+
break
355+
}
356+
select {
357+
case <-ctx.Done():
358+
return errors.Wrap(ctx.Err(), "context cancelled")
359+
case <-ticker.C:
360+
}
361+
}
362+
return nil
363+
}
364+
}
365+
366+
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
367+
shouldStop := func(state svc.State) bool {
368+
return !(state == svc.Stopped || state == svc.StopPending)
369+
}
370+
return func() error {
371+
status, err := service.Query()
337372
if err != nil {
338-
return errors.Wrap(err, "could not stop service")
373+
return errors.Wrap(err, "could not query service status")
374+
}
375+
if shouldStop(status.State) {
376+
_, err = service.Control(svc.Stop)
377+
if err != nil {
378+
return errors.Wrap(err, "could not stop service")
379+
}
339380
}
340381
// Wait for the service to stop
341382
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
342383
defer ticker.Stop()
343384
for {
344-
log.Printf("Waiting for HNS service to stop")
345385
status, err := service.Query()
346386
if err != nil {
347387
return errors.Wrap(err, "could not query service status")
348388
}
349389
if status.State == svc.Stopped {
350-
log.Printf("HNS service stopped")
390+
log.Printf("service stopped")
351391
break
352392
}
353393
select {
354394
case <-ctx.Done():
355-
return errors.New("context cancelled")
395+
return errors.Wrap(ctx.Err(), "context cancelled")
356396
case <-ticker.C:
357397
}
358398
}

platform/os_windows_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ func TestExecuteCommandTimeout(t *testing.T) {
151151
type mockManagedService struct {
152152
queryFuncs []func() (svc.Status, error)
153153
controlFunc func(svc.Cmd) (svc.Status, error)
154+
startFunc func(args ...string) error
154155
}
155156

156157
func (m *mockManagedService) Query() (svc.Status, error) {
@@ -163,6 +164,10 @@ func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
163164
return m.controlFunc(cmd)
164165
}
165166

167+
func (m *mockManagedService) Start(args ...string) error {
168+
return m.startFunc(args...)
169+
}
170+
166171
func TestTryStopServiceFn(t *testing.T) {
167172
tests := []struct {
168173
name string
@@ -254,3 +259,84 @@ func TestTryStopServiceFn(t *testing.T) {
254259
})
255260
}
256261
}
262+
263+
func TestTryStartServiceFn(t *testing.T) {
264+
tests := []struct {
265+
name string
266+
queryFuncs []func() (svc.Status, error)
267+
startFunc func(...string) error
268+
expectError bool
269+
}{
270+
{
271+
name: "Service already running",
272+
queryFuncs: []func() (svc.Status, error){
273+
func() (svc.Status, error) {
274+
return svc.Status{State: svc.Running}, nil
275+
},
276+
},
277+
startFunc: nil,
278+
expectError: false,
279+
},
280+
{
281+
name: "Service already starting",
282+
queryFuncs: []func() (svc.Status, error){
283+
func() (svc.Status, error) {
284+
return svc.Status{State: svc.StartPending}, nil
285+
},
286+
},
287+
startFunc: nil,
288+
expectError: false,
289+
},
290+
{
291+
name: "Service starts successfully",
292+
queryFuncs: []func() (svc.Status, error){
293+
func() (svc.Status, error) {
294+
return svc.Status{State: svc.Stopped}, nil
295+
},
296+
func() (svc.Status, error) {
297+
return svc.Status{State: svc.Running}, nil
298+
},
299+
},
300+
startFunc: func(...string) error {
301+
return nil
302+
},
303+
expectError: false,
304+
},
305+
{
306+
name: "Service fails to start",
307+
queryFuncs: []func() (svc.Status, error){
308+
func() (svc.Status, error) {
309+
return svc.Status{State: svc.Stopped}, nil
310+
},
311+
},
312+
startFunc: func(...string) error {
313+
return errors.New("failed to start service") //nolint:err113 // test error
314+
},
315+
expectError: true,
316+
},
317+
{
318+
name: "Service query fails",
319+
queryFuncs: []func() (svc.Status, error){
320+
func() (svc.Status, error) {
321+
return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
322+
},
323+
},
324+
startFunc: nil,
325+
expectError: true,
326+
},
327+
}
328+
for _, tt := range tests {
329+
t.Run(tt.name, func(t *testing.T) {
330+
service := &mockManagedService{
331+
queryFuncs: tt.queryFuncs,
332+
startFunc: tt.startFunc,
333+
}
334+
err := tryStartServiceFn(context.Background(), service)()
335+
if tt.expectError {
336+
assert.Error(t, err)
337+
return
338+
}
339+
assert.NoError(t, err)
340+
})
341+
}
342+
}

0 commit comments

Comments
 (0)