diff --git a/boundary.go b/boundary.go index 9599820..d3e98a6 100644 --- a/boundary.go +++ b/boundary.go @@ -62,12 +62,11 @@ func (b *Boundary) Start() error { } // Start proxy server in background - go func() { - err := b.proxyServer.Start(b.ctx) - if err != nil { - b.logger.Error("Proxy server error", "error", err) - } - }() + err = b.proxyServer.Start() + if err != nil { + b.logger.Error("Proxy server error", "error", err) + return err + } // Give proxy time to start time.Sleep(100 * time.Millisecond) @@ -90,4 +89,4 @@ func (b *Boundary) Close() error { // Close jailer return b.jailer.Close() -} \ No newline at end of file +} diff --git a/go.mod b/go.mod index 638ddb7..4d634e8 100644 --- a/go.mod +++ b/go.mod @@ -2,12 +2,16 @@ module github.com/coder/boundary go 1.24 -require github.com/coder/serpent v0.10.0 +require ( + github.com/coder/serpent v0.10.0 + github.com/stretchr/testify v1.8.4 +) require ( cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/kr/text v0.2.0 // indirect @@ -18,6 +22,7 @@ require ( github.com/muesli/termenv v0.15.2 // indirect github.com/pion/transport/v2 v2.0.0 // indirect github.com/pion/udp v0.1.4 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/spf13/pflag v1.0.5 // indirect go.opentelemetry.io/otel v1.19.0 // indirect diff --git a/proxy/proxy.go b/proxy/proxy.go index 6328a38..e2aa537 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -2,8 +2,8 @@ package proxy import ( "bufio" - "context" "crypto/tls" + "errors" "fmt" "io" "log/slog" @@ -12,7 +12,7 @@ import ( "net/url" "strings" "sync" - "time" + "sync/atomic" "github.com/coder/boundary/audit" "github.com/coder/boundary/rules" @@ -25,8 +25,9 @@ type Server struct { logger *slog.Logger tlsConfig *tls.Config httpPort int + started atomic.Bool - httpServer *http.Server + listener net.Listener } // Config holds configuration for the proxy server @@ -50,36 +51,31 @@ func NewProxyServer(config Config) *Server { } // Start starts the HTTP proxy server with TLS termination capability -func (p *Server) Start(ctx context.Context) error { - // Create HTTP server with TLS termination capability - p.httpServer = &http.Server{ - Addr: fmt.Sprintf(":%d", p.httpPort), - Handler: http.HandlerFunc(p.handleHTTPWithTLSTermination), +func (p *Server) Start() error { + if p.isStarted() { + return nil + } + + p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort) + var err error + p.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort)) + if err != nil { + p.logger.Error("Failed to create HTTP listener", "error", err) + return err } + p.started.Store(true) + // Start HTTP server with custom listener for TLS detection go func() { - p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort) - listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort)) - if err != nil { - p.logger.Error("Failed to create HTTP listener", "error", err) - return - } - for { - conn, err := listener.Accept() + conn, err := p.listener.Accept() + if err != nil && errors.Is(err, net.ErrClosed) && p.isStopped() { + return + } if err != nil { - select { - case <-ctx.Done(): - err = listener.Close() - if err != nil { - p.logger.Error("Failed to close listener", "error", err) - } - return - default: - p.logger.Error("Failed to accept connection", "error", err) - continue - } + p.logger.Error("Failed to accept connection", "error", err) + continue } // Handle connection with TLS detection @@ -87,27 +83,38 @@ func (p *Server) Start(ctx context.Context) error { } }() - // Wait for context cancellation - <-ctx.Done() - return p.Stop() + return nil } // Stops proxy server func (p *Server) Stop() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + if p.isStopped() { + return nil + } + p.started.Store(false) - var httpErr error - if p.httpServer != nil { - httpErr = p.httpServer.Shutdown(ctx) + if p.listener == nil { + p.logger.Error("unexpected nil listener") + return errors.New("unexpected nil listener") } - if httpErr != nil { - return httpErr + err := p.listener.Close() + if err != nil { + p.logger.Error("Failed to close listener", "error", err) + return err } + return nil } +func (p *Server) isStarted() bool { + return p.started.Load() +} + +func (p *Server) isStopped() bool { + return !p.started.Load() +} + // handleHTTP handles regular HTTP requests and CONNECT tunneling func (p *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { p.logger.Debug("handleHTTP called", "method", r.Method, "url", r.URL.String(), "host", r.Host) @@ -479,13 +486,6 @@ func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) { } } -// handleHTTPWithTLSTermination is the main handler (currently just delegates to regular HTTP) -func (p *Server) handleHTTPWithTLSTermination(w http.ResponseWriter, r *http.Request) { - // This handler is not used when we do custom connection handling - // All traffic goes through handleConnectionWithTLSDetection - p.handleHTTP(w, r) -} - // connectionWrapper lets us "unread" the peeked byte type connectionWrapper struct { net.Conn diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index dae1ad6..a70c9e7 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,9 +1,202 @@ package proxy -import "testing" +import ( + "crypto/tls" + "io" + "log" + "log/slog" + "net/http" + "os" + "os/user" + "strconv" + "testing" + "time" -// Stub test file - tests removed -func TestStub(t *testing.T) { - // This is a stub test - t.Skip("stub test file") + boundary_tls "github.com/coder/boundary/tls" + "github.com/stretchr/testify/require" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/rules" +) + +// mockAuditor is a simple mock auditor for testing +type mockAuditor struct{} + +func (m *mockAuditor) AuditRequest(req audit.Request) { + // No-op for testing +} + +// TestProxyServerBasicHTTP tests basic HTTP request handling +func TestProxyServerBasicHTTP(t *testing.T) { + // Create test logger + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelError, + })) + + // Create test rules (allow all for testing) + testRules, err := rules.ParseAllowSpecs([]string{"*"}) + if err != nil { + t.Fatalf("Failed to parse test rules: %v", err) + } + + // Create rule engine + ruleEngine := rules.NewRuleEngine(testRules, logger) + + // Create mock auditor + auditor := &mockAuditor{} + + // Create TLS config (minimal for testing) + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + // Create proxy server + server := NewProxyServer(Config{ + HTTPPort: 8080, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + }) + + // Start server + err = server.Start() + require.NoError(t, err) + + // Give server time to start + time.Sleep(100 * time.Millisecond) + + // Test basic HTTP request + t.Run("BasicHTTPRequest", func(t *testing.T) { + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // Skip cert verification for testing + }, + }, + Timeout: 5 * time.Second, + } + + // Make request to proxy + req, err := http.NewRequest("GET", "http://localhost:8080/todos/1", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + // Override the Host header + req.Host = "jsonplaceholder.typicode.com" + + // Make the request + resp, err := client.Do(req) + require.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + expectedResponse := `{ + "userId": 1, + "id": 1, + "title": "delectus aut autem", + "completed": false +}` + require.Equal(t, expectedResponse, string(body)) + }) + + err = server.Stop() + require.NoError(t, err) +} + +// TestProxyServerBasicHTTPS tests basic HTTPS request handling +func TestProxyServerBasicHTTPS(t *testing.T) { + // Create test logger + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelError, + })) + + // Create test rules (allow all for testing) + testRules, err := rules.ParseAllowSpecs([]string{"*"}) + if err != nil { + t.Fatalf("Failed to parse test rules: %v", err) + } + + // Create rule engine + ruleEngine := rules.NewRuleEngine(testRules, logger) + + // Create mock auditor + auditor := &mockAuditor{} + + currentUser, err := user.Current() + if err != nil { + log.Fatal(err) + } + + uid, _ := strconv.Atoi(currentUser.Uid) + gid, _ := strconv.Atoi(currentUser.Gid) + + // Create TLS certificate manager + certManager, err := boundary_tls.NewCertificateManager(boundary_tls.Config{ + Logger: logger, + ConfigDir: "/tmp/boundary", + Uid: uid, + Gid: gid, + }) + require.NoError(t, err) + + // Setup TLS to get cert path for jailer + tlsConfig, caCertPath, configDir, err := certManager.SetupTLSAndWriteCACert() + require.NoError(t, err) + _, _ = caCertPath, configDir + + // Create proxy server + server := NewProxyServer(Config{ + HTTPPort: 8080, + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + }) + + // Start server + err = server.Start() + require.NoError(t, err) + + // Give server time to start + time.Sleep(100 * time.Millisecond) + + // Test basic HTTPS request + t.Run("BasicHTTPSRequest", func(t *testing.T) { + // Create HTTP client + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // Skip cert verification for testing + }, + }, + Timeout: 5 * time.Second, + } + + // Make request to proxy + req, err := http.NewRequest("GET", "https://localhost:8080/api/v2", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + // Override the Host header + req.Host = "dev.coder.com" + + // Make the request + resp, err := client.Do(req) + require.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + expectedResponse := `{"message":"👋"} +` + require.Equal(t, expectedResponse, string(body)) + }) + + err = server.Stop() + require.NoError(t, err) }