@@ -14,7 +14,16 @@ import (
1414const pollingInterval = 250 * time .Millisecond
1515
1616type windowsService struct {
17- handle * mgr.Service
17+ handle windowsServiceHandle
18+ }
19+
20+ // Replaces mgr.Service
21+ type windowsServiceHandle interface {
22+ Close () error
23+ Start (args ... string ) error
24+ Control (c svc.Cmd ) (svc.Status , error )
25+ Query () (svc.Status , error )
26+ Delete () error
1827}
1928
2029func (winSvc * windowsService ) Close () error {
@@ -59,8 +68,25 @@ func (winSvc *windowsService) IsActive() bool {
5968 return status .State == svc .Running
6069}
6170
62- func Create (params AgentParams ) (Service , error ) {
63- svcMgr , err := mgr .Connect ()
71+ // Substitute for mgr.Mgr
72+ type windowsServiceManager interface {
73+ Disconnect () error
74+ CreateService (name string , exepath string , c mgr.Config , args ... string ) (* mgr.Service , error )
75+ OpenService (name string ) (* mgr.Service , error )
76+ }
77+
78+ type windowsServiceManagerFactory interface {
79+ Connect () (windowsServiceManager , error )
80+ }
81+
82+ type defaultWindowsServiceManagerFactory struct {}
83+
84+ func (f * defaultWindowsServiceManagerFactory ) Connect () (windowsServiceManager , error ) {
85+ return mgr .Connect ()
86+ }
87+
88+ func createWithFactory (params AgentParams , factory windowsServiceManagerFactory ) (Service , error ) {
89+ svcMgr , err := factory .Connect ()
6490 if err != nil {
6591 return nil , err
6692 }
@@ -80,8 +106,8 @@ func Create(params AgentParams) (Service, error) {
80106 }, nil
81107}
82108
83- func Open (name string ) (Service , error ) {
84- svcMgr , err := mgr .Connect ()
109+ func openWithFactory (name string , factory windowsServiceManagerFactory ) (Service , error ) {
110+ svcMgr , err := factory .Connect ()
85111 if err != nil {
86112 return nil , err
87113 }
@@ -97,6 +123,14 @@ func Open(name string) (Service, error) {
97123 }, nil
98124}
99125
126+ func Create (params AgentParams ) (Service , error ) {
127+ return createWithFactory (params , & defaultWindowsServiceManagerFactory {})
128+ }
129+
130+ func Open (name string ) (Service , error ) {
131+ return openWithFactory (name , & defaultWindowsServiceManagerFactory {})
132+ }
133+
100134type windowsRunner struct {
101135 runner Runner
102136 exitCode int
@@ -106,7 +140,7 @@ func (host *windowsRunner) Execute(args []string, request <-chan svc.ChangeReque
106140 response <- svc.Status {State : svc .StartPending }
107141
108142 // Make the channels
109- stop := make (chan struct {})
143+ stop := make (chan struct {}, 1 )
110144 running := make (chan struct {})
111145
112146 // Make go routines for the channels
@@ -148,22 +182,41 @@ func (host *windowsRunner) Execute(args []string, request <-chan svc.ChangeReque
148182 return host .exitCode == 0 , uint32 (host .exitCode )
149183}
150184
185+ type windowsServiceFactory interface {
186+ IsWindowsService () (bool , error )
187+ Run (name string , handler svc.Handler ) error
188+ }
189+
190+ type defaultWindowsServiceFactory struct {}
191+
192+ func (f * defaultWindowsServiceFactory ) IsWindowsService () (bool , error ) {
193+ return svc .IsWindowsService ()
194+ }
195+
196+ func (f * defaultWindowsServiceFactory ) Run (name string , handler svc.Handler ) error {
197+ return svc .Run (name , handler )
198+ }
199+
151200func Run (runner Runner ) (int , error ) {
201+ return runWithFactory (runner , & defaultWindowsServiceFactory {})
202+ }
203+
204+ func runWithFactory (runner Runner , factory windowsServiceFactory ) (int , error ) {
152205 // Check if this is running as a service
153- isWinSvc , err := svc .IsWindowsService ()
206+ isWinSvc , err := factory .IsWindowsService ()
154207 if err != nil {
155- return 1 , err
208+ return int ( GenericError ) , err
156209 }
157210
158211 if ! isWinSvc {
159- return 1 , fmt .Errorf ("executable should be run as a service" )
212+ return int ( GenericError ) , fmt .Errorf ("executable should be run as a service" )
160213 }
161214
162215 // Start the windows service
163216 host := & windowsRunner {
164217 runner : runner ,
165218 }
166- err = svc .Run (runner .Name (), host )
219+ err = factory .Run (runner .Name (), host )
167220 if err != nil {
168221 return host .exitCode , err
169222 }
0 commit comments