Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions internal/client/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io/fs"
"net"
"os"
"regexp"
"strings"
"sync"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/deevus/terraform-provider-truenas/internal/api"
"github.com/hashicorp/terraform-plugin-log/tflog"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

// ansiRegex matches ANSI escape sequences.
Expand All @@ -28,6 +30,9 @@ type SSHConfig struct {
Port int
User string
PrivateKey string
UseAgent bool // Use SSH agent (SSH_AUTH_SOCK) instead of private_key
AgentSocket string // Override SSH_AUTH_SOCK path (optional, defaults to env)
UseSudo *bool // Prefix commands with sudo (defaults to true)
HostKeyFingerprint string
MaxSessions int // Maximum concurrent SSH sessions (0 = default of 5)
}
Expand All @@ -37,8 +42,11 @@ func (c *SSHConfig) Validate() error {
if c.Host == "" {
return errors.New("host is required")
}
if c.PrivateKey == "" {
return errors.New("private_key is required")
if !c.UseAgent && c.PrivateKey == "" {
return errors.New("private_key is required when use_agent is false")
}
if c.UseAgent && c.PrivateKey != "" {
return errors.New("private_key and use_agent are mutually exclusive")
}
if c.HostKeyFingerprint == "" {
return errors.New("host_key_fingerprint is required")
Expand All @@ -51,6 +59,12 @@ func (c *SSHConfig) Validate() error {
if c.User == "" {
c.User = "root"
}
if c.UseAgent && c.AgentSocket == "" {
c.AgentSocket = os.Getenv("SSH_AUTH_SOCK")
if c.AgentSocket == "" {
return errors.New("SSH_AUTH_SOCK is not set and agent_socket was not provided")
}
}

return nil
}
Expand Down Expand Up @@ -149,6 +163,14 @@ func NewSSHClient(config *SSHConfig) (*SSHClient, error) {
}, nil
}

// sudoPrefix returns "sudo " if use_sudo is enabled, or "" if disabled.
func (c *SSHClient) sudoPrefix() string {
if c.config.UseSudo != nil && !*c.config.UseSudo {
return ""
}
return "sudo "
}

// acquireSession blocks until a session slot is available and returns a release function.
func (c *SSHClient) acquireSession() func() {
c.sessionSem <- struct{}{}
Expand All @@ -167,16 +189,26 @@ func (c *SSHClient) connect() error {
return nil
}

signer, err := parsePrivateKey(c.config.PrivateKey)
if err != nil {
return err
var authMethods []ssh.AuthMethod

if c.config.UseAgent {
conn, err := net.Dial("unix", c.config.AgentSocket)
if err != nil {
return fmt.Errorf("failed to connect to SSH agent at %s: %w", c.config.AgentSocket, err)
}
agentClient := agent.NewClient(conn)
authMethods = []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}
} else {
signer, err := parsePrivateKey(c.config.PrivateKey)
if err != nil {
return err
}
authMethods = []ssh.AuthMethod{ssh.PublicKeys(signer)}
}

sshConfig := &ssh.ClientConfig{
User: c.config.User,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
User: c.config.User,
Auth: authMethods,
HostKeyCallback: verifyHostKey(c.config.HostKeyFingerprint),
}

Expand Down Expand Up @@ -236,7 +268,7 @@ func (c *SSHClient) Call(ctx context.Context, method string, params any) (json.R
}

// Build command (use sudo for non-root users with sudo access)
cmd := fmt.Sprintf("sudo midclt call %s", method)
cmd := fmt.Sprintf("%smidclt call %s", c.sudoPrefix(), method)
paramsStr, err := serializeParams(params)
if err != nil {
return nil, err
Expand Down Expand Up @@ -308,7 +340,7 @@ func (c *SSHClient) callAndWaitWithFlag(ctx context.Context, method string, para
}

// Build command with -j flag for job waiting
cmd := fmt.Sprintf("sudo midclt call -j %s", method)
cmd := fmt.Sprintf("%smidclt call -j %s", c.sudoPrefix(), method)
paramsStr, err := serializeParams(params)
if err != nil {
return nil, err
Expand Down Expand Up @@ -563,7 +595,7 @@ func (c *SSHClient) runSudo(ctx context.Context, args ...string) error {
for _, arg := range args {
escaped = append(escaped, shellescape.Quote(arg))
}
cmd := "sudo " + strings.Join(escaped, " ")
cmd := c.sudoPrefix() + strings.Join(escaped, " ")

// Create session
session, err := c.clientWrapper.NewSession()
Expand Down Expand Up @@ -598,7 +630,7 @@ func (c *SSHClient) runSudoOutput(ctx context.Context, args ...string) ([]byte,
for _, arg := range args {
escaped = append(escaped, shellescape.Quote(arg))
}
cmd := "sudo " + strings.Join(escaped, " ")
cmd := c.sudoPrefix() + strings.Join(escaped, " ")

session, err := c.clientWrapper.NewSession()
if err != nil {
Expand Down
78 changes: 76 additions & 2 deletions internal/client/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,82 @@ func TestSSHConfig_Validate_MissingPrivateKey(t *testing.T) {
t.Fatal("expected error for missing private key")
}

if err.Error() != "private_key is required" {
t.Errorf("expected 'private_key is required', got %q", err.Error())
if err.Error() != "private_key is required when use_agent is false" {
t.Errorf("expected 'private_key is required when use_agent is false', got %q", err.Error())
}
}

func TestSSHConfig_Validate_AgentAndPrivateKeyMutuallyExclusive(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
PrivateKey: testPrivateKey,
UseAgent: true,
HostKeyFingerprint: "SHA256:test",
}

err := config.Validate()
if err == nil {
t.Fatal("expected error for both private_key and use_agent set")
}

if err.Error() != "private_key and use_agent are mutually exclusive" {
t.Errorf("expected mutual exclusivity error, got %q", err.Error())
}
}

func TestSSHConfig_Validate_AgentWithSocket(t *testing.T) {
config := &SSHConfig{
Host: "truenas.local",
UseAgent: true,
AgentSocket: "/tmp/test-agent.sock",
HostKeyFingerprint: "SHA256:test",
}

err := config.Validate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if config.AgentSocket != "/tmp/test-agent.sock" {
t.Errorf("expected agent socket /tmp/test-agent.sock, got %q", config.AgentSocket)
}
}

func TestSSHConfig_Validate_AgentWithoutSocket(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "/tmp/env-agent.sock")

config := &SSHConfig{
Host: "truenas.local",
UseAgent: true,
HostKeyFingerprint: "SHA256:test",
}

err := config.Validate()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if config.AgentSocket != "/tmp/env-agent.sock" {
t.Errorf("expected agent socket from env, got %q", config.AgentSocket)
}
}

func TestSSHConfig_Validate_AgentNoSocketAvailable(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")

config := &SSHConfig{
Host: "truenas.local",
UseAgent: true,
HostKeyFingerprint: "SHA256:test",
}

err := config.Validate()
if err == nil {
t.Fatal("expected error when no agent socket available")
}

if err.Error() != "SSH_AUTH_SOCK is not set and agent_socket was not provided" {
t.Errorf("unexpected error: %q", err.Error())
}
}

Expand Down
80 changes: 48 additions & 32 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ type SSHBlockModel struct {
Port types.Int64 `tfsdk:"port"`
User types.String `tfsdk:"user"`
PrivateKey types.String `tfsdk:"private_key"`
UseAgent types.Bool `tfsdk:"use_agent"`
AgentSocket types.String `tfsdk:"agent_socket"`
UseSudo types.Bool `tfsdk:"use_sudo"`
HostKeyFingerprint types.String `tfsdk:"host_key_fingerprint"`
MaxSessions types.Int64 `tfsdk:"max_sessions"`
}
Expand Down Expand Up @@ -101,10 +104,22 @@ func (p *TrueNASProvider) Schema(ctx context.Context, req provider.SchemaRequest
Optional: true,
},
"private_key": schema.StringAttribute{
Description: "SSH private key content.",
Required: true,
Description: "SSH private key content. Required unless use_agent is true.",
Optional: true,
Sensitive: true,
},
"use_agent": schema.BoolAttribute{
Description: "Use SSH agent (SSH_AUTH_SOCK) for authentication instead of private_key.",
Optional: true,
},
"agent_socket": schema.StringAttribute{
Description: "Path to SSH agent socket. Defaults to SSH_AUTH_SOCK environment variable.",
Optional: true,
},
"use_sudo": schema.BoolAttribute{
Description: "Prefix commands with sudo. Defaults to true. Set to false when the SSH user can run midclt directly.",
Optional: true,
},
"host_key_fingerprint": schema.StringAttribute{
Description: "SHA256 fingerprint of the TrueNAS server's SSH host key. " +
"Get it with: ssh-keyscan <host> 2>/dev/null | ssh-keygen -lf -",
Expand Down Expand Up @@ -156,6 +171,35 @@ func (p *TrueNASProvider) Schema(ctx context.Context, req provider.SchemaRequest
}
}

func buildSSHConfig(host string, ssh *SSHBlockModel) *client.SSHConfig {
cfg := &client.SSHConfig{
Host: host,
HostKeyFingerprint: ssh.HostKeyFingerprint.ValueString(),
}
if !ssh.UseAgent.IsNull() && ssh.UseAgent.ValueBool() {
cfg.UseAgent = true
if !ssh.AgentSocket.IsNull() {
cfg.AgentSocket = ssh.AgentSocket.ValueString()
}
} else {
cfg.PrivateKey = ssh.PrivateKey.ValueString()
}
if !ssh.Port.IsNull() {
cfg.Port = int(ssh.Port.ValueInt64())
}
if !ssh.User.IsNull() {
cfg.User = ssh.User.ValueString()
}
if !ssh.MaxSessions.IsNull() {
cfg.MaxSessions = int(ssh.MaxSessions.ValueInt64())
}
if !ssh.UseSudo.IsNull() {
v := ssh.UseSudo.ValueBool()
cfg.UseSudo = &v
}
return cfg
}

func (p *TrueNASProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
var config TrueNASProviderModel

Expand Down Expand Up @@ -210,20 +254,7 @@ func (p *TrueNASProvider) Configure(ctx context.Context, req provider.ConfigureR
}

// Create SSH client for fallback
sshConfig := &client.SSHConfig{
Host: config.Host.ValueString(),
PrivateKey: config.SSH.PrivateKey.ValueString(),
HostKeyFingerprint: config.SSH.HostKeyFingerprint.ValueString(),
}
if !config.SSH.Port.IsNull() {
sshConfig.Port = int(config.SSH.Port.ValueInt64())
}
if !config.SSH.User.IsNull() {
sshConfig.User = config.SSH.User.ValueString()
}
if !config.SSH.MaxSessions.IsNull() {
sshConfig.MaxSessions = int(config.SSH.MaxSessions.ValueInt64())
}
sshConfig := buildSSHConfig(config.Host.ValueString(), config.SSH)

sshClient, err := factory.NewSSHClient(sshConfig)
if err != nil {
Expand Down Expand Up @@ -307,22 +338,7 @@ func (p *TrueNASProvider) Configure(ctx context.Context, req provider.ConfigureR
}

// Build SSH config with values from provider configuration
sshConfig := &client.SSHConfig{
Host: config.Host.ValueString(),
PrivateKey: config.SSH.PrivateKey.ValueString(),
HostKeyFingerprint: config.SSH.HostKeyFingerprint.ValueString(),
}

// Set optional values if provided
if !config.SSH.Port.IsNull() {
sshConfig.Port = int(config.SSH.Port.ValueInt64())
}
if !config.SSH.User.IsNull() {
sshConfig.User = config.SSH.User.ValueString()
}
if !config.SSH.MaxSessions.IsNull() {
sshConfig.MaxSessions = int(config.SSH.MaxSessions.ValueInt64())
}
sshConfig := buildSSHConfig(config.Host.ValueString(), config.SSH)

// Create SSH client (validates config and applies defaults)
sshClient, err := factory.NewSSHClient(sshConfig)
Expand Down
Loading