Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 57 additions & 17 deletions client/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (
"sync"
"time"

"os"

"al.essio.dev/pkg/shellescape"
truenas "github.com/deevus/truenas-go"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

// ansiRegex matches ANSI escape sequences.
Expand All @@ -28,17 +31,36 @@ type SSHConfig struct {
User string
PrivateKey string
HostKeyFingerprint string
MaxSessions int // Maximum concurrent SSH sessions (0 = default of 5)
MaxSessions int // Maximum concurrent SSH sessions (0 = default of 5)
UseAgent bool // Use SSH agent instead of private key
AgentSocket string // Path to agent socket (defaults to SSH_AUTH_SOCK)
NoSudo bool // Skip sudo prefix on commands (default false = use sudo)
}

// Validate validates the SSHConfig and sets defaults.
func (c *SSHConfig) Validate() error {
if c.Host == "" {
return errors.New("host is required")
}
if c.PrivateKey == "" {
return errors.New("private_key is required")

if c.UseAgent && c.PrivateKey != "" {
return errors.New("use_agent and private_key are mutually exclusive")
}

if c.UseAgent {
// Resolve agent socket
if c.AgentSocket == "" {
c.AgentSocket = os.Getenv("SSH_AUTH_SOCK")
}
if c.AgentSocket == "" {
return errors.New("agent_socket is required when use_agent is true (SSH_AUTH_SOCK not set)")
}
} else {
if c.PrivateKey == "" {
return errors.New("private_key is required")
}
}

if c.HostKeyFingerprint == "" {
return errors.New("host_key_fingerprint is required")
}
Expand Down Expand Up @@ -175,6 +197,14 @@ func (c *SSHClient) acquireSession() func() {
}
}

// sudoPrefix returns "sudo " or "" depending on the NoSudo config.
func (c *SSHClient) sudoPrefix() string {
if c.config.NoSudo {
return ""
}
return "sudo "
}

// connect establishes the SSH connection if not already connected.
func (c *SSHClient) connect() error {
c.mu.Lock()
Expand All @@ -185,16 +215,26 @@ func (c *SSHClient) connect() error {
return nil
}

signer, err := parsePrivateKey(c.config.PrivateKey)
if err != nil {
return err
var authMethods []ssh.AuthMethod

if c.config.UseAgent {
conn, err := net.Dial("unix", c.config.AgentSocket)
if err != nil {
return fmt.Errorf("failed to connect to SSH agent at %q: %w", c.config.AgentSocket, err)
}
agentClient := agent.NewClient(conn)
authMethods = append(authMethods, ssh.PublicKeysCallback(agentClient.Signers))
} else {
signer, err := parsePrivateKey(c.config.PrivateKey)
if err != nil {
return err
}
authMethods = append(authMethods, ssh.PublicKeys(signer))
}

sshConfig := &ssh.ClientConfig{
User: c.config.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
User: c.config.User,
Auth: authMethods,
HostKeyCallback: verifyHostKey(c.config.HostKeyFingerprint),
}

Expand Down Expand Up @@ -253,8 +293,8 @@ func (c *SSHClient) Call(ctx context.Context, method string, params any) (json.R
}
}

// Build command (use sudo for non-root users with sudo access)
cmd := fmt.Sprintf("sudo midclt call %s", method)
// Build command
cmd := fmt.Sprintf("%smidclt call %s", c.sudoPrefix(), method)
paramsStr, err := serializeParams(params)
if err != nil {
return nil, err
Expand Down Expand Up @@ -326,7 +366,7 @@ func (c *SSHClient) callAndWaitWithFlag(ctx context.Context, method string, para
}

// Build command with -j flag for job waiting
cmd := fmt.Sprintf("sudo midclt call -j %s", method)
cmd := fmt.Sprintf("%smidclt call -j %s", c.sudoPrefix(), method)
paramsStr, err := serializeParams(params)
if err != nil {
return nil, err
Expand Down Expand Up @@ -563,7 +603,7 @@ func (c *SSHClient) ReadFile(ctx context.Context, path string) ([]byte, error) {
return output, nil
}

// runSudo executes a command with sudo via SSH.
// runSudo executes a command via SSH, optionally with sudo prefix.
func (c *SSHClient) runSudo(ctx context.Context, args ...string) error {
// Acquire session slot (blocks if at limit)
release := c.acquireSession()
Expand All @@ -581,7 +621,7 @@ func (c *SSHClient) runSudo(ctx context.Context, args ...string) error {
for _, arg := range args {
escaped = append(escaped, shellescape.Quote(arg))
}
cmd := "sudo " + strings.Join(escaped, " ")
cmd := c.sudoPrefix() + strings.Join(escaped, " ")

// Create session
session, err := c.clientWrapper.NewSession()
Expand All @@ -601,7 +641,7 @@ func (c *SSHClient) runSudo(ctx context.Context, args ...string) error {
return nil
}

// runSudoOutput executes a command with sudo via SSH and returns stdout.
// runSudoOutput executes a command via SSH, optionally with sudo prefix, and returns stdout.
func (c *SSHClient) runSudoOutput(ctx context.Context, args ...string) ([]byte, error) {
release := c.acquireSession()
defer release()
Expand All @@ -616,7 +656,7 @@ func (c *SSHClient) runSudoOutput(ctx context.Context, args ...string) ([]byte,
for _, arg := range args {
escaped = append(escaped, shellescape.Quote(arg))
}
cmd := "sudo " + strings.Join(escaped, " ")
cmd := c.sudoPrefix() + strings.Join(escaped, " ")

session, err := c.clientWrapper.NewSession()
if err != nil {
Expand Down
202 changes: 202 additions & 0 deletions client/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1612,4 +1612,206 @@ func TestSSHClient_ReadFile_RespectsSemaphore(t *testing.T) {
if atomic.LoadInt32(&maxActive) > 2 {
t.Errorf("max concurrent sessions = %d, want <= 2", maxActive)
}
}

// --- SSH Agent Auth and NoSudo tests ---

func TestSSHConfig_Validate_AgentMode(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
HostKeyFingerprint: testHostKeyFingerprint,
UseAgent: true,
AgentSocket: "/tmp/test-agent.sock",
}

err := config.Validate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if config.AgentSocket != "/tmp/test-agent.sock" {
t.Errorf("expected agent socket %q, got %q", "/tmp/test-agent.sock", config.AgentSocket)
}
}

func TestSSHConfig_Validate_AgentAndPrivateKeyMutuallyExclusive(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
HostKeyFingerprint: testHostKeyFingerprint,
UseAgent: true,
AgentSocket: "/tmp/test-agent.sock",
PrivateKey: testPrivateKey,
}

err := config.Validate()
if err == nil {
t.Fatal("expected error for mutually exclusive agent and private key")
}

if !strings.Contains(err.Error(), "mutually exclusive") {
t.Errorf("expected error about mutual exclusivity, got %q", err.Error())
}
}

func TestSSHConfig_Validate_AgentWithEnvSocket(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")

config := &SSHConfig{
Host: "truenas.local",
HostKeyFingerprint: testHostKeyFingerprint,
UseAgent: true,
}

err := config.Validate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if config.AgentSocket != "/tmp/env-agent.sock" {
t.Errorf("expected agent socket from env %q, got %q", "/tmp/env-agent.sock", config.AgentSocket)
}
}

func TestSSHConfig_Validate_AgentNoSocketAvailable(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")

config := &SSHConfig{
Host: "truenas.local",
HostKeyFingerprint: testHostKeyFingerprint,
UseAgent: true,
}

err := config.Validate()
if err == nil {
t.Fatal("expected error when no agent socket available")
}

if !strings.Contains(err.Error(), "SSH_AUTH_SOCK") {
t.Errorf("expected error mentioning SSH_AUTH_SOCK, got %q", err.Error())
}
}

func TestSSHConfig_Validate_PrivateKeyRequiredWithoutAgent(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
HostKeyFingerprint: testHostKeyFingerprint,
UseAgent: false,
PrivateKey: "",
}

err := config.Validate()
if err == nil {
t.Fatal("expected error for missing private key without agent")
}

if err.Error() != "private_key is required" {
t.Errorf("expected 'private_key is required', got %q", err.Error())
}
}

func TestSSHClient_SudoPrefix_Default(t *testing.T) {
client := &SSHClient{
config: &SSHConfig{NoSudo: false},
}

prefix := client.sudoPrefix()
if prefix != "sudo " {
t.Errorf("expected %q, got %q", "sudo ", prefix)
}
}

func TestSSHClient_SudoPrefix_NoSudo(t *testing.T) {
client := &SSHClient{
config: &SSHConfig{NoSudo: true},
}

prefix := client.sudoPrefix()
if prefix != "" {
t.Errorf("expected empty prefix, got %q", prefix)
}
}

func TestSSHClient_Call_NoSudo(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
PrivateKey: testPrivateKey,
HostKeyFingerprint: testHostKeyFingerprint,
NoSudo: true,
}

client, err := NewSSHClient(config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

mockSess := &mockSession{
combinedOutputFunc: func(cmd string) ([]byte, error) {
expected := `midclt call system.info`
if cmd != expected {
t.Errorf("expected command %q, got %q", expected, cmd)
}
return []byte(`{"version": "24.04"}`), nil
},
}

mockClient := &mockSSHClient{
newSessionFunc: func() (sshSession, error) {
return mockSess, nil
},
}

client.clientWrapper = mockClient

result, err := client.Call(context.Background(), "system.info", nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if string(result) != `{"version": "24.04"}` {
t.Errorf("expected result, got %s", result)
}
}

func TestSSHClient_CallAndWait_NoSudo(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
PrivateKey: testPrivateKey,
HostKeyFingerprint: testHostKeyFingerprint,
NoSudo: true,
}

client, err := NewSSHClient(config)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Set version to 25.x so callAndWaitWithFlag is used
client.version = truenas.Version{Major: 25, Minor: 10}
client.connected = true

mockSess := &mockSession{
combinedOutputFunc: func(cmd string) ([]byte, error) {
expected := `midclt call -j app.create '{"app_name":"test","custom_app":true}'`
if cmd != expected {
t.Errorf("expected command %q, got %q", expected, cmd)
}
return []byte(`null`), nil
},
}

mockClient := &mockSSHClient{
newSessionFunc: func() (sshSession, error) {
return mockSess, nil
},
}

client.clientWrapper = mockClient

_, err = client.CallAndWait(context.Background(), "app.create", map[string]any{
"app_name": "test",
"custom_app": true,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}