Skip to content

Commit f11b9a0

Browse files
authored
fix: carefully retry restarting HNS if it hangs
Signed-off-by: Evan Baker <[email protected]>
1 parent 082ab85 commit f11b9a0

File tree

2 files changed

+155
-17
lines changed

2 files changed

+155
-17
lines changed

platform/os_windows.go

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/Azure/azure-container-networking/log"
1818
"github.com/Azure/azure-container-networking/platform/windows/adapter"
1919
"github.com/Azure/azure-container-networking/platform/windows/adapter/mellanox"
20+
"github.com/avast/retry-go/v4"
2021
"github.com/pkg/errors"
2122
"go.uber.org/zap"
2223
"golang.org/x/sys/windows"
@@ -302,32 +303,61 @@ func restartHNS(ctx context.Context) error {
302303
}
303304
defer service.Close()
304305
// Stop the service
305-
_, err = service.Control(svc.Stop)
306-
if err != nil {
307-
return errors.Wrap(err, "could not stop service")
306+
log.Printf("Stopping HNS service")
307+
_ = retry.Do(
308+
tryStopServiceFn(ctx, service),
309+
retry.UntilSucceeded(),
310+
retry.Context(ctx),
311+
)
312+
// Start the service again
313+
log.Printf("Starting HNS service")
314+
if err := service.Start(); err != nil {
315+
return errors.Wrap(err, "could not start service")
308316
}
309-
// Wait for the service to stop
310-
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
311-
defer ticker.Stop()
312-
for { // hacky cancellable do-while
317+
log.Printf("HNS service started")
318+
return nil
319+
}
320+
321+
type managedService interface {
322+
Control(control svc.Cmd) (svc.Status, error)
323+
Query() (svc.Status, error)
324+
}
325+
326+
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
327+
return func() error {
313328
status, err := service.Query()
314329
if err != nil {
315330
return errors.Wrap(err, "could not query service status")
316331
}
332+
// If the service is already stopped, no need to stop it again
317333
if status.State == svc.Stopped {
318-
break
334+
return nil
319335
}
320-
select {
321-
case <-ctx.Done():
322-
return errors.New("context cancelled")
323-
case <-ticker.C:
336+
_, err = service.Control(svc.Stop)
337+
if err != nil {
338+
return errors.Wrap(err, "could not stop service")
324339
}
340+
// Wait for the service to stop
341+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
342+
defer ticker.Stop()
343+
for {
344+
log.Printf("Waiting for HNS service to stop")
345+
status, err := service.Query()
346+
if err != nil {
347+
return errors.Wrap(err, "could not query service status")
348+
}
349+
if status.State == svc.Stopped {
350+
log.Printf("HNS service stopped")
351+
break
352+
}
353+
select {
354+
case <-ctx.Done():
355+
return errors.New("context cancelled")
356+
case <-ticker.C:
357+
}
358+
}
359+
return nil
325360
}
326-
// Start the service again
327-
if err := service.Start(); err != nil {
328-
return errors.Wrap(err, "could not start service")
329-
}
330-
return nil
331361
}
332362

333363
func HasMellanoxAdapter() bool {

platform/os_windows_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/golang/mock/gomock"
1212
"github.com/stretchr/testify/assert"
1313
"github.com/stretchr/testify/require"
14+
"golang.org/x/sys/windows/svc"
1415
)
1516

1617
var errTestFailure = errors.New("test failure")
@@ -146,3 +147,110 @@ func TestExecuteCommandTimeout(t *testing.T) {
146147
_, err := client.ExecuteCommand(context.Background(), "ping", "-t", "localhost")
147148
require.Error(t, err)
148149
}
150+
151+
type mockManagedService struct {
152+
queryFuncs []func() (svc.Status, error)
153+
controlFunc func(svc.Cmd) (svc.Status, error)
154+
}
155+
156+
func (m *mockManagedService) Query() (svc.Status, error) {
157+
queryFunc := m.queryFuncs[0]
158+
m.queryFuncs = m.queryFuncs[1:]
159+
return queryFunc()
160+
}
161+
162+
func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
163+
return m.controlFunc(cmd)
164+
}
165+
166+
func TestTryStopServiceFn(t *testing.T) {
167+
tests := []struct {
168+
name string
169+
queryFuncs []func() (svc.Status, error)
170+
controlFunc func(svc.Cmd) (svc.Status, error)
171+
expectError bool
172+
}{
173+
{
174+
name: "Service already stopped",
175+
queryFuncs: []func() (svc.Status, error){
176+
func() (svc.Status, error) {
177+
return svc.Status{State: svc.Stopped}, nil
178+
},
179+
},
180+
controlFunc: nil,
181+
expectError: false,
182+
},
183+
{
184+
name: "Service running and stops successfully",
185+
queryFuncs: []func() (svc.Status, error){
186+
func() (svc.Status, error) {
187+
return svc.Status{State: svc.Running}, nil
188+
},
189+
func() (svc.Status, error) {
190+
return svc.Status{State: svc.Stopped}, nil
191+
},
192+
},
193+
controlFunc: func(svc.Cmd) (svc.Status, error) {
194+
return svc.Status{State: svc.Stopped}, nil
195+
},
196+
expectError: false,
197+
},
198+
{
199+
name: "Service running and stops after multiple attempts",
200+
queryFuncs: []func() (svc.Status, error){
201+
func() (svc.Status, error) {
202+
return svc.Status{State: svc.Running}, nil
203+
},
204+
func() (svc.Status, error) {
205+
return svc.Status{State: svc.Running}, nil
206+
},
207+
func() (svc.Status, error) {
208+
return svc.Status{State: svc.Running}, nil
209+
},
210+
func() (svc.Status, error) {
211+
return svc.Status{State: svc.Stopped}, nil
212+
},
213+
},
214+
controlFunc: func(svc.Cmd) (svc.Status, error) {
215+
return svc.Status{State: svc.Stopped}, nil
216+
},
217+
expectError: false,
218+
},
219+
{
220+
name: "Service running and fails to stop",
221+
queryFuncs: []func() (svc.Status, error){
222+
func() (svc.Status, error) {
223+
return svc.Status{State: svc.Running}, nil
224+
},
225+
},
226+
controlFunc: func(svc.Cmd) (svc.Status, error) {
227+
return svc.Status{State: svc.Running}, errors.New("failed to stop service") //nolint:err113 // test error
228+
},
229+
expectError: true,
230+
},
231+
{
232+
name: "Service query fails",
233+
queryFuncs: []func() (svc.Status, error){
234+
func() (svc.Status, error) {
235+
return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
236+
},
237+
},
238+
controlFunc: nil,
239+
expectError: true,
240+
},
241+
}
242+
for _, tt := range tests {
243+
t.Run(tt.name, func(t *testing.T) {
244+
service := &mockManagedService{
245+
queryFuncs: tt.queryFuncs,
246+
controlFunc: tt.controlFunc,
247+
}
248+
err := tryStopServiceFn(context.Background(), service)()
249+
if tt.expectError {
250+
assert.Error(t, err)
251+
return
252+
}
253+
assert.NoError(t, err)
254+
})
255+
}
256+
}

0 commit comments

Comments
 (0)