@@ -2,10 +2,13 @@ package system
22
33import  (
44	"fmt" 
5+ 	"time" 
56
67	"github.com/kubernetes-csi/csi-proxy/pkg/cim" 
78	"github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl" 
8- 	"github.com/kubernetes-csi/csi-proxy/pkg/utils" 
9+ 	"github.com/microsoft/wmi/pkg/errors" 
10+ 	wmiinst "github.com/microsoft/wmi/pkg/wmiinstance" 
11+ 	"github.com/microsoft/wmi/server2019/root/cimv2" 
912)
1013
1114// Implements the System OS API calls. All code here should be very simple 
@@ -24,6 +27,28 @@ type ServiceInfo struct {
2427	Status  uint32  `json:"Status"` 
2528}
2629
30+ type  periodicalCheckFunc  func () (bool , error )
31+ 
32+ const  (
33+ 	// startServiceErrorCodeAccepted indicates the request is accepted 
34+ 	startServiceErrorCodeAccepted  =  0 
35+ 
36+ 	// startServiceErrorCodeAlreadyRunning indicates a service is already running 
37+ 	startServiceErrorCodeAlreadyRunning  =  10 
38+ 
39+ 	// stopServiceErrorCodeAccepted indicates the request is accepted 
40+ 	stopServiceErrorCodeAccepted  =  0 
41+ 
42+ 	// stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending) 
43+ 	stopServiceErrorCodeStopPending  =  5 
44+ 
45+ 	// stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running 
46+ 	stopServiceErrorCodeDependentRunning  =  3 
47+ 
48+ 	serviceStateRunning  =  "Running" 
49+ 	serviceStateStopped  =  "Stopped" 
50+ )
51+ 
2752var  (
2853	startModeMappings  =  map [string ]uint32 {
2954		"Boot" :     impl .START_TYPE_BOOT ,
@@ -33,24 +58,27 @@ var (
3358		"Disabled" : impl .START_TYPE_DISABLED ,
3459	}
3560
36- 	statusMappings  =  map [string ]uint32 {
37- 		"Unknown" :          impl .SERVICE_STATUS_UNKNOWN ,
38- 		"Stopped" :           impl .SERVICE_STATUS_STOPPED ,
39- 		"Start Pending" :    impl .SERVICE_STATUS_START_PENDING ,
40- 		"Stop Pending" :     impl .SERVICE_STATUS_STOP_PENDING ,
41- 		"Running" :           impl .SERVICE_STATUS_RUNNING ,
42- 		"Continue Pending" : impl .SERVICE_STATUS_CONTINUE_PENDING ,
43- 		"Pause Pending" :    impl .SERVICE_STATUS_PAUSE_PENDING ,
44- 		"Paused" :           impl .SERVICE_STATUS_PAUSED ,
61+ 	stateMappings  =  map [string ]uint32 {
62+ 		"Unknown" :            impl .SERVICE_STATUS_UNKNOWN ,
63+ 		serviceStateStopped :  impl .SERVICE_STATUS_STOPPED ,
64+ 		"Start Pending" :      impl .SERVICE_STATUS_START_PENDING ,
65+ 		"Stop Pending" :       impl .SERVICE_STATUS_STOP_PENDING ,
66+ 		serviceStateRunning :  impl .SERVICE_STATUS_RUNNING ,
67+ 		"Continue Pending" :   impl .SERVICE_STATUS_CONTINUE_PENDING ,
68+ 		"Pause Pending" :      impl .SERVICE_STATUS_PAUSE_PENDING ,
69+ 		"Paused" :             impl .SERVICE_STATUS_PAUSED ,
4570	}
71+ 
72+ 	serviceStateCheckInternal  =  500  *  time .Millisecond 
73+ 	serviceStateCheckTimeout   =  5  *  time .Second 
4674)
4775
4876func  serviceStartModeToStartType (startMode  string ) uint32  {
4977	return  startModeMappings [startMode ]
5078}
5179
5280func  serviceState (status  string ) uint32  {
53- 	return  statusMappings [status ]
81+ 	return  stateMappings [status ]
5482}
5583
5684type  APIImplementor  struct {}
@@ -101,23 +129,180 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) {
101129	}, nil 
102130}
103131
132+ func  waitForServiceState (serviceCheck  periodicalCheckFunc , interval  time.Duration , timeout  time.Duration ) error  {
133+ 	timeoutChan  :=  time .After (timeout )
134+ 	ticker  :=  time .NewTicker (interval )
135+ 	defer  ticker .Stop ()
136+ 
137+ 	for  {
138+ 		select  {
139+ 		case  <- timeoutChan :
140+ 			return  errors .Timedout 
141+ 		case  <- ticker .C :
142+ 			done , err  :=  serviceCheck ()
143+ 			if  err  !=  nil  {
144+ 				return  err 
145+ 			}
146+ 
147+ 			if  done  {
148+ 				return  nil 
149+ 			}
150+ 		}
151+ 	}
152+ }
153+ 
154+ func  getServiceState (name  string ) (string , * cimv2.Win32_Service , error ) {
155+ 	service , err  :=  cim .QueryServiceByName (name , nil )
156+ 	if  err  !=  nil  {
157+ 		return  "" , nil , err 
158+ 	}
159+ 
160+ 	state , err  :=  service .GetPropertyState ()
161+ 	if  err  !=  nil  {
162+ 		return  "" , nil , fmt .Errorf ("failed to get state property of service %s: %w" , name , err )
163+ 	}
164+ 
165+ 	return  state , service , nil 
166+ }
167+ 
104168func  (APIImplementor ) StartService (name  string ) error  {
105- 	// Note: both StartService and StopService are not implemented by WMI 
106- 	script  :=  `Start-Service -Name $env:ServiceName` 
107- 	cmdEnv  :=  fmt .Sprintf ("ServiceName=%s" , name )
108- 	out , err  :=  utils .RunPowershellCmd (script , cmdEnv )
169+ 	state , service , err  :=  getServiceState (name )
109170	if  err  !=  nil  {
110- 		return  fmt .Errorf ("error starting service name=%s. cmd: %s, output: %s, error: %v" , name , script , string (out ), err )
171+ 		return  err 
172+ 	}
173+ 
174+ 	if  state  !=  serviceStateRunning  {
175+ 		var  retVal  uint32 
176+ 		retVal , err  =  service .StartService ()
177+ 		if  err  !=  nil  ||  (retVal  !=  startServiceErrorCodeAccepted  &&  retVal  !=  startServiceErrorCodeAlreadyRunning ) {
178+ 			return  fmt .Errorf ("error starting service name %s. return value: %d, error: %v" , name , retVal , err )
179+ 		}
180+ 
181+ 		err  =  waitForServiceState (func () (bool , error ) {
182+ 			state , service , err  =  getServiceState (name )
183+ 			if  err  !=  nil  {
184+ 				return  false , err 
185+ 			}
186+ 
187+ 			return  state  ==  serviceStateRunning , nil 
188+ 
189+ 		}, serviceStateCheckInternal , serviceStateCheckTimeout )
190+ 		if  err  !=  nil  {
191+ 			return  fmt .Errorf ("error waiting service %s become running. error: %v" , name , err )
192+ 		}
193+ 	}
194+ 
195+ 	if  state  !=  serviceStateRunning  {
196+ 		return  fmt .Errorf ("error starting service name %s. current state: %s" , name , state )
111197	}
112198
113199	return  nil 
114200}
115201
116202func  (APIImplementor ) StopService (name  string , force  bool ) error  {
117- 	script  :=  `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))` 
118- 	out , err  :=  utils .RunPowershellCmd (script , fmt .Sprintf ("ServiceName=%s" , name ), fmt .Sprintf ("Force=%t" , force ))
203+ 	state , service , err  :=  getServiceState (name )
119204	if  err  !=  nil  {
120- 		return  fmt .Errorf ("error stopping service name=%s. cmd: %s, output: %s, error: %v" , name , script , string (out ), err )
205+ 		return  err 
206+ 	}
207+ 
208+ 	if  state  ==  serviceStateStopped  {
209+ 		return  nil 
210+ 	}
211+ 
212+ 	stopSingleService  :=  func (name  string , service  * wmiinst.WmiInstance ) (bool , error ) {
213+ 		retVal , err  :=  service .InvokeMethodWithReturn ("StopService" )
214+ 		if  err  !=  nil  ||  (retVal  !=  stopServiceErrorCodeAccepted  &&  retVal  !=  stopServiceErrorCodeStopPending ) {
215+ 			if  retVal  ==  stopServiceErrorCodeDependentRunning  {
216+ 				return  true , fmt .Errorf ("error stopping service %s as dependent services are not stopped" , name )
217+ 			}
218+ 			return  false , fmt .Errorf ("error stopping service %s. return value: %d, error: %v" , name , retVal , err )
219+ 		}
220+ 
221+ 		var  serviceState  string 
222+ 		err  =  waitForServiceState (func () (bool , error ) {
223+ 			serviceState , _ , err  =  getServiceState (name )
224+ 			if  err  !=  nil  {
225+ 				return  false , err 
226+ 			}
227+ 
228+ 			return  serviceState  ==  serviceStateStopped , nil 
229+ 
230+ 		}, serviceStateCheckInternal , serviceStateCheckTimeout )
231+ 		if  err  !=  nil  {
232+ 			return  false , fmt .Errorf ("error waiting service %s become stopped. error: %v" , name , err )
233+ 		}
234+ 
235+ 		if  serviceState  !=  serviceStateStopped  {
236+ 			return  false , fmt .Errorf ("error stopping service name %s. current state: %s" , name , serviceState )
237+ 		}
238+ 
239+ 		return  false , nil 
240+ 	}
241+ 
242+ 	dependentRunning , err  :=  stopSingleService (name , service .WmiInstance )
243+ 	if  ! force  ||  err  ==  nil  ||  ! dependentRunning  {
244+ 		return  err 
245+ 	}
246+ 
247+ 	var  serviceNames  []string 
248+ 	var  servicesToCheck  wmiinst.WmiInstanceCollection 
249+ 	servicesByName  :=  map [string ]* wmiinst.WmiInstance {}
250+ 
251+ 	servicesToCheck  =  append (servicesToCheck , service .WmiInstance )
252+ 	i  :=  0 
253+ 	for  i  <  len (servicesToCheck ) {
254+ 		current  :=  servicesToCheck [i ]
255+ 		i  +=  1 
256+ 
257+ 		currentNameVal , err  :=  current .GetProperty ("Name" )
258+ 		if  err  !=  nil  {
259+ 			return  err 
260+ 		}
261+ 
262+ 		currentName  :=  currentNameVal .(string )
263+ 		if  _ , ok  :=  servicesByName [currentName ]; ok  {
264+ 			continue 
265+ 		}
266+ 
267+ 		currentStateVal , err  :=  current .GetProperty ("State" )
268+ 		if  err  !=  nil  {
269+ 			return  err 
270+ 		}
271+ 
272+ 		currentState  :=  currentStateVal 
273+ 		if  currentState  !=  serviceStateRunning  {
274+ 			continue 
275+ 		}
276+ 
277+ 		servicesByName [currentName ] =  current 
278+ 		serviceNames  =  append (serviceNames , currentName )
279+ 
280+ 		dependents , err  :=  current .GetAssociated ("Win32_DependentService" , "Win32_Service" , "Dependent" , "Antecedent" )
281+ 		if  err  !=  nil  {
282+ 			return  err 
283+ 		}
284+ 
285+ 		servicesToCheck  =  append (servicesToCheck , dependents ... )
286+ 	}
287+ 
288+ 	i  =  len (serviceNames ) -  1 
289+ 	for  i  >=  0  {
290+ 		serviceName  :=  serviceNames [i ]
291+ 		i  -=  1 
292+ 
293+ 		state , service , err  :=  getServiceState (serviceName )
294+ 		if  err  !=  nil  {
295+ 			return  err 
296+ 		}
297+ 
298+ 		if  state  ==  serviceStateStopped  {
299+ 			continue 
300+ 		}
301+ 
302+ 		_ , err  =  stopSingleService (serviceName , service .WmiInstance )
303+ 		if  err  !=  nil  {
304+ 			return  err 
305+ 		}
121306	}
122307
123308	return  nil 
0 commit comments