Skip to content

Commit 85e9c6c

Browse files
authored
chore: consolidate common buffer and size values into a single package (#1245)
* chore: consolidate common buffer values into a single package Signed-off-by: egibs <[email protected]> * move the pool values as well Signed-off-by: egibs <[email protected]> * use int64 for all of the consts Signed-off-by: egibs <[email protected]> * fix comment Signed-off-by: egibs <[email protected]> * use more appropriate file variable names Signed-off-by: egibs <[email protected]> * update remaining file variable names to avoid shadowing in the future Signed-off-by: egibs <[email protected]> --------- Signed-off-by: egibs <[email protected]>
1 parent 360493b commit 85e9c6c

File tree

13 files changed

+87
-79
lines changed

13 files changed

+87
-79
lines changed

pkg/action/scan.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/chainguard-dev/clog"
2323
"github.com/chainguard-dev/malcontent/pkg/archive"
2424
"github.com/chainguard-dev/malcontent/pkg/compile"
25+
"github.com/chainguard-dev/malcontent/pkg/file"
2526
"github.com/chainguard-dev/malcontent/pkg/malcontent"
2627
"github.com/chainguard-dev/malcontent/pkg/pool"
2728
"github.com/chainguard-dev/malcontent/pkg/programkind"
@@ -41,10 +42,8 @@ var (
4142
compiledRuleCache atomic.Pointer[yarax.Rules] // compiledRuleCache are a cache of previously compiled rules.
4243
compileOnce sync.Once // compileOnce ensures that we compile rules only once even across threads.
4344
ErrMatchedCondition = errors.New("matched exit criteria")
44-
initReadPool sync.Once // initReadPool ensures that the bytes read pool is only initialized once.
45-
initScannerPool sync.Once // initScannerPool ensures that the scanner pool is only initialized once.
46-
maxBytes int64 = 1 << 32 // 4GB
47-
readBuffer int64 = 64 * 1024 // 64KB
45+
initReadPool sync.Once // initReadPool ensures that the bytes read pool is only initialized once.
46+
initScannerPool sync.Once // initScannerPool ensures that the scanner pool is only initialized once.
4847
readPool *pool.BufferPool
4948
scannerPool *pool.ScannerPool
5049
)
@@ -82,7 +81,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
8281
initReadPool.Do(func() {
8382
readPool = pool.NewBufferPool(runtime.GOMAXPROCS(0))
8483
})
85-
buf := readPool.Get(readBuffer) //nolint:nilaway // the buffer pool is created above
84+
buf := readPool.Get(file.ReadBuffer) //nolint:nilaway // the buffer pool is created above
8685

8786
mime := "<unknown>"
8887
kind, err := programkind.File(ctx, path)
@@ -141,7 +140,7 @@ func scanSinglePath(ctx context.Context, c malcontent.Config, path string, ruleF
141140

142141
// Only retrieve the file's contents and calculate its checksum if we need to generate a report
143142
var fc bytes.Buffer
144-
_, err = io.CopyBuffer(&fc, io.LimitReader(f, maxBytes), buf)
143+
_, err = io.CopyBuffer(&fc, io.LimitReader(f, file.MaxBytes), buf)
145144
if err != nil {
146145
return nil, err
147146
}

pkg/action/scan_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ func TestCleanPath(t *testing.T) {
7777
}
7878

7979
filePath := filepath.Join(nestedDir, "test.txt")
80-
file, err := os.Create(filePath)
80+
f, err := os.Create(filePath)
8181
if err != nil {
8282
t.Fatalf("failed to create file: %v", err)
8383
}
84-
file.Close()
84+
defer f.Close()
8585

8686
fullPath := filepath.Join(tempDir, tt.path)
8787
fullPrefix := filepath.Join(tempDir, tt.prefix)

pkg/archive/archive.go

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,12 @@ import (
1313
"time"
1414

1515
"github.com/chainguard-dev/clog"
16+
"github.com/chainguard-dev/malcontent/pkg/file"
1617
"github.com/chainguard-dev/malcontent/pkg/malcontent"
1718
"github.com/chainguard-dev/malcontent/pkg/pool"
1819
"github.com/chainguard-dev/malcontent/pkg/programkind"
1920
)
2021

21-
const (
22-
extractBuffer = 64 * 1024 // 64KB
23-
maxBytes = 1 << 31 // 2048MB
24-
zipBuffer = 2 * 1024 // 2KB
25-
)
26-
2722
var (
2823
archivePool, tarPool, zipPool *pool.BufferPool
2924
initializeOnce sync.Once
@@ -132,16 +127,16 @@ func extractNestedArchive(ctx context.Context, c malcontent.Config, d string, f
132127
return fmt.Errorf("failed to remove archive file: %w", err)
133128
}
134129

135-
files, err := os.ReadDir(d)
130+
entries, err := os.ReadDir(d)
136131
if err != nil {
137132
return fmt.Errorf("failed to read directory after extraction: %w", err)
138133
}
139134

140-
for _, file := range files {
135+
for _, entry := range entries {
141136
if ctx.Err() != nil {
142137
return ctx.Err()
143138
}
144-
rel := file.Name()
139+
rel := entry.Name()
145140
if _, alreadyProcessed := extracted.Load(rel); !alreadyProcessed {
146141
if err := extractNestedArchive(ctx, c, d, rel, extracted, logger); err != nil {
147142
return fmt.Errorf("process nested file %s: %w", rel, err)
@@ -263,7 +258,7 @@ func handleDirectory(target string) error {
263258

264259
// handleFile extracts valid files within .deb or .tar archives.
265260
func handleFile(target string, tr *tar.Reader) error {
266-
buf := tarPool.Get(extractBuffer) //nolint:nilaway // the buffer pool is created above
261+
buf := tarPool.Get(file.ExtractBuffer) //nolint:nilaway // the buffer pool is created above
267262
defer tarPool.Put(buf)
268263

269264
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
@@ -276,15 +271,15 @@ func handleFile(target string, tr *tar.Reader) error {
276271
}
277272
defer out.Close()
278273

279-
written, err := io.CopyBuffer(out, io.LimitReader(tr, maxBytes), buf)
274+
written, err := io.CopyBuffer(out, io.LimitReader(tr, file.MaxBytes), buf)
280275
if err != nil {
281276
if (strings.Contains(err.Error(), "unexpected EOF") && written == 0) ||
282277
!strings.Contains(err.Error(), "unexpected EOF") {
283278
return fmt.Errorf("failed to copy file: %w", err)
284279
}
285280
}
286-
if written >= maxBytes {
287-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
281+
if written >= file.MaxBytes {
282+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
288283
}
289284

290285
return nil

pkg/archive/bz2.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111

1212
"github.com/chainguard-dev/clog"
13+
"github.com/chainguard-dev/malcontent/pkg/file"
1314
bzip2 "github.com/cosnicolaou/pbzip2"
1415
)
1516

@@ -31,7 +32,7 @@ func ExtractBz2(ctx context.Context, d, f string) error {
3132
return nil
3233
}
3334

34-
buf := archivePool.Get(extractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
35+
buf := archivePool.Get(file.ExtractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
3536

3637
tf, err := os.Open(f)
3738
if err != nil {
@@ -69,15 +70,15 @@ func ExtractBz2(ctx context.Context, d, f string) error {
6970

7071
var written int64
7172
for {
72-
if written > 0 && written%extractBuffer == 0 && ctx.Err() != nil {
73+
if written > 0 && written%file.ExtractBuffer == 0 && ctx.Err() != nil {
7374
return ctx.Err()
7475
}
7576

7677
n, err := br.Read(buf)
7778
if n > 0 {
7879
written += int64(n)
79-
if written > maxBytes {
80-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
80+
if written > file.MaxBytes {
81+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
8182
}
8283
if _, writeErr := out.Write(buf[:n]); writeErr != nil {
8384
return fmt.Errorf("failed to write file contents: %w", writeErr)

pkg/archive/gzip.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"path/filepath"
1010

1111
"github.com/chainguard-dev/clog"
12+
"github.com/chainguard-dev/malcontent/pkg/file"
1213
"github.com/chainguard-dev/malcontent/pkg/programkind"
1314
gzip "github.com/klauspost/pgzip"
1415
)
@@ -53,7 +54,7 @@ func ExtractGzip(ctx context.Context, d string, f string) error {
5354
return nil
5455
}
5556

56-
buf := archivePool.Get(extractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
57+
buf := archivePool.Get(file.ExtractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
5758

5859
gf, err := os.Open(f)
5960
if err != nil {
@@ -85,15 +86,15 @@ func ExtractGzip(ctx context.Context, d string, f string) error {
8586

8687
var written int64
8788
for {
88-
if written > 0 && written%extractBuffer == 0 && ctx.Err() != nil {
89+
if written > 0 && written%file.ExtractBuffer == 0 && ctx.Err() != nil {
8990
return ctx.Err()
9091
}
9192

9293
n, err := gr.Read(buf)
9394
if n > 0 {
9495
written += int64(n)
95-
if written > maxBytes {
96-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
96+
if written > file.MaxBytes {
97+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
9798
}
9899

99100
if _, writeErr := out.Write(buf[:n]); writeErr != nil {

pkg/archive/rpm.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/cavaliergopher/cpio"
1414
"github.com/cavaliergopher/rpm"
1515
"github.com/chainguard-dev/clog"
16+
"github.com/chainguard-dev/malcontent/pkg/file"
1617
"github.com/klauspost/compress/zstd"
1718
"github.com/ulikunitz/xz"
1819
)
@@ -40,7 +41,7 @@ func ExtractRPM(ctx context.Context, d, f string) error {
4041
return nil
4142
}
4243

43-
buf := archivePool.Get(extractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
44+
buf := archivePool.Get(file.ExtractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
4445
defer archivePool.Put(buf)
4546

4647
pkg, err := rpm.Read(rpmFile)
@@ -123,15 +124,15 @@ func ExtractRPM(ctx context.Context, d, f string) error {
123124

124125
var written int64
125126
for {
126-
if written > 0 && written%extractBuffer == 0 && ctx.Err() != nil {
127+
if written > 0 && written%file.ExtractBuffer == 0 && ctx.Err() != nil {
127128
return ctx.Err()
128129
}
129130

130131
n, err := cr.Read(buf)
131132
if n > 0 {
132133
written += int64(n)
133-
if written > maxBytes {
134-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
134+
if written > file.MaxBytes {
135+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
135136
}
136137
if _, writeErr := out.Write(buf[:n]); writeErr != nil {
137138
return fmt.Errorf("failed to write file contents: %w", writeErr)

pkg/archive/tar.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"sync"
1414

1515
"github.com/chainguard-dev/clog"
16+
"github.com/chainguard-dev/malcontent/pkg/file"
1617
"github.com/chainguard-dev/malcontent/pkg/pool"
1718
"github.com/chainguard-dev/malcontent/pkg/programkind"
1819
bzip2 "github.com/cosnicolaou/pbzip2"
@@ -48,7 +49,7 @@ func ExtractTar(ctx context.Context, d string, f string) error {
4849
return nil
4950
}
5051

51-
buf := tarPool.Get(extractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
52+
buf := tarPool.Get(file.ExtractBuffer) //nolint:nilaway // the buffer pool is created in archive.go
5253

5354
filename := filepath.Base(f)
5455
tf, err := os.Open(f)
@@ -109,15 +110,15 @@ func ExtractTar(ctx context.Context, d string, f string) error {
109110

110111
var written int64
111112
for {
112-
if written > 0 && written%extractBuffer == 0 && ctx.Err() != nil {
113+
if written > 0 && written%file.ExtractBuffer == 0 && ctx.Err() != nil {
113114
return ctx.Err()
114115
}
115116

116117
n, err := xzStream.Read(buf)
117118
if n > 0 {
118119
written += int64(n)
119-
if written > maxBytes {
120-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
120+
if written > file.MaxBytes {
121+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
121122
}
122123
if _, writeErr := out.Write(buf[:n]); writeErr != nil {
123124
return fmt.Errorf("failed to write file contents: %w", writeErr)
@@ -146,15 +147,15 @@ func ExtractTar(ctx context.Context, d string, f string) error {
146147
}
147148
var written int64
148149
for {
149-
if written > 0 && written%extractBuffer == 0 && ctx.Err() != nil {
150+
if written > 0 && written%file.ExtractBuffer == 0 && ctx.Err() != nil {
150151
return ctx.Err()
151152
}
152153

153154
n, err := br.Read(buf)
154155
if n > 0 {
155156
written += int64(n)
156-
if written > maxBytes {
157-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
157+
if written > file.MaxBytes {
158+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
158159
}
159160

160161
if _, writeErr := out.Write(buf[:n]); writeErr != nil {

pkg/archive/zip.go

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/chainguard-dev/clog"
16+
"github.com/chainguard-dev/malcontent/pkg/file"
1617
"github.com/chainguard-dev/malcontent/pkg/pool"
1718
"github.com/chainguard-dev/malcontent/pkg/programkind"
1819
zip "github.com/klauspost/compress/zip"
@@ -72,11 +73,11 @@ func ExtractZip(ctx context.Context, d string, f string) error {
7273
return fmt.Errorf("failed to create extraction directory: %w", err)
7374
}
7475

75-
for _, file := range read.File {
76-
if file.Mode().IsDir() {
77-
clean := filepath.Clean(filepath.ToSlash(file.Name))
76+
for _, zf := range read.File {
77+
if zf.Mode().IsDir() {
78+
clean := filepath.Clean(filepath.ToSlash(zf.Name))
7879
if strings.Contains(clean, "..") {
79-
logger.Warnf("skipping potentially unsafe directory path: %s", file.Name)
80+
logger.Warnf("skipping potentially unsafe directory path: %s", zf.Name)
8081
continue
8182
}
8283

@@ -95,12 +96,12 @@ func ExtractZip(ctx context.Context, d string, f string) error {
9596
g, gCtx := errgroup.WithContext(ctx)
9697
g.SetLimit(runtime.GOMAXPROCS(0))
9798

98-
for _, file := range read.File {
99-
if file.Mode().IsDir() {
99+
for _, zf := range read.File {
100+
if zf.Mode().IsDir() {
100101
continue
101102
}
102103
g.Go(func() error {
103-
return extractFile(gCtx, file, d, logger)
104+
return extractFile(gCtx, zf, d, logger)
104105
})
105106
}
106107

@@ -110,24 +111,24 @@ func ExtractZip(ctx context.Context, d string, f string) error {
110111
return nil
111112
}
112113

113-
func extractFile(ctx context.Context, file *zip.File, destDir string, logger *clog.Logger) error {
114+
func extractFile(ctx context.Context, zf *zip.File, destDir string, logger *clog.Logger) error {
114115
if ctx.Err() != nil {
115116
return ctx.Err()
116117
}
117118

118119
// macOS will encounter issues with paths like META-INF/LICENSE and META-INF/license/foo
119120
// this case insensitivity will break scans, so rename files that collide with existing directories
120121
if runtime.GOOS == "darwin" {
121-
if _, err := os.Stat(filepath.Join(destDir, file.Name)); err == nil {
122-
file.Name = fmt.Sprintf("%s%d", file.Name, time.Now().UnixNano())
122+
if _, err := os.Stat(filepath.Join(destDir, zf.Name)); err == nil {
123+
zf.Name = fmt.Sprintf("%s%d", zf.Name, time.Now().UnixNano())
123124
}
124125
}
125126

126-
buf := zipPool.Get(zipBuffer) //nolint:nilaway // the buffer pool is created in archive.go
127+
buf := zipPool.Get(file.ZipBuffer) //nolint:nilaway // the buffer pool is created in archive.go
127128

128-
clean := filepath.Clean(filepath.ToSlash(file.Name))
129+
clean := filepath.Clean(filepath.ToSlash(zf.Name))
129130
if strings.Contains(clean, "..") {
130-
logger.Warnf("skipping potentially unsafe file path: %s", file.Name)
131+
logger.Warnf("skipping potentially unsafe file path: %s", zf.Name)
131132
return nil
132133
}
133134

@@ -141,7 +142,7 @@ func extractFile(ctx context.Context, file *zip.File, destDir string, logger *cl
141142
return fmt.Errorf("failed to create directory structure: %w", err)
142143
}
143144

144-
src, err := file.Open()
145+
src, err := zf.Open()
145146
if err != nil {
146147
return fmt.Errorf("failed to open archived file: %w", err)
147148
}
@@ -159,15 +160,15 @@ func extractFile(ctx context.Context, file *zip.File, destDir string, logger *cl
159160

160161
var written int64
161162
for {
162-
if written > 0 && written%zipBuffer == 0 && ctx.Err() != nil {
163+
if written > 0 && written%file.ZipBuffer == 0 && ctx.Err() != nil {
163164
return ctx.Err()
164165
}
165166

166167
n, err := src.Read(buf)
167168
if n > 0 {
168169
written += int64(n)
169-
if written > maxBytes {
170-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
170+
if written > file.MaxBytes {
171+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", file.MaxBytes, target)
171172
}
172173

173174
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {

0 commit comments

Comments
 (0)