@@ -19,6 +19,7 @@ import (
1919 "archive/tar"
2020 "compress/gzip"
2121 "crypto/sha256"
22+ "errors"
2223 "fmt"
2324 "io"
2425 "net/http"
@@ -28,6 +29,8 @@ import (
2829 "time"
2930)
3031
32+ var errChecksumMismatch = errors .New ("checksum mismatch" )
33+
3134// Endpoints defines the endpoints used to access GitHub.
3235type Endpoints struct {
3336 // API defines the endpoint used to make API calls.
@@ -113,48 +116,78 @@ func TarballLink(githubDownload string, repo *Repo, sha string) string {
113116 return fmt .Sprintf ("%s/%s/%s/archive/%s.tar.gz" , githubDownload , repo .Org , repo .Repo , sha )
114117}
115118
116- // DownloadTarball downloads a tarball from the given source URL to the target path,
117- // verifying its SHA256 checksum matches expectedSha256. It retries up to 3 times
118- // with exponential backoff on failure.
119- func DownloadTarball (target , source , expectedSha256 string ) error {
119+ // DownloadTarball downloads a tarball from the given url to the target
120+ // path, verifying its SHA256 checksum matches expectedSha256. It retries up to
121+ // 3 times with exponential backoff on failure.
122+ func DownloadTarball (target , url , expectedSha256 string ) error {
120123 if fileExists (target ) {
121124 return nil
122125 }
126+ if err := os .MkdirAll (filepath .Dir (target ), 0755 ); err != nil {
127+ return err
128+ }
129+ tempFile , err := os .CreateTemp (filepath .Dir (target ), "temp-" )
130+ if err != nil {
131+ return err
132+ }
133+ tempPath := tempFile .Name ()
134+ defer func () {
135+ tempFile .Close ()
136+ cerr := os .Remove (tempPath )
137+ if err == nil && cerr != nil && ! os .IsNotExist (cerr ) {
138+ err = cerr
139+ }
140+ }()
141+
142+ if err := downloadTarball (tempPath , url ); err != nil {
143+ return err
144+ }
145+
146+ sha , err := computeSHA256 (tempPath )
147+ if err != nil {
148+ return err
149+ }
150+ if sha != expectedSha256 {
151+ return fmt .Errorf ("%w: expected=%s, got=%s" , errChecksumMismatch , expectedSha256 , sha )
152+ }
153+ if err := os .MkdirAll (filepath .Dir (target ), 0755 ); err != nil {
154+ return err
155+ }
156+ return os .Rename (tempPath , target )
157+ }
158+
159+ // downloadTarball downloads a tarball from the given source URL to the target
160+ // path. It retries up to 3 times with exponential backoff on failure.
161+ func downloadTarball (target , source string ) error {
123162 var err error
124163 backoff := 10 * time .Second
125164 for i := range 3 {
126165 if i != 0 {
127166 time .Sleep (backoff )
128167 backoff = 2 * backoff
129168 }
130- if err = downloadAttempt (target , source , expectedSha256 ); err == nil {
169+ if err = downloadAttempt (target , source ); err == nil {
131170 return nil
132171 }
133172 }
134173 return fmt .Errorf ("download failed after 3 attempts, last error=%w" , err )
135174}
136175
137- func downloadAttempt (target , source , expectedSha256 string ) (err error ) {
138- if err := os .MkdirAll (filepath .Dir (target ), 0755 ); err != nil {
139- return err
140- }
141- tempFile , err := os .CreateTemp (filepath .Dir (target ), "temp-" )
176+ func downloadAttempt (target , source string ) (err error ) {
177+ file , err := os .Create (target )
142178 if err != nil {
143179 return err
144180 }
145181 defer func () {
146- cerr := tempFile .Close ()
182+ cerr := file .Close ()
147183 if err == nil {
148184 err = cerr
149185 }
150186 if err != nil {
151- os .Remove (tempFile . Name () )
187+ os .Remove (target )
152188 }
153189 }()
154190
155- hasher := sha256 .New ()
156- writer := io .MultiWriter (tempFile , hasher )
157-
158191 client := http.Client {Timeout : 60 * time .Second }
159192 response , err := client .Get (source )
160193 if err != nil {
@@ -165,15 +198,11 @@ func downloadAttempt(target, source, expectedSha256 string) (err error) {
165198 return fmt .Errorf ("http error in download %s" , response .Status )
166199 }
167200
168- if _ , err := io .Copy (writer , response .Body ); err != nil {
201+ if _ , err := io .Copy (file , response .Body ); err != nil {
169202 return err
170203 }
171204
172- got := fmt .Sprintf ("%x" , hasher .Sum (nil ))
173- if expectedSha256 != got {
174- return fmt .Errorf ("mismatched hash on download, expected=%s, got=%s" , expectedSha256 , got )
175- }
176- return os .Rename (tempFile .Name (), target )
205+ return nil
177206}
178207
179208func fileExists (name string ) bool {
0 commit comments