Skip to content

Commit b9a418b

Browse files
refactor: start/stop methods
1 parent 80604e3 commit b9a418b

File tree

2 files changed

+33
-34
lines changed

2 files changed

+33
-34
lines changed

proxy/proxy.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package proxy
22

33
import (
44
"bufio"
5-
"context"
65
"crypto/tls"
6+
"errors"
77
"fmt"
88
"io"
99
"log/slog"
@@ -12,6 +12,7 @@ import (
1212
"net/url"
1313
"strings"
1414
"sync"
15+
"sync/atomic"
1516

1617
"github.com/coder/boundary/audit"
1718
"github.com/coder/boundary/rules"
@@ -24,6 +25,7 @@ type Server struct {
2425
logger *slog.Logger
2526
tlsConfig *tls.Config
2627
httpPort int
28+
started atomic.Bool
2729

2830
listener net.Listener
2931
}
@@ -49,7 +51,12 @@ func NewProxyServer(config Config) *Server {
4951
}
5052

5153
// Start starts the HTTP proxy server with TLS termination capability
52-
func (p *Server) Start(ctx context.Context) error {
54+
func (p *Server) Start() error {
55+
if p.isStarted() {
56+
return nil
57+
}
58+
p.started.Store(true)
59+
5360
// Start HTTP server with custom listener for TLS detection
5461
go func() {
5562
p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort)
@@ -62,46 +69,49 @@ func (p *Server) Start(ctx context.Context) error {
6269

6370
for {
6471
conn, err := p.listener.Accept()
72+
if err != nil && errors.Is(err, net.ErrClosed) && p.isStopped() {
73+
return
74+
}
6575
if err != nil {
66-
select {
67-
case <-ctx.Done():
68-
err = p.listener.Close()
69-
if err != nil {
70-
p.logger.Error("Failed to close listener", "error", err)
71-
}
72-
return
73-
default:
74-
p.logger.Error("Failed to accept connection", "error", err)
75-
continue
76-
}
76+
p.logger.Error("Failed to accept connection", "error", err)
77+
continue
7778
}
7879

7980
// Handle connection with TLS detection
8081
go p.handleConnectionWithTLSDetection(conn)
8182
}
8283
}()
8384

84-
// Wait for context cancellation
85-
<-ctx.Done()
86-
return p.Stop()
85+
return nil
8786
}
8887

8988
// Stops proxy server
9089
func (p *Server) Stop() error {
91-
if p.listener == nil {
90+
if p.isStopped() {
9291
return nil
9392
}
93+
p.started.Store(false)
94+
95+
if p.listener == nil {
96+
return errors.New("listener is nil; server was not started")
97+
}
9498

9599
err := p.listener.Close()
96100
if err != nil {
97101
p.logger.Error("Failed to close listener", "error", err)
98102
}
99103

100-
fmt.Printf("STOP is finished\n")
101-
102104
return nil
103105
}
104106

107+
func (p *Server) isStarted() bool {
108+
return p.started.Load()
109+
}
110+
111+
func (p *Server) isStopped() bool {
112+
return !p.started.Load()
113+
}
114+
105115
// handleHTTP handles regular HTTP requests and CONNECT tunneling
106116
func (p *Server) handleHTTP(w http.ResponseWriter, r *http.Request) {
107117
p.logger.Debug("handleHTTP called", "method", r.Method, "url", r.URL.String(), "host", r.Host)

proxy/proxy_test.go

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package proxy
22

33
import (
4-
"context"
54
"crypto/tls"
65
"io"
76
"log"
@@ -31,7 +30,7 @@ func (m *mockAuditor) AuditRequest(req audit.Request) {
3130
func TestProxyServerBasicHTTP(t *testing.T) {
3231
// Create test logger
3332
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
34-
Level: slog.LevelDebug,
33+
Level: slog.LevelError,
3534
}))
3635

3736
// Create test rules (allow all for testing)
@@ -60,14 +59,10 @@ func TestProxyServerBasicHTTP(t *testing.T) {
6059
TLSConfig: tlsConfig,
6160
})
6261

63-
// Create context with timeout
64-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
65-
defer cancel()
66-
6762
// Start server in goroutine
6863
serverDone := make(chan error, 1)
6964
go func() {
70-
serverDone <- server.Start(ctx)
65+
serverDone <- server.Start()
7166
}()
7267

7368
// Give server time to start
@@ -112,7 +107,6 @@ func TestProxyServerBasicHTTP(t *testing.T) {
112107

113108
err = server.Stop()
114109
require.NoError(t, err)
115-
cancel()
116110
err = <-serverDone
117111
require.NoError(t, err)
118112
}
@@ -121,7 +115,7 @@ func TestProxyServerBasicHTTP(t *testing.T) {
121115
func TestProxyServerBasicHTTPS(t *testing.T) {
122116
// Create test logger
123117
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
124-
Level: slog.LevelDebug,
118+
Level: slog.LevelError,
125119
}))
126120

127121
// Create test rules (allow all for testing)
@@ -167,14 +161,10 @@ func TestProxyServerBasicHTTPS(t *testing.T) {
167161
TLSConfig: tlsConfig,
168162
})
169163

170-
// Create context with timeout
171-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
172-
defer cancel()
173-
174164
// Start server in goroutine
175165
serverDone := make(chan error, 1)
176166
go func() {
177-
serverDone <- server.Start(ctx)
167+
serverDone <- server.Start()
178168
}()
179169

180170
// Give server time to start
@@ -215,7 +205,6 @@ func TestProxyServerBasicHTTPS(t *testing.T) {
215205

216206
err = server.Stop()
217207
require.NoError(t, err)
218-
cancel()
219208
err = <-serverDone
220209
require.NoError(t, err)
221210
}

0 commit comments

Comments
 (0)