Skip to content

Commit a39345e

Browse files
committed
pkg/downloader: refactor choosing of decompressor
Signed-off-by: Oleksandr Redko <[email protected]>
1 parent 37b74b6 commit a39345e

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

pkg/downloader/downloader.go

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
164164

165165
ext := path.Ext(remote)
166166
if IsLocal(remote) {
167-
if err := copyLocal(localPath, remote, ext, o.decompress, o.description, o.expectedDigest); err != nil {
167+
if err := copyLocal(ctx, localPath, remote, ext, o.decompress, o.description, o.expectedDigest); err != nil {
168168
return nil, err
169169
}
170170
res := &Result{
@@ -199,11 +199,11 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
199199
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
200200
return nil, err
201201
}
202-
if err := copyLocal(localPath, shadData, ext, o.decompress, "", ""); err != nil {
202+
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
203203
return nil, err
204204
}
205205
} else {
206-
if err := copyLocal(localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
206+
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
207207
return nil, err
208208
}
209209
}
@@ -228,7 +228,7 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
228228
return nil, err
229229
}
230230
// no need to pass the digest to copyLocal(), as we already verified the digest
231-
if err := copyLocal(localPath, shadData, ext, o.decompress, "", ""); err != nil {
231+
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
232232
return nil, err
233233
}
234234
if shadDigest != "" && o.expectedDigest != "" {
@@ -336,7 +336,7 @@ func canonicalLocalPath(s string) (string, error) {
336336
return localpathutil.Expand(s)
337337
}
338338

339-
func copyLocal(dst, src, ext string, decompress bool, description string, expectedDigest digest.Digest) error {
339+
func copyLocal(ctx context.Context, dst, src, ext string, decompress bool, description string, expectedDigest digest.Digest) error {
340340
srcPath, err := canonicalLocalPath(src)
341341
if err != nil {
342342
return err
@@ -357,37 +357,33 @@ func copyLocal(dst, src, ext string, decompress bool, description string, expect
357357
if err != nil {
358358
return err
359359
}
360-
if _, ok := Decompressor(ext); ok && decompress {
361-
return decompressLocal(dstPath, srcPath, ext, description)
360+
if decompress {
361+
command := decompressor(ext)
362+
if command != "" {
363+
return decompressLocal(ctx, command, dstPath, srcPath, ext, description)
364+
}
362365
}
363366
// TODO: progress bar for copy
364367
return fs.CopyFile(dstPath, srcPath)
365368
}
366369

367-
func Decompressor(ext string) ([]string, bool) {
368-
var program string
370+
func decompressor(ext string) string {
369371
switch ext {
370372
case ".gz":
371-
program = "gzip"
373+
return "gzip"
372374
case ".bz2":
373-
program = "bzip2"
375+
return "bzip2"
374376
case ".xz":
375-
program = "xz"
377+
return "xz"
376378
case ".zst":
377-
program = "zstd"
379+
return "zstd"
378380
default:
379-
return nil, false
381+
return ""
380382
}
381-
// -d --decompress
382-
return []string{program, "-d"}, true
383383
}
384384

385-
func decompressLocal(dst, src, ext, description string) error {
386-
command, found := Decompressor(ext)
387-
if !found {
388-
return fmt.Errorf("decompressLocal: unknown extension %s", ext)
389-
}
390-
logrus.Infof("decompressing %s with %v", ext, command)
385+
func decompressLocal(ctx context.Context, decompressCmd, dst, src, ext, description string) error {
386+
logrus.Infof("decompressing %s with %v", ext, decompressCmd)
391387

392388
st, err := os.Stat(src)
393389
if err != nil {
@@ -412,7 +408,7 @@ func decompressLocal(dst, src, ext, description string) error {
412408
}
413409
defer out.Close()
414410
buf := new(bytes.Buffer)
415-
cmd := exec.Command(command[0], command[1:]...)
411+
cmd := exec.CommandContext(ctx, decompressCmd, "-d") // -d --decompress
416412
cmd.Stdin = bar.NewProxyReader(in)
417413
cmd.Stdout = out
418414
cmd.Stderr = buf

pkg/downloader/downloader_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,38 @@ func TestDownloadCompressed(t *testing.T) {
179179
assert.NilError(t, err)
180180
assert.Equal(t, string(got), string(testDownloadCompressedContents))
181181
})
182+
183+
t.Run("bzip2", func(t *testing.T) {
184+
localPath := filepath.Join(t.TempDir(), t.Name())
185+
localFile := filepath.Join(t.TempDir(), "test-file")
186+
testDownloadCompressedContents := []byte("TestDownloadCompressed")
187+
assert.NilError(t, os.WriteFile(localFile, testDownloadCompressedContents, 0o644))
188+
assert.NilError(t, exec.Command("bzip2", localFile).Run())
189+
localFile += ".bz2"
190+
testLocalFileURL := "file://" + localFile
191+
192+
r, err := Download(context.Background(), localPath, testLocalFileURL, WithDecompress(true))
193+
assert.NilError(t, err)
194+
assert.Equal(t, StatusDownloaded, r.Status)
195+
196+
got, err := os.ReadFile(localPath)
197+
assert.NilError(t, err)
198+
assert.Equal(t, string(got), string(testDownloadCompressedContents))
199+
})
200+
201+
t.Run("unknown decompressor", func(t *testing.T) {
202+
localPath := filepath.Join(t.TempDir(), t.Name())
203+
localFile := filepath.Join(t.TempDir(), "test-file.rar")
204+
testDownloadCompressedContents := []byte("TestDownloadCompressed")
205+
assert.NilError(t, os.WriteFile(localFile, testDownloadCompressedContents, 0o644))
206+
testLocalFileURL := "file://" + localFile
207+
208+
r, err := Download(context.Background(), localPath, testLocalFileURL, WithDecompress(true))
209+
assert.NilError(t, err)
210+
assert.Equal(t, StatusDownloaded, r.Status)
211+
212+
got, err := os.ReadFile(localPath)
213+
assert.NilError(t, err)
214+
assert.Equal(t, string(got), string(testDownloadCompressedContents))
215+
})
182216
}

0 commit comments

Comments
 (0)