Skip to content

Commit fe310f3

Browse files
committed
Fix InheritFd by passing fds via cmd.ExtraFiles
The InheritFd feature wasn't working because file descriptors registered via AddInheritFd() were not passed to the criu swrk child process. Go's exec.Command only inherits stdin/stdout/stderr by default. Additional file descriptors must be explicitly passed via cmd.ExtraFiles, which assigns them sequential fd numbers (3, 4, 5...) in the child process. This change: - Adds AddInheritFd(key, file) method to register fds for inheritance - Collects registered fds and passes them via cmd.ExtraFiles to swrk child - Auto-populates opts.InheritFd from registered fds (sorted by key for deterministic ordering) - Adds test verifying fd passing works with AddInheritFd and fails without it Example usage: c := criu.MakeCriu() c.AddInheritFd("net", netnsFile) c.Restore(opts, nfy) Fixes #223 Signed-off-by: Dan Feigin <dfeigin@nvidia.com>
1 parent 9684c5e commit fe310f3

File tree

3 files changed

+232
-6
lines changed

3 files changed

+232
-6
lines changed

main.go

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,24 @@ import (
55
"fmt"
66
"os"
77
"os/exec"
8+
"sort"
89
"strconv"
910
"syscall"
1011

1112
"github.com/checkpoint-restore/go-criu/v8/rpc"
1213
"google.golang.org/protobuf/proto"
1314
)
1415

16+
// extraFilesStartFd is the first fd number assigned to cmd.ExtraFiles by os/exec.
17+
// As documented in os/exec: "entry i becomes file descriptor 3+i"
18+
const extraFilesStartFd = 3
19+
1520
// Criu struct
1621
type Criu struct {
17-
swrkCmd *exec.Cmd
18-
swrkSk *os.File
19-
swrkPath string
22+
swrkCmd *exec.Cmd
23+
swrkSk *os.File
24+
swrkPath string
25+
inheritFds map[string]*os.File
2026
}
2127

2228
// MakeCriu returns the Criu object required for most operations
@@ -32,8 +38,52 @@ func (c *Criu) SetCriuPath(path string) {
3238
c.swrkPath = path
3339
}
3440

41+
// AddInheritFd registers a file descriptor to be passed to CRIU.
42+
// If opts.InheritFd is not set for an operation, it will be populated
43+
// from these registrations using the same key order.
44+
func (c *Criu) AddInheritFd(key string, file *os.File) {
45+
if c.inheritFds == nil {
46+
c.inheritFds = make(map[string]*os.File)
47+
}
48+
c.inheritFds[key] = file
49+
}
50+
51+
func (c *Criu) inheritFdKeys() []string {
52+
if len(c.inheritFds) == 0 {
53+
return nil
54+
}
55+
keys := make([]string, 0, len(c.inheritFds))
56+
for k := range c.inheritFds {
57+
keys = append(keys, k)
58+
}
59+
sort.Strings(keys)
60+
return keys
61+
}
62+
63+
func (c *Criu) ensureInheritFd(opts *rpc.CriuOpts) {
64+
if opts == nil || len(opts.GetInheritFd()) > 0 || len(c.inheritFds) == 0 {
65+
return
66+
}
67+
keys := c.inheritFdKeys()
68+
if len(keys) == 0 {
69+
return
70+
}
71+
opts.InheritFd = make([]*rpc.InheritFd, 0, len(keys))
72+
for i, key := range keys {
73+
fd := int32(extraFilesStartFd + i)
74+
opts.InheritFd = append(opts.InheritFd, &rpc.InheritFd{
75+
Key: proto.String(key),
76+
Fd: proto.Int32(fd),
77+
})
78+
}
79+
}
80+
3581
// Prepare sets up everything for the RPC communication to CRIU
3682
func (c *Criu) Prepare() error {
83+
return c.doPrepare(nil)
84+
}
85+
86+
func (c *Criu) doPrepare(opts *rpc.CriuOpts) error {
3787
fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_SEQPACKET, 0)
3888
if err != nil {
3989
return err
@@ -48,6 +98,19 @@ func (c *Criu) Prepare() error {
4898
// #nosec G204
4999
cmd := exec.Command(c.swrkPath, args...)
50100

101+
// Collect file descriptors to pass to child
102+
inheritKeys := c.inheritFdKeys()
103+
extraFiles := make([]*os.File, 0, len(inheritKeys))
104+
105+
// Add fds from AddInheritFd (sorted for stable ordering)
106+
for _, k := range inheritKeys {
107+
extraFiles = append(extraFiles, c.inheritFds[k])
108+
}
109+
110+
c.ensureInheritFd(opts)
111+
112+
cmd.ExtraFiles = extraFiles
113+
51114
err = cmd.Start()
52115
if err != nil {
53116
cln.Close()
@@ -106,6 +169,8 @@ func (c *Criu) doSwrk(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy Notify) e
106169
}
107170

108171
func (c *Criu) doSwrkWithResp(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy Notify, features *rpc.CriuFeatures) (resp *rpc.CriuResp, retErr error) {
172+
c.ensureInheritFd(opts)
173+
109174
req := rpc.CriuReq{
110175
Type: &reqType,
111176
Opts: opts,
@@ -120,7 +185,7 @@ func (c *Criu) doSwrkWithResp(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy N
120185
}
121186

122187
if c.swrkCmd == nil {
123-
err := c.Prepare()
188+
err := c.doPrepare(opts)
124189
if err != nil {
125190
return nil, err
126191
}

test/Makefile

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ CRIU_FEATURE_PIDFD_STORE = $(shell if criu check --feature pidfd_store > /dev/nu
99
export CRIU_FEATURE_MEM_TRACK CRIU_FEATURE_LAZY_PAGES CRIU_FEATURE_PIDFD_STORE
1010

1111
TEST_PAYLOAD := piggie/piggie
12-
TEST_BINARIES := test $(TEST_PAYLOAD) phaul/phaul
13-
COVERAGE_BINARIES := test.coverage phaul/phaul.coverage crit/crit-test.coverage
12+
TEST_BINARIES := test $(TEST_PAYLOAD) phaul/phaul inheritfd/inheritfd
13+
COVERAGE_BINARIES := test.coverage phaul/phaul.coverage crit/crit-test.coverage inheritfd/inheritfd.coverage
1414

1515
all: $(TEST_BINARIES) phaul-test
1616
mkdir -p image
@@ -29,6 +29,9 @@ test: main.go
2929
phaul/phaul: phaul/main.go
3030
$(GO) build -v -o $@ $^
3131

32+
inheritfd/inheritfd: inheritfd/main.go
33+
$(GO) build -v -o $@ $^
34+
3235
phaul-test: $(TEST_BINARIES)
3336
rm -rf image
3437
PID=$$(piggie/piggie) && \
@@ -58,6 +61,11 @@ crit/crit-test.coverage: crit/*.go
5861
-cover \
5962
-o $@ crit/main.go
6063

64+
inheritfd/inheritfd.coverage: inheritfd/*.go
65+
$(GO) build \
66+
-cover \
67+
-o $@ inheritfd/main.go
68+
6169
coverage: $(COVERAGE_BINARIES) $(TEST_PAYLOAD)
6270
mkdir -p $(COVERAGE_PATH)
6371
mkdir -p image
@@ -70,6 +78,7 @@ coverage: $(COVERAGE_BINARIES) $(TEST_PAYLOAD)
7078
trap 'pkill -9 piggie' INT EXIT TERM && \
7179
GOCOVERDIR=${COVERAGE_PATH} phaul/phaul.coverage $$PID
7280
cd crit/ && GOCOVERDIR=${COVERAGE_PATH} ./crit-test.coverage
81+
cd inheritfd/ && GOCOVERDIR=${COVERAGE_PATH} ./inheritfd.coverage
7382
$(MAKE) -C ../crit/ unit-test GOFLAGS="-coverprofile=${COVERAGE_PATH}/coverprofile-crit-unit-test"
7483
$(MAKE) -C crit/ e2e-test GOCOVERDIR=${COVERAGE_PATH}
7584
$(MAKE) -C crit/ clean

test/inheritfd/main.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"os"
7+
"syscall"
8+
"time"
9+
10+
"github.com/checkpoint-restore/go-criu/v8"
11+
"github.com/checkpoint-restore/go-criu/v8/rpc"
12+
"google.golang.org/protobuf/proto"
13+
)
14+
15+
func getSwrkPid() int {
16+
myPid := os.Getpid()
17+
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/task/%d/children", myPid, myPid))
18+
if err != nil {
19+
return 0
20+
}
21+
var swrkPid int
22+
if _, err := fmt.Sscanf(string(data), "%d", &swrkPid); err != nil {
23+
return 0
24+
}
25+
return swrkPid
26+
}
27+
28+
func waitForSwrkPid() int {
29+
for i := 0; i < 50; i++ {
30+
swrkPid := getSwrkPid()
31+
if swrkPid != 0 {
32+
return swrkPid
33+
}
34+
time.Sleep(10 * time.Millisecond)
35+
}
36+
return 0
37+
}
38+
39+
func swrkHasInode(swrkPid int, ino uint64) (bool, error) {
40+
entries, err := os.ReadDir(fmt.Sprintf("/proc/%d/fd", swrkPid))
41+
if err != nil {
42+
return false, err
43+
}
44+
for _, entry := range entries {
45+
var stat syscall.Stat_t
46+
if err := syscall.Stat(fmt.Sprintf("/proc/%d/fd/%s", swrkPid, entry.Name()), &stat); err != nil {
47+
continue
48+
}
49+
if stat.Ino == ino {
50+
return true, nil
51+
}
52+
}
53+
return false, nil
54+
}
55+
56+
func testInheritFd(netnsIno uint64, netnsFile *os.File) {
57+
c := criu.MakeCriu()
58+
c.AddInheritFd("net", netnsFile)
59+
if err := c.Prepare(); err != nil {
60+
log.Fatalln(err)
61+
}
62+
63+
swrkPid := waitForSwrkPid()
64+
if swrkPid == 0 {
65+
log.Fatalln("no swrk pid found")
66+
}
67+
has, err := swrkHasInode(swrkPid, netnsIno)
68+
if err != nil {
69+
log.Fatalln(err)
70+
}
71+
if !has {
72+
log.Fatalln("fd not inherited with AddInheritFd")
73+
}
74+
75+
// Send a dummy RPC so swrk exits cleanly
76+
if _, err := c.GetCriuVersion(); err != nil {
77+
log.Fatalln(err)
78+
}
79+
80+
if err := c.Cleanup(); err != nil {
81+
log.Fatalln(err)
82+
}
83+
}
84+
85+
func testNoInheritFd(netnsIno uint64) {
86+
c := criu.MakeCriu()
87+
if err := c.Prepare(); err != nil {
88+
log.Fatalln(err)
89+
}
90+
91+
swrkPid := waitForSwrkPid()
92+
if swrkPid == 0 {
93+
log.Fatalln("no swrk pid found")
94+
}
95+
has, err := swrkHasInode(swrkPid, netnsIno)
96+
if err != nil {
97+
log.Fatalln(err)
98+
}
99+
if has {
100+
log.Fatalln("fd incorrectly inherited without AddInheritFd")
101+
}
102+
103+
// Send a dummy RPC so swrk exits cleanly
104+
if _, err := c.GetCriuVersion(); err != nil {
105+
log.Fatalln(err)
106+
}
107+
108+
if err := c.Cleanup(); err != nil {
109+
log.Fatalln(err)
110+
}
111+
}
112+
113+
func testAutoPopulateInheritFd(netnsFile *os.File) {
114+
c := criu.MakeCriu()
115+
c.AddInheritFd("testKey", netnsFile)
116+
117+
opts := &rpc.CriuOpts{
118+
ImagesDir: proto.String("/nonexistent"),
119+
}
120+
121+
// Call will fail (no images), but ensureInheritFd() runs first
122+
_ = c.PreDump(opts, nil)
123+
124+
// Verify opts.InheritFd was auto-populated
125+
inheritFds := opts.GetInheritFd()
126+
if len(inheritFds) != 1 {
127+
log.Fatalf("opts.InheritFd not auto-populated: got %d, want 1", len(inheritFds))
128+
}
129+
if inheritFds[0].GetKey() != "testKey" || inheritFds[0].GetFd() != 3 {
130+
log.Fatalf("opts.InheritFd wrong: key=%s fd=%d", inheritFds[0].GetKey(), inheritFds[0].GetFd())
131+
}
132+
}
133+
134+
// Usage: test-inheritfd
135+
func main() {
136+
netnsFile, err := os.Open("/proc/self/ns/net")
137+
if err != nil {
138+
log.Fatalln(err)
139+
}
140+
defer netnsFile.Close()
141+
142+
var stat syscall.Stat_t
143+
if err := syscall.Fstat(int(netnsFile.Fd()), &stat); err != nil {
144+
log.Fatalln(err)
145+
}
146+
netnsIno := stat.Ino
147+
148+
testNoInheritFd(netnsIno)
149+
testInheritFd(netnsIno, netnsFile)
150+
testAutoPopulateInheritFd(netnsFile)
151+
log.Println("PASS")
152+
}

0 commit comments

Comments
 (0)