55 "bufio"
66 "errors"
77 "fmt"
8+ "io"
89 stdioutil "io/ioutil"
910 "os"
1011 "strings"
@@ -242,7 +243,39 @@ func (d *DotGit) Object(h plumbing.Hash) (billy.File, error) {
242243 return d .fs .Open (file )
243244}
244245
245- func (d * DotGit ) SetRef (r * plumbing.Reference ) error {
246+ func (d * DotGit ) readReferenceFrom (rd io.Reader , name string ) (ref * plumbing.Reference , err error ) {
247+ b , err := stdioutil .ReadAll (rd )
248+ if err != nil {
249+ return nil , err
250+ }
251+
252+ line := strings .TrimSpace (string (b ))
253+ return plumbing .NewReferenceFromStrings (name , line ), nil
254+ }
255+
256+ func (d * DotGit ) checkReferenceAndTruncate (f billy.File , old * plumbing.Reference ) error {
257+ if old == nil {
258+ return nil
259+ }
260+ ref , err := d .readReferenceFrom (f , old .Name ().String ())
261+ if err != nil {
262+ return err
263+ }
264+ if ref .Hash () != old .Hash () {
265+ return fmt .Errorf ("reference has changed concurrently" )
266+ }
267+ _ , err = f .Seek (0 , io .SeekStart )
268+ if err != nil {
269+ return err
270+ }
271+ err = f .Truncate (0 )
272+ if err != nil {
273+ return err
274+ }
275+ return nil
276+ }
277+
278+ func (d * DotGit ) SetRef (r , old * plumbing.Reference ) error {
246279 var content string
247280 switch r .Type () {
248281 case plumbing .SymbolicReference :
@@ -251,13 +284,34 @@ func (d *DotGit) SetRef(r *plumbing.Reference) error {
251284 content = fmt .Sprintln (r .Hash ().String ())
252285 }
253286
254- f , err := d .fs .Create (r .Name ().String ())
287+ // If we are not checking an old ref, just truncate the file.
288+ mode := os .O_RDWR | os .O_CREATE
289+ if old == nil {
290+ mode |= os .O_TRUNC
291+ }
292+
293+ f , err := d .fs .OpenFile (r .Name ().String (), mode , 0666 )
255294 if err != nil {
256295 return err
257296 }
258297
259298 defer ioutil .CheckClose (f , & err )
260299
300+ // Lock is unlocked by the deferred Close above. This is because Unlock
301+ // does not imply a fsync and thus there would be a race between
302+ // Unlock+Close and other concurrent writers. Adding Sync to go-billy
303+ // could work, but this is better (and avoids superfluous syncs).
304+ err = f .Lock ()
305+ if err != nil {
306+ return err
307+ }
308+
309+ // this is a no-op to call even when old is nil.
310+ err = d .checkReferenceAndTruncate (f , old )
311+ if err != nil {
312+ return err
313+ }
314+
261315 _ , err = f .Write ([]byte (content ))
262316 return err
263317}
@@ -512,13 +566,7 @@ func (d *DotGit) readReferenceFile(path, name string) (ref *plumbing.Reference,
512566 }
513567 defer ioutil .CheckClose (f , & err )
514568
515- b , err := stdioutil .ReadAll (f )
516- if err != nil {
517- return nil , err
518- }
519-
520- line := strings .TrimSpace (string (b ))
521- return plumbing .NewReferenceFromStrings (name , line ), nil
569+ return d .readReferenceFrom (f , name )
522570}
523571
524572// Module return a billy.Filesystem poiting to the module folder
0 commit comments