@@ -2,8 +2,8 @@ package proxy
2
2
3
3
import (
4
4
"bufio"
5
- "context"
6
5
"crypto/tls"
6
+ "errors"
7
7
"fmt"
8
8
"io"
9
9
"log/slog"
@@ -12,7 +12,7 @@ import (
12
12
"net/url"
13
13
"strings"
14
14
"sync"
15
- "time "
15
+ "sync/atomic "
16
16
17
17
"github.com/coder/boundary/audit"
18
18
"github.com/coder/boundary/rules"
@@ -25,8 +25,9 @@ type Server struct {
25
25
logger * slog.Logger
26
26
tlsConfig * tls.Config
27
27
httpPort int
28
+ started atomic.Bool
28
29
29
- httpServer * http. Server
30
+ listener net. Listener
30
31
}
31
32
32
33
// Config holds configuration for the proxy server
@@ -50,64 +51,70 @@ func NewProxyServer(config Config) *Server {
50
51
}
51
52
52
53
// Start starts the HTTP proxy server with TLS termination capability
53
- func (p * Server ) Start (ctx context.Context ) error {
54
- // Create HTTP server with TLS termination capability
55
- p .httpServer = & http.Server {
56
- Addr : fmt .Sprintf (":%d" , p .httpPort ),
57
- Handler : http .HandlerFunc (p .handleHTTPWithTLSTermination ),
54
+ func (p * Server ) Start () error {
55
+ if p .isStarted () {
56
+ return nil
57
+ }
58
+
59
+ p .logger .Info ("Starting HTTP proxy with TLS termination" , "port" , p .httpPort )
60
+ var err error
61
+ p .listener , err = net .Listen ("tcp" , fmt .Sprintf (":%d" , p .httpPort ))
62
+ if err != nil {
63
+ p .logger .Error ("Failed to create HTTP listener" , "error" , err )
64
+ return err
58
65
}
59
66
67
+ p .started .Store (true )
68
+
60
69
// Start HTTP server with custom listener for TLS detection
61
70
go func () {
62
- p .logger .Info ("Starting HTTP proxy with TLS termination" , "port" , p .httpPort )
63
- listener , err := net .Listen ("tcp" , fmt .Sprintf (":%d" , p .httpPort ))
64
- if err != nil {
65
- p .logger .Error ("Failed to create HTTP listener" , "error" , err )
66
- return
67
- }
68
-
69
71
for {
70
- conn , err := listener .Accept ()
72
+ conn , err := p .listener .Accept ()
73
+ if err != nil && errors .Is (err , net .ErrClosed ) && p .isStopped () {
74
+ return
75
+ }
71
76
if err != nil {
72
- select {
73
- case <- ctx .Done ():
74
- err = listener .Close ()
75
- if err != nil {
76
- p .logger .Error ("Failed to close listener" , "error" , err )
77
- }
78
- return
79
- default :
80
- p .logger .Error ("Failed to accept connection" , "error" , err )
81
- continue
82
- }
77
+ p .logger .Error ("Failed to accept connection" , "error" , err )
78
+ continue
83
79
}
84
80
85
81
// Handle connection with TLS detection
86
82
go p .handleConnectionWithTLSDetection (conn )
87
83
}
88
84
}()
89
85
90
- // Wait for context cancellation
91
- <- ctx .Done ()
92
- return p .Stop ()
86
+ return nil
93
87
}
94
88
95
89
// Stops proxy server
96
90
func (p * Server ) Stop () error {
97
- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
98
- defer cancel ()
91
+ if p .isStopped () {
92
+ return nil
93
+ }
94
+ p .started .Store (false )
99
95
100
- var httpErr error
101
- if p . httpServer != nil {
102
- httpErr = p . httpServer . Shutdown ( ctx )
96
+ if p . listener == nil {
97
+ p . logger . Error ( "unexpected nil listener" )
98
+ return errors . New ( "unexpected nil listener" )
103
99
}
104
100
105
- if httpErr != nil {
106
- return httpErr
101
+ err := p .listener .Close ()
102
+ if err != nil {
103
+ p .logger .Error ("Failed to close listener" , "error" , err )
104
+ return err
107
105
}
106
+
108
107
return nil
109
108
}
110
109
110
+ func (p * Server ) isStarted () bool {
111
+ return p .started .Load ()
112
+ }
113
+
114
+ func (p * Server ) isStopped () bool {
115
+ return ! p .started .Load ()
116
+ }
117
+
111
118
// handleHTTP handles regular HTTP requests and CONNECT tunneling
112
119
func (p * Server ) handleHTTP (w http.ResponseWriter , r * http.Request ) {
113
120
p .logger .Debug ("handleHTTP called" , "method" , r .Method , "url" , r .URL .String (), "host" , r .Host )
@@ -479,13 +486,6 @@ func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) {
479
486
}
480
487
}
481
488
482
- // handleHTTPWithTLSTermination is the main handler (currently just delegates to regular HTTP)
483
- func (p * Server ) handleHTTPWithTLSTermination (w http.ResponseWriter , r * http.Request ) {
484
- // This handler is not used when we do custom connection handling
485
- // All traffic goes through handleConnectionWithTLSDetection
486
- p .handleHTTP (w , r )
487
- }
488
-
489
489
// connectionWrapper lets us "unread" the peeked byte
490
490
type connectionWrapper struct {
491
491
net.Conn
0 commit comments