Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 55 additions & 101 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ import (
"log/slog"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"

"github.com/coder/jail"
"github.com/coder/jail/audit"
"github.com/coder/jail/network"
"github.com/coder/jail/namespace"
"github.com/coder/jail/proxy"
"github.com/coder/jail/rules"
"github.com/coder/jail/tls"
Expand All @@ -25,7 +24,6 @@ type Config struct {
AllowStrings []string
NoTLSIntercept bool
LogLevel string
NoJailCleanup bool
}

// NewCommand creates and returns the root serpent command
Expand Down Expand Up @@ -70,14 +68,6 @@ Examples:
Default: "warn",
Value: serpent.StringOf(&config.LogLevel),
},
{
Name: "no-jail-cleanup",
Flag: "no-jail-cleanup",
Env: "JAIL_NO_JAIL_CLEANUP",
Description: "Skip jail cleanup (hidden flag for testing).",
Value: serpent.BoolOf(&config.NoJailCleanup),
Hidden: true,
},
},
Handler: func(inv *serpent.Invocation) error {
return Run(config, inv.Args)
Expand Down Expand Up @@ -123,90 +113,85 @@ func Run(config Config, args []string) error {
logger.Warn("No allow rules specified; all network traffic will be denied by default")
}

// Parse allow rules
allowRules, err := rules.ParseAllowSpecs(config.AllowStrings)
if err != nil {
logger.Error("Failed to parse allow rules", "error", err)
return fmt.Errorf("failed to parse allow rules: %v", err)
}

// Implicit final deny-all is handled by the RuleEngine default behavior when no rules match.
// Build final rules slice in order: user allows only.
ruleList := allowRules

// Create rule engine
ruleEngine := rules.NewRuleEngine(ruleList, logger)
ruleEngine := rules.NewRuleEngine(allowRules, logger)

// Create auditor
auditor := audit.NewLoggingAuditor(logger)

// Get configuration directory
configDir, err := tls.GetConfigDir()
// Create network namespace configuration
nsConfig := namespace.Config{
HTTPPort: 8040,
HTTPSPort: 8043,
}

// Create commander
commander, err := namespace.New(nsConfig, logger)
if err != nil {
logger.Error("Failed to get config directory", "error", err)
return fmt.Errorf("failed to get config directory: %v", err)
logger.Error("Failed to create network namespace", "error", err)
return fmt.Errorf("failed to create network namespace: %v", err)
}

// Create certificate manager (if TLS interception is enabled)
var certManager *tls.CertificateManager
var tlsConfig *cryptotls.Config
var extraEnv map[string]string = make(map[string]string)

if !config.NoTLSIntercept {
certManager, err = tls.NewCertificateManager(configDir, logger)
certManager, err := tls.NewCertificateManager("", logger) // Empty configDir since it will be determined internally
if err != nil {
logger.Error("Failed to create certificate manager", "error", err)
return fmt.Errorf("failed to create certificate manager: %v", err)
}

tlsConfig = certManager.GetTLSConfig()

// Get CA certificate for environment
caCertPEM, err := certManager.GetCACertPEM()
if err != nil {
logger.Error("Failed to get CA certificate", "error", err)
return fmt.Errorf("failed to get CA certificate: %v", err)
}

// Write CA certificate to a temporary file for tools that need a file path
caCertPath := filepath.Join(configDir, "ca-cert.pem")
err = os.WriteFile(caCertPath, caCertPEM, 0644)
// Setup TLS config and write CA certificate to file
var caCertPath, configDir string
tlsConfig, caCertPath, configDir, err = certManager.SetupTLSAndWriteCACert()
if err != nil {
logger.Error("Failed to write CA certificate file", "error", err)
return fmt.Errorf("failed to write CA certificate file: %v", err)
logger.Error("Failed to setup TLS and CA certificate", "error", err)
return fmt.Errorf("failed to setup TLS and CA certificate: %v", err)
}

// Set standard CA certificate environment variables for common tools
// This makes tools like curl, git, etc. trust our dynamically generated CA
extraEnv["SSL_CERT_FILE"] = caCertPath // OpenSSL/LibreSSL-based tools
extraEnv["SSL_CERT_DIR"] = configDir // OpenSSL certificate directory
extraEnv["CURL_CA_BUNDLE"] = caCertPath // curl
extraEnv["GIT_SSL_CAINFO"] = caCertPath // Git
extraEnv["REQUESTS_CA_BUNDLE"] = caCertPath // Python requests
extraEnv["NODE_EXTRA_CA_CERTS"] = caCertPath // Node.js
extraEnv["JAIL_CA_CERT"] = string(caCertPEM) // Keep for backward compatibility
commander.SetEnv("SSL_CERT_FILE", caCertPath) // OpenSSL/LibreSSL-based tools
commander.SetEnv("SSL_CERT_DIR", configDir) // OpenSSL certificate directory
commander.SetEnv("CURL_CA_BUNDLE", caCertPath) // curl
commander.SetEnv("GIT_SSL_CAINFO", caCertPath) // Git
commander.SetEnv("REQUESTS_CA_BUNDLE", caCertPath) // Python requests
commander.SetEnv("NODE_EXTRA_CA_CERTS", caCertPath) // Node.js
}

// Create network jail configuration
networkConfig := network.JailConfig{
HTTPPort: 8040,
HTTPSPort: 8043,
NetJailName: "jail",
SkipCleanup: config.NoJailCleanup,
}
// Create proxy server
proxyServer := proxy.NewProxyServer(proxy.Config{
HTTPPort: 8040,
HTTPSPort: 8043,
RuleEngine: ruleEngine,
Auditor: auditor,
Logger: logger,
TLSConfig: tlsConfig,
})

// Create network jail
networkInstance, err := network.NewJail(networkConfig, logger)
if err != nil {
logger.Error("Failed to create network jail", "error", err)
return fmt.Errorf("failed to create network jail: %v", err)
}
// Create jail instance
jailInstance := jail.New(jail.Config{
Commander: commander,
ProxyServer: proxyServer,
Logger: logger,
})

// Setup signal handling BEFORE any network setup
// Setup signal handling BEFORE any setup
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

// Handle signals immediately in background
go func() {
sig := <-sigChan
logger.Info("Received signal during setup, cleaning up...", "signal", sig)
err := networkInstance.Cleanup()
err := jailInstance.Close()
if err != nil {
logger.Error("Emergency cleanup failed", "error", err)
}
Expand All @@ -216,55 +201,29 @@ func Run(config Config, args []string) error {
// Ensure cleanup happens no matter what
defer func() {
logger.Debug("Starting cleanup process")
err := networkInstance.Cleanup()
err := jailInstance.Close()
if err != nil {
logger.Error("Failed to cleanup network jail", "error", err)
logger.Error("Failed to cleanup jail", "error", err)
} else {
logger.Debug("Cleanup completed successfully")
}
}()

// Setup network jail
err = networkInstance.Setup(networkConfig.HTTPPort, networkConfig.HTTPSPort)
// Open jail (starts network namespace and proxy server)
err = jailInstance.Open()
if err != nil {
logger.Error("Failed to setup network jail", "error", err)
return fmt.Errorf("failed to setup network jail: %v", err)
}

// Create auditor
auditor := audit.NewLoggingAuditor(logger)

// Create proxy server
proxyConfig := proxy.Config{
HTTPPort: networkConfig.HTTPPort,
HTTPSPort: networkConfig.HTTPSPort,
RuleEngine: ruleEngine,
Auditor: auditor,
Logger: logger,
TLSConfig: tlsConfig,
logger.Error("Failed to open jail", "error", err)
return fmt.Errorf("failed to open jail: %v", err)
}

proxyServer := proxy.NewProxyServer(proxyConfig)

// Create context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Start proxy server in background
go func() {
err := proxyServer.Start(ctx)
if err != nil {
logger.Error("Proxy server error", "error", err)
}
}()

// Give proxy time to start
time.Sleep(100 * time.Millisecond)

// Execute command in network jail
// Execute command in jail
go func() {
defer cancel()
err := networkInstance.Execute(args, extraEnv)
err := jailInstance.Command(args).Run()
if err != nil {
logger.Error("Command execution failed", "error", err)
}
Expand All @@ -277,12 +236,7 @@ func Run(config Config, args []string) error {
cancel()
case <-ctx.Done():
// Context cancelled by command completion
}

// Stop proxy server
err = proxyServer.Stop()
if err != nil {
logger.Error("Failed to stop proxy server", "error", err)
logger.Info("Command completed, shutting down...")
}

return nil
Expand Down
87 changes: 87 additions & 0 deletions jail.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package jail

import (
"context"
"fmt"
"log/slog"
"os/exec"
"time"

"github.com/coder/jail/proxy"
)

type Commander interface {
Open() error
SetEnv(key string, value string)
Command(command []string) *exec.Cmd
Close() error
}

type Config struct {
Commander Commander
ProxyServer *proxy.ProxyServer
Logger *slog.Logger
}

type Jail struct {
commandExecutor Commander
proxyServer *proxy.ProxyServer
logger *slog.Logger
cancel context.CancelFunc
ctx context.Context
}

func New(config Config) *Jail {
ctx, cancel := context.WithCancel(context.Background())

return &Jail{
commandExecutor: config.Commander,
proxyServer: config.ProxyServer,
logger: config.Logger,
ctx: ctx,
cancel: cancel,
}
}

func (j *Jail) Open() error {
// Open the command executor (network namespace)
err := j.commandExecutor.Open()
if err != nil {
return fmt.Errorf("failed to open command executor: %v", err)
}

// Start proxy server in background
go func() {
err := j.proxyServer.Start(j.ctx)
if err != nil {
j.logger.Error("Proxy server error", "error", err)
}
}()

// Give proxy time to start
time.Sleep(100 * time.Millisecond)

return nil
}

func (j *Jail) Command(command []string) *exec.Cmd {
return j.commandExecutor.Command(command)
}

func (j *Jail) Close() error {
// Cancel context to stop proxy server
if j.cancel != nil {
j.cancel()
}

// Stop proxy server
if j.proxyServer != nil {
err := j.proxyServer.Stop()
if err != nil {
j.logger.Error("Failed to stop proxy server", "error", err)
}
}

// Close command executor
return j.commandExecutor.Close()
}
Loading
Loading