Skip to content

Commit 07deeeb

Browse files
Merge pull request #57 from coder/yevhenii/proxy-tests
refactor: add tests for proxy package
2 parents 2ed7c35 + fa02d24 commit 07deeeb

File tree

4 files changed

+254
-57
lines changed

4 files changed

+254
-57
lines changed

boundary.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,11 @@ func (b *Boundary) Start() error {
6262
}
6363

6464
// Start proxy server in background
65-
go func() {
66-
err := b.proxyServer.Start(b.ctx)
67-
if err != nil {
68-
b.logger.Error("Proxy server error", "error", err)
69-
}
70-
}()
65+
err = b.proxyServer.Start()
66+
if err != nil {
67+
b.logger.Error("Proxy server error", "error", err)
68+
return err
69+
}
7170

7271
// Give proxy time to start
7372
time.Sleep(100 * time.Millisecond)
@@ -90,4 +89,4 @@ func (b *Boundary) Close() error {
9089

9190
// Close jailer
9291
return b.jailer.Close()
93-
}
92+
}

go.mod

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ module github.com/coder/boundary
22

33
go 1.24
44

5-
require github.com/coder/serpent v0.10.0
5+
require (
6+
github.com/coder/serpent v0.10.0
7+
github.com/stretchr/testify v1.8.4
8+
)
69

710
require (
811
cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 // indirect
912
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
1013
github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 // indirect
14+
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
1115
github.com/hashicorp/errwrap v1.1.0 // indirect
1216
github.com/hashicorp/go-multierror v1.1.1 // indirect
1317
github.com/kr/text v0.2.0 // indirect
@@ -18,6 +22,7 @@ require (
1822
github.com/muesli/termenv v0.15.2 // indirect
1923
github.com/pion/transport/v2 v2.0.0 // indirect
2024
github.com/pion/udp v0.1.4 // indirect
25+
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
2126
github.com/rivo/uniseg v0.4.4 // indirect
2227
github.com/spf13/pflag v1.0.5 // indirect
2328
go.opentelemetry.io/otel v1.19.0 // indirect

proxy/proxy.go

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package proxy
22

33
import (
44
"bufio"
5-
"context"
65
"crypto/tls"
6+
"errors"
77
"fmt"
88
"io"
99
"log/slog"
@@ -12,7 +12,7 @@ import (
1212
"net/url"
1313
"strings"
1414
"sync"
15-
"time"
15+
"sync/atomic"
1616

1717
"github.com/coder/boundary/audit"
1818
"github.com/coder/boundary/rules"
@@ -25,8 +25,9 @@ type Server struct {
2525
logger *slog.Logger
2626
tlsConfig *tls.Config
2727
httpPort int
28+
started atomic.Bool
2829

29-
httpServer *http.Server
30+
listener net.Listener
3031
}
3132

3233
// Config holds configuration for the proxy server
@@ -50,64 +51,70 @@ func NewProxyServer(config Config) *Server {
5051
}
5152

5253
// Start starts the HTTP proxy server with TLS termination capability
53-
func (p *Server) Start(ctx context.Context) error {
54-
// Create HTTP server with TLS termination capability
55-
p.httpServer = &http.Server{
56-
Addr: fmt.Sprintf(":%d", p.httpPort),
57-
Handler: http.HandlerFunc(p.handleHTTPWithTLSTermination),
54+
func (p *Server) Start() error {
55+
if p.isStarted() {
56+
return nil
57+
}
58+
59+
p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort)
60+
var err error
61+
p.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort))
62+
if err != nil {
63+
p.logger.Error("Failed to create HTTP listener", "error", err)
64+
return err
5865
}
5966

67+
p.started.Store(true)
68+
6069
// Start HTTP server with custom listener for TLS detection
6170
go func() {
62-
p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort)
63-
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort))
64-
if err != nil {
65-
p.logger.Error("Failed to create HTTP listener", "error", err)
66-
return
67-
}
68-
6971
for {
70-
conn, err := listener.Accept()
72+
conn, err := p.listener.Accept()
73+
if err != nil && errors.Is(err, net.ErrClosed) && p.isStopped() {
74+
return
75+
}
7176
if err != nil {
72-
select {
73-
case <-ctx.Done():
74-
err = listener.Close()
75-
if err != nil {
76-
p.logger.Error("Failed to close listener", "error", err)
77-
}
78-
return
79-
default:
80-
p.logger.Error("Failed to accept connection", "error", err)
81-
continue
82-
}
77+
p.logger.Error("Failed to accept connection", "error", err)
78+
continue
8379
}
8480

8581
// Handle connection with TLS detection
8682
go p.handleConnectionWithTLSDetection(conn)
8783
}
8884
}()
8985

90-
// Wait for context cancellation
91-
<-ctx.Done()
92-
return p.Stop()
86+
return nil
9387
}
9488

9589
// Stops proxy server
9690
func (p *Server) Stop() error {
97-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
98-
defer cancel()
91+
if p.isStopped() {
92+
return nil
93+
}
94+
p.started.Store(false)
9995

100-
var httpErr error
101-
if p.httpServer != nil {
102-
httpErr = p.httpServer.Shutdown(ctx)
96+
if p.listener == nil {
97+
p.logger.Error("unexpected nil listener")
98+
return errors.New("unexpected nil listener")
10399
}
104100

105-
if httpErr != nil {
106-
return httpErr
101+
err := p.listener.Close()
102+
if err != nil {
103+
p.logger.Error("Failed to close listener", "error", err)
104+
return err
107105
}
106+
108107
return nil
109108
}
110109

110+
func (p *Server) isStarted() bool {
111+
return p.started.Load()
112+
}
113+
114+
func (p *Server) isStopped() bool {
115+
return !p.started.Load()
116+
}
117+
111118
// handleHTTP handles regular HTTP requests and CONNECT tunneling
112119
func (p *Server) handleHTTP(w http.ResponseWriter, r *http.Request) {
113120
p.logger.Debug("handleHTTP called", "method", r.Method, "url", r.URL.String(), "host", r.Host)
@@ -479,13 +486,6 @@ func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) {
479486
}
480487
}
481488

482-
// handleHTTPWithTLSTermination is the main handler (currently just delegates to regular HTTP)
483-
func (p *Server) handleHTTPWithTLSTermination(w http.ResponseWriter, r *http.Request) {
484-
// This handler is not used when we do custom connection handling
485-
// All traffic goes through handleConnectionWithTLSDetection
486-
p.handleHTTP(w, r)
487-
}
488-
489489
// connectionWrapper lets us "unread" the peeked byte
490490
type connectionWrapper struct {
491491
net.Conn

0 commit comments

Comments
 (0)