@@ -2,10 +2,13 @@ package system
22
33import (
44 "fmt"
5+ "time"
56
67 "github.com/kubernetes-csi/csi-proxy/pkg/cim"
78 "github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl"
8- "github.com/kubernetes-csi/csi-proxy/pkg/utils"
9+ "github.com/microsoft/wmi/pkg/errors"
10+ wmiinst "github.com/microsoft/wmi/pkg/wmiinstance"
11+ "github.com/microsoft/wmi/server2019/root/cimv2"
912)
1013
1114// Implements the System OS API calls. All code here should be very simple
@@ -24,6 +27,28 @@ type ServiceInfo struct {
2427 Status uint32 `json:"Status"`
2528}
2629
30+ type periodicalCheckFunc func () (bool , error )
31+
32+ const (
33+ // startServiceErrorCodeAccepted indicates the request is accepted
34+ startServiceErrorCodeAccepted = 0
35+
36+ // startServiceErrorCodeAlreadyRunning indicates a service is already running
37+ startServiceErrorCodeAlreadyRunning = 10
38+
39+ // stopServiceErrorCodeAccepted indicates the request is accepted
40+ stopServiceErrorCodeAccepted = 0
41+
42+ // stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending)
43+ stopServiceErrorCodeStopPending = 5
44+
45+ // stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running
46+ stopServiceErrorCodeDependentRunning = 3
47+
48+ serviceStateRunning = "Running"
49+ serviceStateStopped = "Stopped"
50+ )
51+
2752var (
2853 startModeMappings = map [string ]uint32 {
2954 "Boot" : impl .START_TYPE_BOOT ,
@@ -33,24 +58,27 @@ var (
3358 "Disabled" : impl .START_TYPE_DISABLED ,
3459 }
3560
36- statusMappings = map [string ]uint32 {
37- "Unknown" : impl .SERVICE_STATUS_UNKNOWN ,
38- "Stopped" : impl .SERVICE_STATUS_STOPPED ,
39- "Start Pending" : impl .SERVICE_STATUS_START_PENDING ,
40- "Stop Pending" : impl .SERVICE_STATUS_STOP_PENDING ,
41- "Running" : impl .SERVICE_STATUS_RUNNING ,
42- "Continue Pending" : impl .SERVICE_STATUS_CONTINUE_PENDING ,
43- "Pause Pending" : impl .SERVICE_STATUS_PAUSE_PENDING ,
44- "Paused" : impl .SERVICE_STATUS_PAUSED ,
61+ stateMappings = map [string ]uint32 {
62+ "Unknown" : impl .SERVICE_STATUS_UNKNOWN ,
63+ serviceStateStopped : impl .SERVICE_STATUS_STOPPED ,
64+ "Start Pending" : impl .SERVICE_STATUS_START_PENDING ,
65+ "Stop Pending" : impl .SERVICE_STATUS_STOP_PENDING ,
66+ serviceStateRunning : impl .SERVICE_STATUS_RUNNING ,
67+ "Continue Pending" : impl .SERVICE_STATUS_CONTINUE_PENDING ,
68+ "Pause Pending" : impl .SERVICE_STATUS_PAUSE_PENDING ,
69+ "Paused" : impl .SERVICE_STATUS_PAUSED ,
4570 }
71+
72+ serviceStateCheckInternal = 500 * time .Millisecond
73+ serviceStateCheckTimeout = 5 * time .Second
4674)
4775
4876func serviceStartModeToStartType (startMode string ) uint32 {
4977 return startModeMappings [startMode ]
5078}
5179
5280func serviceState (status string ) uint32 {
53- return statusMappings [status ]
81+ return stateMappings [status ]
5482}
5583
5684type APIImplementor struct {}
@@ -101,23 +129,180 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) {
101129 }, nil
102130}
103131
132+ func waitForServiceState (serviceCheck periodicalCheckFunc , interval time.Duration , timeout time.Duration ) error {
133+ timeoutChan := time .After (timeout )
134+ ticker := time .NewTicker (interval )
135+ defer ticker .Stop ()
136+
137+ for {
138+ select {
139+ case <- timeoutChan :
140+ return errors .Timedout
141+ case <- ticker .C :
142+ done , err := serviceCheck ()
143+ if err != nil {
144+ return err
145+ }
146+
147+ if done {
148+ return nil
149+ }
150+ }
151+ }
152+ }
153+
154+ func getServiceState (name string ) (string , * cimv2.Win32_Service , error ) {
155+ service , err := cim .QueryServiceByName (name , nil )
156+ if err != nil {
157+ return "" , nil , err
158+ }
159+
160+ state , err := service .GetPropertyState ()
161+ if err != nil {
162+ return "" , nil , fmt .Errorf ("failed to get state property of service %s: %w" , name , err )
163+ }
164+
165+ return state , service , nil
166+ }
167+
104168func (APIImplementor ) StartService (name string ) error {
105- // Note: both StartService and StopService are not implemented by WMI
106- script := `Start-Service -Name $env:ServiceName`
107- cmdEnv := fmt .Sprintf ("ServiceName=%s" , name )
108- out , err := utils .RunPowershellCmd (script , cmdEnv )
169+ state , service , err := getServiceState (name )
109170 if err != nil {
110- return fmt .Errorf ("error starting service name=%s. cmd: %s, output: %s, error: %v" , name , script , string (out ), err )
171+ return err
172+ }
173+
174+ if state != serviceStateRunning {
175+ var retVal uint32
176+ retVal , err = service .StartService ()
177+ if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning ) {
178+ return fmt .Errorf ("error starting service name %s. return value: %d, error: %v" , name , retVal , err )
179+ }
180+
181+ err = waitForServiceState (func () (bool , error ) {
182+ state , service , err = getServiceState (name )
183+ if err != nil {
184+ return false , err
185+ }
186+
187+ return state == serviceStateRunning , nil
188+
189+ }, serviceStateCheckInternal , serviceStateCheckTimeout )
190+ if err != nil {
191+ return fmt .Errorf ("error waiting service %s become running. error: %v" , name , err )
192+ }
193+ }
194+
195+ if state != serviceStateRunning {
196+ return fmt .Errorf ("error starting service name %s. current state: %s" , name , state )
111197 }
112198
113199 return nil
114200}
115201
116202func (APIImplementor ) StopService (name string , force bool ) error {
117- script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))`
118- out , err := utils .RunPowershellCmd (script , fmt .Sprintf ("ServiceName=%s" , name ), fmt .Sprintf ("Force=%t" , force ))
203+ state , service , err := getServiceState (name )
119204 if err != nil {
120- return fmt .Errorf ("error stopping service name=%s. cmd: %s, output: %s, error: %v" , name , script , string (out ), err )
205+ return err
206+ }
207+
208+ if state == serviceStateStopped {
209+ return nil
210+ }
211+
212+ stopSingleService := func (name string , service * wmiinst.WmiInstance ) (bool , error ) {
213+ retVal , err := service .InvokeMethodWithReturn ("StopService" )
214+ if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending ) {
215+ if retVal == stopServiceErrorCodeDependentRunning {
216+ return true , fmt .Errorf ("error stopping service %s as dependent services are not stopped" , name )
217+ }
218+ return false , fmt .Errorf ("error stopping service %s. return value: %d, error: %v" , name , retVal , err )
219+ }
220+
221+ var serviceState string
222+ err = waitForServiceState (func () (bool , error ) {
223+ serviceState , _ , err = getServiceState (name )
224+ if err != nil {
225+ return false , err
226+ }
227+
228+ return serviceState == serviceStateStopped , nil
229+
230+ }, serviceStateCheckInternal , serviceStateCheckTimeout )
231+ if err != nil {
232+ return false , fmt .Errorf ("error waiting service %s become stopped. error: %v" , name , err )
233+ }
234+
235+ if serviceState != serviceStateStopped {
236+ return false , fmt .Errorf ("error stopping service name %s. current state: %s" , name , serviceState )
237+ }
238+
239+ return false , nil
240+ }
241+
242+ dependentRunning , err := stopSingleService (name , service .WmiInstance )
243+ if ! force || err == nil || ! dependentRunning {
244+ return err
245+ }
246+
247+ var serviceNames []string
248+ var servicesToCheck wmiinst.WmiInstanceCollection
249+ servicesByName := map [string ]* wmiinst.WmiInstance {}
250+
251+ servicesToCheck = append (servicesToCheck , service .WmiInstance )
252+ i := 0
253+ for i < len (servicesToCheck ) {
254+ current := servicesToCheck [i ]
255+ i += 1
256+
257+ currentNameVal , err := current .GetProperty ("Name" )
258+ if err != nil {
259+ return err
260+ }
261+
262+ currentName := currentNameVal .(string )
263+ if _ , ok := servicesByName [currentName ]; ok {
264+ continue
265+ }
266+
267+ currentStateVal , err := current .GetProperty ("State" )
268+ if err != nil {
269+ return err
270+ }
271+
272+ currentState := currentStateVal
273+ if currentState != serviceStateRunning {
274+ continue
275+ }
276+
277+ servicesByName [currentName ] = current
278+ serviceNames = append (serviceNames , currentName )
279+
280+ dependents , err := current .GetAssociated ("Win32_DependentService" , "Win32_Service" , "Dependent" , "Antecedent" )
281+ if err != nil {
282+ return err
283+ }
284+
285+ servicesToCheck = append (servicesToCheck , dependents ... )
286+ }
287+
288+ i = len (serviceNames ) - 1
289+ for i >= 0 {
290+ serviceName := serviceNames [i ]
291+ i -= 1
292+
293+ state , service , err := getServiceState (serviceName )
294+ if err != nil {
295+ return err
296+ }
297+
298+ if state == serviceStateStopped {
299+ continue
300+ }
301+
302+ _ , err = stopSingleService (serviceName , service .WmiInstance )
303+ if err != nil {
304+ return err
305+ }
121306 }
122307
123308 return nil
0 commit comments