@@ -24,7 +24,7 @@ import (
2424// Verify that an error in server.Start will return an error instead of deadlock.
2525func TestNoDeadlockOnError (t * testing.T ) {
2626 registry := make (map [string ](map [string ]OsqueryPlugin ))
27- for reg , _ := range validRegistryNames {
27+ for reg := range validRegistryNames {
2828 registry [reg ] = make (map [string ]OsqueryPlugin )
2929 }
3030 mut := sync.Mutex {}
@@ -43,8 +43,9 @@ func TestNoDeadlockOnError(t *testing.T) {
4343 CloseFunc : func () {},
4444 }
4545 server := & ExtensionManagerServer {
46- serverClient : mock ,
47- registry : registry ,
46+ serverClient : mock ,
47+ registry : registry ,
48+ serverClientShouldShutdown : true ,
4849 }
4950
5051 log := func (ctx context.Context , typ logger.LogType , logText string ) error {
@@ -63,8 +64,12 @@ func TestNoDeadlockOnError(t *testing.T) {
6364// Ensure that the extension server will shutdown and return if the osquery
6465// instance it is talking to stops responding to pings.
6566func TestShutdownWhenPingFails (t * testing.T ) {
67+ tempPath , err := ioutil .TempFile ("" , "" )
68+ require .Nil (t , err )
69+ defer os .Remove (tempPath .Name ())
70+
6671 registry := make (map [string ](map [string ]OsqueryPlugin ))
67- for reg , _ := range validRegistryNames {
72+ for reg := range validRegistryNames {
6873 registry [reg ] = make (map [string ]OsqueryPlugin )
6974 }
7075 mock := & MockExtensionManager {
@@ -81,11 +86,14 @@ func TestShutdownWhenPingFails(t *testing.T) {
8186 CloseFunc : func () {},
8287 }
8388 server := & ExtensionManagerServer {
84- serverClient : mock ,
85- registry : registry ,
89+ serverClient : mock ,
90+ registry : registry ,
91+ serverClientShouldShutdown : true ,
92+ pingInterval : 1 * time .Second ,
93+ sockPath : tempPath .Name (),
8694 }
8795
88- err : = server .Run ()
96+ err = server .Run ()
8997 assert .Error (t , err )
9098 assert .Contains (t , err .Error (), "broken pipe" )
9199 assert .True (t , mock .DeRegisterExtensionFuncInvoked )
@@ -106,6 +114,7 @@ func TestShutdownDeadlock(t *testing.T) {
106114 })
107115 }
108116}
117+
109118func testShutdownDeadlock (t * testing.T , uuid int ) {
110119 tempPath , err := ioutil .TempFile ("" , "" )
111120 require .Nil (t , err )
@@ -122,9 +131,10 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
122131 CloseFunc : func () {},
123132 }
124133 server := ExtensionManagerServer {
125- serverClient : mock ,
126- sockPath : tempPath .Name (),
127- timeout : defaultTimeout ,
134+ serverClient : mock ,
135+ sockPath : tempPath .Name (),
136+ timeout : defaultTimeout ,
137+ serverClientShouldShutdown : true ,
128138 }
129139
130140 var wait sync.WaitGroup
@@ -152,8 +162,12 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
152162 for ! opened && attempt < 10 {
153163 transport = thrift .NewTSocketFromAddrTimeout (addr , timeout , timeout )
154164 err = transport .Open ()
155- opened = err == nil
156165 attempt ++
166+ if err != nil {
167+ time .Sleep (1 * time .Second )
168+ } else {
169+ opened = true
170+ }
157171 }
158172 require .NoError (t , err )
159173 client := osquery .NewExtensionManagerClientFactory (transport ,
@@ -193,9 +207,13 @@ func testShutdownDeadlock(t *testing.T, uuid int) {
193207}
194208
195209func TestShutdownBasic (t * testing.T ) {
196- tempPath , err := ioutil .TempFile ("" , "" )
197- require .Nil (t , err )
198- defer os .Remove (tempPath .Name ())
210+ dir := t .TempDir ()
211+
212+ tempPath := func () string {
213+ tmp , err := os .CreateTemp (dir , "" )
214+ require .NoError (t , err )
215+ return tmp .Name ()
216+ }
199217
200218 retUUID := osquery .ExtensionRouteUUID (0 )
201219 mock := & MockExtensionManager {
@@ -207,26 +225,38 @@ func TestShutdownBasic(t *testing.T) {
207225 },
208226 CloseFunc : func () {},
209227 }
210- server := ExtensionManagerServer {serverClient : mock , sockPath : tempPath .Name ()}
211228
212- completed := make (chan struct {})
213- go func () {
214- err := server .Start ()
229+ for _ , server := range []* ExtensionManagerServer {
230+ // Create the extension manager without using NewExtensionManagerServer.
231+ {serverClient : mock , sockPath : tempPath ()},
232+ // Create the extension manager using ExtensionManagerServer.
233+ {serverClient : mock , sockPath : tempPath (), serverClientShouldShutdown : true },
234+ } {
235+ completed := make (chan struct {})
236+ go func () {
237+ err := server .Start ()
238+ require .NoError (t , err )
239+ close (completed )
240+ }()
241+
242+ server .waitStarted ()
243+
244+ err := server .Shutdown (context .Background ())
215245 require .NoError (t , err )
216- close (completed )
217- }()
218246
219- server .waitStarted ()
220- err = server .Shutdown (context .Background ())
221- require .NoError (t , err )
247+ // Test that server.Shutdown is idempotent.
248+ err = server .Shutdown (context .Background ())
249+ require .NoError (t , err )
250+
251+ // Either indicate successful shutdown, or fatal the test because it
252+ // hung
253+ select {
254+ case <- completed :
255+ // Success. Do nothing.
256+ case <- time .After (5 * time .Second ):
257+ t .Fatal ("hung on shutdown" )
258+ }
222259
223- // Either indicate successful shutdown, or fatal the test because it
224- // hung
225- select {
226- case <- completed :
227- // Success. Do nothing.
228- case <- time .After (5 * time .Second ):
229- t .Fatal ("hung on shutdown" )
230260 }
231261}
232262
0 commit comments