diff --git a/libvirt/config.go b/libvirt/config.go index a382479e1..b41290b66 100644 --- a/libvirt/config.go +++ b/libvirt/config.go @@ -3,11 +3,14 @@ package libvirt import ( "fmt" "log" + "net/url" + "strings" "sync" libvirt "github.com/digitalocean/go-libvirt" + "github.com/dmacvicar/terraform-provider-libvirt/libvirt/dialers" "github.com/dmacvicar/terraform-provider-libvirt/libvirt/helper/mutexkv" - uri "github.com/dmacvicar/terraform-provider-libvirt/libvirt/uri" + luri "github.com/dmacvicar/terraform-provider-libvirt/libvirt/uri" ) // Config struct for the libvirt-provider. @@ -26,15 +29,55 @@ type Client struct { // Client libvirt, returns a libvirt client for a config. func (c *Config) Client() (*Client, error) { - u, err := uri.Parse(c.URI) + uri, err := url.Parse(c.URI) if err != nil { return nil, err } - l := libvirt.NewWithDialer(u) + var l *libvirt.Libvirt - if err := l.ConnectToURI(libvirt.ConnectURI(u.RemoteName())); err != nil { - return nil, fmt.Errorf("failed to connect: %w", err) + // Check if we should use the command-line SSH tool + useSSHCmd := uri.Query().Has("use_ssh_cmd") + + // We only use our custom SSH command dialer if: + // 1. The use_ssh_cmd parameter is present + // 2. The URI scheme contains +ssh (like qemu+ssh) + if useSSHCmd && strings.Contains(uri.Scheme, "+ssh") { + // Remove the special param to not interfere with other URI processing + q := uri.Query() + q.Del("use_ssh_cmd") + uri.RawQuery = q.Encode() + + // Create a dialer using the SSH command-line tool + sshDialer := dialers.NewSSHCmdDialer(uri) + + // Use NewWithDialer to create a connection with the custom dialer + l = libvirt.NewWithDialer(sshDialer) + + // Connect to the remote URI + remoteName := libvirt.RemoteURI(uri) + if err := l.ConnectToURI(remoteName); err != nil { + return nil, fmt.Errorf("failed to connect to libvirt with ssh command: %w", err) + } + + log.Printf("[INFO] Connected to libvirt using SSH command-line tool") + } else if strings.Contains(uri.Scheme, "+ssh") { + u, err := luri.Parse(c.URI) + if err != nil { + return nil, err + } + + l = libvirt.NewWithDialer(u) + + if err := l.ConnectToURI(libvirt.ConnectURI(u.RemoteName())); err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + } else { + // Use the default connection method + l, err = libvirt.ConnectToURI(uri) + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } } v, err := l.ConnectGetLibVersion() diff --git a/libvirt/dialers/sshcmd.go b/libvirt/dialers/sshcmd.go new file mode 100644 index 000000000..c597e80c0 --- /dev/null +++ b/libvirt/dialers/sshcmd.go @@ -0,0 +1,444 @@ +package dialers + +import ( + "bufio" + "container/ring" + "context" + "fmt" + "io" + "log" + "net" + "net/url" + "os" + "os/exec" + "strings" + "sync" + "time" +) + +const ( + defaultUnixSock = "/var/run/libvirt/libvirt-sock" + defaultNetcatBin = "nc" + defaultHelperBin = "virt-ssh-helper" + defaultSSHBin = "ssh" +) + +// https://libvirt.org/uri.html#proxy-parameter +type ProxyMode string + +const ( + ProxyAuto ProxyMode = "auto" + ProxyNative ProxyMode = "native" + ProxyNetcat ProxyMode = "netcat" +) + +// SSHCmdDialer implements socket.Dialer interface for go-libvirt +// It uses the command-line ssh tool for communication, which automatically +// respects OpenSSH config settings in ~/.ssh/config. +type SSHCmdDialer struct { + // Connection details + hostname string + port string + username string + socket string + remoteURI string // Remote URI to pass to virt-ssh-helper + + // Proxy configuration + proxyMode ProxyMode + netcatBin string // Netcat binary to use when needed + + // SSH options + sshBin string + keyFiles []string + knownHostsFile string + strictHostCheck bool + batchMode bool + forwardAgent bool + authMethods []string // Authentication methods to use + extraArgs []string // Any additional SSH arguments +} + +func NewSSHCmdDialer(uri *url.URL) *SSHCmdDialer { + hostname := uri.Hostname() + if hostname == "" { + hostname = "localhost" + } + + query := uri.Query() + + // Extract the driver part of the URI (qemu, lxc, etc.) + driver := strings.Split(uri.Scheme, "+")[0] + + // Construct the remote URI (e.g., "qemu:///system") + remoteName := driver + ":///system" + if uri.Path != "" && uri.Path != "/" { + remoteName = driver + "://" + uri.Path + } + + dialer := &SSHCmdDialer{ + // Connection details + hostname: hostname, + port: uri.Port(), + socket: defaultUnixSock, + remoteURI: remoteName, + + // Proxy configuration + proxyMode: ProxyAuto, // default like upstream libvirt + netcatBin: defaultNetcatBin, + + // SSH options + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{}, + } + + if uri.User != nil { + dialer.username = uri.User.Username() + } + + if socketParam := query.Get("socket"); socketParam != "" { + dialer.socket = socketParam + } + + if proxyParam := query.Get("proxy"); proxyParam != "" { + switch ProxyMode(proxyParam) { + case ProxyAuto, ProxyNative, ProxyNetcat: + dialer.proxyMode = ProxyMode(proxyParam) + default: + log.Printf("[WARN] Unknown proxy mode: %s, using 'auto'", proxyParam) + } + } + + if netcatParam := query.Get("netcat"); netcatParam != "" { + dialer.netcatBin = netcatParam + } + + dialer.applyURIOptions(uri) + + return dialer +} + +// see: https://libvirt.org/uri.html#ssh-transport +func (d *SSHCmdDialer) applyURIOptions(uri *url.URL) { + query := uri.Query() + + if keyFile := query.Get("keyfile"); keyFile != "" { + keyFile = os.ExpandEnv(keyFile) + if strings.HasPrefix(keyFile, "~") { + if home, err := os.UserHomeDir(); err == nil { + keyFile = strings.Replace(keyFile, "~", home, 1) + } + } + d.keyFiles = append(d.keyFiles, keyFile) + } + + if knownHosts := query.Get("knownhosts"); knownHosts != "" { + knownHosts = os.ExpandEnv(knownHosts) + if strings.HasPrefix(knownHosts, "~") { + if home, err := os.UserHomeDir(); err == nil { + knownHosts = strings.Replace(knownHosts, "~", home, 1) + } + } + d.knownHostsFile = knownHosts + } + + knownHostsVerify := query.Get("known_hosts_verify") + if knownHostsVerify == "ignore" || query.Has("no_verify") { + d.strictHostCheck = false + } + + if command := query.Get("command"); command != "" { + d.sshBin = command + } + + if proxy := query.Get("proxy"); proxy != "" { + switch proxy { + case string(ProxyAuto): + d.proxyMode = ProxyAuto + case string(ProxyNative): + d.proxyMode = ProxyNative + case string(ProxyNetcat): + d.proxyMode = ProxyNetcat + default: + log.Printf("[WARN] Unknown proxy mode: %s, using 'auto'", proxy) + + } + } + + if sshAuth := query.Get("sshauth"); sshAuth != "" { + authMethods := strings.Split(sshAuth, ",") + d.authMethods = authMethods + + // Set specific options based on auth methods + for _, auth := range authMethods { + switch auth { + case "agent": + d.forwardAgent = true + case "password", "keyboard-interactive": + d.batchMode = false // Disable batch mode for interactive auth + } + } + } + + // TODO mode parameter +} + +// Dial implements the socket.Dialer interface to enable using this dialer with go-libvirt. +func (d *SSHCmdDialer) Dial() (net.Conn, error) { + args := d.buildSSHArgs() + + log.Printf("[INFO] SSH command dialer connecting to %s with args: %v", d.hostname, args) + + //nolint:gosec + cmd := exec.Command(d.sshBin, args...) + + var err error + + var stdout io.ReadCloser + if stdout, err = cmd.StdoutPipe(); err != nil { + return nil, fmt.Errorf("failed to acquire stdout pipe: %w", err) + } + + var stdin io.WriteCloser + if stdin, err = cmd.StdinPipe(); err != nil { + return nil, fmt.Errorf("failed to acquire stdin pipe: %w", err) + + } + + var stderr io.ReadCloser + if stderr, err = cmd.StderrPipe(); err != nil { + return nil, fmt.Errorf("failed to acquire stdout pipe: %w", err) + + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start ssh command: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + + // custom net.Conn implementation that communicates with the ssh process + conn := &sshCmdConn{ + cmd: cmd, + stdin: stdin, + stdout: stdout, + stderr: stderr, + cancel: cancel, + hostAndPort: d.hostname, + remoteSocket: d.socket, + lastStdErrLines: ring.New(5), + } + + go func() { + err := cmd.Wait() + if err != nil { + log.Printf("[ERROR] SSH command exited unexpectedly: %v", err) + } + + cancel() // Ensure cleanup is triggered + }() + + // Monitor the process in a goroutine + go func() { + defer cancel() + <-ctx.Done() + if cmd.Process != nil { + if err := cmd.Process.Kill(); err != nil { + log.Printf("[ERROR] Failed to kill ssh command: %v", err) + } + } + }() + + // collect std err to give context to any errors later + go func() { + scanner := bufio.NewScanner(stderr) + for scanner.Scan() { + conn.appendStderrLine(scanner.Text()) + log.Printf("[WARN] ssh: %s", scanner.Text()) + } + }() + + // Wait for initial connection (give ssh some time to establish the connection) + //nolint:mnd + time.Sleep(100 * time.Millisecond) + if cmd.ProcessState != nil && cmd.ProcessState.Exited() { + return nil, fmt.Errorf("ssh command terminated prematurely with exit code %d:\n%s", cmd.ProcessState.ExitCode(), strings.Join(conn.lastStderrLines(), "\n")) + } + + return conn, nil +} + +func (d *SSHCmdDialer) buildSSHArgs() []string { + args := []string{} + + if d.port != "" { + args = append(args, "-p", d.port) + } + + // Standard arguments for libvirt connections + args = append(args, + "-T", // Disable pseudo-terminal allocation + "-o", "ControlPath=none", // Don't use multiplexing + "-e", "none", // Disable escape character + ) + + for _, keyFile := range d.keyFiles { + args = append(args, "-i", keyFile) + } + + if d.knownHostsFile != "" { + args = append(args, "-o", "UserKnownHostsFile="+d.knownHostsFile) + } + + if !d.strictHostCheck { + args = append(args, "-o", "StrictHostKeyChecking=no") + } + + if d.batchMode { + args = append(args, "-o", "BatchMode=yes") + } + + if d.forwardAgent { + args = append(args, "-o", "ForwardAgent=yes") + } + + for _, auth := range d.authMethods { + switch auth { + case "privkey": + args = append(args, "-o", "PreferredAuthentications=publickey") + case "password": + args = append(args, "-o", "PreferredAuthentications=password") + case "keyboard-interactive": + args = append(args, "-o", "PreferredAuthentications=keyboard-interactive") + } + } + + args = append(args, d.extraArgs...) + + // Build the destination string (user@host) + destination := d.hostname + if d.username != "" { + destination = d.username + "@" + destination + } + + // Use the remote URI that was constructed during initialization + var shellCmd string + + switch d.proxyMode { + case ProxyNative: + // Native mode uses virt-ssh-helper directly + shellCmd = fmt.Sprintf("sh -c 'virt-ssh-helper \"%s\"'", d.remoteURI) + log.Printf("[DEBUG] Using native virt-ssh-helper with URI: %s", d.remoteURI) + + case ProxyNetcat: + // Netcat mode - detect proper flags for netcat + //nolint:lll + shellCmd = fmt.Sprintf("sh -c 'if \"%s\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"%s\" $ARG -U %s'", + d.netcatBin, d.netcatBin, d.socket) + log.Printf("[DEBUG] Using netcat %s for socket connection to %s", d.netcatBin, d.socket) + + case ProxyAuto: + // Auto mode - try virt-ssh-helper first, then fall back to netcat + //nolint:lll + shellCmd = fmt.Sprintf("sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"%s\"; else if \"%s\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"%s\" $ARG -U %s; fi'", + d.remoteURI, d.netcatBin, d.netcatBin, d.socket) + log.Printf("[DEBUG] Using auto proxy mode with URI: %s", d.remoteURI) + + default: + // This shouldn't happen, but use netcat as a safe fallback + shellCmd = fmt.Sprintf("sh -c '\"%s\" -U %s'", d.netcatBin, d.socket) + log.Printf("[WARN] Unknown proxy mode, falling back to netcat") + } + + args = append(args, "--", destination, shellCmd) + + return args +} + +// sshCmdConn implements net.Conn to communicate with the ssh process. +type sshCmdConn struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + stderr io.ReadCloser + cancel context.CancelFunc + hostAndPort string + remoteSocket string + + lastStdErrLines *ring.Ring + stderrRingMu sync.Mutex +} + +func (c *sshCmdConn) Read(b []byte) (int, error) { + n, err := c.stdout.Read(b) + if err != nil { + log.Printf("[ERROR] ssh: read error: %s", err) + return n, fmt.Errorf("ssh: %s", strings.Join(c.lastStderrLines(), "\n")) + } + + return n, nil +} + +func (c *sshCmdConn) Write(b []byte) (int, error) { + n, err := c.stdin.Write(b) + if err != nil { + log.Printf("[ERROR] ssh: write error: %s", err) + return n, fmt.Errorf("ssh: %s", strings.Join(c.lastStderrLines(), "\n")) + } + + return n, nil +} + +func (c *sshCmdConn) Close() error { + c.cancel() + c.stdin.Close() + c.stdout.Close() + c.stderr.Close() + return nil +} + +func (c *sshCmdConn) lastStderrLines() []string { + c.stderrRingMu.Lock() + defer c.stderrRingMu.Unlock() + + var lines []string + c.lastStdErrLines.Do(func(el any) { + if el == nil { + return + } + lines = append(lines, el.(string)) + }) + + return lines +} + +func (c *sshCmdConn) appendStderrLine(line string) { + c.stderrRingMu.Lock() + defer c.stderrRingMu.Unlock() + + c.lastStdErrLines.Value = line + c.lastStdErrLines = c.lastStdErrLines.Next() +} + +func (c *sshCmdConn) LocalAddr() net.Addr { + return &net.UnixAddr{Name: "local", Net: "unix"} +} + +func (c *sshCmdConn) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: c.remoteSocket, Net: "unix"} +} + +func (c *sshCmdConn) SetDeadline(t time.Time) error { + return fmt.Errorf("SetDeadline not implemented for SSH command connection") +} + +func (c *sshCmdConn) SetReadDeadline(t time.Time) error { + return fmt.Errorf("SetReadDeadline not implemented for SSH command connection") +} + +func (c *sshCmdConn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline not implemented for SSH command connection") +} diff --git a/libvirt/dialers/sshcmd_test.go b/libvirt/dialers/sshcmd_test.go new file mode 100644 index 000000000..b88230e6e --- /dev/null +++ b/libvirt/dialers/sshcmd_test.go @@ -0,0 +1,409 @@ +package dialers + +import ( + "net/url" + "os" + "reflect" + "strings" + "testing" +) + +func TestNewSSHCmdDialer(t *testing.T) { + tests := []struct { + name string + uri string + expected *SSHCmdDialer + }{ + { + name: "basic ssh uri", + uri: "qemu+ssh://user@example.com/system", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "", + username: "user", + socket: defaultUnixSock, + remoteURI: "qemu:///system", + proxyMode: ProxyAuto, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{}, + }, + }, + { + name: "ssh uri with port", + uri: "qemu+ssh://user@example.com:2222/system", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "2222", + username: "user", + socket: defaultUnixSock, + remoteURI: "qemu:///system", + proxyMode: ProxyAuto, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{}, + }, + }, + { + name: "ssh uri with custom socket", + uri: "qemu+ssh://user@example.com/system?socket=/tmp/libvirt.sock", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "", + username: "user", + socket: "/tmp/libvirt.sock", + remoteURI: "qemu:///system", + proxyMode: ProxyAuto, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{}, + }, + }, + { + name: "ssh uri with proxy mode", + uri: "qemu+ssh://user@example.com/system?proxy=netcat", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "", + username: "user", + socket: defaultUnixSock, + remoteURI: "qemu:///system", + proxyMode: ProxyNetcat, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{}, + }, + }, + { + name: "ssh uri with keyfile", + uri: "qemu+ssh://user@example.com/system?keyfile=/home/user/.ssh/id_rsa", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "", + username: "user", + socket: defaultUnixSock, + remoteURI: "qemu:///system", + proxyMode: ProxyAuto, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{"/home/user/.ssh/id_rsa"}, + }, + }, + { + name: "ssh uri with known hosts settings", + uri: "qemu+ssh://user@example.com/system?knownhosts=/home/user/.ssh/known_hosts&known_hosts_verify=ignore", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "", + username: "user", + socket: defaultUnixSock, + remoteURI: "qemu:///system", + proxyMode: ProxyAuto, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: false, + batchMode: true, + forwardAgent: false, + authMethods: []string{}, + keyFiles: []string{}, + knownHostsFile: "/home/user/.ssh/known_hosts", + }, + }, + { + name: "ssh uri with auth methods", + uri: "qemu+ssh://user@example.com/system?sshauth=agent,password", + expected: &SSHCmdDialer{ + hostname: "example.com", + port: "", + username: "user", + socket: defaultUnixSock, + remoteURI: "qemu:///system", + proxyMode: ProxyAuto, + netcatBin: defaultNetcatBin, + sshBin: defaultSSHBin, + strictHostCheck: true, + batchMode: false, // Disabled for password auth + forwardAgent: true, // Enabled for agent auth + authMethods: []string{"agent", "password"}, + keyFiles: []string{}, + }, + }, + } + + // Save original environment to restore later + origEnv := os.Environ() + defer func() { + os.Clearenv() + for _, env := range origEnv { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + os.Setenv(parts[0], parts[1]) + } + } + }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the URI + parsedURI, err := url.Parse(tt.uri) + if err != nil { + t.Fatalf("Failed to parse URI: %v", err) + } + + // Create the dialer + dialer := NewSSHCmdDialer(parsedURI) + + // Compare fields that we care about + if dialer.hostname != tt.expected.hostname { + t.Errorf("hostname = %v, want %v", dialer.hostname, tt.expected.hostname) + } + if dialer.port != tt.expected.port { + t.Errorf("port = %v, want %v", dialer.port, tt.expected.port) + } + if dialer.username != tt.expected.username { + t.Errorf("username = %v, want %v", dialer.username, tt.expected.username) + } + if dialer.socket != tt.expected.socket { + t.Errorf("socket = %v, want %v", dialer.socket, tt.expected.socket) + } + if dialer.remoteURI != tt.expected.remoteURI { + t.Errorf("remoteURI = %v, want %v", dialer.remoteURI, tt.expected.remoteURI) + } + if dialer.proxyMode != tt.expected.proxyMode { + t.Errorf("proxyMode = %v, want %v", dialer.proxyMode, tt.expected.proxyMode) + } + if dialer.netcatBin != tt.expected.netcatBin { + t.Errorf("netcatBin = %v, want %v", dialer.netcatBin, tt.expected.netcatBin) + } + if dialer.sshBin != tt.expected.sshBin { + t.Errorf("sshBin = %v, want %v", dialer.sshBin, tt.expected.sshBin) + } + if dialer.strictHostCheck != tt.expected.strictHostCheck { + t.Errorf("strictHostCheck = %v, want %v", dialer.strictHostCheck, tt.expected.strictHostCheck) + } + if dialer.batchMode != tt.expected.batchMode { + t.Errorf("batchMode = %v, want %v", dialer.batchMode, tt.expected.batchMode) + } + if dialer.forwardAgent != tt.expected.forwardAgent { + t.Errorf("forwardAgent = %v, want %v", dialer.forwardAgent, tt.expected.forwardAgent) + } + if !reflect.DeepEqual(dialer.authMethods, tt.expected.authMethods) { + t.Errorf("authMethods = %v, want %v", dialer.authMethods, tt.expected.authMethods) + } + if !reflect.DeepEqual(dialer.keyFiles, tt.expected.keyFiles) { + t.Errorf("keyFiles = %v, want %v", dialer.keyFiles, tt.expected.keyFiles) + } + }) + } +} + +func TestBuildSSHArgs(t *testing.T) { + tests := []struct { + name string + dialer *SSHCmdDialer + expected []string // Complete expected command arguments + }{ + { + name: "basic args", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyAuto, + remoteURI: "qemu:///system", + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"qemu:///system\"; else if \"\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"\" $ARG -U ; fi'", + }, + }, + { + name: "with port", + dialer: &SSHCmdDialer{ + hostname: "example.com", + port: "2222", + username: "user", + proxyMode: ProxyAuto, + remoteURI: "qemu:///system", + }, + expected: []string{ + "-p", "2222", + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"qemu:///system\"; else if \"\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"\" $ARG -U ; fi'", + }, + }, + { + name: "with keyfile", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyAuto, + remoteURI: "qemu:///system", + keyFiles: []string{"/home/user/.ssh/id_rsa"}, + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-i", "/home/user/.ssh/id_rsa", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"qemu:///system\"; else if \"\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"\" $ARG -U ; fi'", + }, + }, + { + name: "with known hosts settings", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyAuto, + remoteURI: "qemu:///system", + knownHostsFile: "/home/user/.ssh/known_hosts", + strictHostCheck: false, + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "UserKnownHostsFile=/home/user/.ssh/known_hosts", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"qemu:///system\"; else if \"\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"\" $ARG -U ; fi'", + }, + }, + { + name: "with batch mode disabled", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyAuto, + remoteURI: "qemu:///system", + batchMode: false, + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"qemu:///system\"; else if \"\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"\" $ARG -U ; fi'", + }, + }, + { + name: "with agent forwarding", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyAuto, + remoteURI: "qemu:///system", + forwardAgent: true, + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "StrictHostKeyChecking=no", + "-o", "ForwardAgent=yes", + "--", + "user@example.com", + "sh -c 'which virt-ssh-helper 1>/dev/null 2>&1; if test $? = 0; then virt-ssh-helper \"qemu:///system\"; else if \"\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"\" $ARG -U ; fi'", + }, + }, + { + name: "with netcat proxy mode", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyNetcat, + remoteURI: "qemu:///system", + socket: "/var/run/libvirt/libvirt-sock", + netcatBin: "nc", + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'if \"nc\" -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then ARG=-q0; else ARG=; fi; \"nc\" $ARG -U /var/run/libvirt/libvirt-sock'", + }, + }, + { + name: "with native proxy mode", + dialer: &SSHCmdDialer{ + hostname: "example.com", + username: "user", + proxyMode: ProxyNative, + remoteURI: "qemu:///system", + }, + expected: []string{ + "-T", + "-o", "ControlPath=none", + "-e", "none", + "-o", "StrictHostKeyChecking=no", + "--", + "user@example.com", + "sh -c 'virt-ssh-helper \"qemu:///system\"'", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := tt.dialer.buildSSHArgs() + + // Check if the number of arguments matches + if len(args) != len(tt.expected) { + t.Errorf("Expected %d arguments, got %d", len(tt.expected), len(args)) + t.Errorf("Expected: %v", tt.expected) + t.Errorf("Got: %v", args) + return + } + + // Check each argument + for i, arg := range tt.expected { + if i >= len(args) { + t.Errorf("Missing argument at position %d, expected '%s'", i, arg) + continue + } + + if args[i] != arg { + t.Errorf("Argument mismatch at position %d: expected '%s', got '%s'", i, arg, args[i]) + } + } + }) + } +} diff --git a/libvirt/uri/ssh.go b/libvirt/uri/ssh.go index 3c886c4fc..7cadd15da 100644 --- a/libvirt/uri/ssh.go +++ b/libvirt/uri/ssh.go @@ -118,7 +118,7 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi // construct the whole ssh connection, which can consist of multiple hops if using proxy jumps, // the ssh configuration file is loaded once and passed along to each host connection. func (u *ConnectionURI) dialSSH() (net.Conn, error) { - var sshcfg* ssh_config.Config = nil + var sshcfg *ssh_config.Config = nil sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile)) if err != nil { @@ -307,7 +307,6 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth } } - cfg.Auth = u.parseAuthMethods(target, sshcfg) if len(cfg.Auth) < 1 { return nil, fmt.Errorf("could not configure SSH authentication methods") diff --git a/website/docs/index.html.markdown b/website/docs/index.html.markdown index e9287a805..b76d6e5f3 100644 --- a/website/docs/index.html.markdown +++ b/website/docs/index.html.markdown @@ -20,18 +20,24 @@ working and haven't been tested. ## The connection URI -The provider understands [connection URIs](https://libvirt.org/uri.html). The supported transports are: +The provider understands [connection URIs](https://libvirt.org/uri.html). + +As the provider does not use libvirt on the client side, not all connection URI options are supported or apply. + +The supported transports are: * `tcp` (non-encrypted connection) * `unix` (UNIX domain socket) * `tls` (See [here](https://libvirt.org/kbase/tlscerts.html) for information how to setup certificates) * `ssh` (Secure shell) -Unlike the original libvirt, the `ssh` transport is not implemented using the ssh command and therefore does not require `nc` (netcat) on the server side. +### SSH + +The `ssh` transport is implemented using the `golang.org/x/crypto/ssh` library, and therefore it does not honour the `ssh` command configuration. This `ssh` transport is not implemented using the ssh command and therefore does not require it, nor `nc` (netcat) on the server side. Additionally, the `ssh` URI supports passwords using the `driver+ssh://[username:PASSWORD@][hostname][:port]/[path]?sshauth=ssh-password` syntax. -As the provider does not use libvirt on the client side, not all connection URI options are supported or apply. +An experimental `ssh` transport using the `ssh` command can be enabled with the `use_ssh_cmd=1` parameter. Example: `qemu+ssh://user@localhost/system?no_verify=1&use_ssh_cmd=1`. This transport honours the local `ssh` configuration. It requies `nc` (netcat) on the server side and emulates the options of the original libvirt [ssh transport](https://libvirt.org/uri.html#ssh-transport). ## Example Usage