Skip to content

Commit 2bdcc43

Browse files
committed
fix: harden image manager and split CI tests by platform
Amp-Thread-ID: https://ampcode.com/threads/T-019c7d89-b268-73f7-9375-68d0c1c83103
1 parent 1c7f948 commit 2bdcc43

File tree

7 files changed

+334
-15
lines changed

7 files changed

+334
-15
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ jobs:
1414
strategy:
1515
fail-fast: false
1616
matrix:
17-
os:
18-
- ubuntu-latest
19-
- macos-latest
17+
include:
18+
- os: ubuntu-latest
19+
test_task: test-full
20+
- os: macos-latest
21+
test_task: test
2022
steps:
2123
- name: Checkout
2224
uses: actions/checkout@v4
@@ -30,4 +32,4 @@ jobs:
3032
run: mise run build
3133

3234
- name: Test
33-
run: mise run test
35+
run: mise run ${{ matrix.test_task }}

.mise.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ go = "1.23"
55
GOTOOLCHAIN = "local"
66

77
[tasks.test]
8-
description = "Run Go test suite"
8+
description = "Run cross-platform Go test suite"
9+
run = "go test ./..."
10+
11+
[tasks.test-full]
12+
description = "Run full Go test suite"
913
run = "go test ./..."
1014

1115
[tasks.build]

internal/backend/firecracker/backend.go

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,28 @@ import (
2525
fcvsock "github.com/firecracker-microvm/firecracker-go-sdk/vsock"
2626
)
2727

28-
type Adapter struct{}
28+
type imageEnsurer interface {
29+
Ensure(context.Context, string) (imagemgr.EnsureResult, error)
30+
}
31+
32+
type imageManagerFactory func() (imageEnsurer, error)
33+
34+
type Adapter struct {
35+
imageManagerOnce sync.Once
36+
imageManager imageEnsurer
37+
imageManagerErr error
38+
newImageManager imageManagerFactory
39+
}
2940

3041
const runObservabilityFile = "run-observability.json"
3142
const vsockDialRetryInterval = 50 * time.Millisecond
3243

3344
func New() *Adapter {
34-
return &Adapter{}
45+
return &Adapter{newImageManager: defaultImageManagerFactory}
46+
}
47+
48+
func defaultImageManagerFactory() (imageEnsurer, error) {
49+
return imagemgr.New(imagemgr.Options{})
3550
}
3651

3752
func (a *Adapter) Name() string {
@@ -280,7 +295,7 @@ func (a *Adapter) run(ctx context.Context, req backend.RunRequest, stream backen
280295
return nil, fmt.Errorf("kernel image %s: %w", kernelPath, err)
281296
}
282297

283-
imageArtifact, err := ensureImageArtifact(ctx, req.Policy.ImageRef)
298+
imageArtifact, err := a.ensureImageArtifact(ctx, req.Policy.ImageRef)
284299
if err != nil {
285300
return nil, err
286301
}
@@ -570,13 +585,13 @@ type imageArtifact struct {
570585
CacheHit bool
571586
}
572587

573-
func ensureImageArtifact(ctx context.Context, imageRef string) (imageArtifact, error) {
588+
func (a *Adapter) ensureImageArtifact(ctx context.Context, imageRef string) (imageArtifact, error) {
574589
trimmedRef := strings.TrimSpace(imageRef)
575590
if trimmedRef == "" {
576591
return imageArtifact{}, errors.New("sandbox.image.ref is required for launched execution")
577592
}
578593

579-
mgr, err := imagemgr.New(imagemgr.Options{})
594+
mgr, err := a.getImageManager()
580595
if err != nil {
581596
return imageArtifact{}, fmt.Errorf("initialise image manager: %w", err)
582597
}
@@ -594,6 +609,22 @@ func ensureImageArtifact(ctx context.Context, imageRef string) (imageArtifact, e
594609
}, nil
595610
}
596611

612+
func (a *Adapter) getImageManager() (imageEnsurer, error) {
613+
if a.newImageManager == nil {
614+
a.newImageManager = defaultImageManagerFactory
615+
}
616+
a.imageManagerOnce.Do(func() {
617+
a.imageManager, a.imageManagerErr = a.newImageManager()
618+
})
619+
if a.imageManagerErr != nil {
620+
return nil, a.imageManagerErr
621+
}
622+
if a.imageManager == nil {
623+
return nil, errors.New("image manager factory returned nil manager")
624+
}
625+
return a.imageManager, nil
626+
}
627+
597628
func runResultMessage(base string) string {
598629
return base + "; rootfs writes discarded after run"
599630
}

internal/imagemgr/imagemgr.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,18 +328,22 @@ func (m *Manager) persistFromTarStream(ctx context.Context, req persistFromTarRe
328328
}
329329

330330
outputPath := filepath.Join(m.cacheDir, strings.TrimPrefix(req.Digest, "sha256:")+".ext4")
331-
tmpPath := outputPath + ".tmp"
332-
if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
333-
return Record{}, fmt.Errorf("remove stale image temp file %q: %w", tmpPath, err)
331+
tmpFile, err := os.CreateTemp(m.cacheDir, strings.TrimPrefix(req.Digest, "sha256:")+".tmp-*.ext4")
332+
if err != nil {
333+
return Record{}, fmt.Errorf("create temporary image artifact for %q: %w", req.Digest, err)
334+
}
335+
tmpPath := tmpFile.Name()
336+
if err := tmpFile.Close(); err != nil {
337+
_ = os.Remove(tmpPath)
338+
return Record{}, fmt.Errorf("close temporary image artifact file %q: %w", tmpPath, err)
334339
}
340+
defer os.Remove(tmpPath)
335341

336342
sizeBytes, err := m.materialize(ctx, req.TarStream, tmpPath)
337343
if err != nil {
338-
_ = os.Remove(tmpPath)
339344
return Record{}, err
340345
}
341346
if err := os.Rename(tmpPath, outputPath); err != nil {
342-
_ = os.Remove(tmpPath)
343347
return Record{}, fmt.Errorf("move image artifact to cache %q: %w", outputPath, err)
344348
}
345349

@@ -354,6 +358,7 @@ func (m *Manager) persistFromTarStream(ctx context.Context, req persistFromTarRe
354358
OCIConfig: req.OCIConfig,
355359
}
356360
if err := m.upsertRecord(ctx, record); err != nil {
361+
_ = os.Remove(outputPath)
357362
return Record{}, err
358363
}
359364

internal/imagemgr/imagemgr_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"archive/tar"
55
"bytes"
66
"context"
7+
"errors"
78
"io"
89
"os"
910
"path/filepath"
@@ -119,6 +120,110 @@ func TestRemoveByRefSelector(t *testing.T) {
119120
}
120121
}
121122

123+
func TestPersistFromTarStreamUsesUniqueTempPaths(t *testing.T) {
124+
t.Parallel()
125+
126+
cacheDir := filepath.Join(t.TempDir(), "cache")
127+
dbPath := filepath.Join(t.TempDir(), "state", "metadata.db")
128+
if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil {
129+
t.Fatalf("create state dir: %v", err)
130+
}
131+
132+
var outputPaths []string
133+
manager, err := New(Options{
134+
CacheDir: cacheDir,
135+
MetadataDBPath: dbPath,
136+
MaterializeRootFS: func(_ context.Context, _ io.Reader, outputPath string) (int64, error) {
137+
outputPaths = append(outputPaths, outputPath)
138+
return 0, errors.New("materialise fail")
139+
},
140+
})
141+
if err != nil {
142+
t.Fatalf("create manager: %v", err)
143+
}
144+
145+
now := time.Unix(1_700_000_001, 0).UTC()
146+
req := persistFromTarRequest{
147+
Ref: testImageRef,
148+
Digest: "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
149+
TarStream: bytes.NewReader(testRootFSTar(t)),
150+
Source: "import",
151+
CreatedAt: now,
152+
LastUsedAt: now,
153+
}
154+
155+
if _, err := manager.persistFromTarStream(context.Background(), req); err == nil {
156+
t.Fatal("expected first persistFromTarStream to fail")
157+
}
158+
req.TarStream = bytes.NewReader(testRootFSTar(t))
159+
if _, err := manager.persistFromTarStream(context.Background(), req); err == nil {
160+
t.Fatal("expected second persistFromTarStream to fail")
161+
}
162+
163+
if len(outputPaths) != 2 {
164+
t.Fatalf("expected two materialise attempts, got %d", len(outputPaths))
165+
}
166+
if outputPaths[0] == outputPaths[1] {
167+
t.Fatalf("expected unique temporary output paths, got %q twice", outputPaths[0])
168+
}
169+
}
170+
171+
func TestPersistFromTarStreamRemovesArtifactWhenMetadataWriteFails(t *testing.T) {
172+
t.Parallel()
173+
174+
cacheDir := filepath.Join(t.TempDir(), "cache")
175+
stateDir := filepath.Join(t.TempDir(), "state")
176+
dbPath := filepath.Join(stateDir, "metadata.db")
177+
if err := os.MkdirAll(stateDir, 0o755); err != nil {
178+
t.Fatalf("create state dir: %v", err)
179+
}
180+
181+
manager, err := New(Options{
182+
CacheDir: cacheDir,
183+
MetadataDBPath: dbPath,
184+
MaterializeRootFS: func(_ context.Context, _ io.Reader, outputPath string) (int64, error) {
185+
if err := os.WriteFile(outputPath, []byte("fake-ext4"), 0o644); err != nil {
186+
return 0, err
187+
}
188+
if err := os.Chmod(dbPath, 0o444); err != nil {
189+
return 0, err
190+
}
191+
if err := os.Chmod(stateDir, 0o555); err != nil {
192+
return 0, err
193+
}
194+
return int64(len("fake-ext4")), nil
195+
},
196+
})
197+
if err != nil {
198+
t.Fatalf("create manager: %v", err)
199+
}
200+
t.Cleanup(func() {
201+
_ = os.Chmod(stateDir, 0o755)
202+
_ = os.Chmod(dbPath, 0o644)
203+
})
204+
205+
digest := "sha256:bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
206+
now := time.Unix(1_700_000_002, 0).UTC()
207+
req := persistFromTarRequest{
208+
Ref: "ghcr.io/buildkite/cleanroom-base/alpine@" + digest,
209+
Digest: digest,
210+
TarStream: bytes.NewReader(testRootFSTar(t)),
211+
Source: "import",
212+
CreatedAt: now,
213+
LastUsedAt: now,
214+
}
215+
216+
_, err = manager.persistFromTarStream(context.Background(), req)
217+
if err == nil {
218+
t.Fatal("expected persistFromTarStream to fail when metadata write fails")
219+
}
220+
221+
artifactPath := filepath.Join(cacheDir, strings.TrimPrefix(digest, "sha256:")+".ext4")
222+
if _, statErr := os.Stat(artifactPath); !os.IsNotExist(statErr) {
223+
t.Fatalf("expected artifact %s to be removed after metadata failure, stat err=%v", artifactPath, statErr)
224+
}
225+
}
226+
122227
func newTestManager(t *testing.T, pullFn func(context.Context, string) (io.ReadCloser, OCIConfig, error)) *Manager {
123228
t.Helper()
124229

internal/imagemgr/materialize.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package imagemgr
33
import (
44
"archive/tar"
55
"context"
6+
"errors"
67
"fmt"
78
"io"
89
"os"
@@ -125,13 +126,22 @@ func extractTar(root string, stream io.Reader) error {
125126

126127
switch hdr.Typeflag {
127128
case tar.TypeDir:
129+
if err := ensureNoSymlinkPath(root, targetPath, true); err != nil {
130+
return err
131+
}
128132
if err := os.MkdirAll(targetPath, os.FileMode(hdr.Mode)); err != nil {
129133
return fmt.Errorf("create directory %q from tar stream: %w", targetPath, err)
130134
}
131135
case tar.TypeReg, tar.TypeRegA:
136+
if err := ensureNoSymlinkPath(root, filepath.Dir(targetPath), true); err != nil {
137+
return err
138+
}
132139
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
133140
return fmt.Errorf("create parent directory for %q: %w", targetPath, err)
134141
}
142+
if err := ensureNoSymlinkPath(root, targetPath, true); err != nil {
143+
return err
144+
}
135145
f, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, os.FileMode(hdr.Mode))
136146
if err != nil {
137147
return fmt.Errorf("create file %q from tar stream: %w", targetPath, err)
@@ -144,9 +154,18 @@ func extractTar(root string, stream io.Reader) error {
144154
return fmt.Errorf("close file %q from tar stream: %w", targetPath, err)
145155
}
146156
case tar.TypeSymlink:
157+
if err := ensureNoSymlinkPath(root, filepath.Dir(targetPath), true); err != nil {
158+
return err
159+
}
147160
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
148161
return fmt.Errorf("create parent directory for symlink %q: %w", targetPath, err)
149162
}
163+
if err := validateSymlinkTarget(root, targetPath, hdr.Linkname); err != nil {
164+
return err
165+
}
166+
if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) {
167+
return fmt.Errorf("remove existing symlink path %q: %w", targetPath, err)
168+
}
150169
if err := os.Symlink(hdr.Linkname, targetPath); err != nil && !os.IsExist(err) {
151170
return fmt.Errorf("create symlink %q -> %q from tar stream: %w", targetPath, hdr.Linkname, err)
152171
}
@@ -155,9 +174,21 @@ func extractTar(root string, stream io.Reader) error {
155174
if err != nil {
156175
return err
157176
}
177+
if err := ensureNoSymlinkPath(root, linkTarget, false); err != nil {
178+
return err
179+
}
180+
if err := ensureNoSymlinkPath(root, filepath.Dir(targetPath), true); err != nil {
181+
return err
182+
}
158183
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
159184
return fmt.Errorf("create parent directory for hard link %q: %w", targetPath, err)
160185
}
186+
if err := ensureNoSymlinkPath(root, targetPath, true); err != nil {
187+
return err
188+
}
189+
if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) {
190+
return fmt.Errorf("remove existing hardlink path %q: %w", targetPath, err)
191+
}
161192
if err := os.Link(linkTarget, targetPath); err != nil {
162193
return fmt.Errorf("create hard link %q -> %q from tar stream: %w", targetPath, linkTarget, err)
163194
}
@@ -182,3 +213,54 @@ func safeJoin(root, name string) (string, error) {
182213
}
183214
return joined, nil
184215
}
216+
217+
func ensureNoSymlinkPath(root, target string, allowMissingLeaf bool) error {
218+
rel, err := filepath.Rel(root, target)
219+
if err != nil {
220+
return fmt.Errorf("resolve relative path for %q: %w", target, err)
221+
}
222+
if rel == "." {
223+
return nil
224+
}
225+
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
226+
return fmt.Errorf("refusing path outside extraction root %q", target)
227+
}
228+
229+
parts := strings.Split(rel, string(filepath.Separator))
230+
current := root
231+
for i, part := range parts {
232+
if part == "" || part == "." {
233+
continue
234+
}
235+
current = filepath.Join(current, part)
236+
info, statErr := os.Lstat(current)
237+
if statErr != nil {
238+
if errors.Is(statErr, os.ErrNotExist) {
239+
if i == len(parts)-1 && allowMissingLeaf {
240+
return nil
241+
}
242+
continue
243+
}
244+
return fmt.Errorf("inspect path %q: %w", current, statErr)
245+
}
246+
if info.Mode()&os.ModeSymlink != 0 {
247+
return fmt.Errorf("refusing archive entry that traverses symlink path component %q", current)
248+
}
249+
}
250+
return nil
251+
}
252+
253+
func validateSymlinkTarget(root, targetPath, linkName string) error {
254+
if strings.TrimSpace(linkName) == "" {
255+
return fmt.Errorf("refusing symlink %q with empty target", targetPath)
256+
}
257+
resolved := filepath.Clean(filepath.Join(filepath.Dir(targetPath), linkName))
258+
rel, err := filepath.Rel(root, resolved)
259+
if err != nil {
260+
return fmt.Errorf("resolve symlink target for %q: %w", targetPath, err)
261+
}
262+
if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
263+
return fmt.Errorf("refusing symlink %q -> %q that resolves outside extraction root", targetPath, linkName)
264+
}
265+
return nil
266+
}

0 commit comments

Comments
 (0)