diff --git a/.gitignore b/.gitignore index afc60e4..b2b0f28 100644 --- a/.gitignore +++ b/.gitignore @@ -47,4 +47,4 @@ Thumbs.db build/ # Jail binary -./jail +jail diff --git a/cli/cli.go b/cli/cli.go index de39881..93ce806 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -24,6 +24,7 @@ import ( type Config struct { AllowStrings []string LogLevel string + Unprivileged bool } // NewCommand creates and returns the root serpent command @@ -31,7 +32,7 @@ func NewCommand() *serpent.Command { // To make the top level jail command, we just make some minor changes to the base command cmd := BaseCommand() cmd.Use = "jail [flags] -- command [args...]" // Add the flags and args pieces to usage. - + // Add example usage to the long description. This is different from usage as a subcommand because it // may be called something different when used as a subcommand / there will be a leading binary (i.e. `coder jail` vs. `jail`). cmd.Long += `Examples: @@ -43,7 +44,7 @@ func NewCommand() *serpent.Command { # Block everything by default (implicit)` - return cmd + return cmd } // Base command returns the jail serpent command without the information involved in making it the @@ -74,6 +75,13 @@ user-defined rules.`, Default: "warn", Value: serpent.StringOf(&config.LogLevel), }, + { + Name: "unprivileged", + Flag: "unprivileged", + Env: "JAIL_UNPRIVILEGED", + Description: "Use unprivileged mode (proxy environment variables).", + Value: serpent.BoolOf(&config.Unprivileged), + }, }, Handler: func(inv *serpent.Invocation) error { return Run(inv.Context(), config, inv.Args) @@ -123,10 +131,12 @@ func Run(ctx context.Context, config Config, args []string) error { // Create jail instance jailInstance, err := jail.New(ctx, jail.Config{ - RuleEngine: ruleEngine, - Auditor: auditor, - CertManager: certManager, - Logger: logger, + RuleEngine: ruleEngine, + Auditor: auditor, + CertManager: certManager, + Logger: logger, + UserInfo: userInfo, + Unprivileged: config.Unprivileged, }) if err != nil { return fmt.Errorf("failed to create jail instance: %v", err) @@ -171,30 +181,29 @@ func Run(ctx context.Context, config Config, args []string) error { return nil } +// getUserInfo returns information about the current user, handling sudo scenarios func getUserInfo() namespace.UserInfo { - // get the user info of the original user even if we are running under sudo - sudoUser := os.Getenv("SUDO_USER") - - // If running under sudo, get original user information - if sudoUser != "" { + // Only consider SUDO_USER if we're actually running with elevated privileges + // In environments like Coder workspaces, SUDO_USER may be set to 'root' + // but we're not actually running under sudo + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" && os.Geteuid() == 0 && sudoUser != "root" { + // We're actually running under sudo with a non-root original user user, err := user.Lookup(sudoUser) if err != nil { - // Fallback to current user if lookup fails - return getCurrentUserInfo() + return getCurrentUserInfo() // Fallback to current user } - // Parse SUDO_UID and SUDO_GID - uid := 0 - gid := 0 + uid, _ := strconv.Atoi(os.Getenv("SUDO_UID")) + gid, _ := strconv.Atoi(os.Getenv("SUDO_GID")) - if sudoUID := os.Getenv("SUDO_UID"); sudoUID != "" { - if parsedUID, err := strconv.Atoi(sudoUID); err == nil { + // If we couldn't get UID/GID from env, parse from user info + if uid == 0 { + if parsedUID, err := strconv.Atoi(user.Uid); err == nil { uid = parsedUID } } - - if sudoGID := os.Getenv("SUDO_GID"); sudoGID != "" { - if parsedGID, err := strconv.Atoi(sudoGID); err == nil { + if gid == 0 { + if parsedGID, err := strconv.Atoi(user.Gid); err == nil { gid = parsedGID } } @@ -210,7 +219,7 @@ func getUserInfo() namespace.UserInfo { } } - // Not running under sudo, use current user + // Not actually running under sudo, use current user return getCurrentUserInfo() } diff --git a/jail.go b/jail.go index 0b2a4b5..881346b 100644 --- a/jail.go +++ b/jail.go @@ -16,10 +16,12 @@ import ( ) type Config struct { - RuleEngine rules.Evaluator - Auditor audit.Auditor - CertManager tls.Manager - Logger *slog.Logger + RuleEngine rules.Evaluator + Auditor audit.Auditor + CertManager tls.Manager + Logger *slog.Logger + UserInfo namespace.UserInfo + Unprivileged bool } type Jail struct { @@ -40,31 +42,22 @@ func New(ctx context.Context, config Config) (*Jail, error) { // Create proxy server proxyServer := proxy.NewProxyServer(proxy.Config{ HTTPPort: 8080, - HTTPSPort: 8443, - Auditor: config.Auditor, RuleEngine: config.RuleEngine, + Auditor: config.Auditor, Logger: config.Logger, TLSConfig: tlsConfig, }) - // Create commander + // Create namespace commander, err := newNamespaceCommander(namespace.Config{ - Logger: config.Logger, - HttpProxyPort: 8080, - HttpsProxyPort: 8443, - Env: map[string]string{ - // Set standard CA certificate environment variables for common tools - // This makes tools like curl, git, etc. trust our dynamically generated CA - "SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools - "SSL_CERT_DIR": configDir, // OpenSSL certificate directory - "CURL_CA_BUNDLE": caCertPath, // curl - "GIT_SSL_CAINFO": caCertPath, // Git - "REQUESTS_CA_BUNDLE": caCertPath, // Python requests - "NODE_EXTRA_CA_CERTS": caCertPath, // Node.js - }, - }) + Logger: config.Logger, + HttpProxyPort: 8080, + TlsConfigDir: configDir, + CACertPath: caCertPath, + UserInfo: config.UserInfo, + }, config.Unprivileged) if err != nil { - return nil, fmt.Errorf("failed to create commander: %v", err) + return nil, fmt.Errorf("failed to create namespace commander: %v", err) } // Create cancellable context for jail @@ -118,7 +111,11 @@ func (j *Jail) Close() error { } // newNamespaceCommander creates a new namespace instance for the current platform -func newNamespaceCommander(config namespace.Config) (namespace.Commander, error) { +func newNamespaceCommander(config namespace.Config, unprivledged bool) (namespace.Commander, error) { + if unprivledged { + return namespace.NewUnprivileged(config) + } + switch runtime.GOOS { case "darwin": return namespace.NewMacOS(config) diff --git a/namespace/linux.go b/namespace/linux.go index 173d536..3e9de0e 100644 --- a/namespace/linux.go +++ b/namespace/linux.go @@ -7,44 +7,33 @@ import ( "log/slog" "os" "os/exec" - "strings" - "syscall" "time" ) // Linux implements jail.Commander using Linux network namespaces type Linux struct { - namespace string - vethHost string // Host-side veth interface name for iptables rules - logger *slog.Logger - preparedEnv map[string]string - procAttr *syscall.SysProcAttr - httpProxyPort int - httpsProxyPort int - user string - homeDir string - uid int - gid int + logger *slog.Logger + namespace string + vethHost string // Host-side veth interface name for iptables rules + commandEnv []string + httpProxyPort int + tlsConfigDir string + caCertPath string + userInfo UserInfo } -// NewLinux creates a new Linux network jail instance func NewLinux(config Config) (*Linux, error) { - // Initialize preparedEnv with config environment variables - preparedEnv := make(map[string]string) - for key, value := range config.Env { - preparedEnv[key] = value - } - return &Linux{ - namespace: newNamespaceName(), - logger: config.Logger, - preparedEnv: preparedEnv, - httpProxyPort: config.HttpProxyPort, - httpsProxyPort: config.HttpsProxyPort, + logger: config.Logger, + namespace: newNamespaceName(), + httpProxyPort: config.HttpProxyPort, + tlsConfigDir: config.TlsConfigDir, + caCertPath: config.CACertPath, + userInfo: config.UserInfo, }, nil } -// Setup creates network namespace and configures iptables rules +// Start creates network namespace and configures iptables rules func (l *Linux) Start() error { l.logger.Debug("Setup called") @@ -75,30 +64,12 @@ func (l *Linux) Start() error { // Prepare environment once during setup l.logger.Debug("Preparing environment") - - // Start with current environment - for _, envVar := range os.Environ() { - if parts := strings.SplitN(envVar, "=", 2); len(parts) == 2 { - // Only set if not already set by config - if _, exists := l.preparedEnv[parts[0]]; !exists { - l.preparedEnv[parts[0]] = parts[1] - } - } - } - - // Set HOME to original user's home directory - l.preparedEnv["HOME"] = l.homeDir - // Set USER to original username - l.preparedEnv["USER"] = l.user - // Set LOGNAME to original username (some tools check this instead of USER) - l.preparedEnv["LOGNAME"] = l.user - - l.procAttr = &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uint32(l.uid), - Gid: uint32(l.gid), - }, - } + e := getEnvs(l.tlsConfigDir, l.caCertPath) + l.commandEnv = mergeEnvs(e, map[string]string{ + "HOME": l.userInfo.HomeDir, + "USER": l.userInfo.Username, + "LOGNAME": l.userInfo.Username, + }) l.logger.Debug("Setup completed successfully") return nil @@ -116,23 +87,15 @@ func (l *Linux) Command(command []string) *exec.Cmd { cmd := exec.Command("ip", cmdArgs[1:]...) - // Use prepared environment from Open method - env := make([]string, 0, len(l.preparedEnv)) - for key, value := range l.preparedEnv { - env = append(env, fmt.Sprintf("%s=%s", key, value)) - } - cmd.Env = env + cmd.Env = l.commandEnv cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Use prepared process attributes from Open method - cmd.SysProcAttr = l.procAttr - return cmd } -// Cleanup removes the network namespace and iptables rules +// Close removes the network namespace and iptables rules func (l *Linux) Close() error { // Remove iptables rules err := l.removeIptables() @@ -246,23 +209,22 @@ func (l *Linux) setupIptables() error { return fmt.Errorf("failed to add NAT rule: %v", err) } - // COMPREHENSIVE APPROACH: Intercept ALL TCP traffic from namespace - // Use PREROUTING on host to catch traffic after it exits namespace but before routing - // This ensures NO TCP traffic can bypass the proxy - cmd = exec.Command("iptables", "-t", "nat", "-A", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpsProxyPort)) + // COMPREHENSIVE APPROACH: Route ALL TCP traffic to HTTP proxy + // The HTTP proxy will intelligently handle both HTTP and TLS traffic + cmd = exec.Command("iptables", "-t", "nat", "-A", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpProxyPort)) err = cmd.Run() if err != nil { return fmt.Errorf("failed to add comprehensive TCP redirect rule: %v", err) } - l.logger.Debug("Comprehensive TCP jailing enabled", "interface", l.vethHost, "proxy_port", l.httpsProxyPort) + l.logger.Debug("Comprehensive TCP jailing enabled", "interface", l.vethHost, "proxy_port", l.httpProxyPort) return nil } // removeIptables removes iptables rules func (l *Linux) removeIptables() error { // Remove comprehensive TCP redirect rule - cmd := exec.Command("iptables", "-t", "nat", "-D", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpsProxyPort)) + cmd := exec.Command("iptables", "-t", "nat", "-D", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpProxyPort)) cmd.Run() // Ignore errors during cleanup // Remove NAT rule @@ -280,4 +242,4 @@ func (l *Linux) removeNamespace() error { return fmt.Errorf("failed to remove namespace: %v", err) } return nil -} \ No newline at end of file +} diff --git a/namespace/macos.go b/namespace/macos.go index 30d9fd8..1e21e01 100644 --- a/namespace/macos.go +++ b/namespace/macos.go @@ -19,15 +19,16 @@ const ( // MacOSNetJail implements network jail using macOS PF (Packet Filter) and group-based isolation type MacOSNetJail struct { - restrictedGid int - pfRulesPath string - mainRulesPath string - logger *slog.Logger - preparedEnv map[string]string - procAttr *syscall.SysProcAttr - httpProxyPort int - httpsProxyPort int - userInfo UserInfo + restrictedGid int + pfRulesPath string + mainRulesPath string + logger *slog.Logger + commandEnv []string + procAttr *syscall.SysProcAttr + httpProxyPort int + tlsConfigDir string + caCertPath string + userInfo UserInfo } // NewMacOS creates a new macOS network jail instance @@ -36,20 +37,14 @@ func NewMacOS(config Config) (*MacOSNetJail, error) { pfRulesPath := fmt.Sprintf("/tmp/%s.pf", ns) mainRulesPath := fmt.Sprintf("/tmp/%s_main.pf", ns) - // Initialize preparedEnv with config environment variables - preparedEnv := make(map[string]string) - for key, value := range config.Env { - preparedEnv[key] = value - } - return &MacOSNetJail{ - pfRulesPath: pfRulesPath, - mainRulesPath: mainRulesPath, - logger: config.Logger, - preparedEnv: preparedEnv, - httpProxyPort: config.HttpProxyPort, - httpsProxyPort: config.HttpsProxyPort, - userInfo: config.UserInfo, + pfRulesPath: pfRulesPath, + mainRulesPath: mainRulesPath, + logger: config.Logger, + httpProxyPort: config.HttpProxyPort, + tlsConfigDir: config.TlsConfigDir, + caCertPath: config.CACertPath, + userInfo: config.UserInfo, }, nil } @@ -74,22 +69,12 @@ func (m *MacOSNetJail) Start() error { // Prepare environment once during setup m.logger.Debug("Preparing environment") - // Start with current environment - for _, envVar := range os.Environ() { - if parts := strings.SplitN(envVar, "=", 2); len(parts) == 2 { - // Only set if not already set by config - if _, exists := m.preparedEnv[parts[0]]; !exists { - m.preparedEnv[parts[0]] = parts[1] - } - } - } - - // Set HOME to original user's home directory - m.preparedEnv["HOME"] = m.userInfo.HomeDir - // Set USER to original username - m.preparedEnv["USER"] = m.userInfo.Username - // Set LOGNAME to original username (some tools check this instead of USER) - m.preparedEnv["LOGNAME"] = m.userInfo.Username + e := getEnvs(m.tlsConfigDir, m.caCertPath) + m.commandEnv = mergeEnvs(e, map[string]string{ + "HOME": m.userInfo.HomeDir, + "USER": m.userInfo.Username, + "LOGNAME": m.userInfo.Username, + }) // Prepare process credentials once during setup m.logger.Debug("Preparing process credentials") @@ -117,12 +102,7 @@ func (m *MacOSNetJail) Command(command []string) *exec.Cmd { cmd := exec.Command(command[0], command[1:]...) m.logger.Debug("Full command args", "args", command) - // Use prepared environment from Open method - env := make([]string, 0, len(m.preparedEnv)) - for key, value := range m.preparedEnv { - env = append(env, fmt.Sprintf("%s=%s", key, value)) - } - cmd.Env = env + cmd.Env = m.commandEnv cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Stdin = os.Stdin @@ -240,8 +220,8 @@ func (m *MacOSNetJail) createPFRules() (string, error) { # COMPREHENSIVE APPROACH: Intercept ALL TCP traffic from the jailed group # This ensures NO TCP traffic can bypass the proxy by using alternative ports -# First, redirect ALL TCP traffic arriving on lo0 to our HTTPS proxy port -# The HTTPS proxy can handle both HTTP and HTTPS traffic +# First, redirect ALL TCP traffic arriving on lo0 to our HTTP proxy with TLS termination +# The HTTP proxy with TLS termination can handle both HTTP and HTTPS traffic rdr pass on lo0 inet proto tcp from any to any -> 127.0.0.1 port %d # Route ALL TCP traffic from boundary group to lo0 where it will be redirected @@ -255,13 +235,13 @@ pass on lo0 all `, m.restrictedGid, iface, - m.httpsProxyPort, // Use HTTPS proxy port for all TCP traffic + m.httpProxyPort, // Use HTTP proxy with TLS termination for all TCP traffic m.restrictedGid, iface, m.restrictedGid, ) - m.logger.Debug("Comprehensive TCP jailing enabled for macOS", "group_id", m.restrictedGid, "proxy_port", m.httpsProxyPort) + m.logger.Debug("Comprehensive TCP jailing enabled for macOS", "group_id", m.restrictedGid, "proxy_port", m.httpProxyPort) return rules, nil } @@ -283,6 +263,7 @@ func (m *MacOSNetJail) setupPFRules() error { cmd := exec.Command("pfctl", "-a", pfAnchorName, "-f", m.pfRulesPath) err = cmd.Run() if err != nil { + m.logger.Error("Failed to load PF rules", "error", err, "rules_file", m.pfRulesPath) return fmt.Errorf("failed to load PF rules: %v", err) } diff --git a/namespace/name.go b/namespace/name.go index 235ab30..8b7c997 100644 --- a/namespace/name.go +++ b/namespace/name.go @@ -2,6 +2,8 @@ package namespace import ( "fmt" + "os" + "strings" "time" ) @@ -12,3 +14,41 @@ const ( func newNamespaceName() string { return fmt.Sprintf("%s_%d", prefix, time.Now().UnixNano()%10000000) } + +func getEnvs(configDir string, caCertPath string) []string { + e := os.Environ() + + e = mergeEnvs(e, map[string]string{ + // Set standard CA certificate environment variables for common tools + // This makes tools like curl, git, etc. trust our dynamically generated CA + "SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools + "SSL_CERT_DIR": configDir, // OpenSSL certificate directory + "CURL_CA_BUNDLE": caCertPath, // curl + "GIT_SSL_CAINFO": caCertPath, // Git + "REQUESTS_CA_BUNDLE": caCertPath, // Python requests + "NODE_EXTRA_CA_CERTS": caCertPath, // Node.js + }) + + return e +} + +func mergeEnvs(base []string, extra map[string]string) []string { + envMap := make(map[string]string) + for _, env := range base { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + for key, value := range extra { + envMap[key] = value + } + + merged := make([]string, 0, len(envMap)) + for key, value := range envMap { + merged = append(merged, key+"="+value) + } + + return merged +} diff --git a/namespace/namespace.go b/namespace/namespace.go index 5d9a33a..4b7dfa9 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -12,11 +12,11 @@ type Commander interface { } type Config struct { - Logger *slog.Logger - HttpProxyPort int - HttpsProxyPort int - Env map[string]string - UserInfo UserInfo + Logger *slog.Logger + HttpProxyPort int + TlsConfigDir string + CACertPath string + UserInfo UserInfo } type UserInfo struct { diff --git a/namespace/unprivileged.go b/namespace/unprivileged.go new file mode 100644 index 0000000..083daf7 --- /dev/null +++ b/namespace/unprivileged.go @@ -0,0 +1,52 @@ +package namespace + +import ( + "fmt" + "log/slog" + "os" + "os/exec" +) + +type Unprivileged struct { + logger *slog.Logger + commandEnv []string + httpProxyPort int + tlsConfigDir string + caCertPath string + userInfo UserInfo +} + +func NewUnprivileged(config Config) (*Unprivileged, error) { + return &Unprivileged{ + logger: config.Logger, + httpProxyPort: config.HttpProxyPort, + tlsConfigDir: config.TlsConfigDir, + caCertPath: config.CACertPath, + userInfo: config.UserInfo, + }, nil +} + +func (u *Unprivileged) Start() error { + u.logger.Debug("Starting in unprivileged mode") + e := getEnvs(u.tlsConfigDir, u.caCertPath) + u.commandEnv = mergeEnvs(e, map[string]string{ + "HTTP_PROXY": fmt.Sprintf("http://127.0.0.1:%d", u.httpProxyPort), + "HTTPS_PROXY": fmt.Sprintf("http://127.0.0.1:%d", u.httpProxyPort), + "http_proxy": fmt.Sprintf("http://127.0.0.1:%d", u.httpProxyPort), + "https_proxy": fmt.Sprintf("http://127.0.0.1:%d", u.httpProxyPort), + }) + return nil +} + +func (u *Unprivileged) Command(command []string) *exec.Cmd { + cmd := exec.Command(command[0], command[1:]...) + cmd.Env = u.commandEnv + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + return cmd +} + +func (u *Unprivileged) Close() error { + return nil +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 944b8aa..dce7219 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,13 +1,18 @@ package proxy import ( + "bufio" "context" "crypto/tls" "fmt" "io" "log/slog" + "net" "net/http" + "net/http/httptest" "net/url" + "strings" + "sync" "time" "github.com/coder/jail/audit" @@ -21,16 +26,13 @@ type Server struct { logger *slog.Logger tlsConfig *tls.Config httpPort int - httpsPort int - httpServer *http.Server - httpsServer *http.Server + httpServer *http.Server } // Config holds configuration for the proxy server type Config struct { HTTPPort int - HTTPSPort int RuleEngine rules.Evaluator Auditor audit.Auditor Logger *slog.Logger @@ -45,40 +47,41 @@ func NewProxyServer(config Config) *Server { logger: config.Logger, tlsConfig: config.TLSConfig, httpPort: config.HTTPPort, - httpsPort: config.HTTPSPort, } } -// Start starts both HTTP and HTTPS proxy servers +// Start starts the HTTP proxy server with TLS termination capability func (p *Server) Start(ctx context.Context) error { - // Create HTTP server + // Create HTTP server with TLS termination capability p.httpServer = &http.Server{ Addr: fmt.Sprintf(":%d", p.httpPort), - Handler: http.HandlerFunc(p.handleHTTP), + Handler: http.HandlerFunc(p.handleHTTPWithTLSTermination), } - // Create HTTPS server - p.httpsServer = &http.Server{ - Addr: fmt.Sprintf(":%d", p.httpsPort), - Handler: http.HandlerFunc(p.handleHTTPS), - TLSConfig: p.tlsConfig, - } - - // Start HTTP server + // Start HTTP server with custom listener for TLS detection go func() { - p.logger.Info("Starting HTTP proxy", "port", p.httpPort) - err := p.httpServer.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - p.logger.Error("HTTP proxy server error", "error", err) + p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort) + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort)) + if err != nil { + p.logger.Error("Failed to create HTTP listener", "error", err) + return } - }() - // Start HTTPS server - go func() { - p.logger.Info("Starting HTTPS proxy", "port", p.httpsPort) - err := p.httpsServer.ListenAndServeTLS("", "") - if err != nil && err != http.ErrServerClosed { - p.logger.Error("HTTPS proxy server error", "error", err) + for { + conn, err := listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + listener.Close() + return + default: + p.logger.Error("Failed to accept connection", "error", err) + continue + } + } + + // Handle connection with TLS detection + go p.handleConnectionWithTLSDetection(conn) } }() @@ -87,49 +90,32 @@ func (p *Server) Start(ctx context.Context) error { return p.Stop() } -// Stop stops both proxy servers +// Stops proxy server func (p *Server) Stop() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - var httpErr, httpsErr error + var httpErr error if p.httpServer != nil { httpErr = p.httpServer.Shutdown(ctx) } - if p.httpsServer != nil { - httpsErr = p.httpsServer.Shutdown(ctx) - } if httpErr != nil { return httpErr } - return httpsErr + return nil } -// handleHTTP handles regular HTTP requests +// handleHTTP handles regular HTTP requests and CONNECT tunneling func (p *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { - // Check if request should be allowed - result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) + p.logger.Debug("handleHTTP called", "method", r.Method, "url", r.URL.String(), "host", r.Host) - // Audit the request - p.auditor.AuditRequest(audit.Request{ - Method: r.Method, - URL: r.URL.String(), - Allowed: result.Allowed, - Rule: result.Rule, - }) - - if !result.Allowed { - p.writeBlockedResponse(w, r) + // Handle CONNECT method for HTTPS tunneling + if r.Method == "CONNECT" { + p.handleConnect(w, r) return } - // Forward regular HTTP request - p.forwardHTTPRequest(w, r) -} - -// handleHTTPS handles HTTPS requests (after TLS termination) -func (p *Server) handleHTTPS(w http.ResponseWriter, r *http.Request) { // Check if request should be allowed result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) @@ -146,22 +132,25 @@ func (p *Server) handleHTTPS(w http.ResponseWriter, r *http.Request) { return } - // Forward HTTPS request - p.forwardHTTPSRequest(w, r) + // Forward regular HTTP request + p.forwardHTTPRequest(w, r) } // forwardHTTPRequest forwards a regular HTTP request func (p *Server) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) { + p.logger.Debug("forwardHTTPRequest called", "method", r.Method, "url", r.URL.String(), "host", r.Host) + // Create a new request to the target server - targetURL := r.URL - if targetURL.Scheme == "" { - targetURL.Scheme = "http" - } - if targetURL.Host == "" { - targetURL.Host = r.Host + targetURL := &url.URL{ + Scheme: "http", + Host: r.Host, + Path: r.URL.Path, + RawQuery: r.URL.RawQuery, } - // Create HTTP client + p.logger.Debug("Target URL constructed", "target", targetURL.String()) + + // Create HTTP client with very short timeout for debugging client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse // Don't follow redirects @@ -171,27 +160,38 @@ func (p *Server) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) { // Create new request req, err := http.NewRequest(r.Method, targetURL.String(), r.Body) if err != nil { + p.logger.Error("Failed to create forward request", "error", err) http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) return } // Copy headers for name, values := range r.Header { + // Skip connection-specific headers + if strings.ToLower(name) == "connection" || strings.ToLower(name) == "proxy-connection" { + continue + } for _, value := range values { req.Header.Add(name, value) } } - // Make the request + p.logger.Debug("About to make HTTP request", "target", targetURL.String()) resp, err := client.Do(req) if err != nil { + p.logger.Error("Failed to make forward request", "error", err, "target", targetURL.String(), "error_type", fmt.Sprintf("%T", err)) http.Error(w, fmt.Sprintf("Failed to make request: %v", err), http.StatusBadGateway) return } defer resp.Body.Close() - // Copy response headers + p.logger.Debug("Received response", "status", resp.StatusCode, "target", targetURL.String()) + + // Copy response headers (except connection-specific ones) for name, values := range resp.Header { + if strings.ToLower(name) == "connection" || strings.ToLower(name) == "transfer-encoding" { + continue + } for _, value := range values { w.Header().Add(name, value) } @@ -201,89 +201,296 @@ func (p *Server) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) { w.WriteHeader(resp.StatusCode) // Copy response body - io.Copy(w, resp.Body) + bytesWritten, copyErr := io.Copy(w, resp.Body) + if copyErr != nil { + p.logger.Error("Error copying response body", "error", copyErr, "bytes_written", bytesWritten) + http.Error(w, "Failed to copy response", http.StatusBadGateway) + } else { + p.logger.Debug("Successfully forwarded HTTP response", "bytes_written", bytesWritten, "status", resp.StatusCode) + } + + // Ensure response is flushed + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + p.logger.Debug("forwardHTTPRequest completed") } -// forwardHTTPSRequest forwards an HTTPS request -func (p *Server) forwardHTTPSRequest(w http.ResponseWriter, r *http.Request) { - // Create target URL - targetURL := &url.URL{ - Scheme: "https", - Host: r.Host, - Path: r.URL.Path, - RawQuery: r.URL.RawQuery, +// writeBlockedResponse writes a blocked response +func (p *Server) writeBlockedResponse(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusForbidden) + + // Extract host from URL for cleaner display + host := r.URL.Host + if host == "" { + host = r.Host } - // Create HTTPS client - client := &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: false, - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, + fmt.Fprintf(w, `🚫 Request Blocked by Coder Jail + +Request: %s %s +Host: %s + +To allow this request, restart jail with: + --allow "%s" # Allow all methods to this host + --allow "%s %s" # Allow only %s requests to this host + +For more help: https://github.com/coder/jail +`, + r.Method, r.URL.Path, host, host, r.Method, host, r.Method) +} + +// handleConnect handles CONNECT requests for HTTPS tunneling with TLS termination +func (p *Server) handleConnect(w http.ResponseWriter, r *http.Request) { + // Extract hostname from the CONNECT request + hostname := r.URL.Hostname() + if hostname == "" { + // Fallback to Host header parsing + host := r.URL.Host + if host == "" { + host = r.Host + } + if h, _, err := net.SplitHostPort(host); err == nil { + hostname = h + } else { + hostname = host + } } - // Create new request - req, err := http.NewRequest(r.Method, targetURL.String(), r.Body) + if hostname == "" { + http.Error(w, "Invalid CONNECT request: no hostname", http.StatusBadRequest) + return + } + + // Allow all CONNECT requests - we'll evaluate rules on the decrypted HTTPS content + p.logger.Debug("Establishing CONNECT tunnel with TLS termination", "hostname", hostname) + + // Hijack the connection to handle TLS manually + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "Hijacking not supported", http.StatusInternalServerError) + return + } + + // Hijack the underlying connection + conn, _, err := hijacker.Hijack() if err != nil { - http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) + p.logger.Error("Failed to hijack connection", "error", err) return } + defer conn.Close() - // Copy headers - for name, values := range r.Header { - for _, value := range values { - req.Header.Add(name, value) + // Send 200 Connection established response manually + _, err = conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) + if err != nil { + p.logger.Error("Failed to send CONNECT response", "error", err) + return + } + + // Perform TLS handshake with the client using our certificates + p.logger.Debug("Starting TLS handshake", "hostname", hostname) + + // Create TLS config that forces HTTP/1.1 (disable HTTP/2 ALPN) + tlsConfig := p.tlsConfig.Clone() + tlsConfig.NextProtos = []string{"http/1.1"} + + tlsConn := tls.Server(conn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + p.logger.Error("TLS handshake failed", "hostname", hostname, "error", err) + return + } + p.logger.Debug("TLS handshake successful", "hostname", hostname) + + // Now we have a TLS connection - handle HTTPS requests + p.logger.Debug("Starting HTTPS request handling", "hostname", hostname) + p.handleTLSConnection(tlsConn, hostname) + p.logger.Debug("HTTPS request handling completed", "hostname", hostname) +} + +// handleTLSConnection processes decrypted HTTPS requests over the TLS connection +func (p *Server) handleTLSConnection(tlsConn *tls.Conn, hostname string) { + p.logger.Debug("Creating HTTP server for TLS connection", "hostname", hostname) + + // Use ReadRequest to manually read HTTP requests from the TLS connection + bufReader := bufio.NewReader(tlsConn) + for { + // Read HTTP request from TLS connection + req, err := http.ReadRequest(bufReader) + if err != nil { + if err == io.EOF { + p.logger.Debug("TLS connection closed by client", "hostname", hostname) + } else { + p.logger.Debug("Failed to read HTTP request", "hostname", hostname, "error", err) + } + break + } + + p.logger.Debug("Processing decrypted HTTPS request", "hostname", hostname, "method", req.Method, "path", req.URL.Path) + + // Set the hostname and scheme if not already set + if req.URL.Host == "" { + req.URL.Host = hostname + } + if req.URL.Scheme == "" { + req.URL.Scheme = "https" + } + + // Create a response recorder to capture the response + recorder := httptest.NewRecorder() + + // Process the HTTPS request + p.handleDecryptedHTTPS(recorder, req) + + // Write the response back to the TLS connection + resp := recorder.Result() + err = resp.Write(tlsConn) + if err != nil { + p.logger.Debug("Failed to write response", "hostname", hostname, "error", err) + break } } - // Make the request - resp, err := client.Do(req) + p.logger.Debug("TLS connection handling completed", "hostname", hostname) +} + +// handleDecryptedHTTPS handles decrypted HTTPS requests and applies rules +func (p *Server) handleDecryptedHTTPS(w http.ResponseWriter, r *http.Request) { + // Check if request should be allowed + result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) + + // Audit the request + p.auditor.AuditRequest(audit.Request{ + Method: r.Method, + URL: r.URL.String(), + Allowed: result.Allowed, + Rule: result.Rule, + }) + + if !result.Allowed { + p.writeBlockedResponse(w, r) + return + } + + // Forward the HTTPS request (now handled same as HTTP after TLS termination) + p.forwardHTTPRequest(w, r) +} + +// handleConnectionWithTLSDetection detects TLS vs HTTP and handles appropriately +func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) { + defer conn.Close() + + // Peek at first byte to detect protocol + buf := make([]byte, 1) + _, err := conn.Read(buf) if err != nil { - http.Error(w, fmt.Sprintf("Failed to make request: %v", err), http.StatusBadGateway) + p.logger.Debug("Failed to read first byte from connection", "error", err) return } - defer resp.Body.Close() - // Copy response headers - for name, values := range resp.Header { - for _, value := range values { - w.Header().Add(name, value) + // Create connection wrapper that can "unread" the peeked byte + connWrapper := &connectionWrapper{conn, buf, false} + + // TLS handshake starts with 0x16 (TLS Content Type: Handshake) + if buf[0] == 0x16 { + p.logger.Debug("Detected TLS handshake, performing TLS termination") + // Perform TLS handshake + tlsConn := tls.Server(connWrapper, p.tlsConfig) + err := tlsConn.Handshake() + if err != nil { + p.logger.Debug("TLS handshake failed", "error", err) + return } + p.logger.Debug("TLS handshake successful") + // Use HTTP server with TLS connection + listener := newSingleConnectionListener(tlsConn) + defer listener.Close() + err = http.Serve(listener, http.HandlerFunc(p.handleDecryptedHTTPS)) + p.logger.Debug("http.Serve completed for HTTPS", "error", err) + } else { + p.logger.Debug("Detected HTTP request, handling normally") + // Use HTTP server with regular connection + p.logger.Debug("About to call http.Serve for HTTP connection") + listener := newSingleConnectionListener(connWrapper) + defer listener.Close() + err = http.Serve(listener, http.HandlerFunc(p.handleHTTP)) + p.logger.Debug("http.Serve completed", "error", err) } +} - // Copy status code - w.WriteHeader(resp.StatusCode) +// handleHTTPWithTLSTermination is the main handler (currently just delegates to regular HTTP) +func (p *Server) handleHTTPWithTLSTermination(w http.ResponseWriter, r *http.Request) { + // This handler is not used when we do custom connection handling + // All traffic goes through handleConnectionWithTLSDetection + p.handleHTTP(w, r) +} - // Copy response body - io.Copy(w, resp.Body) +// connectionWrapper lets us "unread" the peeked byte +type connectionWrapper struct { + net.Conn + buf []byte + bufUsed bool } -// writeBlockedResponse writes a blocked response -func (p *Server) writeBlockedResponse(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(http.StatusForbidden) +func (c *connectionWrapper) Read(p []byte) (int, error) { + if !c.bufUsed && len(c.buf) > 0 { + n := copy(p, c.buf) + c.bufUsed = true + return n, nil + } + return c.Conn.Read(p) +} - // Extract host from URL for cleaner display - host := r.URL.Host - if host == "" { - host = r.Host +// singleConnectionListener wraps a single connection into a net.Listener +type singleConnectionListener struct { + conn net.Conn + used bool + closed chan struct{} + mu sync.Mutex +} + +func newSingleConnectionListener(conn net.Conn) *singleConnectionListener { + return &singleConnectionListener{ + conn: conn, + closed: make(chan struct{}), } +} - fmt.Fprintf(w, `🚫 Request Blocked by Coder Jail +func (sl *singleConnectionListener) Accept() (net.Conn, error) { + sl.mu.Lock() + defer sl.mu.Unlock() -Request: %s %s -Host: %s + if sl.used || sl.conn == nil { + // Wait for close signal + <-sl.closed + return nil, io.EOF + } + sl.used = true + return sl.conn, nil +} -To allow this request, restart jail with: - --allow "%s" # Allow all methods to this host - --allow "%s %s" # Allow only %s requests to this host +func (sl *singleConnectionListener) Close() error { + sl.mu.Lock() + defer sl.mu.Unlock() -For more help: https://github.com/coder/jail -`, - r.Method, r.URL.Path, host, host, r.Method, host, r.Method) + select { + case <-sl.closed: + // Already closed + default: + close(sl.closed) + } + + if sl.conn != nil { + sl.conn.Close() + sl.conn = nil + } + return nil +} + +func (sl *singleConnectionListener) Addr() net.Addr { + if sl.conn == nil { + return nil + } + return sl.conn.LocalAddr() } diff --git a/tls/tls.go b/tls/tls.go index e36e6de..f40a9c3 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -335,4 +335,4 @@ func (cm *CertificateManager) generateServerCertificate(hostname string) (*tls.C cm.logger.Debug("Generated certificate", "hostname", hostname) return tlsCert, nil -} \ No newline at end of file +}