Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 56 additions & 25 deletions cmd/image/image.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package image

import (
"crypto/sha512"
"fmt"
"hash"
"log/slog"

pb "github.com/cheggaaa/pb/v3"
Expand All @@ -17,6 +19,11 @@ import (
"time"
)

const (
HashMD5 = "md5"
HashSHA512 = "sha512"
)

type Image struct {
log *slog.Logger
}
Expand All @@ -27,22 +34,38 @@ func NewImage(log *slog.Logger) *Image {

// Pull a image from s3
func (i *Image) Pull(image, destination string) error {
var (
sha512destination = destination + ".sha512sum"
sha512file = image + ".sha512sum"
md5destination = destination + ".md5"
md5file = image + ".md5"
)

i.log.Info("pull image", "image", image)
md5destination := destination + ".md5"
md5file := image + ".md5"
err := i.download(image, destination)
if err != nil {
return fmt.Errorf("unable to pull image %s %w", image, err)
}
err = i.download(md5file, md5destination)
defer os.Remove(md5destination)

err = i.download(sha512file, sha512destination)
defer os.Remove(sha512destination)
if err != nil {
return fmt.Errorf("unable to pull md5 %s %w", md5file, err)
}
i.log.Info("check md5")
matches, err := i.checkMD5(destination, md5destination)
if err != nil || !matches {
return fmt.Errorf("md5sum mismatch")
i.log.Info("unable to process sha512 file, trying with md5", "error", err)
err = i.download(md5file, md5destination)
defer os.Remove(md5destination)
if err != nil {
return fmt.Errorf("unable to pull hash file %s %w", md5file, err)
}
matches, err := i.checkHash(destination, md5destination, HashMD5)
if err != nil || !matches {
return fmt.Errorf("md5 mismatch, matches: %v with error: %w", matches, err)
}
} else {
i.log.Info("check sha512")
matches, err := i.checkHash(destination, sha512destination, HashSHA512)
if err != nil || !matches {
return fmt.Errorf("sha512 mismatch, matches: %v with error: %w", matches, err)
}
}

i.log.Info("pull image done", "image", image)
Expand Down Expand Up @@ -102,32 +125,40 @@ func (i *Image) Burn(prefix, image, source string) error {
return nil
}

// checkMD5 check the md5 signature of file with the md5sum given in the md5file.
// the content of the md5file must be in the form:
// <md5sum> filename
// this is the same format as create by the "md5sum" unix command
func (i *Image) checkMD5(file, md5file string) (bool, error) {
md5fileContent, err := os.ReadFile(md5file)
// checkHash check the sha512 or md5 signature of file with the sha512sum or md5sum given in the file.
// the content of the file must be in the form:
// <sha512sum | md5sum> filename
// this is the same format as create by the "sha512 | md5sum" unix command
func (i *Image) checkHash(file, hashfile, hashType string) (bool, error) {
hashfileContent, err := os.ReadFile(hashfile)
if err != nil {
return false, fmt.Errorf("unable to read md5sum file %s %w", md5file, err)
return false, fmt.Errorf("unable to read hash file %s %w", hashfile, err)
}
expectedMD5 := strings.Split(string(md5fileContent), " ")[0]
expectedHash := strings.Split(string(hashfileContent), " ")[0]

f, err := os.Open(file)
if err != nil {
return false, fmt.Errorf("unable to read file: %s %w", file, err)
}
defer f.Close()

//nolint:gosec
h := md5.New()
var h hash.Hash
switch hashType {
case HashSHA512:
h = sha512.New()
case HashMD5:
h = md5.New()
default:
return false, fmt.Errorf("unsupported hash type: %s", hashType)
}

if _, err := io.Copy(h, f); err != nil {
return false, fmt.Errorf("unable to calculate md5sum of file: %s %w", file, err)
return false, fmt.Errorf("unable to calculate %s of file: %s %w", hashType, file, err)
}
sourceMD5 := fmt.Sprintf("%x", h.Sum(nil))
i.log.Info("check md5", "source md5", sourceMD5, "expected md5", expectedMD5)
if sourceMD5 != expectedMD5 {
return false, fmt.Errorf("source md5:%s expected md5:%s", sourceMD5, expectedMD5)
sourceHash := fmt.Sprintf("%x", h.Sum(nil))
i.log.Info("check hash", "source hash", sourceHash, "expected hash", expectedHash)
if sourceHash != expectedHash {
return false, fmt.Errorf("source %s:%s expected %s:%s", hashType, sourceHash, hashType, expectedHash)
}
return true, nil
}
Expand Down
35 changes: 34 additions & 1 deletion cmd/image/image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestCheckMD5(t *testing.T) {
defer os.Remove(testfile)
defer os.Remove(testfileMD5)

matches, err := NewImage(slog.Default()).checkMD5(testfile, testfileMD5)
matches, err := NewImage(slog.Default()).checkHash(testfile, testfileMD5, HashMD5)
if err != nil {
t.Error(err)
}
Expand All @@ -43,3 +43,36 @@ func TestCheckMD5(t *testing.T) {
}

}

func TestCheckSHA512(t *testing.T) {
testfile := "/tmp/testsha512"
testfileSHA512 := "/tmp/testsha512.sha512sum"
content := []byte("This is testcontent")
err := os.WriteFile(testfile, content, os.ModePerm) // nolint:gosec
if err != nil {
t.Error(err)
}
cmd := exec.Command("sha512sum", testfile)
sha512Content, err := cmd.Output()
if err != nil {
t.Error(err)
}
sha512, err := os.Create(testfileSHA512)
if err != nil {
t.Error(err)
}
_, err = sha512.Write(sha512Content)
if err != nil {
t.Error(err)
}
sha512.Close()
defer os.Remove(testfile)
defer os.Remove(testfileSHA512)
matches, err := NewImage(slog.Default()).checkHash(testfile, testfileSHA512, HashSHA512)
if err != nil {
t.Error(err)
}
if !matches {
t.Error("expected sha512 matches, but didn't")
}
}
Loading