diff --git a/client/ssh.go b/client/ssh.go index f968808..1512c0b 100644 --- a/client/ssh.go +++ b/client/ssh.go @@ -13,9 +13,12 @@ import ( "sync" "time" + "os" + "al.essio.dev/pkg/shellescape" truenas "github.com/deevus/truenas-go" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) // ansiRegex matches ANSI escape sequences. @@ -28,7 +31,10 @@ type SSHConfig struct { User string PrivateKey string HostKeyFingerprint string - MaxSessions int // Maximum concurrent SSH sessions (0 = default of 5) + MaxSessions int // Maximum concurrent SSH sessions (0 = default of 5) + UseAgent bool // Use SSH agent instead of private key + AgentSocket string // Path to agent socket (defaults to SSH_AUTH_SOCK) + NoSudo bool // Skip sudo prefix on commands (default false = use sudo) } // Validate validates the SSHConfig and sets defaults. @@ -36,9 +42,25 @@ func (c *SSHConfig) Validate() error { if c.Host == "" { return errors.New("host is required") } - if c.PrivateKey == "" { - return errors.New("private_key is required") + + if c.UseAgent && c.PrivateKey != "" { + return errors.New("use_agent and private_key are mutually exclusive") + } + + if c.UseAgent { + // Resolve agent socket + if c.AgentSocket == "" { + c.AgentSocket = os.Getenv("SSH_AUTH_SOCK") + } + if c.AgentSocket == "" { + return errors.New("agent_socket is required when use_agent is true (SSH_AUTH_SOCK not set)") + } + } else { + if c.PrivateKey == "" { + return errors.New("private_key is required") + } } + if c.HostKeyFingerprint == "" { return errors.New("host_key_fingerprint is required") } @@ -175,6 +197,14 @@ func (c *SSHClient) acquireSession() func() { } } +// sudoPrefix returns "sudo " or "" depending on the NoSudo config. +func (c *SSHClient) sudoPrefix() string { + if c.config.NoSudo { + return "" + } + return "sudo " +} + // connect establishes the SSH connection if not already connected. func (c *SSHClient) connect() error { c.mu.Lock() @@ -185,16 +215,26 @@ func (c *SSHClient) connect() error { return nil } - signer, err := parsePrivateKey(c.config.PrivateKey) - if err != nil { - return err + var authMethods []ssh.AuthMethod + + if c.config.UseAgent { + conn, err := net.Dial("unix", c.config.AgentSocket) + if err != nil { + return fmt.Errorf("failed to connect to SSH agent at %q: %w", c.config.AgentSocket, err) + } + agentClient := agent.NewClient(conn) + authMethods = append(authMethods, ssh.PublicKeysCallback(agentClient.Signers)) + } else { + signer, err := parsePrivateKey(c.config.PrivateKey) + if err != nil { + return err + } + authMethods = append(authMethods, ssh.PublicKeys(signer)) } sshConfig := &ssh.ClientConfig{ - User: c.config.User, - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, + User: c.config.User, + Auth: authMethods, HostKeyCallback: verifyHostKey(c.config.HostKeyFingerprint), } @@ -253,8 +293,8 @@ func (c *SSHClient) Call(ctx context.Context, method string, params any) (json.R } } - // Build command (use sudo for non-root users with sudo access) - cmd := fmt.Sprintf("sudo midclt call %s", method) + // Build command + cmd := fmt.Sprintf("%smidclt call %s", c.sudoPrefix(), method) paramsStr, err := serializeParams(params) if err != nil { return nil, err @@ -326,7 +366,7 @@ func (c *SSHClient) callAndWaitWithFlag(ctx context.Context, method string, para } // Build command with -j flag for job waiting - cmd := fmt.Sprintf("sudo midclt call -j %s", method) + cmd := fmt.Sprintf("%smidclt call -j %s", c.sudoPrefix(), method) paramsStr, err := serializeParams(params) if err != nil { return nil, err @@ -563,7 +603,7 @@ func (c *SSHClient) ReadFile(ctx context.Context, path string) ([]byte, error) { return output, nil } -// runSudo executes a command with sudo via SSH. +// runSudo executes a command via SSH, optionally with sudo prefix. func (c *SSHClient) runSudo(ctx context.Context, args ...string) error { // Acquire session slot (blocks if at limit) release := c.acquireSession() @@ -581,7 +621,7 @@ func (c *SSHClient) runSudo(ctx context.Context, args ...string) error { for _, arg := range args { escaped = append(escaped, shellescape.Quote(arg)) } - cmd := "sudo " + strings.Join(escaped, " ") + cmd := c.sudoPrefix() + strings.Join(escaped, " ") // Create session session, err := c.clientWrapper.NewSession() @@ -601,7 +641,7 @@ func (c *SSHClient) runSudo(ctx context.Context, args ...string) error { return nil } -// runSudoOutput executes a command with sudo via SSH and returns stdout. +// runSudoOutput executes a command via SSH, optionally with sudo prefix, and returns stdout. func (c *SSHClient) runSudoOutput(ctx context.Context, args ...string) ([]byte, error) { release := c.acquireSession() defer release() @@ -616,7 +656,7 @@ func (c *SSHClient) runSudoOutput(ctx context.Context, args ...string) ([]byte, for _, arg := range args { escaped = append(escaped, shellescape.Quote(arg)) } - cmd := "sudo " + strings.Join(escaped, " ") + cmd := c.sudoPrefix() + strings.Join(escaped, " ") session, err := c.clientWrapper.NewSession() if err != nil { diff --git a/client/ssh_test.go b/client/ssh_test.go index 3fe81a6..5e7e056 100644 --- a/client/ssh_test.go +++ b/client/ssh_test.go @@ -1612,4 +1612,206 @@ func TestSSHClient_ReadFile_RespectsSemaphore(t *testing.T) { if atomic.LoadInt32(&maxActive) > 2 { t.Errorf("max concurrent sessions = %d, want <= 2", maxActive) } +} + +// --- SSH Agent Auth and NoSudo tests --- + +func TestSSHConfig_Validate_AgentMode(t *testing.T) { + config := &SSHConfig{ + Host: "truenas.local", + HostKeyFingerprint: testHostKeyFingerprint, + UseAgent: true, + AgentSocket: "/tmp/test-agent.sock", + } + + err := config.Validate() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if config.AgentSocket != "/tmp/test-agent.sock" { + t.Errorf("expected agent socket %q, got %q", "/tmp/test-agent.sock", config.AgentSocket) + } +} + +func TestSSHConfig_Validate_AgentAndPrivateKeyMutuallyExclusive(t *testing.T) { + config := &SSHConfig{ + Host: "truenas.local", + HostKeyFingerprint: testHostKeyFingerprint, + UseAgent: true, + AgentSocket: "/tmp/test-agent.sock", + PrivateKey: testPrivateKey, + } + + err := config.Validate() + if err == nil { + t.Fatal("expected error for mutually exclusive agent and private key") + } + + if !strings.Contains(err.Error(), "mutually exclusive") { + t.Errorf("expected error about mutual exclusivity, got %q", err.Error()) + } +} + +func TestSSHConfig_Validate_AgentWithEnvSocket(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock") + + config := &SSHConfig{ + Host: "truenas.local", + HostKeyFingerprint: testHostKeyFingerprint, + UseAgent: true, + } + + err := config.Validate() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if config.AgentSocket != "/tmp/env-agent.sock" { + t.Errorf("expected agent socket from env %q, got %q", "/tmp/env-agent.sock", config.AgentSocket) + } +} + +func TestSSHConfig_Validate_AgentNoSocketAvailable(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "") + + config := &SSHConfig{ + Host: "truenas.local", + HostKeyFingerprint: testHostKeyFingerprint, + UseAgent: true, + } + + err := config.Validate() + if err == nil { + t.Fatal("expected error when no agent socket available") + } + + if !strings.Contains(err.Error(), "SSH_AUTH_SOCK") { + t.Errorf("expected error mentioning SSH_AUTH_SOCK, got %q", err.Error()) + } +} + +func TestSSHConfig_Validate_PrivateKeyRequiredWithoutAgent(t *testing.T) { + config := &SSHConfig{ + Host: "truenas.local", + HostKeyFingerprint: testHostKeyFingerprint, + UseAgent: false, + PrivateKey: "", + } + + err := config.Validate() + if err == nil { + t.Fatal("expected error for missing private key without agent") + } + + if err.Error() != "private_key is required" { + t.Errorf("expected 'private_key is required', got %q", err.Error()) + } +} + +func TestSSHClient_SudoPrefix_Default(t *testing.T) { + client := &SSHClient{ + config: &SSHConfig{NoSudo: false}, + } + + prefix := client.sudoPrefix() + if prefix != "sudo " { + t.Errorf("expected %q, got %q", "sudo ", prefix) + } +} + +func TestSSHClient_SudoPrefix_NoSudo(t *testing.T) { + client := &SSHClient{ + config: &SSHConfig{NoSudo: true}, + } + + prefix := client.sudoPrefix() + if prefix != "" { + t.Errorf("expected empty prefix, got %q", prefix) + } +} + +func TestSSHClient_Call_NoSudo(t *testing.T) { + config := &SSHConfig{ + Host: "truenas.local", + PrivateKey: testPrivateKey, + HostKeyFingerprint: testHostKeyFingerprint, + NoSudo: true, + } + + client, err := NewSSHClient(config) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + mockSess := &mockSession{ + combinedOutputFunc: func(cmd string) ([]byte, error) { + expected := `midclt call system.info` + if cmd != expected { + t.Errorf("expected command %q, got %q", expected, cmd) + } + return []byte(`{"version": "24.04"}`), nil + }, + } + + mockClient := &mockSSHClient{ + newSessionFunc: func() (sshSession, error) { + return mockSess, nil + }, + } + + client.clientWrapper = mockClient + + result, err := client.Call(context.Background(), "system.info", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if string(result) != `{"version": "24.04"}` { + t.Errorf("expected result, got %s", result) + } +} + +func TestSSHClient_CallAndWait_NoSudo(t *testing.T) { + config := &SSHConfig{ + Host: "truenas.local", + PrivateKey: testPrivateKey, + HostKeyFingerprint: testHostKeyFingerprint, + NoSudo: true, + } + + client, err := NewSSHClient(config) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Set version to 25.x so callAndWaitWithFlag is used + client.version = truenas.Version{Major: 25, Minor: 10} + client.connected = true + + mockSess := &mockSession{ + combinedOutputFunc: func(cmd string) ([]byte, error) { + expected := `midclt call -j app.create '{"app_name":"test","custom_app":true}'` + if cmd != expected { + t.Errorf("expected command %q, got %q", expected, cmd) + } + return []byte(`null`), nil + }, + } + + mockClient := &mockSSHClient{ + newSessionFunc: func() (sshSession, error) { + return mockSess, nil + }, + } + + client.clientWrapper = mockClient + + _, err = client.CallAndWait(context.Background(), "app.create", map[string]any{ + "app_name": "test", + "custom_app": true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } } \ No newline at end of file