Skip to content

Commit e7d91da

Browse files
authored
Extract .jar and .zip files concurrently, use buffer for all io.Copy operations (#779)
* Improve efficiency and performance of zip extractions
  Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
* Use buffer for all io.Copy operations; add limit reader for .xz files
  Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
* Consolidate consts
  Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
* Keep parameters on a single line
  Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
---------
Signed-off-by: egibs <20933572+egibs@users.noreply.github.com>
1 parent 3147dd1 commit e7d91da

File tree

9 files changed

+148
-65
lines changed

9 files changed

+148
-65
lines changed

pkg/archive/archive.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,27 @@ import (
1515
"github.com/chainguard-dev/malcontent/pkg/programkind"
1616
)
1717

18+
// Sizing limits shared by the archive extractors in this package.
const (
19+
// bufferSize is the length in bytes (32 * 1024 = 32KB) of each scratch
// buffer allocated by bufferPool and handed to io.CopyBuffer.
20+
bufferSize = 32 * 1024
21+
// maxBytes is the per-file extraction limit (1 << 29 bytes = 512MB).
// Extractors wrap their source in io.LimitReader(..., maxBytes) and treat
// a copy that wrote maxBytes or more as an oversized file.
22+
maxBytes = 1 << 29
23+
)
24+
25+
// bufferPool is the shared pool of scratch buffers for io.CopyBuffer
// operations. Each pooled item is a *[]byte whose slice has length
// bufferSize; callers type-assert the result of Get to *[]byte, pass the
// dereferenced slice to io.CopyBuffer, and Put the pointer back when done.
26+
var bufferPool = sync.Pool{
27+
New: func() interface{} {
28+
b := make([]byte, bufferSize)
29+
// Return a pointer to the slice rather than the slice itself: a pointer
// fits in an interface value directly, whereas boxing the slice header
// would cost an extra allocation on every Put.
return &b
30+
},
31+
}
32+
1833
// IsValidPath reports whether target is dir itself or a path located beneath
// dir.
//
// Both paths are normalized with filepath.Clean and compared lexically; the
// filesystem is not consulted and symlinks are not resolved. The comparison is
// made on path-component boundaries, so a sibling that merely shares a textual
// prefix with dir (for example "/tmp/out-evil" against "/tmp/out") is
// rejected. If the two paths cannot be made relative to each other (for
// example one is absolute and the other is relative) the result is false.
func IsValidPath(target, dir string) bool {
	rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(target))
	if err != nil {
		return false
	}

	// A relative path that is exactly ".." or starts with a ".." component
	// climbs out of dir. Names that only begin with two dots (such as
	// "..hidden") are ordinary entries inside dir and remain valid.
	if rel == ".." {
		return false
	}
	return !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
2237

23-
const maxBytes = 1 << 29 // 512MB
24-
25-
func extractNestedArchive(
26-
ctx context.Context,
27-
d string,
28-
f string,
29-
extracted *sync.Map,
30-
) error {
38+
func extractNestedArchive(ctx context.Context, d string, f string, extracted *sync.Map) error {
3139
isArchive := false
3240
// zlib-compressed files are also archives
3341
ft, err := programkind.File(f)
@@ -223,6 +231,12 @@ func handleDirectory(target string) error {
223231

224232
// handleFile extracts valid files within .deb or .tar archives.
225233
func handleFile(target string, tr *tar.Reader) error {
234+
buf, ok := bufferPool.Get().(*[]byte)
235+
if !ok {
236+
return fmt.Errorf("failed to retrieve buffer")
237+
}
238+
defer bufferPool.Put(buf)
239+
226240
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
227241
return fmt.Errorf("failed to create parent directory: %w", err)
228242
}
@@ -233,7 +247,7 @@ func handleFile(target string, tr *tar.Reader) error {
233247
}
234248
defer out.Close()
235249

236-
written, err := io.Copy(out, io.LimitReader(tr, maxBytes))
250+
written, err := io.CopyBuffer(out, io.LimitReader(tr, maxBytes), *buf)
237251
if err != nil {
238252
return fmt.Errorf("failed to copy file: %w", err)
239253
}

pkg/archive/bz2.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ func ExtractBz2(ctx context.Context, d, f string) error {
1717
logger := clog.FromContext(ctx).With("dir", d, "file", f)
1818
logger.Debug("extracting bzip2 file")
1919

20+
buf, ok := bufferPool.Get().(*[]byte)
21+
if !ok {
22+
return fmt.Errorf("failed to retrieve buffer")
23+
}
24+
defer bufferPool.Put(buf)
25+
2026
// Check if the file is valid
2127
_, err := os.Stat(f)
2228
if err != nil {
@@ -53,7 +59,7 @@ func ExtractBz2(ctx context.Context, d, f string) error {
5359
}
5460
defer out.Close()
5561

56-
written, err := io.Copy(out, io.LimitReader(br, maxBytes))
62+
written, err := io.CopyBuffer(out, io.LimitReader(br, maxBytes), *buf)
5763
if err != nil {
5864
return fmt.Errorf("failed to copy file: %w", err)
5965
}

pkg/archive/deb.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ func ExtractDeb(ctx context.Context, d, f string) error {
1919
logger := clog.FromContext(ctx).With("dir", d, "file", f)
2020
logger.Debug("extracting deb")
2121

22+
buf, ok := bufferPool.Get().(*[]byte)
23+
if !ok {
24+
return fmt.Errorf("failed to retrieve buffer")
25+
}
26+
defer bufferPool.Put(buf)
27+
2228
fd, err := os.Open(f)
2329
if err != nil {
2430
panic(err)

pkg/archive/gzip.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ func ExtractGzip(ctx context.Context, d string, f string) error {
2929
logger := clog.FromContext(ctx).With("dir", d, "file", f)
3030
logger.Debug("extracting gzip")
3131

32+
buf, ok := bufferPool.Get().(*[]byte)
33+
if !ok {
34+
return fmt.Errorf("failed to retrieve buffer")
35+
}
36+
defer bufferPool.Put(buf)
37+
3238
// Check if the file is valid
3339
_, err := os.Stat(f)
3440
if err != nil {
@@ -59,7 +65,7 @@ func ExtractGzip(ctx context.Context, d string, f string) error {
5965
}
6066
defer out.Close()
6167

62-
written, err := io.Copy(out, io.LimitReader(gr, maxBytes))
68+
written, err := io.CopyBuffer(out, io.LimitReader(gr, maxBytes), *buf)
6369
if err != nil {
6470
return fmt.Errorf("failed to copy file: %w", err)
6571
}

pkg/archive/rpm.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ func ExtractRPM(ctx context.Context, d, f string) error {
2222
logger := clog.FromContext(ctx).With("dir", d, "file", f)
2323
logger.Debug("extracting rpm")
2424

25+
buf, ok := bufferPool.Get().(*[]byte)
26+
if !ok {
27+
return fmt.Errorf("failed to retrieve buffer")
28+
}
29+
defer bufferPool.Put(buf)
30+
2531
rpmFile, err := os.Open(f)
2632
if err != nil {
2733
return fmt.Errorf("failed to open RPM file: %w", err)
@@ -106,7 +112,7 @@ func ExtractRPM(ctx context.Context, d, f string) error {
106112
return fmt.Errorf("failed to create file: %w", err)
107113
}
108114

109-
written, err := io.Copy(out, io.LimitReader(cr, maxBytes))
115+
written, err := io.CopyBuffer(out, io.LimitReader(cr, maxBytes), *buf)
110116
if err != nil {
111117
return fmt.Errorf("failed to copy file: %w", err)
112118
}

pkg/archive/tar.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ func ExtractTar(ctx context.Context, d string, f string) error {
2222
logger := clog.FromContext(ctx).With("dir", d, "file", f)
2323
logger.Debug("extracting tar")
2424

25+
buf, ok := bufferPool.Get().(*[]byte)
26+
if !ok {
27+
return fmt.Errorf("failed to retrieve buffer")
28+
}
29+
defer bufferPool.Put(buf)
30+
2531
// Check if the file is valid
2632
_, err := os.Stat(f)
2733
if err != nil {
@@ -83,9 +89,13 @@ func ExtractTar(ctx context.Context, d string, f string) error {
8389
}
8490
defer out.Close()
8591

86-
if _, err = io.Copy(out, xzStream); err != nil {
92+
written, err := io.CopyBuffer(out, io.LimitReader(xzStream, maxBytes), *buf)
93+
if err != nil {
8794
return fmt.Errorf("failed to write decompressed xz output: %w", err)
8895
}
96+
if written >= maxBytes {
97+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
98+
}
8999
return nil
90100
case strings.Contains(filename, ".tar.bz2") || strings.Contains(filename, ".tbz"):
91101
br := bzip2.NewReader(tf)

pkg/archive/upx.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ func ExtractUPX(ctx context.Context, d, f string) error {
2020
logger := clog.FromContext(ctx).With("dir", d, "file", f)
2121
logger.Debug("extracting upx")
2222

23+
buf, ok := bufferPool.Get().(*[]byte)
24+
if !ok {
25+
return fmt.Errorf("failed to retrieve buffer")
26+
}
27+
defer bufferPool.Put(buf)
28+
2329
// Check if the file is valid
2430
_, err := os.Stat(f)
2531
if err != nil {

pkg/archive/zip.go

Lines changed: 74 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,86 +7,109 @@ import (
77
"io"
88
"os"
99
"path/filepath"
10+
"runtime"
1011
"strings"
1112

1213
"github.com/chainguard-dev/clog"
14+
"golang.org/x/sync/errgroup"
1315
)
1416

15-
// extractZip extracts .jar and .zip archives.
17+
// ExtractZip extracts .jar and .zip archives.
//
// It stats f and rejects a zero-length file, opens it with zip.OpenReader,
// creates the destination directory d, and then hands every archive entry to
// extractFile on an errgroup whose concurrency is capped at
// runtime.GOMAXPROCS(0). The first error reported by the group is returned
// wrapped as "extraction failed".
1618
func ExtractZip(ctx context.Context, d string, f string) error {
1719
logger := clog.FromContext(ctx).With("dir", d, "file", f)
1820
logger.Debug("extracting zip")
1921

20-
// Check if the file is valid
21-
_, err := os.Stat(f)
22+
fi, err := os.Stat(f)
2223
if err != nil {
2324
return fmt.Errorf("failed to stat file %s: %w", f, err)
2425
}
26+
// A zero-byte file is reported as an error up front instead of being passed
// to zip.OpenReader.
if fi.Size() == 0 {
27+
return fmt.Errorf("empty zip file: %s", f)
28+
}
2529

2630
read, err := zip.OpenReader(f)
2731
if err != nil {
2832
return fmt.Errorf("failed to open zip file %s: %w", f, err)
2933
}
3034
defer read.Close()
3135

36+
if err := os.MkdirAll(d, 0o700); err != nil {
37+
return fmt.Errorf("failed to create extraction directory: %w", err)
38+
}
39+
40+
// gCtx is the group's derived context; it is what each extractFile call
// receives, so workers can observe cancellation of the group.
g, gCtx := errgroup.WithContext(ctx)
41+
g.SetLimit(runtime.GOMAXPROCS(0))
42+
3243
for _, file := range read.File {
33-
clean := filepath.Clean(filepath.ToSlash(file.Name))
34-
if strings.Contains(clean, "..") {
35-
logger.Warnf("skipping potentially unsafe file path: %s", file.Name)
36-
continue
37-
}
44+
// NOTE(review): the closure captures the range variable `file` directly.
// That is only safe with per-iteration loop variables (Go 1.22+) — confirm
// the module's go directive, otherwise goroutines may see the wrong entry.
g.Go(func() error {
45+
return extractFile(gCtx, file, d, logger)
46+
})
47+
}
3848

39-
target := filepath.Join(d, clean)
40-
if !IsValidPath(target, d) {
41-
logger.Warnf("skipping file path outside extraction directory: %s", target)
42-
continue
43-
}
49+
if err := g.Wait(); err != nil {
50+
return fmt.Errorf("extraction failed: %w", err)
51+
}
52+
return nil
53+
}
4454

45-
// Check if a directory with the same name exists
46-
if info, err := os.Stat(target); err == nil && info.IsDir() {
47-
continue
48-
}
55+
// extractFile writes a single zip entry beneath destDir.
//
// Entries whose cleaned name contains ".." or whose joined target fails
// IsValidPath are logged with a warning and skipped (nil is returned).
// Directory entries are created with os.MkdirAll. Regular entries are copied
// to a file opened with mode 0o600 through a pooled buffer, reading at most
// maxBytes; a copy that wrote maxBytes or more is reported as an error.
// If ctx is already done when checked, ctx.Err() is returned before any
// filesystem work is attempted.
func extractFile(ctx context.Context, file *zip.File, destDir string, logger *clog.Logger) error {
56+
buf, ok := bufferPool.Get().(*[]byte)
57+
if !ok {
58+
return fmt.Errorf("failed to retrieve buffer")
59+
}
60+
defer bufferPool.Put(buf)
4961

50-
if file.Mode().IsDir() {
51-
err := os.MkdirAll(target, 0o700)
52-
if err != nil {
53-
return fmt.Errorf("failed to create directory: %w", err)
54-
}
55-
continue
56-
}
62+
clean := filepath.Clean(filepath.ToSlash(file.Name))
63+
// This is a substring test, so it also skips names that merely contain two
// consecutive dots (e.g. "a..b"), not only parent-directory components.
if strings.Contains(clean, "..") {
64+
logger.Warnf("skipping potentially unsafe file path: %s", file.Name)
65+
return nil
66+
}
5767

58-
zf, err := file.Open()
59-
if err != nil {
60-
return fmt.Errorf("failed to open file in zip: %w", err)
61-
}
68+
target := filepath.Join(destDir, clean)
69+
if !IsValidPath(target, destDir) {
70+
logger.Warnf("skipping file path outside extraction directory: %s", target)
71+
return nil
72+
}
6273

63-
err = os.MkdirAll(filepath.Dir(target), 0o700)
64-
if err != nil {
65-
zf.Close()
66-
return fmt.Errorf("failed to create directory: %w", err)
67-
}
74+
// Non-blocking cancellation check: bail out if ctx is already done,
// otherwise fall through immediately.
select {
75+
case <-ctx.Done():
76+
return ctx.Err()
77+
default:
78+
}
6879

69-
out, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
70-
if err != nil {
71-
out.Close()
72-
return fmt.Errorf("failed to create file: %w", err)
73-
}
80+
if file.Mode().IsDir() {
81+
return os.MkdirAll(target, 0o700)
82+
}
7483

75-
written, err := io.Copy(out, io.LimitReader(zf, maxBytes))
76-
if err != nil {
77-
return fmt.Errorf("failed to copy file: %w", err)
78-
}
79-
if written >= maxBytes {
80-
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
81-
}
84+
if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil {
85+
return fmt.Errorf("failed to create directory structure: %w", err)
86+
}
8287

83-
if err := out.Close(); err != nil {
84-
return fmt.Errorf("failed to close file: %w", err)
85-
}
88+
src, err := file.Open()
89+
if err != nil {
90+
return fmt.Errorf("failed to open archived file: %w", err)
91+
}
92+
defer src.Close()
8693

87-
if err := zf.Close(); err != nil {
88-
return fmt.Errorf("failed to close file: %w", err)
94+
dst, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
95+
if err != nil {
96+
return fmt.Errorf("failed to create destination file: %w", err)
97+
}
98+
99+
// NOTE(review): closeErr is a plain local and the function has an unnamed
// result. `return closeErr` below evaluates it (still nil) before this
// deferred func runs, so an error from dst.Close() is never returned to the
// caller. Reporting it would need a named result or an explicit Close
// before returning.
var closeErr error
100+
defer func() {
101+
if cerr := dst.Close(); cerr != nil && closeErr == nil {
102+
closeErr = cerr
89103
}
104+
}()
105+
106+
written, err := io.CopyBuffer(dst, io.LimitReader(src, maxBytes), *buf)
107+
if err != nil {
108+
return fmt.Errorf("failed to copy file contents: %w", err)
90109
}
91-
return nil
110+
// The LimitReader stops at maxBytes, so written can never exceed it; reaching
// the cap is treated as "too large", which also rejects an entry of exactly
// maxBytes.
if written >= maxBytes {
111+
return fmt.Errorf("file exceeds maximum allowed size (%d bytes): %s", maxBytes, target)
112+
}
113+
114+
return closeErr
92115
}

pkg/archive/zlib.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ func ExtractZlib(ctx context.Context, d string, f string) error {
1616
logger := clog.FromContext(ctx).With("dir", d, "file", f)
1717
logger.Debugf("extracting zlib")
1818

19+
buf, ok := bufferPool.Get().(*[]byte)
20+
if !ok {
21+
return fmt.Errorf("failed to retrieve buffer")
22+
}
23+
defer bufferPool.Put(buf)
24+
1925
// Check if the file is valid
2026
_, err := os.Stat(f)
2127
if err != nil {
@@ -43,7 +49,7 @@ func ExtractZlib(ctx context.Context, d string, f string) error {
4349
}
4450
defer out.Close()
4551

46-
written, err := io.Copy(out, io.LimitReader(zr, maxBytes))
52+
written, err := io.CopyBuffer(out, io.LimitReader(zr, maxBytes), *buf)
4753
if err != nil {
4854
return fmt.Errorf("failed to copy file: %w", err)
4955
}

0 commit comments

Comments
 (0)