diff --git a/libvirt/uri/ssh.go b/libvirt/uri/ssh.go index 342f4c029..2a42ba540 100644 --- a/libvirt/uri/ssh.go +++ b/libvirt/uri/ssh.go @@ -5,13 +5,14 @@ import ( "log" "net" "os" - "path/filepath" "strings" "github.com/kevinburke/ssh_config" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/knownhosts" + + "github.com/dmacvicar/terraform-provider-libvirt/libvirt/util" ) const ( @@ -80,13 +81,7 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi case "privkey": for _, keypath := range sshKeyPaths { log.Printf("[DEBUG] Reading ssh key '%s'", keypath) - path := os.ExpandEnv(keypath) - if strings.HasPrefix(path, "~/") { - home, err := os.UserHomeDir() - if err == nil { - path = filepath.Join(home, path[2:]) - } - } + path := util.ExpandPath(keypath) sshKey, err := os.ReadFile(path) if err != nil { log.Printf("[ERROR] Failed to read ssh key '%s': %v", keypath, err) @@ -119,8 +114,8 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi // the ssh configuration file is loaded once and passed along to each host connection. func (u *ConnectionURI) dialSSH() (net.Conn, error) { var sshcfg* ssh_config.Config = nil + sshConfigFile, err := os.Open(util.ExpandPath(defaultSSHConfigFile)) - sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile)) if err != nil { log.Printf("[WARN] Failed to open ssh config file: %v", err) } else { @@ -221,11 +216,11 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth ssh.KeyAlgoECDSA521, } if !skipVerify { - kh, err := knownhosts.New(os.ExpandEnv(knownHostsPath)) + kh, err := knownhosts.New(util.ExpandPath(knownHostsPath)) if err != nil { return nil, fmt.Errorf("failed to read ssh known hosts: %w", err) } - log.Printf("[DEBUG] Using known hosts file '%s' for target '%s'", os.ExpandEnv(knownHostsPath), target) + log.Printf("[DEBUG] Using known hosts file '%s' for target '%s'", util.ExpandPath(knownHostsPath), target) hostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { err := kh(net.JoinHostPort(hostName, port), remote, key) diff --git a/libvirt/util/expandenv.go b/libvirt/util/expandenv.go new file mode 100644 index 000000000..ab5450232 --- /dev/null +++ b/libvirt/util/expandenv.go @@ -0,0 +1,34 @@ +package util + +import ( + "os" + "path/filepath" + "strings" +) + +var ( + userHomeDir = os.UserHomeDir + expandEnv = os.ExpandEnv +) + +// ExpandPath expands environment variables and resolves ~ to the home directory +// this is a drop-in replacement for os.ExpandEnv but is additionally '~' aware. +func ExpandPath(path string) string { + path = filepath.Clean(expandEnv(path)) + tilde := filepath.FromSlash("~/") + + // note to maintainers: tilde without a following slash character is simply + // interpreted as part of the filename (e.g. ~foo/bar != ~/foo/bar). However, + // when running on windows, the filepath will be represented by backslashes ('\'), + // therefore we need to convert "~/" to the platform specific format to test for + // it, otherwise on windows systems the prefix test will always fail. + if strings.HasPrefix(path, tilde) { + home, err := userHomeDir() + if err != nil { + return path // return path as-is if unable to resolve home directory + } + // Replace ~ with home directory + path = filepath.Join(home, strings.TrimPrefix(path, tilde)) + } + return path +} diff --git a/libvirt/util/expandenv_test.go b/libvirt/util/expandenv_test.go new file mode 100644 index 000000000..016f4957f --- /dev/null +++ b/libvirt/util/expandenv_test.go @@ -0,0 +1,32 @@ +package util + +import ( + "fmt" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExpandPath(t *testing.T) { + userHomeDir = func() (string, error) { + return "/home/mock", nil + } + expandEnv = func(s string) string { + return strings.Replace(s, "${HOME}", "/home/mock", 1) + } + + + assert.Equal(t, filepath.FromSlash("foo/bar/baz"), ExpandPath("foo/bar/baz")) + assert.Equal(t, filepath.FromSlash("/home/mock/foo/bar/baz"), ExpandPath("~/foo/bar/baz")) + assert.Equal(t, filepath.FromSlash("/home/mock/foo/bar/baz"), ExpandPath("${HOME}/foo/bar/baz")) + assert.Equal(t, filepath.FromSlash("~foo/bar/baz"), ExpandPath("~foo/bar/baz")) + + userHomeDir = func() (string, error) { + return "", fmt.Errorf("some failure") + } + + // failure to get home expansion should leave string unchanged + assert.Equal(t, filepath.FromSlash("~/foo/bar/baz"), ExpandPath("~/foo/bar/baz")) +}