Skip to content

Commit 56a674d

Browse files
authored
retry start, check stop pending
Signed-off-by: Evan Baker <[email protected]>
1 parent f11b9a0 commit 56a674d

File tree

2 files changed

+135
-11
lines changed

2 files changed

+135
-11
lines changed

platform/os_windows.go

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -311,43 +311,81 @@ func restartHNS(ctx context.Context) error {
311311
)
312312
// Start the service again
313313
log.Printf("Starting HNS service")
314-
if err := service.Start(); err != nil {
315-
return errors.Wrap(err, "could not start service")
316-
}
314+
_ = retry.Do(
315+
tryStartServiceFn(ctx, service),
316+
retry.UntilSucceeded(),
317+
retry.Context(ctx),
318+
)
317319
log.Printf("HNS service started")
318320
return nil
319321
}
320322

321323
type managedService interface {
322324
Control(control svc.Cmd) (svc.Status, error)
323325
Query() (svc.Status, error)
326+
Start(args ...string) error
324327
}
325328

326-
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
329+
func tryStartServiceFn(ctx context.Context, service managedService) func() error {
327330
return func() error {
328331
status, err := service.Query()
329332
if err != nil {
330333
return errors.Wrap(err, "could not query service status")
331334
}
332-
// If the service is already stopped, no need to stop it again
333-
if status.State == svc.Stopped {
334-
return nil
335+
// If the service is already running or starting, no need to start it again
336+
if !(status.State == svc.Running || status.State == svc.StartPending) {
337+
err = service.Start()
338+
if err != nil {
339+
return errors.Wrap(err, "could not start service")
340+
}
341+
}
342+
// Wait for the service to start
343+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
344+
defer ticker.Stop()
345+
for {
346+
log.Printf("Waiting for service to start")
347+
status, err := service.Query()
348+
if err != nil {
349+
return errors.Wrap(err, "could not query service status")
350+
}
351+
if status.State == svc.Running {
352+
log.Printf("service started")
353+
break
354+
}
355+
select {
356+
case <-ctx.Done():
357+
return errors.New("context cancelled")
358+
case <-ticker.C:
359+
}
335360
}
336-
_, err = service.Control(svc.Stop)
361+
return nil
362+
}
363+
}
364+
365+
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
366+
return func() error {
367+
status, err := service.Query()
337368
if err != nil {
338-
return errors.Wrap(err, "could not stop service")
369+
return errors.Wrap(err, "could not query service status")
370+
}
371+
// If the service is already stopping or stopped, no need to stop it again
372+
if !(status.State == svc.Stopped || status.State == svc.StopPending) {
373+
_, err = service.Control(svc.Stop)
374+
if err != nil {
375+
return errors.Wrap(err, "could not stop service")
376+
}
339377
}
340378
// Wait for the service to stop
341379
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
342380
defer ticker.Stop()
343381
for {
344-
log.Printf("Waiting for HNS service to stop")
382+
log.Printf("Waiting for service to stop")
345383
status, err := service.Query()
346384
if err != nil {
347385
return errors.Wrap(err, "could not query service status")
348386
}
349387
if status.State == svc.Stopped {
350-
log.Printf("HNS service stopped")
388+
log.Printf("service stopped")
351389
break
352390
}
353391
select {

platform/os_windows_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ func TestExecuteCommandTimeout(t *testing.T) {
151151
type mockManagedService struct {
152152
queryFuncs []func() (svc.Status, error)
153153
controlFunc func(svc.Cmd) (svc.Status, error)
154+
startFunc func() error
154155
}
155156

156157
func (m *mockManagedService) Query() (svc.Status, error) {
@@ -163,6 +164,10 @@ func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
163164
return m.controlFunc(cmd)
164165
}
165166

167+
func (m *mockManagedService) Start(args ...string) error {
168+
return m.startFunc()
169+
}
170+
166171
func TestTryStopServiceFn(t *testing.T) {
167172
tests := []struct {
168173
name string
@@ -254,3 +259,84 @@ func TestTryStopServiceFn(t *testing.T) {
254259
})
255260
}
256261
}
262+
263+
func TestTryStartServiceFn(t *testing.T) {
264+
tests := []struct {
265+
name string
266+
queryFuncs []func() (svc.Status, error)
267+
startFunc func() error
268+
expectError bool
269+
}{
270+
{
271+
name: "Service already running",
272+
queryFuncs: []func() (svc.Status, error){
273+
func() (svc.Status, error) {
274+
return svc.Status{State: svc.Running}, nil
275+
},
276+
},
277+
startFunc: nil,
278+
expectError: false,
279+
},
280+
{
281+
name: "Service already starting",
282+
queryFuncs: []func() (svc.Status, error){
283+
func() (svc.Status, error) {
284+
return svc.Status{State: svc.StartPending}, nil
285+
},
286+
},
287+
startFunc: nil,
288+
expectError: false,
289+
},
290+
{
291+
name: "Service starts successfully",
292+
queryFuncs: []func() (svc.Status, error){
293+
func() (svc.Status, error) {
294+
return svc.Status{State: svc.Stopped}, nil
295+
},
296+
func() (svc.Status, error) {
297+
return svc.Status{State: svc.Running}, nil
298+
},
299+
},
300+
startFunc: func() error {
301+
return nil
302+
},
303+
expectError: false,
304+
},
305+
{
306+
name: "Service fails to start",
307+
queryFuncs: []func() (svc.Status, error){
308+
func() (svc.Status, error) {
309+
return svc.Status{State: svc.Stopped}, nil
310+
},
311+
},
312+
startFunc: func() error {
313+
return errors.New("failed to start service")
314+
},
315+
expectError: true,
316+
},
317+
{
318+
name: "Service query fails",
319+
queryFuncs: []func() (svc.Status, error){
320+
func() (svc.Status, error) {
321+
return svc.Status{}, errors.New("failed to query service status")
322+
},
323+
},
324+
startFunc: nil,
325+
expectError: true,
326+
},
327+
}
328+
for _, tt := range tests {
329+
t.Run(tt.name, func(t *testing.T) {
330+
service := &mockManagedService{
331+
queryFuncs: tt.queryFuncs,
332+
startFunc: tt.startFunc,
333+
}
334+
err := tryStartServiceFn(context.Background(), service)()
335+
if tt.expectError {
336+
assert.Error(t, err)
337+
return
338+
}
339+
assert.NoError(t, err)
340+
})
341+
}
342+
}

0 commit comments

Comments
 (0)