Skip to content

Commit f81fa90

Browse files
authored
Merge pull request #2903 from nirs/download-time
Fix races during parallel downloads
2 parents 3fdfaeb + 5071535 commit f81fa90

File tree

2 files changed

+132
-95
lines changed

2 files changed

+132
-95
lines changed

pkg/downloader/downloader.go

Lines changed: 92 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,31 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
229229
return res, nil
230230
}
231231

232+
shad := cacheDirectoryPath(o.cacheDir, remote)
233+
if err := os.MkdirAll(shad, 0o700); err != nil {
234+
return nil, err
235+
}
236+
237+
var res *Result
238+
err := lockutil.WithDirLock(shad, func() error {
239+
var err error
240+
res, err = getCached(ctx, localPath, remote, o)
241+
if err != nil {
242+
return err
243+
}
244+
if res != nil {
245+
return nil
246+
}
247+
res, err = fetch(ctx, localPath, remote, o)
248+
return err
249+
})
250+
return res, err
251+
}
252+
253+
// getCached tries to copy the file from the cache to local path. It returns
254+
// (result, nil) if the file was copied, (nil, nil) if the file is not in the
255+
// cache or the cache needs an update, or (nil, error) on a fatal error.
256+
func getCached(ctx context.Context, localPath, remote string, o options) (*Result, error) {
232257
shad := cacheDirectoryPath(o.cacheDir, remote)
233258
shadData := filepath.Join(shad, "data")
234259
shadTime := filepath.Join(shad, "time")
@@ -237,53 +262,62 @@ func Download(ctx context.Context, local, remote string, opts ...Opt) (*Result,
237262
if err != nil {
238263
return nil, err
239264
}
240-
if _, err := os.Stat(shadData); err == nil {
241-
logrus.Debugf("file %q is cached as %q", localPath, shadData)
242-
useCache := true
243-
if _, err := os.Stat(shadDigest); err == nil {
244-
logrus.Debugf("Comparing digest %q with the cached digest file %q, not computing the actual digest of %q",
245-
o.expectedDigest, shadDigest, shadData)
246-
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
247-
return nil, err
248-
}
249-
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
265+
if _, err := os.Stat(shadData); err != nil {
266+
return nil, nil
267+
}
268+
ext := path.Ext(remote)
269+
logrus.Debugf("file %q is cached as %q", localPath, shadData)
270+
if _, err := os.Stat(shadDigest); err == nil {
271+
logrus.Debugf("Comparing digest %q with the cached digest file %q, not computing the actual digest of %q",
272+
o.expectedDigest, shadDigest, shadData)
273+
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
274+
return nil, err
275+
}
276+
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, "", ""); err != nil {
277+
return nil, err
278+
}
279+
} else {
280+
if match, lmCached, lmRemote, err := matchLastModified(ctx, shadTime, remote); err != nil {
281+
logrus.WithError(err).Info("Failed to retrieve last-modified for cached digest-less image; using cached image.")
282+
} else if match {
283+
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
250284
return nil, err
251285
}
252286
} else {
253-
if match, lmCached, lmRemote, err := matchLastModified(ctx, shadTime, remote); err != nil {
254-
logrus.WithError(err).Info("Failed to retrieve last-modified for cached digest-less image; using cached image.")
255-
} else if match {
256-
if err := copyLocal(ctx, localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil {
257-
return nil, err
258-
}
259-
} else {
260-
logrus.Infof("Re-downloading digest-less image: last-modified mismatch (cached: %q, remote: %q)", lmCached, lmRemote)
261-
useCache = false
262-
}
263-
}
264-
if useCache {
265-
res := &Result{
266-
Status: StatusUsedCache,
267-
CachePath: shadData,
268-
LastModified: readTime(shadTime),
269-
ContentType: readFile(shadType),
270-
ValidatedDigest: o.expectedDigest != "",
271-
}
272-
return res, nil
287+
logrus.Infof("Re-downloading digest-less image: last-modified mismatch (cached: %q, remote: %q)", lmCached, lmRemote)
288+
return nil, nil
273289
}
274290
}
275-
if err := os.MkdirAll(shad, 0o700); err != nil {
291+
res := &Result{
292+
Status: StatusUsedCache,
293+
CachePath: shadData,
294+
LastModified: readTime(shadTime),
295+
ContentType: readFile(shadType),
296+
ValidatedDigest: o.expectedDigest != "",
297+
}
298+
return res, nil
299+
}
300+
301+
// fetch downloads remote to the cache and copies the cached file to local path.
302+
func fetch(ctx context.Context, localPath, remote string, o options) (*Result, error) {
303+
shad := cacheDirectoryPath(o.cacheDir, remote)
304+
shadData := filepath.Join(shad, "data")
305+
shadTime := filepath.Join(shad, "time")
306+
shadType := filepath.Join(shad, "type")
307+
shadDigest, err := cacheDigestPath(shad, o.expectedDigest)
308+
if err != nil {
276309
return nil, err
277310
}
311+
ext := path.Ext(remote)
278312
shadURL := filepath.Join(shad, "url")
279-
if err := writeFirst(shadURL, []byte(remote), 0o644); err != nil {
313+
if err := os.WriteFile(shadURL, []byte(remote), 0o644); err != nil {
280314
return nil, err
281315
}
282316
if err := downloadHTTP(ctx, shadData, shadTime, shadType, remote, o.description, o.expectedDigest); err != nil {
283317
return nil, err
284318
}
285319
if shadDigest != "" && o.expectedDigest != "" {
286-
if err := writeFirst(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
320+
if err := os.WriteFile(shadDigest, []byte(o.expectedDigest.String()), 0o644); err != nil {
287321
return nil, err
288322
}
289323
}
@@ -327,18 +361,33 @@ func Cached(remote string, opts ...Opt) (*Result, error) {
327361
if err != nil {
328362
return nil, err
329363
}
364+
365+
// Checking if data file exists is safe without locking.
330366
if _, err := os.Stat(shadData); err != nil {
331367
return nil, err
332368
}
333-
if _, err := os.Stat(shadDigest); err != nil {
334-
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
335-
return nil, err
336-
}
337-
} else {
338-
if err := validateLocalFileDigest(shadData, o.expectedDigest); err != nil {
339-
return nil, err
369+
370+
// But validating the digest or the data file must take the lock to avoid races
371+
// with parallel downloads.
372+
if err := os.MkdirAll(shad, 0o700); err != nil {
373+
return nil, err
374+
}
375+
err = lockutil.WithDirLock(shad, func() error {
376+
if _, err := os.Stat(shadDigest); err != nil {
377+
if err := validateCachedDigest(shadDigest, o.expectedDigest); err != nil {
378+
return err
379+
}
380+
} else {
381+
if err := validateLocalFileDigest(shadData, o.expectedDigest); err != nil {
382+
return err
383+
}
340384
}
385+
return nil
386+
})
387+
if err != nil {
388+
return nil, err
341389
}
390+
342391
res := &Result{
343392
Status: StatusUsedCache,
344393
CachePath: shadData,
@@ -612,13 +661,13 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
612661
}
613662
if lastModified != "" {
614663
lm := resp.Header.Get("Last-Modified")
615-
if err := writeFirst(lastModified, []byte(lm), 0o644); err != nil {
664+
if err := os.WriteFile(lastModified, []byte(lm), 0o644); err != nil {
616665
return err
617666
}
618667
}
619668
if contentType != "" {
620669
ct := resp.Header.Get("Content-Type")
621-
if err := writeFirst(contentType, []byte(ct), 0o644); err != nil {
670+
if err := os.WriteFile(contentType, []byte(ct), 0o644); err != nil {
622671
return err
623672
}
624673
}
@@ -679,19 +728,7 @@ func downloadHTTP(ctx context.Context, localPath, lastModified, contentType, url
679728
return err
680729
}
681730

682-
// If localPath was created by a parallel download keep it. Replacing it
683-
// while another process is copying it to the destination may fail the
684-
// clonefile syscall. We use a lock to ensure that only one process updates
685-
// data, and when we return data file exists.
686-
687-
return lockutil.WithDirLock(filepath.Dir(localPath), func() error {
688-
if _, err := os.Stat(localPath); err == nil {
689-
return nil
690-
} else if !errors.Is(err, os.ErrNotExist) {
691-
return err
692-
}
693-
return os.Rename(localPathTmp, localPath)
694-
})
731+
return os.Rename(localPathTmp, localPath)
695732
}
696733

697734
var tempfileCount atomic.Uint64
@@ -706,18 +743,6 @@ func perProcessTempfile(path string) string {
706743
return fmt.Sprintf("%s.tmp.%d.%d", path, os.Getpid(), tempfileCount.Add(1))
707744
}
708745

709-
// writeFirst writes data to path unless path already exists.
710-
func writeFirst(path string, data []byte, perm os.FileMode) error {
711-
return lockutil.WithDirLock(filepath.Dir(path), func() error {
712-
if _, err := os.Stat(path); err == nil {
713-
return nil
714-
} else if !errors.Is(err, os.ErrNotExist) {
715-
return err
716-
}
717-
return os.WriteFile(path, data, perm)
718-
})
719-
}
720-
721746
// CacheEntries returns a map of cache entries.
722747
// The key is the SHA256 of the URL.
723748
// The value is the path to the cache entry.

pkg/downloader/downloader_test.go

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"os/exec"
99
"path/filepath"
1010
"runtime"
11-
"slices"
1211
"strings"
1312
"testing"
1413
"time"
@@ -31,11 +30,6 @@ type downloadResult struct {
3130
// races quicker. 20 parallel downloads take about 120 milliseconds on M1 Pro.
3231
const parallelDownloads = 20
3332

34-
// When downloading in parallel usually all downloads completed with
35-
// StatusDownload, but some may be delayed and find the data file when they
36-
// start. Can be reproduced locally using 100 parallel downloads.
37-
var parallelStatus = []Status{StatusDownloaded, StatusUsedCache}
38-
3933
func TestDownloadRemote(t *testing.T) {
4034
ts := httptest.NewServer(http.FileServer(http.Dir("testdata")))
4135
t.Cleanup(ts.Close)
@@ -103,15 +97,10 @@ func TestDownloadRemote(t *testing.T) {
10397
results <- downloadResult{r, err}
10498
}()
10599
}
106-
// We must process all results before cleanup.
107-
for i := 0; i < parallelDownloads; i++ {
108-
result := <-results
109-
if result.err != nil {
110-
t.Errorf("Download failed: %s", result.err)
111-
} else if !slices.Contains(parallelStatus, result.r.Status) {
112-
t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
113-
}
114-
}
100+
// Only one thread should download, the rest should use the cache.
101+
downloaded, cached := countResults(t, results)
102+
assert.Equal(t, downloaded, 1)
103+
assert.Equal(t, cached, parallelDownloads-1)
115104
})
116105
})
117106
t.Run("caching-only mode", func(t *testing.T) {
@@ -146,15 +135,10 @@ func TestDownloadRemote(t *testing.T) {
146135
results <- downloadResult{r, err}
147136
}()
148137
}
149-
// We must process all results before cleanup.
150-
for i := 0; i < parallelDownloads; i++ {
151-
result := <-results
152-
if result.err != nil {
153-
t.Errorf("Download failed: %s", result.err)
154-
} else if !slices.Contains(parallelStatus, result.r.Status) {
155-
t.Errorf("Expected download status %s, got %s", parallelStatus, result.r.Status)
156-
}
157-
}
138+
// Only one thread should download, the rest should use the cache.
139+
downloaded, cached := countResults(t, results)
140+
assert.Equal(t, downloaded, 1)
141+
assert.Equal(t, cached, parallelDownloads-1)
158142
})
159143
})
160144
t.Run("cached", func(t *testing.T) {
@@ -188,6 +172,26 @@ func TestDownloadRemote(t *testing.T) {
188172
})
189173
}
190174

175+
func countResults(t *testing.T, results chan downloadResult) (downloaded, cached int) {
176+
t.Helper()
177+
for i := 0; i < parallelDownloads; i++ {
178+
result := <-results
179+
if result.err != nil {
180+
t.Errorf("Download failed: %s", result.err)
181+
} else {
182+
switch result.r.Status {
183+
case StatusDownloaded:
184+
downloaded++
185+
case StatusUsedCache:
186+
cached++
187+
default:
188+
t.Errorf("Unexpected download status %q", result.r.Status)
189+
}
190+
}
191+
}
192+
return downloaded, cached
193+
}
194+
191195
func TestRedownloadRemote(t *testing.T) {
192196
remoteDir := t.TempDir()
193197
ts := httptest.NewServer(http.FileServer(http.Dir(remoteDir)))
@@ -203,18 +207,26 @@ func TestRedownloadRemote(t *testing.T) {
203207
assert.NilError(t, os.Chtimes(remoteFile, time.Now(), time.Now().Add(-time.Hour)))
204208
opt := []Opt{cacheOpt}
205209

206-
r, err := Download(context.Background(), filepath.Join(downloadDir, "digest-less1.txt"), ts.URL+"/digest-less.txt", opt...)
210+
// Download on the first call
211+
r, err := Download(context.Background(), filepath.Join(downloadDir, "1"), ts.URL+"/digest-less.txt", opt...)
207212
assert.NilError(t, err)
208213
assert.Equal(t, StatusDownloaded, r.Status)
209-
r, err = Download(context.Background(), filepath.Join(downloadDir, "digest-less2.txt"), ts.URL+"/digest-less.txt", opt...)
214+
215+
// Next download will use the cached download
216+
r, err = Download(context.Background(), filepath.Join(downloadDir, "2"), ts.URL+"/digest-less.txt", opt...)
210217
assert.NilError(t, err)
211218
assert.Equal(t, StatusUsedCache, r.Status)
212219

213-
// modifying remote file will cause redownload
220+
// Modifying remote file will cause redownload
214221
assert.NilError(t, os.Chtimes(remoteFile, time.Now(), time.Now()))
215-
r, err = Download(context.Background(), filepath.Join(downloadDir, "digest-less3.txt"), ts.URL+"/digest-less.txt", opt...)
222+
r, err = Download(context.Background(), filepath.Join(downloadDir, "3"), ts.URL+"/digest-less.txt", opt...)
216223
assert.NilError(t, err)
217224
assert.Equal(t, StatusDownloaded, r.Status)
225+
226+
// Next download will use the cached download
227+
r, err = Download(context.Background(), filepath.Join(downloadDir, "4"), ts.URL+"/digest-less.txt", opt...)
228+
assert.NilError(t, err)
229+
assert.Equal(t, StatusUsedCache, r.Status)
218230
})
219231

220232
t.Run("has-digest", func(t *testing.T) {

0 commit comments

Comments
 (0)