Skip to content

Commit cc64d1d

Browse files
authored
Merge pull request #224 from dfeigin-nv/223-fix-inheritfd
Fix InheritFd by passing fds via cmd.ExtraFiles
2 parents 4c61fef + fe310f3 commit cc64d1d

File tree

3 files changed

+232
-6
lines changed

3 files changed

+232
-6
lines changed

main.go

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,24 @@ import (
55
"fmt"
66
"os"
77
"os/exec"
8+
"sort"
89
"strconv"
910
"syscall"
1011

1112
"github.com/checkpoint-restore/go-criu/v8/rpc"
1213
"google.golang.org/protobuf/proto"
1314
)
1415

16+
// extraFilesStartFd is the first fd number assigned to cmd.ExtraFiles by os/exec.
17+
// As documented in os/exec: "entry i becomes file descriptor 3+i"
18+
const extraFilesStartFd = 3
19+
1520
// Criu struct
1621
type Criu struct {
17-
swrkCmd *exec.Cmd
18-
swrkSk *os.File
19-
swrkPath string
22+
swrkCmd *exec.Cmd
23+
swrkSk *os.File
24+
swrkPath string
25+
inheritFds map[string]*os.File
2026
}
2127

2228
// MakeCriu returns the Criu object required for most operations
@@ -32,8 +38,52 @@ func (c *Criu) SetCriuPath(path string) {
3238
c.swrkPath = path
3339
}
3440

41+
// AddInheritFd registers a file descriptor to be passed to CRIU.
42+
// If opts.InheritFd is not set for an operation, it will be populated
43+
// from these registrations using the same key order.
44+
func (c *Criu) AddInheritFd(key string, file *os.File) {
45+
if c.inheritFds == nil {
46+
c.inheritFds = make(map[string]*os.File)
47+
}
48+
c.inheritFds[key] = file
49+
}
50+
51+
func (c *Criu) inheritFdKeys() []string {
52+
if len(c.inheritFds) == 0 {
53+
return nil
54+
}
55+
keys := make([]string, 0, len(c.inheritFds))
56+
for k := range c.inheritFds {
57+
keys = append(keys, k)
58+
}
59+
sort.Strings(keys)
60+
return keys
61+
}
62+
63+
func (c *Criu) ensureInheritFd(opts *rpc.CriuOpts) {
64+
if opts == nil || len(opts.GetInheritFd()) > 0 || len(c.inheritFds) == 0 {
65+
return
66+
}
67+
keys := c.inheritFdKeys()
68+
if len(keys) == 0 {
69+
return
70+
}
71+
opts.InheritFd = make([]*rpc.InheritFd, 0, len(keys))
72+
for i, key := range keys {
73+
fd := int32(extraFilesStartFd + i)
74+
opts.InheritFd = append(opts.InheritFd, &rpc.InheritFd{
75+
Key: proto.String(key),
76+
Fd: proto.Int32(fd),
77+
})
78+
}
79+
}
80+
3581
// Prepare sets up everything for the RPC communication to CRIU
3682
func (c *Criu) Prepare() error {
83+
return c.doPrepare(nil)
84+
}
85+
86+
func (c *Criu) doPrepare(opts *rpc.CriuOpts) error {
3787
fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_SEQPACKET, 0)
3888
if err != nil {
3989
return err
@@ -48,6 +98,19 @@ func (c *Criu) Prepare() error {
4898
// #nosec G204
4999
cmd := exec.Command(c.swrkPath, args...)
50100

101+
// Collect file descriptors to pass to child
102+
inheritKeys := c.inheritFdKeys()
103+
extraFiles := make([]*os.File, 0, len(inheritKeys))
104+
105+
// Add fds from AddInheritFd (sorted for stable ordering)
106+
for _, k := range inheritKeys {
107+
extraFiles = append(extraFiles, c.inheritFds[k])
108+
}
109+
110+
c.ensureInheritFd(opts)
111+
112+
cmd.ExtraFiles = extraFiles
113+
51114
err = cmd.Start()
52115
if err != nil {
53116
_ = cln.Close()
@@ -106,6 +169,8 @@ func (c *Criu) doSwrk(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy Notify) e
106169
}
107170

108171
func (c *Criu) doSwrkWithResp(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy Notify, features *rpc.CriuFeatures) (resp *rpc.CriuResp, retErr error) {
172+
c.ensureInheritFd(opts)
173+
109174
req := rpc.CriuReq{
110175
Type: &reqType,
111176
Opts: opts,
@@ -120,7 +185,7 @@ func (c *Criu) doSwrkWithResp(reqType rpc.CriuReqType, opts *rpc.CriuOpts, nfy N
120185
}
121186

122187
if c.swrkCmd == nil {
123-
err := c.Prepare()
188+
err := c.doPrepare(opts)
124189
if err != nil {
125190
return nil, err
126191
}

test/Makefile

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ CRIU_FEATURE_PIDFD_STORE = $(shell if criu check --feature pidfd_store > /dev/nu
99
export CRIU_FEATURE_MEM_TRACK CRIU_FEATURE_LAZY_PAGES CRIU_FEATURE_PIDFD_STORE
1010

1111
TEST_PAYLOAD := piggie/piggie
12-
TEST_BINARIES := test $(TEST_PAYLOAD) phaul/phaul
13-
COVERAGE_BINARIES := test.coverage phaul/phaul.coverage crit/crit-test.coverage
12+
TEST_BINARIES := test $(TEST_PAYLOAD) phaul/phaul inheritfd/inheritfd
13+
COVERAGE_BINARIES := test.coverage phaul/phaul.coverage crit/crit-test.coverage inheritfd/inheritfd.coverage
1414

1515
all: $(TEST_BINARIES) phaul-test
1616
mkdir -p image
@@ -29,6 +29,9 @@ test: main.go
2929
phaul/phaul: phaul/main.go
3030
$(GO) build -v -o $@ $^
3131

32+
inheritfd/inheritfd: inheritfd/main.go
33+
$(GO) build -v -o $@ $^
34+
3235
phaul-test: $(TEST_BINARIES)
3336
rm -rf image
3437
PID=$$(piggie/piggie) && \
@@ -58,6 +61,11 @@ crit/crit-test.coverage: crit/*.go
5861
-cover \
5962
-o $@ crit/main.go
6063

64+
inheritfd/inheritfd.coverage: inheritfd/*.go
65+
$(GO) build \
66+
-cover \
67+
-o $@ inheritfd/main.go
68+
6169
coverage: $(COVERAGE_BINARIES) $(TEST_PAYLOAD)
6270
mkdir -p $(COVERAGE_PATH)
6371
mkdir -p image
@@ -70,6 +78,7 @@ coverage: $(COVERAGE_BINARIES) $(TEST_PAYLOAD)
7078
trap 'pkill -9 piggie' INT EXIT TERM && \
7179
GOCOVERDIR=${COVERAGE_PATH} phaul/phaul.coverage $$PID
7280
cd crit/ && GOCOVERDIR=${COVERAGE_PATH} ./crit-test.coverage
81+
cd inheritfd/ && GOCOVERDIR=${COVERAGE_PATH} ./inheritfd.coverage
7382
$(MAKE) -C ../crit/ unit-test GOFLAGS="-coverprofile=${COVERAGE_PATH}/coverprofile-crit-unit-test"
7483
$(MAKE) -C crit/ e2e-test GOCOVERDIR=${COVERAGE_PATH}
7584
$(MAKE) -C crit/ clean

test/inheritfd/main.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"os"
7+
"syscall"
8+
"time"
9+
10+
"github.com/checkpoint-restore/go-criu/v8"
11+
"github.com/checkpoint-restore/go-criu/v8/rpc"
12+
"google.golang.org/protobuf/proto"
13+
)
14+
15+
func getSwrkPid() int {
16+
myPid := os.Getpid()
17+
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/task/%d/children", myPid, myPid))
18+
if err != nil {
19+
return 0
20+
}
21+
var swrkPid int
22+
if _, err := fmt.Sscanf(string(data), "%d", &swrkPid); err != nil {
23+
return 0
24+
}
25+
return swrkPid
26+
}
27+
28+
func waitForSwrkPid() int {
29+
for i := 0; i < 50; i++ {
30+
swrkPid := getSwrkPid()
31+
if swrkPid != 0 {
32+
return swrkPid
33+
}
34+
time.Sleep(10 * time.Millisecond)
35+
}
36+
return 0
37+
}
38+
39+
func swrkHasInode(swrkPid int, ino uint64) (bool, error) {
40+
entries, err := os.ReadDir(fmt.Sprintf("/proc/%d/fd", swrkPid))
41+
if err != nil {
42+
return false, err
43+
}
44+
for _, entry := range entries {
45+
var stat syscall.Stat_t
46+
if err := syscall.Stat(fmt.Sprintf("/proc/%d/fd/%s", swrkPid, entry.Name()), &stat); err != nil {
47+
continue
48+
}
49+
if stat.Ino == ino {
50+
return true, nil
51+
}
52+
}
53+
return false, nil
54+
}
55+
56+
func testInheritFd(netnsIno uint64, netnsFile *os.File) {
57+
c := criu.MakeCriu()
58+
c.AddInheritFd("net", netnsFile)
59+
if err := c.Prepare(); err != nil {
60+
log.Fatalln(err)
61+
}
62+
63+
swrkPid := waitForSwrkPid()
64+
if swrkPid == 0 {
65+
log.Fatalln("no swrk pid found")
66+
}
67+
has, err := swrkHasInode(swrkPid, netnsIno)
68+
if err != nil {
69+
log.Fatalln(err)
70+
}
71+
if !has {
72+
log.Fatalln("fd not inherited with AddInheritFd")
73+
}
74+
75+
// Send a dummy RPC so swrk exits cleanly
76+
if _, err := c.GetCriuVersion(); err != nil {
77+
log.Fatalln(err)
78+
}
79+
80+
if err := c.Cleanup(); err != nil {
81+
log.Fatalln(err)
82+
}
83+
}
84+
85+
func testNoInheritFd(netnsIno uint64) {
86+
c := criu.MakeCriu()
87+
if err := c.Prepare(); err != nil {
88+
log.Fatalln(err)
89+
}
90+
91+
swrkPid := waitForSwrkPid()
92+
if swrkPid == 0 {
93+
log.Fatalln("no swrk pid found")
94+
}
95+
has, err := swrkHasInode(swrkPid, netnsIno)
96+
if err != nil {
97+
log.Fatalln(err)
98+
}
99+
if has {
100+
log.Fatalln("fd incorrectly inherited without AddInheritFd")
101+
}
102+
103+
// Send a dummy RPC so swrk exits cleanly
104+
if _, err := c.GetCriuVersion(); err != nil {
105+
log.Fatalln(err)
106+
}
107+
108+
if err := c.Cleanup(); err != nil {
109+
log.Fatalln(err)
110+
}
111+
}
112+
113+
func testAutoPopulateInheritFd(netnsFile *os.File) {
114+
c := criu.MakeCriu()
115+
c.AddInheritFd("testKey", netnsFile)
116+
117+
opts := &rpc.CriuOpts{
118+
ImagesDir: proto.String("/nonexistent"),
119+
}
120+
121+
// Call will fail (no images), but ensureInheritFd() runs first
122+
_ = c.PreDump(opts, nil)
123+
124+
// Verify opts.InheritFd was auto-populated
125+
inheritFds := opts.GetInheritFd()
126+
if len(inheritFds) != 1 {
127+
log.Fatalf("opts.InheritFd not auto-populated: got %d, want 1", len(inheritFds))
128+
}
129+
if inheritFds[0].GetKey() != "testKey" || inheritFds[0].GetFd() != 3 {
130+
log.Fatalf("opts.InheritFd wrong: key=%s fd=%d", inheritFds[0].GetKey(), inheritFds[0].GetFd())
131+
}
132+
}
133+
134+
// Usage: test-inheritfd
135+
func main() {
136+
netnsFile, err := os.Open("/proc/self/ns/net")
137+
if err != nil {
138+
log.Fatalln(err)
139+
}
140+
defer netnsFile.Close()
141+
142+
var stat syscall.Stat_t
143+
if err := syscall.Fstat(int(netnsFile.Fd()), &stat); err != nil {
144+
log.Fatalln(err)
145+
}
146+
netnsIno := stat.Ino
147+
148+
testNoInheritFd(netnsIno)
149+
testInheritFd(netnsIno, netnsFile)
150+
testAutoPopulateInheritFd(netnsFile)
151+
log.Println("PASS")
152+
}

0 commit comments

Comments
 (0)