diff --git a/audit/logging_auditor.go b/audit/logging_auditor.go new file mode 100644 index 0000000..31de3d2 --- /dev/null +++ b/audit/logging_auditor.go @@ -0,0 +1,29 @@ +package audit + +import "log/slog" + +// LoggingAuditor implements Auditor by logging to slog +type LoggingAuditor struct { + logger *slog.Logger +} + +// NewLoggingAuditor creates a new LoggingAuditor +func NewLoggingAuditor(logger *slog.Logger) *LoggingAuditor { + return &LoggingAuditor{ + logger: logger, + } +} + +// AuditRequest logs the request using structured logging +func (a *LoggingAuditor) AuditRequest(req *Request) { + if req.Allowed { + a.logger.Info("ALLOW", + "method", req.Method, + "url", req.URL, + "rule", req.Rule) + } else { + a.logger.Warn("DENY", + "method", req.Method, + "url", req.URL) + } +} diff --git a/audit/logging_auditor_test.go b/audit/logging_auditor_test.go new file mode 100644 index 0000000..3dac8ca --- /dev/null +++ b/audit/logging_auditor_test.go @@ -0,0 +1,376 @@ +package audit + +import ( + "bytes" + "io" + "log/slog" + "strings" + "testing" +) + +func TestLoggingAuditor(t *testing.T) { + tests := []struct { + name string + request *Request + expectedLevel string + expectedFields []string + }{ + { + name: "allow request", + request: &Request{ + Method: "GET", + URL: "https://github.com", + Allowed: true, + Rule: "allow github.com", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW", "GET", "https://github.com", "allow github.com"}, + }, + { + name: "deny request", + request: &Request{ + Method: "POST", + URL: "https://example.com", + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "POST", "https://example.com"}, + }, + { + name: "allow with empty rule", + request: &Request{ + Method: "PUT", + URL: "https://api.github.com/repos", + Allowed: true, + Rule: "", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW", "PUT", "https://api.github.com/repos"}, + }, + { + name: "deny HTTPS request", + request: &Request{ + Method: "GET", + URL: "https://malware.bad.com/payload", + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "GET", "https://malware.bad.com/payload"}, + }, + { + name: "allow with wildcard rule", + request: &Request{ + Method: "POST", + URL: "https://api.github.com/graphql", + Allowed: true, + Rule: "allow api.github.com/*", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW", "POST", "https://api.github.com/graphql", "allow api.github.com/*"}, + }, + { + name: "deny HTTP request", + request: &Request{ + Method: "GET", + URL: "http://insecure.example.com", + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "GET", "http://insecure.example.com"}, + }, + { + name: "allow HEAD request", + request: &Request{ + Method: "HEAD", + URL: "https://cdn.jsdelivr.net/health", + Allowed: true, + Rule: "allow HEAD cdn.jsdelivr.net", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW", "HEAD", "https://cdn.jsdelivr.net/health", "allow HEAD cdn.jsdelivr.net"}, + }, + { + name: "deny OPTIONS request", + request: &Request{ + Method: "OPTIONS", + URL: "https://restricted.api.com/cors", + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "OPTIONS", "https://restricted.api.com/cors"}, + }, + { + name: "allow with port number", + request: &Request{ + Method: "GET", + URL: "https://localhost:3000/api/health", + Allowed: true, + Rule: "allow localhost:3000", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW", "GET", "https://localhost:3000/api/health", "allow localhost:3000"}, + }, + { + name: "deny DELETE request", + request: &Request{ + Method: "DELETE", + URL: "https://api.production.com/users/admin", + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "DELETE", "https://api.production.com/users/admin"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + auditor := NewLoggingAuditor(logger) + auditor.AuditRequest(tt.request) + + logOutput := buf.String() + if logOutput == "" { + t.Fatalf("expected log output, got empty string") + } + + if !strings.Contains(logOutput, tt.expectedLevel) { + t.Errorf("expected log level %s, got: %s", tt.expectedLevel, logOutput) + } + + for _, field := range tt.expectedFields { + if !strings.Contains(logOutput, field) { + t.Errorf("expected log to contain %q, got: %s", field, logOutput) + } + } + }) + } +} + +func TestLoggingAuditor_EdgeCases(t *testing.T) { + tests := []struct { + name string + request *Request + expectedLevel string + expectedFields []string + }{ + { + name: "empty fields", + request: &Request{ + Method: "", + URL: "", + Allowed: true, + Rule: "", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW"}, + }, + { + name: "special characters in URL", + request: &Request{ + Method: "POST", + URL: "https://api.example.com/users?name=John%20Doe&id=123", + Allowed: true, + Rule: "allow api.example.com/*", + }, + expectedLevel: "INFO", + expectedFields: []string{"ALLOW", "POST", "https://api.example.com/users?name=John%20Doe&id=123", "allow api.example.com/*"}, + }, + { + name: "very long URL", + request: &Request{ + Method: "GET", + URL: "https://example.com/" + strings.Repeat("a", 1000), + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "GET"}, + }, + { + name: "deny with custom URL", + request: &Request{ + Method: "DELETE", + URL: "https://malicious.com", + Allowed: false, + }, + expectedLevel: "WARN", + expectedFields: []string{"DENY", "DELETE", "https://malicious.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + auditor := NewLoggingAuditor(logger) + auditor.AuditRequest(tt.request) + + logOutput := buf.String() + if logOutput == "" { + t.Fatalf("expected log output, got empty string") + } + + if !strings.Contains(logOutput, tt.expectedLevel) { + t.Errorf("expected log level %s, got: %s", tt.expectedLevel, logOutput) + } + + for _, field := range tt.expectedFields { + if !strings.Contains(logOutput, field) { + t.Errorf("expected log to contain %q, got: %s", field, logOutput) + } + } + }) + } +} + +func TestLoggingAuditor_DifferentLogLevels(t *testing.T) { + tests := []struct { + name string + logLevel slog.Level + request *Request + expectOutput bool + }{ + { + name: "info level allows info logs", + logLevel: slog.LevelInfo, + request: &Request{ + Method: "GET", + URL: "https://github.com", + Allowed: true, + Rule: "allow github.com", + }, + expectOutput: true, + }, + { + name: "warn level blocks info logs", + logLevel: slog.LevelWarn, + request: &Request{ + Method: "GET", + URL: "https://github.com", + Allowed: true, + Rule: "allow github.com", + }, + expectOutput: false, + }, + { + name: "warn level allows warn logs", + logLevel: slog.LevelWarn, + request: &Request{ + Method: "POST", + URL: "https://example.com", + Allowed: false, + }, + expectOutput: true, + }, + { + name: "error level blocks warn logs", + logLevel: slog.LevelError, + request: &Request{ + Method: "POST", + URL: "https://example.com", + Allowed: false, + }, + expectOutput: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{ + Level: tt.logLevel, + })) + + auditor := NewLoggingAuditor(logger) + auditor.AuditRequest(tt.request) + + logOutput := buf.String() + hasOutput := logOutput != "" + + if hasOutput != tt.expectOutput { + t.Errorf("expected output: %v, got output: %v (log: %q)", tt.expectOutput, hasOutput, logOutput) + } + }) + } +} + +func TestLoggingAuditor_NilLogger(t *testing.T) { + // This test ensures we handle edge cases gracefully + // In practice, NewLoggingAuditor should never receive a nil logger, + // but we test defensive programming + defer func() { + if r := recover(); r != nil { + // If it panics, that's also acceptable behavior + t.Logf("AuditRequest panicked with nil logger: %v", r) + } + }() + + auditor := &LoggingAuditor{logger: nil} + req := &Request{ + Method: "GET", + URL: "https://example.com", + Allowed: true, + Rule: "test", + } + + // This should either handle gracefully or panic - both are acceptable + auditor.AuditRequest(req) +} + +func TestLoggingAuditor_JSONHandler(t *testing.T) { + // Test with JSON handler instead of text handler + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + auditor := NewLoggingAuditor(logger) + req := &Request{ + Method: "GET", + URL: "https://github.com", + Allowed: true, + Rule: "allow github.com", + } + + auditor.AuditRequest(req) + + logOutput := buf.String() + if logOutput == "" { + t.Fatal("expected log output") + } + + // Verify it contains JSON structure + if !strings.Contains(logOutput, "{") || !strings.Contains(logOutput, "}") { + t.Error("expected JSON format in log output") + } + + // Verify expected fields are present in JSON + expectedFields := []string{"\"msg\":\"ALLOW\"", "\"method\":\"GET\"", "\"url\":\"https://github.com\"", "\"rule\":\"allow github.com\""} + for _, field := range expectedFields { + if !strings.Contains(logOutput, field) { + t.Errorf("expected JSON log to contain %q, got: %s", field, logOutput) + } + } +} + +func TestLoggingAuditor_DiscardHandler(t *testing.T) { + // Test with discard handler (no output) + logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + + auditor := NewLoggingAuditor(logger) + req := &Request{ + Method: "GET", + URL: "https://example.com", + Allowed: true, + Rule: "allow example.com", + } + + // This should not panic even with discard handler + auditor.AuditRequest(req) +} diff --git a/audit/request.go b/audit/request.go new file mode 100644 index 0000000..b0e7bae --- /dev/null +++ b/audit/request.go @@ -0,0 +1,21 @@ +package audit + +import ( + "net/http" +) + +// Request represents information about an HTTP request for auditing +type Request struct { + Method string + URL string + Allowed bool + Rule string // The rule that matched (if any) +} + +// HTTPRequestToAuditRequest converts an http.Request to an audit.Request +func HTTPRequestToAuditRequest(httpReq *http.Request) *Request { + return &Request{ + Method: httpReq.Method, + URL: httpReq.URL.String(), + } +} diff --git a/audit/request_test.go b/audit/request_test.go new file mode 100644 index 0000000..b8b6a5a --- /dev/null +++ b/audit/request_test.go @@ -0,0 +1,117 @@ +package audit + +import ( + "net/http" + "net/url" + "strings" + "testing" +) + +func TestHTTPRequestToAuditRequest(t *testing.T) { + tests := []struct { + name string + request *http.Request + expectedMethod string + expectedURL string + }{ + { + name: "basic GET request", + request: func() *http.Request { + req, _ := http.NewRequest("GET", "https://example.com/path?query=value", nil) + return req + }(), + expectedMethod: "GET", + expectedURL: "https://example.com/path?query=value", + }, + { + name: "POST request with body", + request: func() *http.Request { + req, _ := http.NewRequest("POST", "https://api.example.com/users", strings.NewReader("data")) + return req + }(), + expectedMethod: "POST", + expectedURL: "https://api.example.com/users", + }, + { + name: "request with port", + request: func() *http.Request { + req, _ := http.NewRequest("GET", "https://example.com:8443/api", nil) + return req + }(), + expectedMethod: "GET", + expectedURL: "https://example.com:8443/api", + }, + { + name: "request with complex query parameters", + request: func() *http.Request { + req, _ := http.NewRequest("GET", "https://search.example.com/api?q=hello%20world&limit=10&offset=0", nil) + return req + }(), + expectedMethod: "GET", + expectedURL: "https://search.example.com/api?q=hello%20world&limit=10&offset=0", + }, + { + name: "request with fragment (should be ignored)", + request: func() *http.Request { + u, _ := url.Parse("https://example.com/page#section") + req := &http.Request{ + Method: "GET", + URL: u, + } + return req + }(), + expectedMethod: "GET", + expectedURL: "https://example.com/page#section", + }, + { + name: "HTTP request (not HTTPS)", + request: func() *http.Request { + req, _ := http.NewRequest("GET", "http://insecure.example.com/data", nil) + return req + }(), + expectedMethod: "GET", + expectedURL: "http://insecure.example.com/data", + }, + { + name: "PUT request", + request: func() *http.Request { + req, _ := http.NewRequest("PUT", "https://api.example.com/users/123", strings.NewReader("updated data")) + return req + }(), + expectedMethod: "PUT", + expectedURL: "https://api.example.com/users/123", + }, + { + name: "DELETE request", + request: func() *http.Request { + req, _ := http.NewRequest("DELETE", "https://api.example.com/users/123", nil) + return req + }(), + expectedMethod: "DELETE", + expectedURL: "https://api.example.com/users/123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auditReq := HTTPRequestToAuditRequest(tt.request) + + if auditReq.Method != tt.expectedMethod { + t.Errorf("expected method %s, got %s", tt.expectedMethod, auditReq.Method) + } + + if auditReq.URL != tt.expectedURL { + t.Errorf("expected URL %s, got %s", tt.expectedURL, auditReq.URL) + } + + // Verify that fields not set by HTTPRequestToAuditRequest have zero values + if auditReq.Allowed != false { + t.Errorf("expected Allowed to be false (zero value), got %v", auditReq.Allowed) + } + + if auditReq.Rule != "" { + t.Errorf("expected Rule to be empty (zero value), got %q", auditReq.Rule) + } + }) + } +} diff --git a/cli/cli.go b/cli/cli.go new file mode 100644 index 0000000..f495da2 --- /dev/null +++ b/cli/cli.go @@ -0,0 +1,282 @@ +package cli + +import ( + "context" + cryptotls "crypto/tls" + "fmt" + "log/slog" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/coder/jail/audit" + "github.com/coder/jail/network" + "github.com/coder/jail/proxy" + "github.com/coder/jail/rules" + "github.com/coder/jail/tls" + "github.com/coder/serpent" +) + +// Config holds all configuration for the CLI +type Config struct { + AllowStrings []string + NoTLSIntercept bool + LogLevel string + NoJailCleanup bool +} + +// NewCommand creates and returns the root serpent command +func NewCommand() *serpent.Command { + var config Config + + return &serpent.Command{ + Use: "jail [flags] -- command [args...]", + Short: "Monitor and restrict HTTP/HTTPS requests from processes", + Long: `jail creates an isolated network environment for the target process, +intercepting all HTTP/HTTPS traffic through a transparent proxy that enforces +user-defined rules. + +Examples: + # Allow only requests to github.com + jail --allow "github.com" -- curl https://github.com + + # Monitor all requests to specific domains (allow only those) + jail --allow "github.com/api/issues/*" --allow "GET,HEAD github.com" -- npm install + + # Block everything by default (implicit)`, + Options: serpent.OptionSet{ + { + Name: "allow", + Flag: "allow", + Env: "JAIL_ALLOW", + Description: "Allow rule (can be specified multiple times). Format: 'pattern' or 'METHOD[,METHOD] pattern'.", + Value: serpent.StringArrayOf(&config.AllowStrings), + }, + { + Name: "no-tls-intercept", + Flag: "no-tls-intercept", + Env: "JAIL_NO_TLS_INTERCEPT", + Description: "Disable HTTPS interception.", + Value: serpent.BoolOf(&config.NoTLSIntercept), + }, + { + Name: "log-level", + Flag: "log-level", + Env: "JAIL_LOG_LEVEL", + Description: "Set log level (error, warn, info, debug).", + Default: "warn", + Value: serpent.StringOf(&config.LogLevel), + }, + { + Name: "no-jail-cleanup", + Flag: "no-jail-cleanup", + Env: "JAIL_NO_JAIL_CLEANUP", + Description: "Skip jail cleanup (hidden flag for testing).", + Value: serpent.BoolOf(&config.NoJailCleanup), + Hidden: true, + }, + }, + Handler: func(inv *serpent.Invocation) error { + return Run(config, inv.Args) + }, + } +} + +// setupLogging creates a slog logger with the specified level +func setupLogging(logLevel string) *slog.Logger { + var level slog.Level + switch strings.ToLower(logLevel) { + case "error": + level = slog.LevelError + case "warn": + level = slog.LevelWarn + case "info": + level = slog.LevelInfo + case "debug": + level = slog.LevelDebug + default: + level = slog.LevelWarn // Default to warn if invalid level + } + + // Create a standard slog logger with the appropriate level + handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + }) + + return slog.New(handler) +} + +// Run executes the jail command with the given configuration and arguments +func Run(config Config, args []string) error { + logger := setupLogging(config.LogLevel) + + // Get command arguments + if len(args) == 0 { + return fmt.Errorf("no command specified") + } + + // Parse allow list; default to deny-all if none provided + if len(config.AllowStrings) == 0 { + logger.Warn("No allow rules specified; all network traffic will be denied by default") + } + + allowRules, err := rules.ParseAllowSpecs(config.AllowStrings) + if err != nil { + logger.Error("Failed to parse allow rules", "error", err) + return fmt.Errorf("failed to parse allow rules: %v", err) + } + + // Implicit final deny-all is handled by the RuleEngine default behavior when no rules match. + // Build final rules slice in order: user allows only. + ruleList := allowRules + + // Create rule engine + ruleEngine := rules.NewRuleEngine(ruleList, logger) + + // Get configuration directory + configDir, err := tls.GetConfigDir() + if err != nil { + logger.Error("Failed to get config directory", "error", err) + return fmt.Errorf("failed to get config directory: %v", err) + } + + // Create certificate manager (if TLS interception is enabled) + var certManager *tls.CertificateManager + var tlsConfig *cryptotls.Config + var extraEnv map[string]string = make(map[string]string) + + if !config.NoTLSIntercept { + certManager, err = tls.NewCertificateManager(configDir, logger) + if err != nil { + logger.Error("Failed to create certificate manager", "error", err) + return fmt.Errorf("failed to create certificate manager: %v", err) + } + + tlsConfig = certManager.GetTLSConfig() + + // Get CA certificate for environment + caCertPEM, err := certManager.GetCACertPEM() + if err != nil { + logger.Error("Failed to get CA certificate", "error", err) + return fmt.Errorf("failed to get CA certificate: %v", err) + } + + // Write CA certificate to a temporary file for tools that need a file path + caCertPath := filepath.Join(configDir, "ca-cert.pem") + if err := os.WriteFile(caCertPath, caCertPEM, 0644); err != nil { + logger.Error("Failed to write CA certificate file", "error", err) + return fmt.Errorf("failed to write CA certificate file: %v", err) + } + + // Set standard CA certificate environment variables for common tools + // This makes tools like curl, git, etc. trust our dynamically generated CA + extraEnv["SSL_CERT_FILE"] = caCertPath // OpenSSL/LibreSSL-based tools + extraEnv["SSL_CERT_DIR"] = configDir // OpenSSL certificate directory + extraEnv["CURL_CA_BUNDLE"] = caCertPath // curl + extraEnv["GIT_SSL_CAINFO"] = caCertPath // Git + extraEnv["REQUESTS_CA_BUNDLE"] = caCertPath // Python requests + extraEnv["NODE_EXTRA_CA_CERTS"] = caCertPath // Node.js + extraEnv["JAIL_CA_CERT"] = string(caCertPEM) // Keep for backward compatibility + } + + // Create network jail configuration + networkConfig := network.JailConfig{ + HTTPPort: 8040, + HTTPSPort: 8043, + NetJailName: "jail", + SkipCleanup: config.NoJailCleanup, + } + + // Create network jail + networkInstance, err := network.NewJail(networkConfig, logger) + if err != nil { + logger.Error("Failed to create network jail", "error", err) + return fmt.Errorf("failed to create network jail: %v", err) + } + + // Setup signal handling BEFORE any network setup + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Handle signals immediately in background + go func() { + sig := <-sigChan + logger.Info("Received signal during setup, cleaning up...", "signal", sig) + if err := networkInstance.Cleanup(); err != nil { + logger.Error("Emergency cleanup failed", "error", err) + } + os.Exit(1) + }() + + // Ensure cleanup happens no matter what + defer func() { + logger.Debug("Starting cleanup process") + if err := networkInstance.Cleanup(); err != nil { + logger.Error("Failed to cleanup network jail", "error", err) + } else { + logger.Debug("Cleanup completed successfully") + } + }() + + // Setup network jail + if err := networkInstance.Setup(networkConfig.HTTPPort, networkConfig.HTTPSPort); err != nil { + logger.Error("Failed to setup network jail", "error", err) + return fmt.Errorf("failed to setup network jail: %v", err) + } + + // Create auditor + auditor := audit.NewLoggingAuditor(logger) + + // Create proxy server + proxyConfig := proxy.Config{ + HTTPPort: networkConfig.HTTPPort, + HTTPSPort: networkConfig.HTTPSPort, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + } + + proxyServer := proxy.NewProxyServer(proxyConfig) + + // Create context for graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start proxy server in background + go func() { + if err := proxyServer.Start(ctx); err != nil { + logger.Error("Proxy server error", "error", err) + } + }() + + // Give proxy time to start + time.Sleep(100 * time.Millisecond) + + // Execute command in network jail + go func() { + defer cancel() + if err := networkInstance.Execute(args, extraEnv); err != nil { + logger.Error("Command execution failed", "error", err) + } + }() + + // Wait for signal or context cancellation + select { + case sig := <-sigChan: + logger.Info("Received signal, shutting down...", "signal", sig) + cancel() + case <-ctx.Done(): + // Context cancelled by command completion + } + + // Stop proxy server + if err := proxyServer.Stop(); err != nil { + logger.Error("Failed to stop proxy server", "error", err) + } + + return nil +} \ No newline at end of file diff --git a/go.mod b/go.mod index 273588b..1e0f109 100644 --- a/go.mod +++ b/go.mod @@ -22,10 +22,10 @@ require ( github.com/spf13/pflag v1.0.5 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/trace v1.19.0 // indirect - golang.org/x/crypto v0.35.0 // indirect + golang.org/x/crypto v0.19.0 // indirect golang.org/x/exp v0.0.0-20240213143201-ec583247a57a // indirect - golang.org/x/sys v0.30.0 // indirect - golang.org/x/term v0.29.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/term v0.17.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 501a518..d751167 100644 --- a/go.sum +++ b/go.sum @@ -88,8 +88,8 @@ go.opentelemetry.io/otel/trace v1.19.0 h1:DFVQmlVbfVeOuBRrwdtaehRrWiL1JoVs9CPIQ1 go.opentelemetry.io/otel/trace v1.19.0/go.mod h1:mfaSyvGyEJEI0nyV2I4qhNQnbBOUUmYZpYojqMnX2vo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= -golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE= golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -109,19 +109,19 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= -golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s= +golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/main.go b/main.go index 2572b3d..6a7bb59 100644 --- a/main.go +++ b/main.go @@ -1,81 +1,14 @@ package main import ( - "context" - cryptotls "crypto/tls" "fmt" - "log/slog" "os" - "os/signal" - "path/filepath" - "strings" - "syscall" - "time" - "github.com/coder/jail/network" - "github.com/coder/jail/proxy" - "github.com/coder/jail/rules" - "github.com/coder/jail/tls" - "github.com/coder/serpent" -) - -var ( - allowStrings []string - noTLSIntercept bool - logLevel string - noJailCleanup bool + "github.com/coder/jail/cli" ) func main() { - cmd := &serpent.Command{ - Use: "jail [flags] -- command [args...]", - Short: "Monitor and restrict HTTP/HTTPS requests from processes", - Long: `jail creates an isolated network environment for the target process, -intercepting all HTTP/HTTPS traffic through a transparent proxy that enforces -user-defined rules. - -Examples: - # Allow only requests to github.com - jail --allow "github.com" -- curl https://github.com - - # Monitor all requests to specific domains (allow only those) - jail --allow "github.com/api/issues/*" --allow "GET,HEAD github.com" -- npm install - - # Block everything by default (implicit)`, - Options: serpent.OptionSet{ - { - Name: "allow", - Flag: "allow", - Env: "JAIL_ALLOW", - Description: "Allow rule (can be specified multiple times). Format: 'pattern' or 'METHOD[,METHOD] pattern'.", - Value: serpent.StringArrayOf(&allowStrings), - }, - { - Name: "no-tls-intercept", - Flag: "no-tls-intercept", - Env: "JAIL_NO_TLS_INTERCEPT", - Description: "Disable HTTPS interception.", - Value: serpent.BoolOf(&noTLSIntercept), - }, - { - Name: "log-level", - Flag: "log-level", - Env: "JAIL_LOG_LEVEL", - Description: "Set log level (error, warn, info, debug).", - Default: "warn", - Value: serpent.StringOf(&logLevel), - }, - { - Name: "no-jail-cleanup", - Flag: "no-jail-cleanup", - Env: "JAIL_NO_JAIL_CLEANUP", - Description: "Skip jail cleanup (hidden flag for testing).", - Value: serpent.BoolOf(&noJailCleanup), - Hidden: true, - }, - }, - Handler: runJail, - } + cmd := cli.NewCommand() err := cmd.Invoke().WithOS().Run() if err != nil { @@ -83,194 +16,3 @@ Examples: os.Exit(1) } } - -func setupLogging(logLevel string) *slog.Logger { - var level slog.Level - switch strings.ToLower(logLevel) { - case "error": - level = slog.LevelError - case "warn": - level = slog.LevelWarn - case "info": - level = slog.LevelInfo - case "debug": - level = slog.LevelDebug - default: - level = slog.LevelWarn // Default to warn if invalid level - } - - // Create a standard slog logger with the appropriate level - handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }) - - return slog.New(handler) -} - -func runJail(inv *serpent.Invocation) error { - logger := setupLogging(logLevel) - - // Get command arguments - args := inv.Args - if len(args) == 0 { - return fmt.Errorf("no command specified") - } - - // Parse allow list; default to deny-all if none provided - if len(allowStrings) == 0 { - logger.Warn("No allow rules specified; all network traffic will be denied by default") - } - - allowRules, err := rules.ParseAllowSpecs(allowStrings) - if err != nil { - logger.Error("Failed to parse allow rules", "error", err) - return fmt.Errorf("failed to parse allow rules: %v", err) - } - - // Implicit final deny-all is handled by the RuleEngine default behavior when no rules match. - // Build final rules slice in order: user allows only. - ruleList := allowRules - - // Create rule engine - ruleEngine := rules.NewRuleEngine(ruleList, logger) - - // Get configuration directory - configDir, err := tls.GetConfigDir() - if err != nil { - logger.Error("Failed to get config directory", "error", err) - return fmt.Errorf("failed to get config directory: %v", err) - } - - // Create certificate manager (if TLS interception is enabled) - var certManager *tls.CertificateManager - var tlsConfig *cryptotls.Config - var extraEnv map[string]string = make(map[string]string) - - if !noTLSIntercept { - certManager, err = tls.NewCertificateManager(configDir, logger) - if err != nil { - logger.Error("Failed to create certificate manager", "error", err) - return fmt.Errorf("failed to create certificate manager: %v", err) - } - - tlsConfig = certManager.GetTLSConfig() - - // Get CA certificate for environment - caCertPEM, err := certManager.GetCACertPEM() - if err != nil { - logger.Error("Failed to get CA certificate", "error", err) - return fmt.Errorf("failed to get CA certificate: %v", err) - } - - // Write CA certificate to a temporary file for tools that need a file path - caCertPath := filepath.Join(configDir, "ca-cert.pem") - if err := os.WriteFile(caCertPath, caCertPEM, 0644); err != nil { - logger.Error("Failed to write CA certificate file", "error", err) - return fmt.Errorf("failed to write CA certificate file: %v", err) - } - - // Set standard CA certificate environment variables for common tools - // This makes tools like curl, git, etc. trust our dynamically generated CA - extraEnv["SSL_CERT_FILE"] = caCertPath // OpenSSL/LibreSSL-based tools - extraEnv["SSL_CERT_DIR"] = configDir // OpenSSL certificate directory - extraEnv["CURL_CA_BUNDLE"] = caCertPath // curl - extraEnv["GIT_SSL_CAINFO"] = caCertPath // Git - extraEnv["REQUESTS_CA_BUNDLE"] = caCertPath // Python requests - extraEnv["NODE_EXTRA_CA_CERTS"] = caCertPath // Node.js - extraEnv["JAIL_CA_CERT"] = string(caCertPEM) // Keep for backward compatibility - } - - // Create network jail configuration - networkConfig := network.JailConfig{ - HTTPPort: 8040, - HTTPSPort: 8043, - NetJailName: "jail", - SkipCleanup: noJailCleanup, - } - - // Create network jail - networkInstance, err := network.NewJail(networkConfig, logger) - if err != nil { - logger.Error("Failed to create network jail", "error", err) - return fmt.Errorf("failed to create network jail: %v", err) - } - - // Setup signal handling BEFORE any network setup - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - // Handle signals immediately in background - go func() { - sig := <-sigChan - logger.Info("Received signal during setup, cleaning up...", "signal", sig) - if err := networkInstance.Cleanup(); err != nil { - logger.Error("Emergency cleanup failed", "error", err) - } - os.Exit(1) - }() - - // Ensure cleanup happens no matter what - defer func() { - logger.Debug("Starting cleanup process") - if err := networkInstance.Cleanup(); err != nil { - logger.Error("Failed to cleanup network jail", "error", err) - } else { - logger.Debug("Cleanup completed successfully") - } - }() - - // Setup network jail - if err := networkInstance.Setup(networkConfig.HTTPPort, networkConfig.HTTPSPort); err != nil { - logger.Error("Failed to setup network jail", "error", err) - return fmt.Errorf("failed to setup network jail: %v", err) - } - - // Create proxy server - proxyConfig := proxy.Config{ - HTTPPort: networkConfig.HTTPPort, - HTTPSPort: networkConfig.HTTPSPort, - RuleEngine: ruleEngine, - Logger: logger, - TLSConfig: tlsConfig, - } - - proxyServer := proxy.NewProxyServer(proxyConfig) - - // Create context for graceful shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Start proxy server in background - go func() { - if err := proxyServer.Start(ctx); err != nil { - logger.Error("Proxy server error", "error", err) - } - }() - - // Give proxy time to start - time.Sleep(100 * time.Millisecond) - - // Execute command in network jail - go func() { - defer cancel() - if err := networkInstance.Execute(args, extraEnv); err != nil { - logger.Error("Command execution failed", "error", err) - } - }() - - // Wait for signal or context cancellation - select { - case sig := <-sigChan: - logger.Info("Received signal, shutting down...", "signal", sig) - cancel() - case <-ctx.Done(): - // Context cancelled by command completion - } - - // Stop proxy server - if err := proxyServer.Stop(); err != nil { - logger.Error("Failed to stop proxy server", "error", err) - } - - return nil -} diff --git a/network/linux.go b/network/linux.go index 82eacc2..e05e04d 100644 --- a/network/linux.go +++ b/network/linux.go @@ -7,6 +7,8 @@ import ( "log/slog" "os" "os/exec" + "os/user" + "strconv" "syscall" "time" ) @@ -97,14 +99,53 @@ func (l *LinuxJail) Execute(command []string, extraEnv map[string]string) error env = append(env, fmt.Sprintf("%s=%s", key, value)) } + // When running under sudo, restore essential user environment variables + sudoUser := os.Getenv("SUDO_USER") + if sudoUser != "" { + user, err := user.Lookup(sudoUser) + if err == nil { + // Set HOME to original user's home directory + env = append(env, fmt.Sprintf("HOME=%s", user.HomeDir)) + // Set USER to original username + env = append(env, fmt.Sprintf("USER=%s", sudoUser)) + // Set LOGNAME to original username (some tools check this instead of USER) + env = append(env, fmt.Sprintf("LOGNAME=%s", sudoUser)) + l.logger.Debug("Restored user environment", "home", user.HomeDir, "user", sudoUser) + } + } + cmd.Env = env cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + // Drop privileges to original user if running under sudo + var gid, uid int + var err error + sudoUID := os.Getenv("SUDO_UID") + if sudoUID != "" { + uid, err = strconv.Atoi(sudoUID) + if err != nil { + l.logger.Warn("Invalid SUDO_UID, subprocess will run as root", "sudo_uid", sudoUID, "error", err) + } + } + sudoGID := os.Getenv("SUDO_GID") + if sudoGID != "" { + gid, err = strconv.Atoi(sudoGID) + if err != nil { + l.logger.Warn("Invalid SUDO_GID, subprocess will run as root", "sudo_gid", sudoGID, "error", err) + } + } + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(uid), + Gid: uint32(gid), + }, + } + // Start command l.logger.Debug("Starting command", "path", cmd.Path, "args", cmd.Args) - err := cmd.Start() + err = cmd.Start() if err != nil { return fmt.Errorf("failed to start command: %v", err) } diff --git a/network/macos.go b/network/macos.go index 79f8097..e3b9809 100644 --- a/network/macos.go +++ b/network/macos.go @@ -7,6 +7,7 @@ import ( "log/slog" "os" "os/exec" + "os/user" "strconv" "strings" "syscall" @@ -84,18 +85,51 @@ func (m *MacOSNetJail) Execute(command []string, extraEnv map[string]string) err env = append(env, fmt.Sprintf("%s=%s", key, value)) } + // When running under sudo, restore essential user environment variables + sudoUser := os.Getenv("SUDO_USER") + if sudoUser != "" { + user, err := user.Lookup(sudoUser) + if err == nil { + // Set HOME to original user's home directory + env = append(env, fmt.Sprintf("HOME=%s", user.HomeDir)) + // Set USER to original username + env = append(env, fmt.Sprintf("USER=%s", sudoUser)) + // Set LOGNAME to original username (some tools check this instead of USER) + env = append(env, fmt.Sprintf("LOGNAME=%s", sudoUser)) + m.logger.Debug("Restored user environment", "home", user.HomeDir, "user", sudoUser) + } + } + cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Stdin = os.Stdin - // Set group ID using syscall (like httpjail does) + // Set group ID using syscall cmd.SysProcAttr = &syscall.SysProcAttr{ Credential: &syscall.Credential{ Gid: uint32(m.groupID), }, } + // Drop privileges to original user if running under sudo + sudoUID := os.Getenv("SUDO_UID") + if sudoUID != "" { + uid, err := strconv.Atoi(sudoUID) + if err != nil { + m.logger.Warn("Invalid SUDO_UID, subprocess will run as root", "sudo_uid", sudoUID, "error", err) + } else { + // Use original user ID but KEEP the jail group for network isolation + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: uint32(uid), + Gid: uint32(m.groupID), // Keep jail group, not original user's group + }, + } + m.logger.Debug("Dropping privileges to original user with jail group", "uid", uid, "jail_gid", m.groupID) + } + } + // Start and wait for command to complete m.logger.Debug("Starting command", "path", cmd.Path, "args", cmd.Args) err := cmd.Start() diff --git a/proxy/proxy.go b/proxy/proxy.go index 8f59077..1fab9cc 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -10,6 +10,7 @@ import ( "net/url" "time" + "github.com/coder/jail/audit" "github.com/coder/jail/rules" ) @@ -18,6 +19,7 @@ type ProxyServer struct { httpServer *http.Server httpsServer *http.Server ruleEngine *rules.RuleEngine + auditor *audit.LoggingAuditor logger *slog.Logger tlsConfig *tls.Config httpPort int @@ -29,6 +31,7 @@ type Config struct { HTTPPort int HTTPSPort int RuleEngine *rules.RuleEngine + Auditor *audit.LoggingAuditor Logger *slog.Logger TLSConfig *tls.Config } @@ -37,6 +40,7 @@ type Config struct { func NewProxyServer(config Config) *ProxyServer { return &ProxyServer{ ruleEngine: config.RuleEngine, + auditor: config.Auditor, logger: config.Logger, tlsConfig: config.TLSConfig, httpPort: config.HTTPPort, @@ -102,8 +106,15 @@ func (p *ProxyServer) Stop() error { // handleHTTP handles regular HTTP requests func (p *ProxyServer) handleHTTP(w http.ResponseWriter, r *http.Request) { // Check if request should be allowed - action := p.ruleEngine.Evaluate(r.Method, r.URL.String()) - if action == rules.Deny { + result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) + + // Audit the request + auditReq := audit.HTTPRequestToAuditRequest(r) + auditReq.Allowed = result.Allowed + auditReq.Rule = result.Rule + p.auditRequest(auditReq) + + if !result.Allowed { p.writeBlockedResponse(w, r) return } @@ -121,8 +132,18 @@ func (p *ProxyServer) handleHTTPS(w http.ResponseWriter, r *http.Request) { } // Check if request should be allowed - action := p.ruleEngine.Evaluate(r.Method, fullURL) - if action == rules.Deny { + result := p.ruleEngine.Evaluate(r.Method, fullURL) + + // Audit the request + auditReq := &audit.Request{ + Method: r.Method, + URL: fullURL, + Allowed: result.Allowed, + Rule: result.Rule, + } + p.auditRequest(auditReq) + + if !result.Allowed { p.writeBlockedResponse(w, r) return } @@ -268,3 +289,8 @@ For more help: https://github.com/coder/jail `, r.Method, r.URL.Path, host, host, r.Method, host, r.Method) } + +// auditRequest handles auditing of requests +func (p *ProxyServer) auditRequest(req *audit.Request) { + p.auditor.AuditRequest(req) +} diff --git a/rules/rules.go b/rules/rules.go index 2f4810d..6ad64fe 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -6,23 +6,6 @@ import ( "strings" ) -// Action represents whether to allow a request -type Action int - -const ( - Allow Action = iota - Deny // Default deny when no allow rules match -) - -func (a Action) String() string { - switch a { - case Allow: - return "ALLOW" - default: - return "DENY" - } -} - // Rule represents an allow rule with optional HTTP method restrictions type Rule struct { Pattern string // wildcard pattern for matching @@ -30,7 +13,6 @@ type Rule struct { Raw string // rule string for logging } - // Matches checks if the rule matches the given method and URL using wildcard patterns func (r *Rule) Matches(method, url string) bool { // Check method if specified @@ -137,19 +119,29 @@ func NewRuleEngine(rules []*Rule, logger *slog.Logger) *RuleEngine { } } -// Evaluate evaluates a request against all allow rules and returns the action to take -func (re *RuleEngine) Evaluate(method, url string) Action { +// EvaluationResult contains the result of rule evaluation +type EvaluationResult struct { + Allowed bool + Rule string // The rule that matched (if any) +} + +// Evaluate evaluates a request and returns both result and matching rule +func (re *RuleEngine) Evaluate(method, url string) EvaluationResult { // Check if any allow rule matches for _, rule := range re.rules { if rule.Matches(method, url) { - re.logger.Info("ALLOW", "method", method, "url", url, "rule", rule.Raw) - return Allow + return EvaluationResult{ + Allowed: true, + Rule: rule.Raw, + } } } // Default deny if no allow rules match - re.logger.Warn("DENY", "method", method, "url", url, "reason", "no matching allow rules") - return Deny + return EvaluationResult{ + Allowed: false, + Rule: "", + } } // newAllowRule creates an allow Rule from a spec string used by --allow. diff --git a/rules/rules_test.go b/rules/rules_test.go index fdcd997..5fbe009 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -157,7 +157,7 @@ func TestWildcardMatch(t *testing.T) { // Basic exact matches {"exact match", "github.com", "github.com", true}, {"no match", "github.com", "gitlab.com", false}, - + // Wildcard * tests {"star matches all", "*", "anything.com", true}, {"star matches empty", "*", "", true}, @@ -240,19 +240,19 @@ func TestRuleEngine(t *testing.T) { name string method string url string - expected Action + expected bool }{ - {"allow github", "GET", "https://github.com/user/repo", Allow}, - {"allow api GET", "GET", "https://api.example.com", Allow}, - {"deny api POST", "POST", "https://api.example.com", Deny}, - {"deny other", "GET", "https://example.com", Deny}, + {"allow github", "GET", "https://github.com/user/repo", true}, + {"allow api GET", "GET", "https://api.example.com", true}, + {"deny api POST", "POST", "https://api.example.com", false}, + {"deny other", "GET", "https://example.com", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := engine.Evaluate(tt.method, tt.url) - if result != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, result) + if result.Allowed != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result.Allowed) } }) } @@ -275,21 +275,21 @@ func TestRuleEngineWildcardRules(t *testing.T) { name string method string url string - expected Action + expected bool }{ - {"allow github", "GET", "https://github.com", Allow}, - {"allow github subdomain", "POST", "https://github.io", Allow}, - {"allow api GET", "GET", "https://api.example.com", Allow}, - {"deny api POST", "POST", "https://api.example.com", Deny}, - {"deny unmatched", "GET", "https://example.org", Deny}, + {"allow github", "GET", "https://github.com", true}, + {"allow github subdomain", "POST", "https://github.io", true}, + {"allow api GET", "GET", "https://api.example.com", true}, + {"deny api POST", "POST", "https://api.example.com", false}, + {"deny unmatched", "GET", "https://example.org", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := engine.Evaluate(tt.method, tt.url) - if result != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, result) + if result.Allowed != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result.Allowed) } }) } -} \ No newline at end of file +} diff --git a/tls/tls.go b/tls/tls.go index a9e51b0..07ef6ad 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -12,7 +12,9 @@ import ( "math/big" "net" "os" + "os/user" "path/filepath" + "strconv" "sync" "time" ) @@ -141,6 +143,22 @@ func (cm *CertificateManager) generateCA(keyPath, certPath string) error { return fmt.Errorf("failed to create config directory: %v", err) } + // When running under sudo, ensure the directory is owned by the original user + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + if sudoUID := os.Getenv("SUDO_UID"); sudoUID != "" { + if sudoGID := os.Getenv("SUDO_GID"); sudoGID != "" { + uid, err1 := strconv.Atoi(sudoUID) + gid, err2 := strconv.Atoi(sudoGID) + if err1 == nil && err2 == nil { + // Change ownership of the config directory to the original user + if err := os.Chown(cm.configDir, uid, gid); err != nil { + cm.logger.Warn("Failed to change config directory ownership", "error", err) + } + } + } + } + } + // Generate private key privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -295,9 +313,28 @@ func (cm *CertificateManager) generateServerCertificate(hostname string) (*tls.C // GetConfigDir returns the configuration directory path func GetConfigDir() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %v", err) + // When running under sudo, use the original user's home directory + // so the subprocess can access the CA certificate files + var homeDir string + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + // Get original user's home directory + if user, err := user.Lookup(sudoUser); err == nil { + homeDir = user.HomeDir + } else { + // Fallback to current user if lookup fails + var err2 error + homeDir, err2 = os.UserHomeDir() + if err2 != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err2) + } + } + } else { + // Normal case - use current user's home + var err error + homeDir, err = os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } } // Use platform-specific config directory