Skip to content

Commit d4f3a62

Browse files
authored
Remove signal handling and add tests to shutdown (osquery#58)
Update the shutdown mechanism per discussion in osquery/osquery#4577
1 parent c674f1b commit d4f3a62

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

server.go

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ package osquery
33
import (
44
"context"
55
"fmt"
6-
"os"
7-
"os/signal"
86
"sync"
9-
"syscall"
107
"time"
118

129
"git.apache.org/thrift.git/lib/go/thrift"
@@ -52,6 +49,7 @@ type ExtensionManagerServer struct {
5249
timeout time.Duration
5350
pingInterval time.Duration // How often to ping osquery server
5451
mutex sync.Mutex
52+
started bool // Used to ensure tests wait until the server is actually started
5553
}
5654

5755
// validRegistryNames contains the allowable RegistryName() values. If a plugin
@@ -167,6 +165,9 @@ func (s *ExtensionManagerServer) Start() error {
167165

168166
s.server = thrift.NewTSimpleServer2(processor, s.transport)
169167
server = s.server
168+
169+
s.started = true
170+
170171
return nil
171172
}()
172173

@@ -177,23 +178,14 @@ func (s *ExtensionManagerServer) Start() error {
177178
return server.Serve()
178179
}
179180

180-
// Run starts the extension manager and runs until an an interrupt
181-
// signal is received.
182-
// Run will call Shutdown before exiting.
181+
// Run starts the extension manager and runs until osquery calls for a shutdown
182+
// or the osquery instance goes away.
183183
func (s *ExtensionManagerServer) Run() error {
184184
errc := make(chan error)
185185
go func() {
186186
errc <- s.Start()
187187
}()
188188

189-
// Interrupt handler.
190-
go func() {
191-
sig := make(chan os.Signal)
192-
signal.Notify(sig, os.Interrupt, os.Kill, syscall.SIGTERM)
193-
<-sig
194-
errc <- nil
195-
}()
196-
197189
// Watch for the osquery process going away. If so, initiate shutdown.
198190
go func() {
199191
for {
@@ -265,5 +257,19 @@ func (s *ExtensionManagerServer) Shutdown() error {
265257
server.Stop()
266258
}()
267259
}
260+
268261
return nil
269262
}
263+
264+
// Useful for testing
265+
func (s *ExtensionManagerServer) waitStarted() {
266+
for {
267+
s.mutex.Lock()
268+
started := s.started
269+
s.mutex.Unlock()
270+
if started {
271+
time.Sleep(10 * time.Millisecond)
272+
break
273+
}
274+
}
275+
}

server_test.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,17 @@ func testShutdownDeadlock(t *testing.T) {
106106
},
107107
}
108108
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
109+
110+
wait := sync.WaitGroup{}
111+
112+
wait.Add(1)
109113
go func() {
110114
err := server.Start()
111115
require.Nil(t, err)
116+
wait.Done()
112117
}()
113-
// Sleep long enough for server to start listening on socket
114-
time.Sleep(500 * time.Millisecond)
118+
// Wait for server to be set up
119+
server.waitStarted()
115120

116121
// Create a raw client to access the shutdown method that is not
117122
// usually exposed.
@@ -127,7 +132,6 @@ func testShutdownDeadlock(t *testing.T) {
127132

128133
// Simultaneously call shutdown through a request from the client and
129134
// directly on the server object.
130-
wait := sync.WaitGroup{}
131135
wait.Add(1)
132136
go func() {
133137
defer wait.Done()
@@ -148,6 +152,40 @@ func testShutdownDeadlock(t *testing.T) {
148152
close(completed)
149153
}()
150154

155+
// either indicate successful shutdown, or fatal the test because it
156+
// hung
157+
select {
158+
case <-completed:
159+
// Success. Do nothing.
160+
case <-time.After(5 * time.Second):
161+
t.Fatal("hung on shutdown")
162+
}
163+
}
164+
165+
func TestShutdownBasic(t *testing.T) {
166+
tempPath, err := ioutil.TempFile("", "")
167+
require.Nil(t, err)
168+
defer os.Remove(tempPath.Name())
169+
170+
retUUID := osquery.ExtensionRouteUUID(0)
171+
mock := &MockExtensionManager{
172+
RegisterExtensionFunc: func(info *osquery.InternalExtensionInfo, registry osquery.ExtensionRegistry) (*osquery.ExtensionStatus, error) {
173+
return &osquery.ExtensionStatus{Code: 0, UUID: retUUID}, nil
174+
},
175+
}
176+
server := ExtensionManagerServer{serverClient: mock, sockPath: tempPath.Name()}
177+
178+
completed := make(chan struct{})
179+
go func() {
180+
err := server.Start()
181+
require.NoError(t, err)
182+
close(completed)
183+
}()
184+
185+
server.waitStarted()
186+
err = server.Shutdown()
187+
require.NoError(t, err)
188+
151189
// Either indicate successful shutdown, or fatal the test because it
152190
// hung
153191
select {

0 commit comments

Comments
 (0)