From 64e455b4b65672c83697b5e450f2aaaaca0215f5 Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Sat, 31 May 2025 19:03:22 +0800 Subject: [PATCH 1/5] Use WMI to implement Service related System APIs --- pkg/os/system/api.go | 223 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 204 insertions(+), 19 deletions(-) diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index a09c83a4..b4b18223 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -2,10 +2,13 @@ package system import ( "fmt" + "time" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl" - "github.com/kubernetes-csi/csi-proxy/pkg/utils" + "github.com/microsoft/wmi/pkg/errors" + wmiinst "github.com/microsoft/wmi/pkg/wmiinstance" + "github.com/microsoft/wmi/server2019/root/cimv2" ) // Implements the System OS API calls. All code here should be very simple @@ -24,6 +27,28 @@ type ServiceInfo struct { Status uint32 `json:"Status"` } +type periodicalCheckFunc func() (bool, error) + +const ( + // startServiceErrorCodeAccepted indicates the request is accepted + startServiceErrorCodeAccepted = 0 + + // startServiceErrorCodeAlreadyRunning indicates a service is already running + startServiceErrorCodeAlreadyRunning = 10 + + // stopServiceErrorCodeAccepted indicates the request is accepted + stopServiceErrorCodeAccepted = 0 + + // stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending) + stopServiceErrorCodeStopPending = 5 + + // stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running + stopServiceErrorCodeDependentRunning = 3 + + serviceStateRunning = "Running" + serviceStateStopped = "Stopped" +) + var ( startModeMappings = map[string]uint32{ "Boot": impl.START_TYPE_BOOT, @@ -33,16 +58,19 @@ var ( "Disabled": impl.START_TYPE_DISABLED, } - statusMappings = map[string]uint32{ - "Unknown": impl.SERVICE_STATUS_UNKNOWN, - "Stopped": impl.SERVICE_STATUS_STOPPED, - "Start Pending": impl.SERVICE_STATUS_START_PENDING, - "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING, - "Running": impl.SERVICE_STATUS_RUNNING, - "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING, - "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING, - "Paused": impl.SERVICE_STATUS_PAUSED, + stateMappings = map[string]uint32{ + "Unknown": impl.SERVICE_STATUS_UNKNOWN, + serviceStateStopped: impl.SERVICE_STATUS_STOPPED, + "Start Pending": impl.SERVICE_STATUS_START_PENDING, + "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING, + serviceStateRunning: impl.SERVICE_STATUS_RUNNING, + "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING, + "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING, + "Paused": impl.SERVICE_STATUS_PAUSED, } + + serviceStateCheckInternal = 500 * time.Millisecond + serviceStateCheckTimeout = 5 * time.Second ) func serviceStartModeToStartType(startMode string) uint32 { @@ -50,7 +78,7 @@ func serviceStartModeToStartType(startMode string) uint32 { } func serviceState(status string) uint32 { - return statusMappings[status] + return stateMappings[status] } type APIImplementor struct{} @@ -101,23 +129,180 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) { }, nil } +func waitForServiceState(serviceCheck periodicalCheckFunc, interval time.Duration, timeout time.Duration) error { + timeoutChan := time.After(timeout) + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-timeoutChan: + return errors.Timedout + case <-ticker.C: + done, err := serviceCheck() + if err != nil { + return err + } + + if done { + return nil + } + } + } +} + +func getServiceState(name string) (string, *cimv2.Win32_Service, error) { + service, err := cim.QueryServiceByName(name, nil) + if err != nil { + return "", nil, err + } + + state, err := service.GetPropertyState() + if err != nil { + return "", nil, fmt.Errorf("failed to get state property of service %s: %w", name, err) + } + + return state, service, nil +} + func (APIImplementor) StartService(name string) error { - // Note: both StartService and StopService are not implemented by WMI - script := `Start-Service -Name $env:ServiceName` - cmdEnv := fmt.Sprintf("ServiceName=%s", name) - out, err := utils.RunPowershellCmd(script, cmdEnv) + state, service, err := getServiceState(name) if err != nil { - return fmt.Errorf("error starting service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err) + return err + } + + if state != serviceStateRunning { + var retVal uint32 + retVal, err = service.StartService() + if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { + return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err) + } + + err = waitForServiceState(func() (bool, error) { + state, service, err = getServiceState(name) + if err != nil { + return false, err + } + + return state == serviceStateRunning, nil + + }, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil { + return fmt.Errorf("error waiting service %s become running. error: %v", name, err) + } + } + + if state != serviceStateRunning { + return fmt.Errorf("error starting service name %s. current state: %s", name, state) } return nil } func (APIImplementor) StopService(name string, force bool) error { - script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))` - out, err := utils.RunPowershellCmd(script, fmt.Sprintf("ServiceName=%s", name), fmt.Sprintf("Force=%t", force)) + state, service, err := getServiceState(name) if err != nil { - return fmt.Errorf("error stopping service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err) + return err + } + + if state == serviceStateStopped { + return nil + } + + stopSingleService := func(name string, service *wmiinst.WmiInstance) (bool, error) { + retVal, err := service.InvokeMethodWithReturn("StopService") + if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) { + if retVal == stopServiceErrorCodeDependentRunning { + return true, fmt.Errorf("error stopping service %s as dependent services are not stopped", name) + } + return false, fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err) + } + + var serviceState string + err = waitForServiceState(func() (bool, error) { + serviceState, _, err = getServiceState(name) + if err != nil { + return false, err + } + + return serviceState == serviceStateStopped, nil + + }, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil { + return false, fmt.Errorf("error waiting service %s become stopped. error: %v", name, err) + } + + if serviceState != serviceStateStopped { + return false, fmt.Errorf("error stopping service name %s. current state: %s", name, serviceState) + } + + return false, nil + } + + dependentRunning, err := stopSingleService(name, service.WmiInstance) + if !force || err == nil || !dependentRunning { + return err + } + + var serviceNames []string + var servicesToCheck wmiinst.WmiInstanceCollection + servicesByName := map[string]*wmiinst.WmiInstance{} + + servicesToCheck = append(servicesToCheck, service.WmiInstance) + i := 0 + for i < len(servicesToCheck) { + current := servicesToCheck[i] + i += 1 + + currentNameVal, err := current.GetProperty("Name") + if err != nil { + return err + } + + currentName := currentNameVal.(string) + if _, ok := servicesByName[currentName]; ok { + continue + } + + currentStateVal, err := current.GetProperty("State") + if err != nil { + return err + } + + currentState := currentStateVal + if currentState != serviceStateRunning { + continue + } + + servicesByName[currentName] = current + serviceNames = append(serviceNames, currentName) + + dependents, err := current.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") + if err != nil { + return err + } + + servicesToCheck = append(servicesToCheck, dependents...) + } + + i = len(serviceNames) - 1 + for i >= 0 { + serviceName := serviceNames[i] + i -= 1 + + state, service, err := getServiceState(serviceName) + if err != nil { + return err + } + + if state == serviceStateStopped { + continue + } + + _, err = stopSingleService(serviceName, service.WmiInstance) + if err != nil { + return err + } } return nil From e8f5d13ed3b54f11b3525e0772bf404b0770e02f Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Tue, 10 Jun 2025 20:52:48 +0800 Subject: [PATCH 2/5] Refactor and add test cases --- pkg/os/system/api.go | 323 ++++++++++++++++++++++++-------------- pkg/os/system/api_test.go | 213 +++++++++++++++++++++++++ 2 files changed, 415 insertions(+), 121 deletions(-) create mode 100644 pkg/os/system/api_test.go diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index b4b18223..47dca94f 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -6,9 +6,9 @@ import ( "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl" - "github.com/microsoft/wmi/pkg/errors" - wmiinst "github.com/microsoft/wmi/pkg/wmiinstance" "github.com/microsoft/wmi/server2019/root/cimv2" + "github.com/pkg/errors" + "k8s.io/klog/v2" ) // Implements the System OS API calls. All code here should be very simple @@ -27,7 +27,8 @@ type ServiceInfo struct { Status uint32 `json:"Status"` } -type periodicalCheckFunc func() (bool, error) +type stateCheckFunc func(ServiceInterface, string) (bool, string, error) +type stateTransitionFunc func(ServiceInterface) error const ( // startServiceErrorCodeAccepted indicates the request is accepted @@ -69,8 +70,9 @@ var ( "Paused": impl.SERVICE_STATUS_PAUSED, } - serviceStateCheckInternal = 500 * time.Millisecond - serviceStateCheckTimeout = 5 * time.Second + serviceStateCheckInternal = 200 * time.Millisecond + serviceStateCheckTimeout = 30 * time.Second + errTimedOut = errors.New("Timed out") ) func serviceStartModeToStartType(startMode string) uint32 { @@ -81,10 +83,39 @@ func serviceState(status string) uint32 { return stateMappings[status] } -type APIImplementor struct{} +type ServiceInterface interface { + GetPropertyName() (string, error) + GetPropertyDisplayName() (string, error) + GetPropertyState() (string, error) + GetPropertyStartMode() (string, error) + GetDependents() ([]ServiceInterface, error) + StartService() (result uint32, err error) + StopService() (result uint32, err error) + Refresh() error +} + +type ServiceManager interface { + WaitUntilServiceState(ServiceInterface, stateTransitionFunc, stateCheckFunc, time.Duration, time.Duration) (string, error) + GetDependentsForService(string) ([]string, error) +} + +type ServiceFactory interface { + GetService(name string) (ServiceInterface, error) +} + +type APIImplementor struct { + serviceFactory ServiceFactory + serviceManager ServiceManager +} func New() APIImplementor { - return APIImplementor{} + serviceFactory := Win32ServiceFactory{} + return APIImplementor{ + serviceFactory: serviceFactory, + serviceManager: ServiceManagerImpl{ + serviceFactory: serviceFactory, + }, + } } func (APIImplementor) GetBIOSSerialNumber() (string, error) { @@ -129,181 +160,231 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) { }, nil } -func waitForServiceState(serviceCheck periodicalCheckFunc, interval time.Duration, timeout time.Duration) error { - timeoutChan := time.After(timeout) - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-timeoutChan: - return errors.Timedout - case <-ticker.C: - done, err := serviceCheck() - if err != nil { - return err - } - - if done { - return nil - } +func (impl APIImplementor) StartService(name string) error { + startService := func(service ServiceInterface) error { + retVal, err := service.StartService() + if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { + return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err) } + return nil } -} + serviceRunningCheck := func(service ServiceInterface, state string) (bool, string, error) { + err := service.Refresh() + if err != nil { + return false, "", err + } -func getServiceState(name string) (string, *cimv2.Win32_Service, error) { - service, err := cim.QueryServiceByName(name, nil) - if err != nil { - return "", nil, err + newState, err := service.GetPropertyState() + if err != nil { + return false, state, err + } + + klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) + return state == serviceStateRunning, newState, err } - state, err := service.GetPropertyState() + service, err := impl.serviceFactory.GetService(name) if err != nil { - return "", nil, fmt.Errorf("failed to get state property of service %s: %w", name, err) + return err } - return state, service, nil -} - -func (APIImplementor) StartService(name string) error { - state, service, err := getServiceState(name) - if err != nil { + state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { return err } if state != serviceStateRunning { - var retVal uint32 - retVal, err = service.StartService() - if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { - return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err) - } + return fmt.Errorf("timed out waiting for service %s to become running", name) + } - err = waitForServiceState(func() (bool, error) { - state, service, err = getServiceState(name) - if err != nil { - return false, err - } + return nil +} - return state == serviceStateRunning, nil +func (impl APIImplementor) stopSingleService(name string) (bool, error) { + var dependentRunning bool + stopService := func(service ServiceInterface) error { + retVal, err := service.StopService() + if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) { + if retVal == stopServiceErrorCodeDependentRunning { + dependentRunning = true + return fmt.Errorf("error stopping service %s as dependent services are not stopped", name) + } + return fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err) + } + return nil + } + serviceStoppedCheck := func(service ServiceInterface, state string) (bool, string, error) { + err := service.Refresh() + if err != nil { + return false, "", err + } - }, serviceStateCheckInternal, serviceStateCheckTimeout) + newState, err := service.GetPropertyState() if err != nil { - return fmt.Errorf("error waiting service %s become running. error: %v", name, err) + return false, state, err } + + klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) + return newState == serviceStateStopped, newState, err } - if state != serviceStateRunning { - return fmt.Errorf("error starting service name %s. current state: %s", name, state) + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return dependentRunning, err } - return nil + state, err := impl.serviceManager.WaitUntilServiceState(service, stopService, serviceStoppedCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { + return dependentRunning, fmt.Errorf("error stopping service name %s. current state: %s", name, state) + } + + if state != serviceStateStopped { + return dependentRunning, fmt.Errorf("timed out waiting for service %s to stop", name) + } + + return dependentRunning, nil } -func (APIImplementor) StopService(name string, force bool) error { - state, service, err := getServiceState(name) - if err != nil { +func (impl APIImplementor) StopService(name string, force bool) error { + dependentRunning, err := impl.stopSingleService(name) + if err == nil || !dependentRunning || !force { return err } - if state == serviceStateStopped { - return nil + serviceNames, err := impl.serviceManager.GetDependentsForService(name) + if err != nil { + return fmt.Errorf("error getting dependent services for service name %s", name) } - stopSingleService := func(name string, service *wmiinst.WmiInstance) (bool, error) { - retVal, err := service.InvokeMethodWithReturn("StopService") - if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) { - if retVal == stopServiceErrorCodeDependentRunning { - return true, fmt.Errorf("error stopping service %s as dependent services are not stopped", name) - } - return false, fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err) + for _, serviceName := range serviceNames { + _, err = impl.stopSingleService(serviceName) + if err != nil { + return err } + } - var serviceState string - err = waitForServiceState(func() (bool, error) { - serviceState, _, err = getServiceState(name) - if err != nil { - return false, err - } + return nil +} - return serviceState == serviceStateStopped, nil +type Win32Service struct { + *cimv2.Win32_Service +} - }, serviceStateCheckInternal, serviceStateCheckTimeout) +func (s *Win32Service) GetDependents() ([]ServiceInterface, error) { + collection, err := s.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") + if err != nil { + return nil, err + } + + var result []ServiceInterface + for _, coll := range collection { + service, err := cimv2.NewWin32_ServiceEx1(coll) if err != nil { - return false, fmt.Errorf("error waiting service %s become stopped. error: %v", name, err) + return nil, err } - if serviceState != serviceStateStopped { - return false, fmt.Errorf("error stopping service name %s. current state: %s", name, serviceState) - } + result = append(result, &Win32Service{ + service, + }) + } + return result, nil +} + +type Win32ServiceFactory struct { +} + +func (impl Win32ServiceFactory) GetService(name string) (ServiceInterface, error) { + service, err := cim.QueryServiceByName(name, nil) + if err != nil { + return nil, err + } + + return &Win32Service{Win32_Service: service}, nil +} - return false, nil +type ServiceManagerImpl struct { + serviceFactory ServiceFactory +} + +func (impl ServiceManagerImpl) WaitUntilServiceState(service ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) { + done, state, err := stateCheck(service, "") + if err != nil { + return state, err + } + if done { + return state, err } - dependentRunning, err := stopSingleService(name, service.WmiInstance) - if !force || err == nil || !dependentRunning { - return err + // Perform transition if not already in desired state + if err := stateTransition(service); err != nil { + return state, err } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + timeoutChan := time.After(timeout) + + for { + select { + case <-ticker.C: + klog.V(6).Infof("Checking service (%v) state...", service) + done, state, err = stateCheck(service, state) + if err != nil { + return state, fmt.Errorf("check failed: %w", err) + } + if done { + klog.V(6).Infof("service (%v) state is %s and transition done.", service, state) + return state, nil + } + case <-timeoutChan: + done, state, err = stateCheck(service, state) + return state, errTimedOut + } + } +} + +func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, error) { var serviceNames []string - var servicesToCheck wmiinst.WmiInstanceCollection - servicesByName := map[string]*wmiinst.WmiInstance{} + var servicesToCheck []ServiceInterface + servicesByName := map[string]string{} + + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return serviceNames, err + } - servicesToCheck = append(servicesToCheck, service.WmiInstance) + servicesToCheck = append(servicesToCheck, service) i := 0 for i < len(servicesToCheck) { - current := servicesToCheck[i] + service = servicesToCheck[i] i += 1 - currentNameVal, err := current.GetProperty("Name") + serviceName, err := service.GetPropertyName() if err != nil { - return err - } - - currentName := currentNameVal.(string) - if _, ok := servicesByName[currentName]; ok { - continue + return serviceNames, err } - currentStateVal, err := current.GetProperty("State") + currentState, err := service.GetPropertyState() if err != nil { - return err + return serviceNames, err } - currentState := currentStateVal if currentState != serviceStateRunning { continue } - servicesByName[currentName] = current - serviceNames = append(serviceNames, currentName) + servicesByName[serviceName] = serviceName + // prepend the current service to the front + serviceNames = append([]string{serviceName}, serviceNames...) - dependents, err := current.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") + dependents, err := service.GetDependents() if err != nil { - return err + return serviceNames, err } servicesToCheck = append(servicesToCheck, dependents...) } - i = len(serviceNames) - 1 - for i >= 0 { - serviceName := serviceNames[i] - i -= 1 - - state, service, err := getServiceState(serviceName) - if err != nil { - return err - } - - if state == serviceStateStopped { - continue - } - - _, err = stopSingleService(serviceName, service.WmiInstance) - if err != nil { - return err - } - } - - return nil + return serviceNames, nil } diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go new file mode 100644 index 00000000..4e84cd95 --- /dev/null +++ b/pkg/os/system/api_test.go @@ -0,0 +1,213 @@ +package system + +import ( + "fmt" + "testing" + "time" + + "github.com/pkg/errors" +) + +type MockService struct { + Name string + DisplayName string + State string + StartMode string + Dependents []ServiceInterface + + StartResult uint32 + StopResult uint32 + + Err error +} + +func (m *MockService) GetPropertyName() (string, error) { + return m.Name, m.Err +} + +func (m *MockService) GetPropertyDisplayName() (string, error) { + return m.DisplayName, m.Err +} + +func (m *MockService) GetPropertyState() (string, error) { + return m.State, m.Err +} + +func (m *MockService) GetPropertyStartMode() (string, error) { + return m.StartMode, m.Err +} + +func (m *MockService) GetDependents() ([]ServiceInterface, error) { + return m.Dependents, m.Err +} + +func (m *MockService) StartService() (uint32, error) { + m.State = "Running" + return m.StartResult, m.Err +} + +func (m *MockService) StopService() (uint32, error) { + m.State = "Stopped" + return m.StopResult, m.Err +} + +func (m *MockService) Refresh() error { + return nil +} + +type MockServiceFactory struct { + Services map[string]ServiceInterface + Err error +} + +func (f *MockServiceFactory) GetService(name string) (ServiceInterface, error) { + svc, ok := f.Services[name] + if !ok { + return nil, fmt.Errorf("service not found: %s", name) + } + return svc, f.Err +} + +func TestWaitUntilServiceState_Success(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateChanged := false + + stateCheck := func(s ServiceInterface, state string) (bool, string, error) { + if stateChanged { + svc.State = serviceStateRunning + return true, svc.State, nil + } + return false, svc.State, nil + } + + stateTransition := func(s ServiceInterface) error { + stateChanged = true + return nil + } + + impl := ServiceManagerImpl{} + state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 500*time.Millisecond) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state != serviceStateRunning { + t.Fatalf("expected state %q, got %q", serviceStateRunning, state) + } +} + +func TestWaitUntilServiceState_Timeout(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateCheck := func(s ServiceInterface, state string) (bool, string, error) { + return false, svc.State, nil + } + + stateTransition := func(s ServiceInterface) error { + return nil + } + + impl := ServiceManagerImpl{} + state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) + if !errors.Is(err, errTimedOut) { + t.Fatalf("expected timeout error, got %v", err) + } + if state != svc.State { + t.Fatalf("expected state %q, got %q", svc.State, state) + } +} + +func TestWaitUntilServiceState_TransitionFails(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateCheck := func(s ServiceInterface, state string) (bool, string, error) { + return false, svc.State, nil + } + + stateTransition := func(s ServiceInterface) error { + return fmt.Errorf("transition failed") + } + + impl := ServiceManagerImpl{} + _, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) + if err == nil || err.Error() != "transition failed" { + t.Fatalf("expected transition error, got %v", err) + } +} + +func TestGetDependentsForService(t *testing.T) { + // Construct the dependency tree + svcC := &MockService{Name: "C", State: serviceStateRunning} + svcB := &MockService{Name: "B", State: serviceStateRunning, Dependents: []ServiceInterface{svcC}} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []ServiceInterface{svcB}} + + factory := &MockServiceFactory{ + Services: map[string]ServiceInterface{ + "A": svcA, + "B": svcB, + "C": svcC, + }, + } + + impl := ServiceManagerImpl{ + serviceFactory: factory, + } + + names, err := impl.GetDependentsForService("A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"C", "B", "A"} + if len(names) != len(expected) { + t.Fatalf("expected %d services, got %d", len(expected), len(names)) + } + for i, name := range expected { + if names[i] != name { + t.Errorf("expected %s at position %d, got %s", name, i, names[i]) + } + } +} + +func TestGetDependentsForService_SkipsNonRunning(t *testing.T) { + svcB := &MockService{Name: "B", State: "Stopped"} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []ServiceInterface{svcB}} + + factory := &MockServiceFactory{ + Services: map[string]ServiceInterface{ + "A": svcA, + "B": svcB, + }, + } + + impl := ServiceManagerImpl{ + serviceFactory: factory, + } + + names, err := impl.GetDependentsForService("A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"A"} // B is skipped due to stopped state + if len(names) != len(expected) { + t.Fatalf("expected %d services, got %d", len(expected), len(names)) + } +} + +func TestGetDependenciesForService_Winmgmt(t *testing.T) { + impl := ServiceManagerImpl{ + serviceFactory: Win32ServiceFactory{}, + } + + serviceName := "Winmgmt" + names, err := impl.GetDependentsForService(serviceName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := 4 + if len(names) != expected || names[len(names)-1] != serviceName { + t.Fatalf("expected %d services, got %d", expected, len(names)) + } +} From 8a7c3a0c94bec1f3433cca4d0f6aa84f1bacdcc4 Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Fri, 20 Jun 2025 15:57:10 +0800 Subject: [PATCH 3/5] Move WMI service functions to cim package --- pkg/cim/system.go | 78 +++++++++++++++++++++++++++++++++ pkg/os/system/api.go | 92 ++++++++++----------------------------- pkg/os/system/api_test.go | 33 +++++++------- 3 files changed, 117 insertions(+), 86 deletions(-) diff --git a/pkg/cim/system.go b/pkg/cim/system.go index 3ab32af6..80181794 100644 --- a/pkg/cim/system.go +++ b/pkg/cim/system.go @@ -10,6 +10,22 @@ import ( "github.com/microsoft/wmi/server2019/root/cimv2" ) +var ( + BIOSSelectorList = []string{"SerialNumber"} + ServiceSelectorList = []string{"DisplayName", "State", "StartMode"} +) + +type ServiceInterface interface { + GetPropertyName() (string, error) + GetPropertyDisplayName() (string, error) + GetPropertyState() (string, error) + GetPropertyStartMode() (string, error) + GetDependents() ([]ServiceInterface, error) + StartService() (result uint32, err error) + StopService() (result uint32, err error) + Refresh() error +} + // QueryBIOSElement retrieves the BIOS element. // // The equivalent WMI query is: @@ -33,6 +49,11 @@ func QueryBIOSElement(selectorList []string) (*cimv2.CIM_BIOSElement, error) { return bios, err } +// GetBIOSSerialNumber returns the BIOS serial number. +func GetBIOSSerialNumber(bios *cimv2.CIM_BIOSElement) (string, error) { + return bios.GetPropertySerialNumber() +} + // QueryServiceByName retrieves a specific service by its name. // // The equivalent WMI query is: @@ -55,3 +76,60 @@ func QueryServiceByName(name string, selectorList []string) (*cimv2.Win32_Servic return service, err } + +// GetServiceName returns the name of a service. +func GetServiceName(service ServiceInterface) (string, error) { + return service.GetPropertyName() +} + +// GetServiceDisplayName returns the display name of a service. +func GetServiceDisplayName(service ServiceInterface) (string, error) { + return service.GetPropertyDisplayName() +} + +// GetServiceState returns the state of a service. +func GetServiceState(service ServiceInterface) (string, error) { + return service.GetPropertyState() +} + +// GetServiceStartMode returns the start mode of a service. +func GetServiceStartMode(service ServiceInterface) (string, error) { + return service.GetPropertyStartMode() +} + +// Win32Service wraps the WMI class Win32_Service (mainly for testing) +type Win32Service struct { + *cimv2.Win32_Service +} + +func (s *Win32Service) GetDependents() ([]ServiceInterface, error) { + collection, err := s.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") + if err != nil { + return nil, err + } + + var result []ServiceInterface + for _, coll := range collection { + service, err := cimv2.NewWin32_ServiceEx1(coll) + if err != nil { + return nil, err + } + + result = append(result, &Win32Service{ + service, + }) + } + return result, nil +} + +type Win32ServiceFactory struct { +} + +func (impl Win32ServiceFactory) GetService(name string) (ServiceInterface, error) { + service, err := QueryServiceByName(name, ServiceSelectorList) + if err != nil { + return nil, err + } + + return &Win32Service{Win32_Service: service}, nil +} diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index 47dca94f..02283767 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -6,7 +6,6 @@ import ( "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl" - "github.com/microsoft/wmi/server2019/root/cimv2" "github.com/pkg/errors" "k8s.io/klog/v2" ) @@ -27,8 +26,8 @@ type ServiceInfo struct { Status uint32 `json:"Status"` } -type stateCheckFunc func(ServiceInterface, string) (bool, string, error) -type stateTransitionFunc func(ServiceInterface) error +type stateCheckFunc func(cim.ServiceInterface, string) (bool, string, error) +type stateTransitionFunc func(cim.ServiceInterface) error const ( // startServiceErrorCodeAccepted indicates the request is accepted @@ -83,24 +82,13 @@ func serviceState(status string) uint32 { return stateMappings[status] } -type ServiceInterface interface { - GetPropertyName() (string, error) - GetPropertyDisplayName() (string, error) - GetPropertyState() (string, error) - GetPropertyStartMode() (string, error) - GetDependents() ([]ServiceInterface, error) - StartService() (result uint32, err error) - StopService() (result uint32, err error) - Refresh() error -} - type ServiceManager interface { - WaitUntilServiceState(ServiceInterface, stateTransitionFunc, stateCheckFunc, time.Duration, time.Duration) (string, error) + WaitUntilServiceState(cim.ServiceInterface, stateTransitionFunc, stateCheckFunc, time.Duration, time.Duration) (string, error) GetDependentsForService(string) ([]string, error) } type ServiceFactory interface { - GetService(name string) (ServiceInterface, error) + GetService(name string) (cim.ServiceInterface, error) } type APIImplementor struct { @@ -109,7 +97,7 @@ type APIImplementor struct { } func New() APIImplementor { - serviceFactory := Win32ServiceFactory{} + serviceFactory := cim.Win32ServiceFactory{} return APIImplementor{ serviceFactory: serviceFactory, serviceManager: ServiceManagerImpl{ @@ -119,12 +107,12 @@ func New() APIImplementor { } func (APIImplementor) GetBIOSSerialNumber() (string, error) { - bios, err := cim.QueryBIOSElement([]string{"SerialNumber"}) + bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList) if err != nil { return "", fmt.Errorf("failed to get BIOS element: %w", err) } - sn, err := bios.GetPropertySerialNumber() + sn, err := cim.GetBIOSSerialNumber(bios) if err != nil { return "", fmt.Errorf("failed to get BIOS serial number property: %w", err) } @@ -132,23 +120,23 @@ func (APIImplementor) GetBIOSSerialNumber() (string, error) { return sn, nil } -func (APIImplementor) GetService(name string) (*ServiceInfo, error) { - service, err := cim.QueryServiceByName(name, []string{"DisplayName", "State", "StartMode"}) +func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) { + service, err := impl.serviceFactory.GetService(name) if err != nil { return nil, fmt.Errorf("failed to get service %s: %w", name, err) } - displayName, err := service.GetPropertyDisplayName() + displayName, err := cim.GetServiceDisplayName(service) if err != nil { return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err) } - state, err := service.GetPropertyState() + state, err := cim.GetServiceState(service) if err != nil { return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err) } - startMode, err := service.GetPropertyStartMode() + startMode, err := cim.GetServiceStartMode(service) if err != nil { return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err) } @@ -161,20 +149,20 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) { } func (impl APIImplementor) StartService(name string) error { - startService := func(service ServiceInterface) error { + startService := func(service cim.ServiceInterface) error { retVal, err := service.StartService() if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err) } return nil } - serviceRunningCheck := func(service ServiceInterface, state string) (bool, string, error) { + serviceRunningCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { err := service.Refresh() if err != nil { return false, "", err } - newState, err := service.GetPropertyState() + newState, err := cim.GetServiceState(service) if err != nil { return false, state, err } @@ -202,7 +190,7 @@ func (impl APIImplementor) StartService(name string) error { func (impl APIImplementor) stopSingleService(name string) (bool, error) { var dependentRunning bool - stopService := func(service ServiceInterface) error { + stopService := func(service cim.ServiceInterface) error { retVal, err := service.StopService() if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) { if retVal == stopServiceErrorCodeDependentRunning { @@ -213,13 +201,13 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) { } return nil } - serviceStoppedCheck := func(service ServiceInterface, state string) (bool, string, error) { + serviceStoppedCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { err := service.Refresh() if err != nil { return false, "", err } - newState, err := service.GetPropertyState() + newState, err := cim.GetServiceState(service) if err != nil { return false, state, err } @@ -266,47 +254,11 @@ func (impl APIImplementor) StopService(name string, force bool) error { return nil } -type Win32Service struct { - *cimv2.Win32_Service -} - -func (s *Win32Service) GetDependents() ([]ServiceInterface, error) { - collection, err := s.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") - if err != nil { - return nil, err - } - - var result []ServiceInterface - for _, coll := range collection { - service, err := cimv2.NewWin32_ServiceEx1(coll) - if err != nil { - return nil, err - } - - result = append(result, &Win32Service{ - service, - }) - } - return result, nil -} - -type Win32ServiceFactory struct { -} - -func (impl Win32ServiceFactory) GetService(name string) (ServiceInterface, error) { - service, err := cim.QueryServiceByName(name, nil) - if err != nil { - return nil, err - } - - return &Win32Service{Win32_Service: service}, nil -} - type ServiceManagerImpl struct { serviceFactory ServiceFactory } -func (impl ServiceManagerImpl) WaitUntilServiceState(service ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) { +func (impl ServiceManagerImpl) WaitUntilServiceState(service cim.ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) { done, state, err := stateCheck(service, "") if err != nil { return state, err @@ -346,7 +298,7 @@ func (impl ServiceManagerImpl) WaitUntilServiceState(service ServiceInterface, s func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, error) { var serviceNames []string - var servicesToCheck []ServiceInterface + var servicesToCheck []cim.ServiceInterface servicesByName := map[string]string{} service, err := impl.serviceFactory.GetService(name) @@ -360,12 +312,12 @@ func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, e service = servicesToCheck[i] i += 1 - serviceName, err := service.GetPropertyName() + serviceName, err := cim.GetServiceName(service) if err != nil { return serviceNames, err } - currentState, err := service.GetPropertyState() + currentState, err := cim.GetServiceState(service) if err != nil { return serviceNames, err } diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go index 4e84cd95..b977c74c 100644 --- a/pkg/os/system/api_test.go +++ b/pkg/os/system/api_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/pkg/errors" ) @@ -13,7 +14,7 @@ type MockService struct { DisplayName string State string StartMode string - Dependents []ServiceInterface + Dependents []cim.ServiceInterface StartResult uint32 StopResult uint32 @@ -37,7 +38,7 @@ func (m *MockService) GetPropertyStartMode() (string, error) { return m.StartMode, m.Err } -func (m *MockService) GetDependents() ([]ServiceInterface, error) { +func (m *MockService) GetDependents() ([]cim.ServiceInterface, error) { return m.Dependents, m.Err } @@ -56,11 +57,11 @@ func (m *MockService) Refresh() error { } type MockServiceFactory struct { - Services map[string]ServiceInterface + Services map[string]cim.ServiceInterface Err error } -func (f *MockServiceFactory) GetService(name string) (ServiceInterface, error) { +func (f *MockServiceFactory) GetService(name string) (cim.ServiceInterface, error) { svc, ok := f.Services[name] if !ok { return nil, fmt.Errorf("service not found: %s", name) @@ -73,7 +74,7 @@ func TestWaitUntilServiceState_Success(t *testing.T) { stateChanged := false - stateCheck := func(s ServiceInterface, state string) (bool, string, error) { + stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { if stateChanged { svc.State = serviceStateRunning return true, svc.State, nil @@ -81,7 +82,7 @@ func TestWaitUntilServiceState_Success(t *testing.T) { return false, svc.State, nil } - stateTransition := func(s ServiceInterface) error { + stateTransition := func(s cim.ServiceInterface) error { stateChanged = true return nil } @@ -99,11 +100,11 @@ func TestWaitUntilServiceState_Success(t *testing.T) { func TestWaitUntilServiceState_Timeout(t *testing.T) { svc := &MockService{Name: "svc", State: "Stopped"} - stateCheck := func(s ServiceInterface, state string) (bool, string, error) { + stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { return false, svc.State, nil } - stateTransition := func(s ServiceInterface) error { + stateTransition := func(s cim.ServiceInterface) error { return nil } @@ -120,11 +121,11 @@ func TestWaitUntilServiceState_Timeout(t *testing.T) { func TestWaitUntilServiceState_TransitionFails(t *testing.T) { svc := &MockService{Name: "svc", State: "Stopped"} - stateCheck := func(s ServiceInterface, state string) (bool, string, error) { + stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { return false, svc.State, nil } - stateTransition := func(s ServiceInterface) error { + stateTransition := func(s cim.ServiceInterface) error { return fmt.Errorf("transition failed") } @@ -138,11 +139,11 @@ func TestWaitUntilServiceState_TransitionFails(t *testing.T) { func TestGetDependentsForService(t *testing.T) { // Construct the dependency tree svcC := &MockService{Name: "C", State: serviceStateRunning} - svcB := &MockService{Name: "B", State: serviceStateRunning, Dependents: []ServiceInterface{svcC}} - svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []ServiceInterface{svcB}} + svcB := &MockService{Name: "B", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcC}} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}} factory := &MockServiceFactory{ - Services: map[string]ServiceInterface{ + Services: map[string]cim.ServiceInterface{ "A": svcA, "B": svcB, "C": svcC, @@ -171,10 +172,10 @@ func TestGetDependentsForService(t *testing.T) { func TestGetDependentsForService_SkipsNonRunning(t *testing.T) { svcB := &MockService{Name: "B", State: "Stopped"} - svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []ServiceInterface{svcB}} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}} factory := &MockServiceFactory{ - Services: map[string]ServiceInterface{ + Services: map[string]cim.ServiceInterface{ "A": svcA, "B": svcB, }, @@ -197,7 +198,7 @@ func TestGetDependentsForService_SkipsNonRunning(t *testing.T) { func TestGetDependenciesForService_Winmgmt(t *testing.T) { impl := ServiceManagerImpl{ - serviceFactory: Win32ServiceFactory{}, + serviceFactory: cim.Win32ServiceFactory{}, } serviceName := "Winmgmt" From 14a454ce80a519e9c9d225e0d472fd98ba442995 Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Thu, 26 Jun 2025 20:13:01 +0800 Subject: [PATCH 4/5] Skip platform specific UT --- pkg/os/system/api_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go index b977c74c..1a4c50a7 100644 --- a/pkg/os/system/api_test.go +++ b/pkg/os/system/api_test.go @@ -2,6 +2,8 @@ package system import ( "fmt" + "os" + "strings" "testing" "time" @@ -197,6 +199,10 @@ func TestGetDependentsForService_SkipsNonRunning(t *testing.T) { } func TestGetDependenciesForService_Winmgmt(t *testing.T) { + if strings.ToLower(os.Getenv("TEST_MULTI_SERVICE_DEPENDENTS")) != "true" { + t.Skipf("Test skipped") + } + impl := ServiceManagerImpl{ serviceFactory: cim.Win32ServiceFactory{}, } From fc1bc4c8776ea16d20e9cccc5ef8d98da57087d7 Mon Sep 17 00:00:00 2001 From: Zhongcheng Lao Date: Thu, 26 Jun 2025 22:18:09 +0800 Subject: [PATCH 5/5] Wrap errors --- pkg/os/system/api.go | 43 +++++++++++++++++++++------------------ pkg/os/system/api_test.go | 18 +++++++++------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index 02283767..2a8ddaf8 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -123,7 +123,7 @@ func (APIImplementor) GetBIOSSerialNumber() (string, error) { func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) { service, err := impl.serviceFactory.GetService(name) if err != nil { - return nil, fmt.Errorf("failed to get service %s: %w", name, err) + return nil, fmt.Errorf("failed to get service %s. error: %w", name, err) } displayName, err := cim.GetServiceDisplayName(service) @@ -152,7 +152,7 @@ func (impl APIImplementor) StartService(name string) error { startService := func(service cim.ServiceInterface) error { retVal, err := service.StartService() if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { - return fmt.Errorf("error starting service name %s. return value: %d, error: %v", name, retVal, err) + return fmt.Errorf("error starting service name %s. return value: %d, error: %w", name, retVal, err) } return nil } @@ -173,12 +173,12 @@ func (impl APIImplementor) StartService(name string) error { service, err := impl.serviceFactory.GetService(name) if err != nil { - return err + return fmt.Errorf("failed to get service %s. error: %w", name, err) } state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout) if err != nil && !errors.Is(err, errTimedOut) { - return err + return fmt.Errorf("failed to wait for service %s state change. error: %w", name, err) } if state != serviceStateRunning { @@ -197,28 +197,28 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) { dependentRunning = true return fmt.Errorf("error stopping service %s as dependent services are not stopped", name) } - return fmt.Errorf("error stopping service %s. return value: %d, error: %v", name, retVal, err) + return fmt.Errorf("error stopping service %s. return value: %d, error: %w", name, retVal, err) } return nil } serviceStoppedCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { err := service.Refresh() if err != nil { - return false, "", err + return false, "", fmt.Errorf("error refresh service %s instance. error: %w", name, err) } newState, err := cim.GetServiceState(service) if err != nil { - return false, state, err + return false, state, fmt.Errorf("error getting service %s state. error: %w", name, err) } klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) - return newState == serviceStateStopped, newState, err + return newState == serviceStateStopped, newState, nil } service, err := impl.serviceFactory.GetService(name) if err != nil { - return dependentRunning, err + return dependentRunning, fmt.Errorf("failed to get service %s. error: %w", name, err) } state, err := impl.serviceManager.WaitUntilServiceState(service, stopService, serviceStoppedCheck, serviceStateCheckInternal, serviceStateCheckTimeout) @@ -235,8 +235,11 @@ func (impl APIImplementor) stopSingleService(name string) (bool, error) { func (impl APIImplementor) StopService(name string, force bool) error { dependentRunning, err := impl.stopSingleService(name) - if err == nil || !dependentRunning || !force { - return err + if err == nil { + return nil + } + if !dependentRunning || !force { + return fmt.Errorf("failed to stop service %s. error: %w", name, err) } serviceNames, err := impl.serviceManager.GetDependentsForService(name) @@ -247,7 +250,7 @@ func (impl APIImplementor) StopService(name string, force bool) error { for _, serviceName := range serviceNames { _, err = impl.stopSingleService(serviceName) if err != nil { - return err + return fmt.Errorf("failed to stop service %s. error: %w", name, err) } } @@ -261,15 +264,15 @@ type ServiceManagerImpl struct { func (impl ServiceManagerImpl) WaitUntilServiceState(service cim.ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) { done, state, err := stateCheck(service, "") if err != nil { - return state, err + return state, fmt.Errorf("service %v state check failed: %w", service, err) } if done { - return state, err + return state, nil } // Perform transition if not already in desired state if err := stateTransition(service); err != nil { - return state, err + return state, fmt.Errorf("service %v state transition failed: %w", service, err) } ticker := time.NewTicker(interval) @@ -283,7 +286,7 @@ func (impl ServiceManagerImpl) WaitUntilServiceState(service cim.ServiceInterfac klog.V(6).Infof("Checking service (%v) state...", service) done, state, err = stateCheck(service, state) if err != nil { - return state, fmt.Errorf("check failed: %w", err) + return state, fmt.Errorf("service %v state check failed: %w", service, err) } if done { klog.V(6).Infof("service (%v) state is %s and transition done.", service, state) @@ -303,7 +306,7 @@ func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, e service, err := impl.serviceFactory.GetService(name) if err != nil { - return serviceNames, err + return serviceNames, fmt.Errorf("failed to get service %s. error: %w", name, err) } servicesToCheck = append(servicesToCheck, service) @@ -314,12 +317,12 @@ func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, e serviceName, err := cim.GetServiceName(service) if err != nil { - return serviceNames, err + return serviceNames, fmt.Errorf("error getting service name %v. error: %w", service, err) } currentState, err := cim.GetServiceState(service) if err != nil { - return serviceNames, err + return serviceNames, fmt.Errorf("error getting service %s state. error: %w", serviceName, err) } if currentState != serviceStateRunning { @@ -332,7 +335,7 @@ func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, e dependents, err := service.GetDependents() if err != nil { - return serviceNames, err + return serviceNames, fmt.Errorf("error getting service %s dependents. error: %w", serviceName, err) } servicesToCheck = append(servicesToCheck, dependents...) diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go index 1a4c50a7..7b2992bb 100644 --- a/pkg/os/system/api_test.go +++ b/pkg/os/system/api_test.go @@ -58,6 +58,8 @@ func (m *MockService) Refresh() error { return nil } +var _ cim.ServiceInterface = &MockService{} + type MockServiceFactory struct { Services map[string]cim.ServiceInterface Err error @@ -71,12 +73,14 @@ func (f *MockServiceFactory) GetService(name string) (cim.ServiceInterface, erro return svc, f.Err } +var _ ServiceFactory = &MockServiceFactory{} + func TestWaitUntilServiceState_Success(t *testing.T) { svc := &MockService{Name: "svc", State: "Stopped"} stateChanged := false - stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { + stateCheck := func(_ cim.ServiceInterface, _ string) (bool, string, error) { if stateChanged { svc.State = serviceStateRunning return true, svc.State, nil @@ -84,7 +88,7 @@ func TestWaitUntilServiceState_Success(t *testing.T) { return false, svc.State, nil } - stateTransition := func(s cim.ServiceInterface) error { + stateTransition := func(_ cim.ServiceInterface) error { stateChanged = true return nil } @@ -102,11 +106,11 @@ func TestWaitUntilServiceState_Success(t *testing.T) { func TestWaitUntilServiceState_Timeout(t *testing.T) { svc := &MockService{Name: "svc", State: "Stopped"} - stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { + stateCheck := func(_ cim.ServiceInterface, _ string) (bool, string, error) { return false, svc.State, nil } - stateTransition := func(s cim.ServiceInterface) error { + stateTransition := func(_ cim.ServiceInterface) error { return nil } @@ -123,17 +127,17 @@ func TestWaitUntilServiceState_Timeout(t *testing.T) { func TestWaitUntilServiceState_TransitionFails(t *testing.T) { svc := &MockService{Name: "svc", State: "Stopped"} - stateCheck := func(s cim.ServiceInterface, state string) (bool, string, error) { + stateCheck := func(_ cim.ServiceInterface, _ string) (bool, string, error) { return false, svc.State, nil } - stateTransition := func(s cim.ServiceInterface) error { + stateTransition := func(_ cim.ServiceInterface) error { return fmt.Errorf("transition failed") } impl := ServiceManagerImpl{} _, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) - if err == nil || err.Error() != "transition failed" { + if err == nil || !strings.Contains(err.Error(), "transition failed") { t.Fatalf("expected transition error, got %v", err) } }