diff --git a/go.mod b/go.mod index 4d634e8..44da418 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,11 @@ module github.com/coder/boundary -go 1.24 +go 1.24.0 require ( github.com/coder/serpent v0.10.0 github.com/stretchr/testify v1.8.4 + golang.org/x/sync v0.17.0 ) require ( diff --git a/go.sum b/go.sum index d751167..5784f70 100644 --- a/go.sum +++ b/go.sum @@ -101,6 +101,8 @@ golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/proxy/proxy.go b/proxy/proxy.go index e2aa537..46f24df 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -2,6 +2,7 @@ package proxy import ( "bufio" + "context" "crypto/tls" "errors" "fmt" @@ -16,6 +17,7 @@ import ( "github.com/coder/boundary/audit" "github.com/coder/boundary/rules" + "golang.org/x/sync/errgroup" ) // Server handles HTTP and HTTPS requests with rule-based filtering @@ -680,20 +682,40 @@ func (p *Server) streamRequestToTarget(clientConn *tls.Conn, bufReader *bufio.Re return fmt.Errorf("failed to write headers to target: %v", err) } - // Stream request body and response bidirectionally - go func() { - // Stream request body: client -> target + // Use errgroup to manage bidirectional streaming and ensure cleanup + g, ctx := errgroup.WithContext(context.Background()) + + // Stream request body: client -> target + g.Go(func() error { _, err := io.Copy(targetConn, bufReader) - if err != nil { - p.logger.Error("Error copying request body to target", "error", err) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + p.logger.Debug("Error copying request body to target", "error", err) } - }() + // Close write side to signal EOF to target + _ = targetConn.CloseWrite() + return nil + }) // Stream response: target -> client - _, err = io.Copy(clientConn, targetConn) - if err != nil { - p.logger.Error("Error copying response from target to client", "error", err) - } + g.Go(func() error { + _, err := io.Copy(clientConn, targetConn) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + p.logger.Debug("Error copying response from target to client", "error", err) + } + return nil + }) + + // Monitor context cancellation to ensure both goroutines exit + g.Go(func() error { + <-ctx.Done() + // Force close connections to unblock any hanging io.Copy + _ = clientConn.Close() + _ = targetConn.Close() + return nil + }) + + // Wait for all goroutines to complete + _ = g.Wait() return nil } @@ -729,16 +751,41 @@ func (p *Server) handleConnectStreaming(tlsConn *tls.Conn, req *http.Request, ho } defer func() { _ = targetConn.Close() }() - // Bidirectional copy - go func() { + // Use errgroup for bidirectional copy with proper cleanup + g, ctx := errgroup.WithContext(context.Background()) + + // Client to target + g.Go(func() error { _, err := io.Copy(targetConn, tlsConn) - if err != nil { - p.logger.Error("Error copying from client to target", "error", err) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + p.logger.Debug("Error copying from client to target", "error", err) } - }() - _, err = io.Copy(tlsConn, targetConn) - if err != nil { - p.logger.Error("Error copying from target to client", "error", err) - } + // Close write side to signal EOF + if tc, ok := targetConn.(*net.TCPConn); ok { + _ = tc.CloseWrite() + } + return nil + }) + + // Target to client + g.Go(func() error { + _, err := io.Copy(tlsConn, targetConn) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + p.logger.Debug("Error copying from target to client", "error", err) + } + return nil + }) + + // Monitor context cancellation to ensure cleanup + g.Go(func() error { + <-ctx.Done() + // Force close connections to unblock any hanging io.Copy + _ = tlsConn.Close() + _ = targetConn.Close() + return nil + }) + + // Wait for all goroutines to complete + _ = g.Wait() p.logger.Debug("CONNECT tunnel closed", "hostname", hostname) }