Skip to content

Commit a56acdf

Browse files
authored
Merge pull request #166 from michaeldwan/disable-signals-option
Configurable signal handling, fix signal cleanup
2 parents 6b73a26 + bdba5b4 commit a56acdf

File tree

3 files changed

+163
-16
lines changed

3 files changed

+163
-16
lines changed

machine.go

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ type Config struct {
117117
// NetNS represents the path to a network namespace handle. If present, the
118118
// application will use this to join the associated network namespace
119119
NetNS string
120+
121+
// ForwardSignals is an optional list of signals to catch and forward to
122+
// firecracker. If not provided, the default signals will be used.
123+
ForwardSignals []os.Signal
120124
}
121125

122126
// Validate will ensure that the required fields are set and that
@@ -303,6 +307,16 @@ func NewMachine(ctx context.Context, cfg Config, opts ...Opt) (*Machine, error)
303307
cfg.VMID = randomID.String()
304308
}
305309

310+
if cfg.ForwardSignals == nil {
311+
cfg.ForwardSignals = []os.Signal{
312+
os.Interrupt,
313+
syscall.SIGQUIT,
314+
syscall.SIGTERM,
315+
syscall.SIGHUP,
316+
syscall.SIGABRT,
317+
}
318+
}
319+
306320
m.machineConfig = cfg.MachineCfg
307321
m.Cfg = cfg
308322

@@ -481,20 +495,7 @@ func (m *Machine) startVMM(ctx context.Context) error {
481495
close(errCh)
482496
}()
483497

484-
// Set up a signal handler and pass INT, QUIT, and TERM through to firecracker
485-
sigchan := make(chan os.Signal)
486-
signal.Notify(sigchan, os.Interrupt,
487-
syscall.SIGQUIT,
488-
syscall.SIGTERM,
489-
syscall.SIGHUP,
490-
syscall.SIGABRT)
491-
m.logger.Debugf("Setting up signal handler")
492-
go func() {
493-
if sig, ok := <-sigchan; ok {
494-
m.logger.Printf("Caught signal %s", sig)
495-
m.cmd.Process.Signal(sig)
496-
}
497-
}()
498+
m.setupSignals()
498499

499500
// Wait for firecracker to initialize:
500501
err = m.waitForSocket(time.Duration(m.client.firecrackerInitTimeout)*time.Second, errCh)
@@ -513,8 +514,6 @@ func (m *Machine) startVMM(ctx context.Context) error {
513514
m.fatalErr = err
514515
}
515516

516-
signal.Stop(sigchan)
517-
close(sigchan)
518517
close(m.exitCh)
519518
}()
520519

@@ -885,3 +884,31 @@ func (m *Machine) waitForSocket(timeout time.Duration, exitchan chan error) erro
885884
}
886885
}
887886
}
887+
888+
// Set up a signal handler to pass through to firecracker
889+
func (m *Machine) setupSignals() {
890+
signals := m.Cfg.ForwardSignals
891+
892+
if len(signals) == 0 {
893+
return
894+
}
895+
896+
m.logger.Debugf("Setting up signal handler: %v", signals)
897+
sigchan := make(chan os.Signal, len(signals))
898+
signal.Notify(sigchan, signals...)
899+
900+
go func() {
901+
for {
902+
select {
903+
case sig := <-sigchan:
904+
m.logger.Printf("Caught signal %s", sig)
905+
m.cmd.Process.Signal(sig)
906+
case <-m.exitCh:
907+
break
908+
}
909+
}
910+
911+
signal.Stop(sigchan)
912+
close(sigchan)
913+
}()
914+
}

machine_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"net"
2424
"os"
2525
"os/exec"
26+
"os/signal"
2627
"path/filepath"
2728
"strconv"
2829
"strings"
@@ -1199,3 +1200,106 @@ func createValidConfig(t *testing.T, socketPath string) Config {
11991200
},
12001201
}
12011202
}
1203+
1204+
func TestSignalForwarding(t *testing.T) {
1205+
forwardedSignals := []os.Signal{
1206+
syscall.SIGUSR1,
1207+
syscall.SIGUSR2,
1208+
syscall.SIGINT,
1209+
syscall.SIGTERM,
1210+
}
1211+
ignoredSignals := []os.Signal{
1212+
syscall.SIGHUP,
1213+
syscall.SIGQUIT,
1214+
}
1215+
1216+
cfg := Config{
1217+
Debug: true,
1218+
KernelImagePath: filepath.Join(testDataPath, "vmlinux"),
1219+
SocketPath: "/tmp/TestSignalForwarding.sock",
1220+
Drives: []models.Drive{
1221+
{
1222+
DriveID: String("0"),
1223+
IsRootDevice: Bool(true),
1224+
IsReadOnly: Bool(false),
1225+
PathOnHost: String(testRootfs),
1226+
},
1227+
},
1228+
DisableValidation: true,
1229+
ForwardSignals: forwardedSignals,
1230+
}
1231+
defer os.RemoveAll("/tmp/TestSignalForwarding.sock")
1232+
1233+
opClient := fctesting.MockClient{}
1234+
1235+
ctx := context.Background()
1236+
client := NewClient(cfg.SocketPath, fctesting.NewLogEntry(t), true, WithOpsClient(&opClient))
1237+
1238+
fd, err := net.Listen("unix", cfg.SocketPath)
1239+
if err != nil {
1240+
t.Fatalf("unexpected error during creation of unix socket: %v", err)
1241+
}
1242+
defer fd.Close()
1243+
1244+
stdout := &bytes.Buffer{}
1245+
stderr := &bytes.Buffer{}
1246+
cmd := exec.Command(filepath.Join(testDataPath, "sigprint.sh"))
1247+
cmd.Stdout = stdout
1248+
cmd.Stderr = stderr
1249+
stdin, err := cmd.StdinPipe()
1250+
assert.NoError(t, err)
1251+
1252+
m, err := NewMachine(
1253+
ctx,
1254+
cfg,
1255+
WithClient(client),
1256+
WithProcessRunner(cmd),
1257+
WithLogger(fctesting.NewLogEntry(t)),
1258+
)
1259+
if err != nil {
1260+
t.Fatalf("failed to create new machine: %v", err)
1261+
}
1262+
1263+
if err := m.startVMM(ctx); err != nil {
1264+
t.Fatalf("error startVMM: %v", err)
1265+
}
1266+
defer m.StopVMM()
1267+
1268+
sigChan := make(chan os.Signal)
1269+
signal.Notify(sigChan, ignoredSignals...)
1270+
defer func() {
1271+
signal.Stop(sigChan)
1272+
close(sigChan)
1273+
}()
1274+
1275+
go func() {
1276+
for sig := range sigChan {
1277+
t.Logf("received signal %v, ignoring", sig)
1278+
}
1279+
}()
1280+
1281+
go func() {
1282+
for _, sig := range append(forwardedSignals, ignoredSignals...) {
1283+
t.Logf("sending signal %v to self", sig)
1284+
syscall.Kill(syscall.Getpid(), sig.(syscall.Signal))
1285+
}
1286+
1287+
// give the child process time to receive signals and flush pipes
1288+
time.Sleep(1 * time.Second)
1289+
1290+
// terminate the signal printing process
1291+
stdin.Write([]byte("q"))
1292+
}()
1293+
1294+
err = m.Wait(ctx)
1295+
require.NoError(t, err, "wait returned an error")
1296+
1297+
receivedSignals := []os.Signal{}
1298+
for _, sigStr := range strings.Split(strings.TrimSpace(stdout.String()), "\n") {
1299+
i, err := strconv.Atoi(sigStr)
1300+
require.NoError(t, err, "expected numeric output")
1301+
receivedSignals = append(receivedSignals, syscall.Signal(i))
1302+
}
1303+
1304+
assert.ElementsMatch(t, forwardedSignals, receivedSignals)
1305+
}

testdata/sigprint.sh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
3+
typeset -i sig=1
4+
while (( sig < 65 )); do
5+
trap "echo '$sig'" $sig 2>/dev/null
6+
let sig=sig+1
7+
done
8+
9+
>&2 echo "Send signals to PID $$ and type [q] when done."
10+
11+
while :
12+
do
13+
read -n1 input
14+
[ "$input" == "q" ] && break
15+
sleep .1
16+
done

0 commit comments

Comments
 (0)