1515package testserver
1616
1717import (
18+ "errors"
1819 "fmt"
1920 "io"
2021 "log"
@@ -28,22 +29,42 @@ import (
2829 "regexp"
2930 "runtime"
3031 "time"
32+
33+ "github.com/gofrs/flock"
3134)
3235
3336const (
3437 latestSuffix = "LATEST"
3538 finishedFileMode = 0555
39+ writingFileMode = 0600 // Allow reads so that another process can check if there's a flock.
3640)
3741
38- func downloadFile (response * http.Response , filePath string ) error {
39- output , err := os .OpenFile (filePath , os .O_WRONLY | os .O_CREATE | os .O_EXCL , 0200 )
42+ func downloadFile (response * http.Response , filePath string , tc * TestConfig ) error {
43+ output , err := os .OpenFile (filePath , os .O_WRONLY | os .O_CREATE | os .O_EXCL , writingFileMode )
4044 if err != nil {
4145 return fmt .Errorf ("error creating %s: %s" , filePath , err )
4246 }
4347 defer func () { _ = output .Close () }()
4448
4549 log .Printf ("saving %s to %s, this may take some time" , response .Request .URL , filePath )
4650
51+ // Assign a flock to the local file.
52+ // If the downloading process is killed in the middle,
53+ // the lock will be automatically dropped.
54+ localFileLock := flock .New (filePath )
55+
56+ if _ , err := localFileLock .TryLock (); err != nil {
57+ return err
58+ }
59+
60+ defer func () { _ = localFileLock .Unlock () }()
61+
62+ if tc .IsTest && tc .StopDownloadInMiddle {
63+ log .Printf ("download process killed" )
64+ output .Close ()
65+ return errStoppedInMiddle
66+ }
67+
4768 if _ , err := io .Copy (output , response .Body ); err != nil {
4869 return fmt .Errorf ("problem saving %s to %s: %s" , response .Request .URL , filePath , err )
4970 }
@@ -53,14 +74,21 @@ func downloadFile(response *http.Response, filePath string) error {
5374 return err
5475 }
5576
77+ if err := localFileLock .Unlock (); err != nil {
78+ return err
79+ }
80+
5681 // We explicitly close here to ensure the error is checked; the deferred
5782 // close above will likely error in this case, but that's harmless.
5883 return output .Close ()
5984}
6085
6186var muslRE = regexp .MustCompile (`(?i)\bmusl\b` )
6287
63- func downloadLatestBinary () (string , error ) {
88+ // GetDownloadResponse return the http response of a CRDB download.
89+ // It creates the url for downloading a CRDB binary for current runtime OS,
90+ // makes a request to this url, and return the response.
91+ func GetDownloadResponse () (* http.Response , error ) {
6492 goos := runtime .GOOS
6593 if goos == "linux" {
6694 goos += func () string {
@@ -88,14 +116,17 @@ func downloadLatestBinary() (string, error) {
88116 log .Printf ("GET %s" , url )
89117 response , err := http .Get (url .String ())
90118 if err != nil {
91- return "" , err
119+ return nil , err
92120 }
93- defer func () { _ = response .Body .Close () }()
94121
95122 if response .StatusCode != 200 {
96- return "" , fmt .Errorf ("error downloading %s: %d (%s)" , url , response .StatusCode , response .Status )
123+ return nil , fmt .Errorf ("error downloading %s: %d (%s)" , url ,
124+ response .StatusCode , response .Status )
97125 }
126+ return response , nil
127+ }
98128
129+ func GetDownloadFilename (response * http.Response ) (string , error ) {
99130 const contentDisposition = "Content-Disposition"
100131 _ , disposition , err := mime .ParseMediaType (response .Header .Get (contentDisposition ))
101132 if err != nil {
@@ -106,6 +137,20 @@ func downloadLatestBinary() (string, error) {
106137 if ! ok {
107138 return "" , fmt .Errorf ("content disposition header %s did not contain filename" , disposition )
108139 }
140+ return filename , nil
141+ }
142+
143+ func downloadLatestBinary (tc * TestConfig ) (string , error ) {
144+ response , err := GetDownloadResponse ()
145+ if err != nil {
146+ return "" , err
147+ }
148+ defer func () { _ = response .Body .Close () }()
149+
150+ filename , err := GetDownloadFilename (response )
151+ if err != nil {
152+ return "" , err
153+ }
109154 localFile := filepath .Join (os .TempDir (), filename )
110155 for {
111156 info , err := os .Stat (localFile )
@@ -120,13 +165,35 @@ func downloadLatestBinary() (string, error) {
120165 if info .Mode ().Perm () == finishedFileMode {
121166 return localFile , nil
122167 }
168+
169+ localFileLock := flock .New (localFile )
170+ // If there's a process downloading the binary, local file cannot be flocked.
171+ locked , err := localFileLock .TryLock ()
172+ if err != nil {
173+ return "" , err
174+ }
175+
176+ if locked {
177+ // If local file can be locked, it means the previous download was
178+ // killed in the middle. Delete local file and re-download.
179+ log .Printf ("previous download failed in the middle, deleting and re-downloading" )
180+ if err := os .Remove (localFile ); err != nil {
181+ log .Printf ("failed to remove partial download %s: %v" , localFile , err )
182+ return "" , err
183+ }
184+ break
185+ }
186+
123187 log .Printf ("waiting for download of %s" , localFile )
124188 time .Sleep (time .Millisecond * 10 )
125189 }
126190
127- if err := downloadFile (response , localFile ); err != nil {
128- if err := os .Remove (localFile ); err != nil {
129- log .Printf ("failed to remove %s: %s" , localFile , err )
191+ err = downloadFile (response , localFile , tc )
192+ if err != nil {
193+ if ! errors .Is (err , errStoppedInMiddle ) {
194+ if err := os .Remove (localFile ); err != nil {
195+ log .Printf ("failed to remove %s: %s" , localFile , err )
196+ }
130197 }
131198 return "" , err
132199 }
0 commit comments