diff --git a/libvirt/uri/ssh.go b/libvirt/uri/ssh.go index 342f4c029..58e91b0a2 100644 --- a/libvirt/uri/ssh.go +++ b/libvirt/uri/ssh.go @@ -118,7 +118,7 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi // construct the whole ssh connection, which can consist of multiple hops if using proxy jumps, // the ssh configuration file is loaded once and passed along to each host connection. func (u *ConnectionURI) dialSSH() (net.Conn, error) { - var sshcfg* ssh_config.Config = nil + var sshcfg *ssh_config.Config = nil sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile)) if err != nil { @@ -132,7 +132,7 @@ func (u *ConnectionURI) dialSSH() (net.Conn, error) { } // configuration loaded, build tunnel - sshClient, err := u.dialHost(u.Host, sshcfg, 0) + sshClient, err := u.dialHost(parsedTarget{hostName: u.Host}, sshcfg, 0) if err != nil { return nil, err } @@ -152,7 +152,12 @@ func (u *ConnectionURI) dialSSH() (net.Conn, error) { return c, nil } -func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth int) (*ssh.Client, error) { +type parsedTarget struct { + hostName string + user string +} + +func (u *ConnectionURI) dialHost(target parsedTarget, sshcfg *ssh_config.Config, depth int) (*ssh.Client, error) { if depth > maxHostHops { return nil, fmt.Errorf("[ERROR] dialHost failed: max tunnel depth of 10 reached") @@ -169,9 +174,9 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth log.Printf("[DEBUG] ssh Port is overridden to: '%s'", port) } - hostName := target + hostName := target.hostName if sshcfg != nil { - host, err := sshcfg.Get(target, "HostName") + host, err := sshcfg.Get(target.hostName, "HostName") if err == nil && host != "" { hostName = host log.Printf("[DEBUG] HostName is overridden to: '%s'", hostName) @@ -188,7 +193,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth skipVerify = true } else { if sshcfg != nil { - strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking") + strictCheck, err := sshcfg.Get(target.hostName, "StrictHostKeyChecking") if err != nil && strictCheck == "yes" { skipVerify = false } @@ -199,7 +204,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth knownHostsPath = defaultSSHKnownHostsPath if sshcfg != nil { - knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile") + knownHosts, err := sshcfg.Get(target.hostName, "UserKnownHostsFile") if err == nil && knownHosts != "" { knownHostsPath = knownHosts } @@ -236,7 +241,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth } if sshcfg != nil { - keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms") + keyAlgs, err := sshcfg.Get(target.hostName, "HostKeyAlgorithms") if err == nil && keyAlgs != "" { log.Printf("[DEBUG] HostKeyAlgorithms is overridden to '%s'", keyAlgs) hostKeyAlgorithms = strings.Split(keyAlgs, ",") @@ -251,22 +256,25 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth HostKeyAlgorithms: hostKeyAlgorithms, Timeout: dialTimeout, } + if target.user != "" { + cfg.User = target.user + } var bastion *ssh.Client = nil var bastion_proxy string = "" if sshcfg != nil { - command, err := sshcfg.Get(target, "ProxyCommand") + command, err := sshcfg.Get(target.hostName, "ProxyCommand") if err == nil && command != "" { log.Printf("[WARNING] unsupported ssh ProxyCommand '%v' - ignoring", command) } } if sshcfg != nil { - proxy, err := sshcfg.Get(target, "ProxyJump") - if err == nil && proxy != "" { + proxy, err := sshcfg.Get(target.hostName, "ProxyJump") + if err == nil && (proxy != "" && proxy != "none") { log.Printf("[DEBUG] found ProxyJump '%v'", proxy) // this is a proxy jump: we recurse into that proxy - bastion, err = u.dialHost(proxy, sshcfg, depth+1) + bastion, err = u.dialHost(proxyJumpStringToParsedTarget(proxy), sshcfg, depth+1) bastion_proxy = proxy if err != nil { return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err) @@ -276,15 +284,14 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth // cfg.User value defaults to u.User.Username() if sshcfg != nil { - sshu, err := sshcfg.Get(target, "User") + sshu, err := sshcfg.Get(target.hostName, "User") if err != nil { log.Printf("[DEBUG] ssh user for target '%v' is overridden to '%v'", target, sshu) cfg.User = sshu } } - - cfg.Auth = u.parseAuthMethods(target, sshcfg) + cfg.Auth = u.parseAuthMethods(target.hostName, sshcfg) if len(cfg.Auth) < 1 { return nil, fmt.Errorf("could not configure SSH authentication methods") } @@ -298,7 +305,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err) } - ncc, chans, reqs, err := ssh.NewClientConn(conn, target, &cfg) + ncc, chans, reqs, err := ssh.NewClientConn(conn, target.hostName, &cfg) if err != nil { return nil, fmt.Errorf("failed to connect to remote host '%v': %w", target, err) } @@ -317,3 +324,17 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth return conn, nil } } + +func proxyJumpStringToParsedTarget(s string) parsedTarget { + atIdx := strings.Index(s, "@") + if atIdx < 0 { + return parsedTarget{ + hostName: s, + } + } + + return parsedTarget{ + hostName: s[atIdx+1:], + user: s[:atIdx], + } +} diff --git a/libvirt/uri/ssh_test.go b/libvirt/uri/ssh_test.go new file mode 100644 index 000000000..a85ea394a --- /dev/null +++ b/libvirt/uri/ssh_test.go @@ -0,0 +1,29 @@ +package uri + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProxyJumpStringToParsedTarget(t *testing.T) { + in := []string{ + "host.enterprise.com", + "user@host.enterprise.com", + } + expectedOut := []parsedTarget{ + { + hostName: "host.enterprise.com", + }, + { + hostName: "host.enterprise.com", + user: "user", + }, + } + + out := []parsedTarget{} + for _, proxyJumpStr := range in { + out = append(out, proxyJumpStringToParsedTarget(proxyJumpStr)) + } + assert.Equal(t, expectedOut, out) +}