diff --git a/get_git.go b/get_git.go index db89edef8..ff7a6ba10 100644 --- a/get_git.go +++ b/get_git.go @@ -155,8 +155,15 @@ func (g *GitGetter) GetFile(dst string, u *url.URL) error { // Get the filename, and strip the filename from the URL so we can // just get the repository directly. - filename := filepath.Base(u.Path) - u.Path = filepath.Dir(u.Path) + var filename string + if u.Host == "github.com" { + tokens := strings.SplitN(u.Path[1:], "/", 3) + u.Path = "/" + tokens[0] + "/" + tokens[1] + filename = tokens[2] + } else { + filename = filepath.Base(u.Path) + u.Path = filepath.Dir(u.Path) + } // Get the full repository if err := g.Get(td, u); err != nil { diff --git a/get_git_test.go b/get_git_test.go index df6ad0390..1dbeb8989 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -410,6 +410,49 @@ func TestGitGetter_GetFile(t *testing.T) { assertContents(t, dst, "hello") } +func TestGitGetter_githubGetWithFileMode(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + dst := tempTestFile(t) + defer os.RemoveAll(filepath.Dir(dst)) + + c := Client{ + Src: "git::https://github.com/arikkfir/go-getter/testdata/basic/foo/main.tf?ref=master", + Dst: dst, + Mode: ClientModeFile, + } + if err := c.Get(); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + if _, err := os.Stat(dst); err != nil { + t.Fatalf("err: %s", err) + } + assertContents(t, dst, "# Hello\n") +} + +func TestGitGetter_githubGetFile(t *testing.T) { + if !testHasGit { + t.Skip("git not found, skipping") + } + + dst := tempTestFile(t) + defer os.RemoveAll(filepath.Dir(dst)) + + if err := GetFile(dst, "git::https://github.com/arikkfir/go-getter/testdata/basic/foo/main.tf?ref=master"); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + if _, err := os.Stat(dst); err != nil { + t.Fatalf("err: %s", err) + } + assertContents(t, dst, "# Hello\n") +} + func TestGitGetter_gitVersion(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping")