Skip to content

Commit c569328

Browse files
authored
Merge pull request #182 from kzys/wait-and-cancel
Make sure Machine#Wait() guarantees that the underlying Firecracker process is stopped
2 parents 7dce86a + 25ecff7 commit c569328

File tree

2 files changed

+98
-29
lines changed

2 files changed

+98
-29
lines changed

machine.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,10 @@ func NewMachine(ctx context.Context, cfg Config, opts ...Opt) (*Machine, error)
338338
return m, nil
339339
}
340340

341-
// Start will iterate through the handler list and call each handler. If an
341+
// Start actually starts a Firecracker microVM.
342+
// The context must not be cancelled while the microVM is running.
343+
//
344+
// It will iterate through the handler list and call each handler. If an
342345
// error occurred during handler execution, that error will be returned. If the
343346
// handlers succeed, then this will start the VMM instance.
344347
// Start may only be called once per Machine. Subsequent calls will return
@@ -516,14 +519,22 @@ func (m *Machine) startVMM(ctx context.Context) error {
516519

517520
return err
518521
}
522+
523+
// This goroutine is used to kill the process by context cancellation,
524+
// but doesn't tell anyone about that.
519525
go func() {
520-
select {
521-
case <-ctx.Done():
522-
m.fatalErr = ctx.Err()
523-
case err := <-errCh:
524-
m.fatalErr = err
526+
<-ctx.Done()
527+
err := m.stopVMM()
528+
if err != nil {
529+
m.logger.WithError(err).Errorf("failed to stop vm %q", m.Cfg.VMID)
525530
}
531+
}()
526532

533+
// This goroutine is used to tell clients that the process is stopped
534+
// (gracefully or not).
535+
go func() {
536+
m.fatalErr = <-errCh
537+
m.logger.Debugf("closing the exitCh %v", m.fatalErr)
527538
close(m.exitCh)
528539
}()
529540

machine_test.go

Lines changed: 81 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,37 +1081,95 @@ func TestCaptureFifoToFile_leak(t *testing.T) {
10811081
assert.Contains(t, loggerBuffer.String(), `file already closed`, "log")
10821082
}
10831083

1084-
func TestWaitWithKill(t *testing.T) {
1084+
// Replace filesystem-unsafe characters (such as /) which are often seen in Go's test names
1085+
var fsSafeTestName = strings.NewReplacer("/", "_")
1086+
1087+
func TestWait(t *testing.T) {
10851088
fctesting.RequiresRoot(t)
1086-
ctx := context.Background()
10871089

1088-
socketPath := filepath.Join(testDataPath, t.Name())
1089-
defer os.Remove(socketPath)
1090+
cases := []struct {
1091+
name string
1092+
stop func(m *Machine, cancel context.CancelFunc)
1093+
}{
1094+
{
1095+
name: "StopVMM",
1096+
stop: func(m *Machine, _ context.CancelFunc) {
1097+
err := m.StopVMM()
1098+
require.NoError(t, err)
1099+
},
1100+
},
1101+
{
1102+
name: "Kill",
1103+
stop: func(m *Machine, cancel context.CancelFunc) {
1104+
pid, err := m.PID()
1105+
require.NoError(t, err)
1106+
1107+
process, err := os.FindProcess(pid)
1108+
err = process.Kill()
1109+
require.NoError(t, err)
1110+
},
1111+
},
1112+
{
1113+
name: "Context Cancel",
1114+
stop: func(m *Machine, cancel context.CancelFunc) {
1115+
cancel()
1116+
},
1117+
},
1118+
{
1119+
name: "StopVMM + Context Cancel",
1120+
stop: func(m *Machine, cancel context.CancelFunc) {
1121+
m.StopVMM()
1122+
time.Sleep(1 * time.Second)
1123+
cancel()
1124+
},
1125+
},
1126+
}
10901127

1091-
cfg := createValidConfig(t, socketPath)
1092-
cmd := VMCommandBuilder{}.
1093-
WithSocketPath(cfg.SocketPath).
1094-
WithBin(getFirecrackerBinaryPath()).
1095-
Build(ctx)
1096-
m, err := NewMachine(ctx, cfg, WithProcessRunner(cmd))
1097-
require.NoError(t, err)
1128+
for _, c := range cases {
1129+
t.Run(c.name, func(t *testing.T) {
1130+
ctx := context.Background()
1131+
vmContext, vmCancel := context.WithCancel(context.Background())
10981132

1099-
err = m.Start(ctx)
1100-
require.NoError(t, err)
1133+
socketPath := filepath.Join(testDataPath, fsSafeTestName.Replace(t.Name()))
1134+
defer os.Remove(socketPath)
11011135

1102-
go func() {
1103-
pid, err := m.PID()
1104-
require.NoError(t, err)
1136+
cfg := createValidConfig(t, socketPath)
1137+
m, err := NewMachine(ctx, cfg, func(m *Machine) {
1138+
// Rewriting m.cmd partially wouldn't work since Cmd has
1139+
// some unexported members
1140+
args := m.cmd.Args[1:]
1141+
m.cmd = exec.Command(getFirecrackerBinaryPath(), args...)
1142+
})
1143+
require.NoError(t, err)
11051144

1106-
process, err := os.FindProcess(pid)
1107-
require.NoError(t, err)
1145+
err = m.Start(vmContext)
1146+
require.NoError(t, err)
11081147

1109-
err = process.Kill()
1110-
require.NoError(t, err)
1111-
}()
1148+
pid, err := m.PID()
1149+
require.NoError(t, err)
11121150

1113-
err = m.Wait(ctx)
1114-
require.Error(t, err, "Firecracker was killed and it must be reported")
1151+
var wg sync.WaitGroup
1152+
wg.Add(1)
1153+
go func() {
1154+
defer wg.Done()
1155+
c.stop(m, vmCancel)
1156+
}()
1157+
1158+
err = m.Wait(ctx)
1159+
require.Error(t, err, "Firecracker was killed and it must be reported")
1160+
t.Logf("err = %v", err)
1161+
1162+
proc, err := os.FindProcess(pid)
1163+
// A nil error here doesn't mean the process is still there.
1164+
// In fact, the error is always nil on Unix systems
1165+
require.NoError(t, err)
1166+
1167+
err = proc.Signal(syscall.Signal(0))
1168+
require.Equal(t, "os: process already finished", err.Error())
1169+
1170+
wg.Wait()
1171+
})
1172+
}
11151173
}
11161174

11171175
func TestWaitWithInvalidBinary(t *testing.T) {

0 commit comments

Comments
 (0)