Skip to content

Commit 62f7b10

Browse files
authored
fix: register defers immediately in scan.go (#1354)
* fix: register defers immediately in scan.go Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> * add FD leak tests Signed-off-by: egibs <20933572+egibs@users.noreply.github.com> --------- Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
1 parent 303b44a commit 62f7b10

File tree

4 files changed

+177
-13
lines changed

4 files changed

+177
-13
lines changed

pkg/action/archive_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ func TestExtractionMultiple(t *testing.T) {
8383
if err != nil {
8484
t.Fatal(err)
8585
}
86+
defer os.RemoveAll(dir)
8687
dirFiles, err := os.ReadDir(dir)
8788
if err != nil {
8889
t.Fatal(err)
@@ -110,6 +111,7 @@ func TestExtractTar(t *testing.T) {
110111
if err != nil {
111112
t.Fatal(err)
112113
}
114+
defer os.RemoveAll(dir)
113115
want := []string{
114116
"apko_0.13.2_linux_arm64",
115117
}
@@ -138,6 +140,7 @@ func TestExtractGzip(t *testing.T) {
138140
if err != nil {
139141
t.Fatal(err)
140142
}
143+
defer os.RemoveAll(dir)
141144
want := []string{
142145
"apko",
143146
}
@@ -166,6 +169,7 @@ func TestExtractZip(t *testing.T) {
166169
if err != nil {
167170
t.Fatal(err)
168171
}
172+
defer os.RemoveAll(dir)
169173
want := []string{
170174
"apko_0.13.2_linux_arm64",
171175
}
@@ -194,6 +198,7 @@ func TestExtractNestedArchive(t *testing.T) {
194198
if err != nil {
195199
t.Fatal(err)
196200
}
201+
defer os.RemoveAll(dir)
197202
want := []string{
198203
"apko_0.13.2_linux_arm64",
199204
}

pkg/action/scan.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
6262

6363
isArchive := archiveRoot != ""
6464

65-
f, err := os.Open(path)
66-
if err != nil {
67-
return nil, err
68-
}
69-
70-
fi, err := f.Stat()
65+
fi, err := os.Stat(path)
7166
if err != nil {
7267
return nil, err
7368
}
@@ -128,6 +123,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
128123
scannerPool = pool.NewScannerPool(yrs, getMaxConcurrency(runtime.GOMAXPROCS(0)))
129124
})
130125
scanner := scannerPool.Get(yrs)
126+
defer scannerPool.Put(scanner)
131127

132128
mrs, err := scanner.ScanFile(path)
133129
if err != nil {
@@ -150,8 +146,14 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
150146
// create a buffer sized to the minimum of the file's size or the default ReadBuffer
151147
// only do so if we actually need to retrieve the file's contents
152148
buf := readPool.Get(min(size, file.ReadBuffer)) //nolint:nilaway // the buffer pool is initialized in init()
149+
defer readPool.Put(buf)
150+
151+
f, err := os.Open(path)
152+
if err != nil {
153+
return nil, err
154+
}
155+
defer f.Close()
153156

154-
// Only retrieve the file's contents and calculate its checksum if we need to generate a report
155157
fc, err := file.GetContents(f, buf)
156158
if err != nil {
157159
return nil, err
@@ -169,12 +171,6 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
169171
return nil, NewFileReportError(err, path, TypeGenerateError)
170172
}
171173

172-
defer func() {
173-
f.Close()
174-
readPool.Put(buf)
175-
scannerPool.Put(scanner)
176-
}()
177-
178174
// Clean up the path if scanning an archive
179175
var clean string
180176
if isArchive || c.OCI {

pkg/action/scan_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Copyright 2026 Chainguard, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package action
5+
6+
import (
7+
"context"
8+
"io/fs"
9+
"os"
10+
"path/filepath"
11+
"runtime"
12+
"testing"
13+
14+
"github.com/chainguard-dev/malcontent/pkg/malcontent"
15+
"github.com/chainguard-dev/malcontent/rules"
16+
thirdparty "github.com/chainguard-dev/malcontent/third_party"
17+
)
18+
19+
// countOpenFDs returns the number of open file descriptors for the current process.
20+
// Returns -1 if unable to count (e.g., on unsupported platforms).
21+
func countOpenFDs(t *testing.T) int {
22+
t.Helper()
23+
24+
// Linux: count entries in /proc/self/fd
25+
if entries, err := os.ReadDir("/proc/self/fd"); err == nil {
26+
return len(entries)
27+
}
28+
29+
// macOS: count entries in /dev/fd
30+
if entries, err := os.ReadDir("/dev/fd"); err == nil {
31+
return len(entries)
32+
}
33+
34+
return -1
35+
}
36+
37+
// TestScanSinglePathNoFDLeak verifies that early return paths in scanSinglePath
38+
// properly close file handles and don't leak file descriptors.
39+
func TestScanSinglePathNoFDLeak(t *testing.T) {
40+
ctx := context.Background()
41+
42+
fdsBefore := countOpenFDs(t)
43+
if fdsBefore == -1 {
44+
t.Skip("cannot count file descriptors on this platform")
45+
}
46+
47+
rfs := []fs.FS{rules.FS, thirdparty.FS}
48+
yrs, err := CachedRules(ctx, rfs)
49+
if err != nil {
50+
t.Fatalf("rules: %v", err)
51+
}
52+
53+
cfg := malcontent.Config{
54+
Concurrency: runtime.NumCPU(),
55+
IgnoreSelf: false,
56+
IncludeDataFiles: false,
57+
MinFileRisk: 0,
58+
MinRisk: 0,
59+
Rules: yrs,
60+
RuleFS: rfs,
61+
}
62+
63+
testFiles := []string{
64+
filepath.Join("testdata", "empty"),
65+
filepath.Join("testdata", "rando"),
66+
filepath.Join("testdata", "short"),
67+
}
68+
69+
iterations := runtime.GOMAXPROCS(0) * 10
70+
for range iterations {
71+
for _, tf := range testFiles {
72+
_, _ = scanSinglePath(ctx, cfg, tf, rfs, tf, "", nil)
73+
}
74+
}
75+
76+
runtime.GC()
77+
78+
fdsAfter := countOpenFDs(t)
79+
80+
maxAllowedGrowth := 0
81+
leaked := fdsAfter - fdsBefore
82+
if leaked > maxAllowedGrowth {
83+
t.Errorf("file descriptor leak detected: before=%d after=%d leaked=%d (ran %d iterations)",
84+
fdsBefore, fdsAfter, leaked, iterations*len(testFiles))
85+
}
86+
}
87+
88+
// TestScanSinglePathNonExistentFile verifies that scanning a non-existent file
89+
// returns an error without leaking resources.
90+
func TestScanSinglePathNonExistentFile(t *testing.T) {
91+
ctx := context.Background()
92+
93+
fdsBefore := countOpenFDs(t)
94+
if fdsBefore == -1 {
95+
t.Skip("cannot count file descriptors on this platform")
96+
}
97+
98+
rfs := []fs.FS{rules.FS, thirdparty.FS}
99+
yrs, err := CachedRules(ctx, rfs)
100+
if err != nil {
101+
t.Fatalf("rules: %v", err)
102+
}
103+
104+
cfg := malcontent.Config{
105+
Rules: yrs,
106+
RuleFS: rfs,
107+
}
108+
109+
iterations := runtime.GOMAXPROCS(0) * 10
110+
for range iterations {
111+
_, err := scanSinglePath(ctx, cfg, "/nonexistent/path/to/file", rfs, "", "", nil)
112+
if err == nil {
113+
t.Error("expected error for non-existent file")
114+
}
115+
}
116+
117+
runtime.GC()
118+
119+
fdsAfter := countOpenFDs(t)
120+
maxAllowedGrowth := 0
121+
leaked := fdsAfter - fdsBefore
122+
if leaked > maxAllowedGrowth {
123+
t.Errorf("file descriptor leak on error path: before=%d after=%d leaked=%d",
124+
fdsBefore, fdsAfter, leaked)
125+
}
126+
}
127+
128+
// TestScanRepeatedScansNoResourceExhaustion verifies that repeated scans
129+
// don't exhaust scanner pool or buffer pool resources.
130+
func TestScanRepeatedScansNoResourceExhaustion(t *testing.T) {
131+
ctx := context.Background()
132+
133+
rfs := []fs.FS{rules.FS, thirdparty.FS}
134+
yrs, err := CachedRules(ctx, rfs)
135+
if err != nil {
136+
t.Fatalf("rules: %v", err)
137+
}
138+
139+
cfg := malcontent.Config{
140+
Concurrency: runtime.NumCPU(),
141+
IgnoreSelf: false,
142+
IncludeDataFiles: false,
143+
MinFileRisk: 0,
144+
MinRisk: 0,
145+
Rules: yrs,
146+
RuleFS: rfs,
147+
}
148+
149+
testFiles := []string{
150+
filepath.Join("testdata", "empty"), // zero-sized, early return before scanner
151+
filepath.Join("testdata", "rando"), // data file, early return before scanner
152+
filepath.Join("testdata", "shell"), // actual script, full scan path
153+
}
154+
155+
iterations := runtime.GOMAXPROCS(0) * 10
156+
157+
for range iterations {
158+
for _, tf := range testFiles {
159+
_, _ = scanSinglePath(ctx, cfg, tf, rfs, tf, "", nil)
160+
}
161+
}
162+
}

pkg/archive/rpm.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ func ExtractRPM(ctx context.Context, d, f string) error {
126126
if err != nil {
127127
return fmt.Errorf("failed to create zstd reader: %w", err)
128128
}
129+
defer zstdStream.Close()
129130
cr = cpio.NewReader(zstdStream)
130131
default:
131132
return fmt.Errorf("unsupported compression format: %s", compression)

0 commit comments

Comments
 (0)