@@ -7,86 +7,109 @@ import (
77 "io"
88 "os"
99 "path/filepath"
10+ "runtime"
1011 "strings"
1112
1213 "github.com/chainguard-dev/clog"
14+ "golang.org/x/sync/errgroup"
1315)
1416
15- // extractZip extracts .jar and .zip archives.
17+ // ExtractZip extracts .jar and .zip archives.
1618func ExtractZip (ctx context.Context , d string , f string ) error {
1719 logger := clog .FromContext (ctx ).With ("dir" , d , "file" , f )
1820 logger .Debug ("extracting zip" )
1921
20- // Check if the file is valid
21- _ , err := os .Stat (f )
22+ fi , err := os .Stat (f )
2223 if err != nil {
2324 return fmt .Errorf ("failed to stat file %s: %w" , f , err )
2425 }
26+ if fi .Size () == 0 {
27+ return fmt .Errorf ("empty zip file: %s" , f )
28+ }
2529
2630 read , err := zip .OpenReader (f )
2731 if err != nil {
2832 return fmt .Errorf ("failed to open zip file %s: %w" , f , err )
2933 }
3034 defer read .Close ()
3135
36+ if err := os .MkdirAll (d , 0o700 ); err != nil {
37+ return fmt .Errorf ("failed to create extraction directory: %w" , err )
38+ }
39+
40+ g , gCtx := errgroup .WithContext (ctx )
41+ g .SetLimit (runtime .GOMAXPROCS (0 ))
42+
3243 for _ , file := range read .File {
33- clean := filepath .Clean (filepath .ToSlash (file .Name ))
34- if strings .Contains (clean , ".." ) {
35- logger .Warnf ("skipping potentially unsafe file path: %s" , file .Name )
36- continue
37- }
44+ g .Go (func () error {
45+ return extractFile (gCtx , file , d , logger )
46+ })
47+ }
3848
39- target := filepath . Join ( d , clean )
40- if ! IsValidPath ( target , d ) {
41- logger . Warnf ( "skipping file path outside extraction directory: %s" , target )
42- continue
43- }
49+ if err := g . Wait (); err != nil {
50+ return fmt . Errorf ( "extraction failed: %w" , err )
51+ }
52+ return nil
53+ }
4454
45- // Check if a directory with the same name exists
46- if info , err := os .Stat (target ); err == nil && info .IsDir () {
47- continue
48- }
55+ func extractFile (ctx context.Context , file * zip.File , destDir string , logger * clog.Logger ) error {
56+ buf , ok := bufferPool .Get ().(* []byte )
57+ if ! ok {
58+ return fmt .Errorf ("failed to retrieve buffer" )
59+ }
60+ defer bufferPool .Put (buf )
4961
50- if file .Mode ().IsDir () {
51- err := os .MkdirAll (target , 0o700 )
52- if err != nil {
53- return fmt .Errorf ("failed to create directory: %w" , err )
54- }
55- continue
56- }
62+ clean := filepath .Clean (filepath .ToSlash (file .Name ))
63+ if strings .Contains (clean , ".." ) {
64+ logger .Warnf ("skipping potentially unsafe file path: %s" , file .Name )
65+ return nil
66+ }
5767
58- zf , err := file .Open ()
59- if err != nil {
60- return fmt .Errorf ("failed to open file in zip: %w" , err )
61- }
68+ target := filepath .Join (destDir , clean )
69+ if ! IsValidPath (target , destDir ) {
70+ logger .Warnf ("skipping file path outside extraction directory: %s" , target )
71+ return nil
72+ }
6273
63- err = os . MkdirAll ( filepath . Dir ( target ), 0o700 )
64- if err != nil {
65- zf . Close ()
66- return fmt . Errorf ( "failed to create directory: %w" , err )
67- }
74+ select {
75+ case <- ctx . Done ():
76+ return ctx . Err ()
77+ default :
78+ }
6879
69- out , err := os .OpenFile (target , os .O_WRONLY | os .O_CREATE | os .O_TRUNC , 0o600 )
70- if err != nil {
71- out .Close ()
72- return fmt .Errorf ("failed to create file: %w" , err )
73- }
80+ if file .Mode ().IsDir () {
81+ return os .MkdirAll (target , 0o700 )
82+ }
7483
75- written , err := io .Copy (out , io .LimitReader (zf , maxBytes ))
76- if err != nil {
77- return fmt .Errorf ("failed to copy file: %w" , err )
78- }
79- if written >= maxBytes {
80- return fmt .Errorf ("file exceeds maximum allowed size (%d bytes): %s" , maxBytes , target )
81- }
84+ if err := os .MkdirAll (filepath .Dir (target ), 0o700 ); err != nil {
85+ return fmt .Errorf ("failed to create directory structure: %w" , err )
86+ }
8287
83- if err := out .Close (); err != nil {
84- return fmt .Errorf ("failed to close file: %w" , err )
85- }
88+ src , err := file .Open ()
89+ if err != nil {
90+ return fmt .Errorf ("failed to open archived file: %w" , err )
91+ }
92+ defer src .Close ()
8693
87- if err := zf .Close (); err != nil {
88- return fmt .Errorf ("failed to close file: %w" , err )
94+ dst , err := os .OpenFile (target , os .O_WRONLY | os .O_CREATE | os .O_TRUNC , 0o600 )
95+ if err != nil {
96+ return fmt .Errorf ("failed to create destination file: %w" , err )
97+ }
98+
99+ var closeErr error
100+ defer func () {
101+ if cerr := dst .Close (); cerr != nil && closeErr == nil {
102+ closeErr = cerr
89103 }
104+ }()
105+
106+ written , err := io .CopyBuffer (dst , io .LimitReader (src , maxBytes ), * buf )
107+ if err != nil {
108+ return fmt .Errorf ("failed to copy file contents: %w" , err )
90109 }
91- return nil
110+ if written >= maxBytes {
111+ return fmt .Errorf ("file exceeds maximum allowed size (%d bytes): %s" , maxBytes , target )
112+ }
113+
114+ return closeErr
92115}
0 commit comments