Skip to content

Commit 9dfc00a

Browse files
authored
fix(internal/fetch): preserve last error in download retry loop (#3188)
The downloadTarball function was using := instead of = when capturing the error from downloadAttempt. This created a new variable scoped to the if statement, shadowing the function-level err variable. When all retry attempts failed, the error message showed `last error=%!w(<nil>)` because the outer err was never assigned. A test is added to verify the error message includes the actual failure. TestDownloadTarballRetry is also refactor to remove the unnecessary t.Run.
1 parent ec24497 commit 9dfc00a

File tree

2 files changed

+60
-43
lines changed

2 files changed

+60
-43
lines changed

internal/fetch/fetch.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ var (
3838
defaultBackoff = 10 * time.Second
3939
)
4040

41+
const maxDownloadRetries = 3
42+
4143
// Endpoints defines the endpoints used to access GitHub.
4244
type Endpoints struct {
4345
// API defines the endpoint used to make API calls.
@@ -142,7 +144,7 @@ func TarballLink(githubDownload string, repo *Repo, sha string) string {
142144

143145
// DownloadTarball downloads a tarball from the given url to the target
144146
// path, verifying its SHA256 checksum matches expectedSha256. It retries up to
145-
// 3 times with exponential backoff on failure.
147+
// maxDownloadRetries times with exponential backoff on failure.
146148
func DownloadTarball(ctx context.Context, target, url, expectedSha256 string) error {
147149
if fileExists(target) {
148150
return nil
@@ -169,7 +171,6 @@ func DownloadTarball(ctx context.Context, target, url, expectedSha256 string) er
169171
if err := downloadTarball(ctx, tempPath, url); err != nil {
170172
return err
171173
}
172-
173174
sha, err := computeSHA256(tempPath)
174175
if err != nil {
175176
return err
@@ -184,10 +185,10 @@ func DownloadTarball(ctx context.Context, target, url, expectedSha256 string) er
184185
}
185186

186187
// downloadTarball downloads a tarball from the given source URL to the target
187-
// path. It retries up to 3 times with exponential backoff on failure.
188+
// path. It retries up to maxDownloadRetries times with exponential backoff on failure.
188189
func downloadTarball(ctx context.Context, target, source string) error {
189190
var err error
190-
for i := range 3 {
191+
for i := range maxDownloadRetries {
191192
if i > 0 {
192193
select {
193194
case <-time.After(defaultBackoff):
@@ -197,16 +198,15 @@ func downloadTarball(ctx context.Context, target, source string) error {
197198
}
198199
}
199200

200-
if err := downloadAttempt(ctx, target, source); err != nil {
201+
if err = downloadAttempt(ctx, target, source); err != nil {
201202
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
202203
return err
203204
}
204205
continue
205206
}
206207
return nil
207-
208208
}
209-
return fmt.Errorf("download failed after 3 attempts, last error=%w", err)
209+
return fmt.Errorf("download failed after %d attempts, last error=%w", maxDownloadRetries, err)
210210
}
211211

212212
func downloadAttempt(ctx context.Context, target, source string) (err error) {
@@ -237,11 +237,9 @@ func downloadAttempt(ctx context.Context, target, source string) (err error) {
237237
if response.StatusCode >= 300 {
238238
return fmt.Errorf("http error in download %s", response.Status)
239239
}
240-
241240
if _, err := io.Copy(file, response.Body); err != nil {
242241
return err
243242
}
244-
245243
return nil
246244
}
247245

internal/fetch/fetch_test.go

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -590,43 +590,62 @@ func TestLatestCommitAndChecksum(t *testing.T) {
590590
}
591591
}
592592

593-
func TestDownloadTarballRetry(t *testing.T) {
594-
t.Run("succeeds after a few retries", func(t *testing.T) {
595-
// Set a short backoff for this test to speed up retries.
596-
defaultBackoff = time.Millisecond
597-
t.Cleanup(func() {
598-
defaultBackoff = 10 * time.Second
599-
})
600-
testDir := t.TempDir()
601-
tarball := makeTestContents(t)
602-
var requestCount int
603-
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
604-
requestCount++
605-
if requestCount < 3 {
606-
w.WriteHeader(http.StatusInternalServerError)
607-
return
608-
}
609-
w.WriteHeader(http.StatusOK)
610-
w.Write(tarball.Contents)
611-
}))
612-
defer server.Close()
593+
func TestDownloadTarballRetryErrorIncludesLastFailure(t *testing.T) {
594+
defaultBackoff = time.Millisecond
595+
t.Cleanup(func() {
596+
defaultBackoff = 10 * time.Second
597+
})
598+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
599+
w.WriteHeader(http.StatusInternalServerError)
600+
}))
601+
defer server.Close()
613602

614-
target := path.Join(testDir, "target-file")
615-
if err := DownloadTarball(t.Context(), target, server.URL+"/test.tar.gz", tarball.Sha256); err != nil {
616-
t.Fatal(err)
617-
}
603+
target := path.Join(t.TempDir(), "target-file")
604+
err := DownloadTarball(t.Context(), target, server.URL+"/test.tar.gz", "any-sha")
605+
if err == nil {
606+
t.Fatal("expected an error")
607+
}
608+
if strings.Contains(err.Error(), "<nil>") {
609+
t.Errorf("error should contain the last failure, not <nil>: %v", err)
610+
}
611+
if !strings.Contains(err.Error(), "500") {
612+
t.Errorf("error should mention the HTTP status code: %v", err)
613+
}
614+
}
618615

619-
if requestCount != 3 {
620-
t.Errorf("expected 3 requests, got %d", requestCount)
621-
}
622-
got, err := os.ReadFile(target)
623-
if err != nil {
624-
t.Fatal(err)
625-
}
626-
if diff := cmp.Diff(tarball.Contents, got); diff != "" {
627-
t.Errorf("mismatch (-want +got):\n%s", diff)
628-
}
616+
func TestDownloadTarballRetrySucceeds(t *testing.T) {
617+
defaultBackoff = time.Millisecond
618+
t.Cleanup(func() {
619+
defaultBackoff = 10 * time.Second
629620
})
621+
tarball := makeTestContents(t)
622+
var requestCount int
623+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
624+
requestCount++
625+
if requestCount < 3 {
626+
w.WriteHeader(http.StatusInternalServerError)
627+
return
628+
}
629+
w.WriteHeader(http.StatusOK)
630+
w.Write(tarball.Contents)
631+
}))
632+
defer server.Close()
633+
634+
target := path.Join(t.TempDir(), "target-file")
635+
if err := DownloadTarball(t.Context(), target, server.URL+"/test.tar.gz", tarball.Sha256); err != nil {
636+
t.Fatal(err)
637+
}
638+
639+
if requestCount != 3 {
640+
t.Errorf("expected 3 requests, got %d", requestCount)
641+
}
642+
got, err := os.ReadFile(target)
643+
if err != nil {
644+
t.Fatal(err)
645+
}
646+
if diff := cmp.Diff(tarball.Contents, got); diff != "" {
647+
t.Errorf("mismatch (-want +got):\n%s", diff)
648+
}
630649
}
631650

632651
func TestLatestCommitAndChecksumFailure(t *testing.T) {

0 commit comments

Comments
 (0)