diff --git a/get_file_test.go b/get_file_test.go index 94ab3c1c1..897b901b8 100644 --- a/get_file_test.go +++ b/get_file_test.go @@ -35,6 +35,33 @@ func TestFileGetter(t *testing.T) { } } +func TestFileGetter_Copy(t *testing.T) { + g := new(FileGetter) + dst := tempDir(t) + + g.Copy = true + + // With a dir that doesn't exist + if err := g.Get(dst, testModuleURL("basic")); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the destination folder is not a symlink + fi, err := os.Lstat(dst) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSymlink == 1 { + t.Fatal("destination is a symlink") + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } +} + func TestFileGetter_sourceFile(t *testing.T) { g := new(FileGetter) dst := tempDir(t) @@ -73,6 +100,22 @@ func TestFileGetter_dir(t *testing.T) { } } +func TestFileGetter_dir_Copy(t *testing.T) { + g := new(FileGetter) + dst := tempDir(t) + + g.Copy = true + + if err := os.MkdirAll(dst, 0755); err != nil { + t.Fatalf("err: %s", err) + } + + // With a dir that exists that isn't a symlink + if err := g.Get(dst, testModuleURL("basic")); err != nil { + t.Fatal("should not error") + } +} + func TestFileGetter_dirSymlink(t *testing.T) { g := new(FileGetter) dst := tempDir(t) diff --git a/get_file_unix.go b/get_file_unix.go index a14a38263..aa33badc6 100644 --- a/get_file_unix.go +++ b/get_file_unix.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package getter @@ -30,12 +31,12 @@ func (g *FileGetter) Get(dst string, u *url.URL) error { // If the destination already exists, it must be a symlink if err == nil { mode := fi.Mode() - if mode&os.ModeSymlink == 0 { + if mode&os.ModeSymlink == 0 && !g.Copy { return fmt.Errorf("destination exists and is not a symlink") } // Remove the destination - if err := os.Remove(dst); err != nil { + if err := os.RemoveAll(dst); err != nil { return err } } @@ -45,7 +46,21 @@ func (g *FileGetter) Get(dst string, u *url.URL) error { return err } - return os.Symlink(path, dst) + if !g.Copy { + return os.Symlink(path, dst) + } + + if err := os.Mkdir(dst, g.client.mode(0755)); err != nil { + return err + } + + var disableSymlinks bool + + if g.client != nil && g.client.DisableSymlinks { + disableSymlinks = true + } + + return copyDir(g.Context(), dst, path, false, disableSymlinks, g.client.umask()) } func (g *FileGetter) GetFile(dst string, u *url.URL) error { diff --git a/get_file_windows.go b/get_file_windows.go index 31146f575..75e48e02e 100644 --- a/get_file_windows.go +++ b/get_file_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package getter @@ -34,12 +35,12 @@ func (g *FileGetter) Get(dst string, u *url.URL) error { // If the destination already exists, it must be a symlink if err == nil { mode := fi.Mode() - if mode&os.ModeSymlink == 0 { + if mode&os.ModeSymlink == 0 && !g.Copy { return fmt.Errorf("destination exists and is not a symlink") } // Remove the destination - if err := os.Remove(dst); err != nil { + if err := os.RemoveAll(dst); err != nil { return err } } @@ -49,15 +50,27 @@ func (g *FileGetter) Get(dst string, u *url.URL) error { return err } - sourcePath := toBackslash(path) + if !g.Copy { + sourcePath := toBackslash(path) + + // Use mklink to create a junction point + output, err := exec.CommandContext(ctx, "cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run mklink %v %v: %v %q", dst, sourcePath, err, output) + } + } + + if err := os.Mkdir(dst, g.client.mode(0755)); err != nil { + return err + } + + var disableSymlinks bool - // Use mklink to create a junction point - output, err := exec.CommandContext(ctx, "cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput() - if err != nil { - return fmt.Errorf("failed to run mklink %v %v: %v %q", dst, sourcePath, err, output) + if g.client != nil && g.client.DisableSymlinks { + disableSymlinks = true } - return nil + return copyDir(g.Context(), dst, path, false, disableSymlinks, g.client.umask()) } func (g *FileGetter) GetFile(dst string, u *url.URL) error {