Skip to content

Commit 67fbd5c

Browse files
authored
refactor(internal/fetch): reorganize download logic (#3022)
Refactor the tarball download code to split the download logic into three functions:
- DownloadTarball: public API that downloads to a temporary file, verifies its SHA256 against the expected value, and renames it to the target path
- downloadTarball: handles retries with exponential backoff
- downloadAttempt: performs a single download attempt

This allows us to reuse these functions more easily when implementing caching.
1 parent 39ec5fa commit 67fbd5c

File tree

2 files changed

+72
-26
lines changed

2 files changed

+72
-26
lines changed

internal/fetch/fetch.go

Lines changed: 50 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ import (
1919
"archive/tar"
2020
"compress/gzip"
2121
"crypto/sha256"
22+
"errors"
2223
"fmt"
2324
"io"
2425
"net/http"
@@ -28,6 +29,8 @@ import (
2829
"time"
2930
)
3031

32+
var errChecksumMismatch = errors.New("checksum mismatch")
33+
3134
// Endpoints defines the endpoints used to access GitHub.
3235
type Endpoints struct {
3336
// API defines the endpoint used to make API calls.
@@ -113,48 +116,78 @@ func TarballLink(githubDownload string, repo *Repo, sha string) string {
113116
return fmt.Sprintf("%s/%s/%s/archive/%s.tar.gz", githubDownload, repo.Org, repo.Repo, sha)
114117
}
115118

116-
// DownloadTarball downloads a tarball from the given source URL to the target path,
117-
// verifying its SHA256 checksum matches expectedSha256. It retries up to 3 times
118-
// with exponential backoff on failure.
119-
func DownloadTarball(target, source, expectedSha256 string) error {
119+
// DownloadTarball downloads a tarball from the given url to the target
120+
// path, verifying its SHA256 checksum matches expectedSha256. It retries up to
121+
// 3 times with exponential backoff on failure.
122+
func DownloadTarball(target, url, expectedSha256 string) error {
120123
if fileExists(target) {
121124
return nil
122125
}
126+
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
127+
return err
128+
}
129+
tempFile, err := os.CreateTemp(filepath.Dir(target), "temp-")
130+
if err != nil {
131+
return err
132+
}
133+
tempPath := tempFile.Name()
134+
defer func() {
135+
tempFile.Close()
136+
cerr := os.Remove(tempPath)
137+
if err == nil && cerr != nil && !os.IsNotExist(cerr) {
138+
err = cerr
139+
}
140+
}()
141+
142+
if err := downloadTarball(tempPath, url); err != nil {
143+
return err
144+
}
145+
146+
sha, err := computeSHA256(tempPath)
147+
if err != nil {
148+
return err
149+
}
150+
if sha != expectedSha256 {
151+
return fmt.Errorf("%w: expected=%s, got=%s", errChecksumMismatch, expectedSha256, sha)
152+
}
153+
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
154+
return err
155+
}
156+
return os.Rename(tempPath, target)
157+
}
158+
159+
// downloadTarball downloads a tarball from the given source URL to the target
160+
// path. It retries up to 3 times with exponential backoff on failure.
161+
func downloadTarball(target, source string) error {
123162
var err error
124163
backoff := 10 * time.Second
125164
for i := range 3 {
126165
if i != 0 {
127166
time.Sleep(backoff)
128167
backoff = 2 * backoff
129168
}
130-
if err = downloadAttempt(target, source, expectedSha256); err == nil {
169+
if err = downloadAttempt(target, source); err == nil {
131170
return nil
132171
}
133172
}
134173
return fmt.Errorf("download failed after 3 attempts, last error=%w", err)
135174
}
136175

137-
func downloadAttempt(target, source, expectedSha256 string) (err error) {
138-
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
139-
return err
140-
}
141-
tempFile, err := os.CreateTemp(filepath.Dir(target), "temp-")
176+
func downloadAttempt(target, source string) (err error) {
177+
file, err := os.Create(target)
142178
if err != nil {
143179
return err
144180
}
145181
defer func() {
146-
cerr := tempFile.Close()
182+
cerr := file.Close()
147183
if err == nil {
148184
err = cerr
149185
}
150186
if err != nil {
151-
os.Remove(tempFile.Name())
187+
os.Remove(target)
152188
}
153189
}()
154190

155-
hasher := sha256.New()
156-
writer := io.MultiWriter(tempFile, hasher)
157-
158191
client := http.Client{Timeout: 60 * time.Second}
159192
response, err := client.Get(source)
160193
if err != nil {
@@ -165,15 +198,11 @@ func downloadAttempt(target, source, expectedSha256 string) (err error) {
165198
return fmt.Errorf("http error in download %s", response.Status)
166199
}
167200

168-
if _, err := io.Copy(writer, response.Body); err != nil {
201+
if _, err := io.Copy(file, response.Body); err != nil {
169202
return err
170203
}
171204

172-
got := fmt.Sprintf("%x", hasher.Sum(nil))
173-
if expectedSha256 != got {
174-
return fmt.Errorf("mismatched hash on download, expected=%s, got=%s", expectedSha256, got)
175-
}
176-
return os.Rename(tempFile.Name(), target)
205+
return nil
177206
}
178207

179208
func fileExists(name string) bool {

internal/fetch/fetch_test.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"bytes"
2020
"compress/gzip"
2121
"crypto/sha256"
22+
"errors"
2223
"fmt"
2324
"net/http"
2425
"net/http/httptest"
@@ -201,24 +202,19 @@ func TestTarballLink(t *testing.T) {
201202

202203
func TestDownloadTarballTgzExists(t *testing.T) {
203204
testDir := t.TempDir()
204-
205205
tarball := makeTestContents(t)
206-
207206
target := path.Join(testDir, "existing-file")
208207
if err := os.WriteFile(target, tarball.Contents, 0644); err != nil {
209208
t.Fatal(err)
210209
}
211-
212210
if err := DownloadTarball(target, "https://unused/placeholder.tar.gz", tarball.Sha256); err != nil {
213211
t.Fatal(err)
214212
}
215213
}
216214

217215
func TestDownloadTarballNeedsDownload(t *testing.T) {
218216
testDir := t.TempDir()
219-
220217
tarball := makeTestContents(t)
221-
222218
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
223219
if r.URL.Path != "/placeholder.tar.gz" {
224220
t.Errorf("Expected to request '/placeholder.tar.gz', got: %s", r.URL.Path)
@@ -241,6 +237,27 @@ func TestDownloadTarballNeedsDownload(t *testing.T) {
241237
}
242238
}
243239

240+
func TestDownloadTarballChecksumMismatch(t *testing.T) {
241+
testDir := t.TempDir()
242+
tarball := makeTestContents(t)
243+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
244+
w.WriteHeader(http.StatusOK)
245+
w.Write(tarball.Contents)
246+
}))
247+
defer server.Close()
248+
249+
target := path.Join(testDir, "target-file")
250+
wrongSha := "0000000000000000000000000000000000000000000000000000000000000000"
251+
252+
err := DownloadTarball(target, server.URL+"/test.tar.gz", wrongSha)
253+
if !errors.Is(err, errChecksumMismatch) {
254+
t.Fatalf("expected errChecksumMismatch, got: %v", err)
255+
}
256+
if _, err := os.Stat(target); !os.IsNotExist(err) {
257+
t.Errorf("target file should not exist after checksum failure: %v", err)
258+
}
259+
}
260+
244261
type contents struct {
245262
Sha256 string
246263
Contents []byte

0 commit comments

Comments
 (0)