@@ -11,6 +11,8 @@ import (
11
11
"net/http"
12
12
"net/http/httptest"
13
13
"net/url"
14
+ "strings"
15
+ "sync"
14
16
"time"
15
17
16
18
"github.com/coder/jail/audit"
@@ -54,10 +56,10 @@ func NewProxyServer(config Config) *Server {
54
56
55
57
// Start starts both HTTP and HTTPS proxy servers
56
58
func (p * Server ) Start (ctx context.Context ) error {
57
- // Create HTTP server
59
+ // Create HTTP server with TLS termination capability for privileged mode
58
60
p .httpServer = & http.Server {
59
61
Addr : fmt .Sprintf (":%d" , p .httpPort ),
60
- Handler : http .HandlerFunc (p .handleHTTP ),
62
+ Handler : http .HandlerFunc (p .handleHTTPWithTLSTermination ),
61
63
}
62
64
63
65
// Create HTTPS server
@@ -67,12 +69,30 @@ func (p *Server) Start(ctx context.Context) error {
67
69
TLSConfig : p .tlsConfig ,
68
70
}
69
71
70
- // Start HTTP server
72
+ // Start HTTP server with custom listener for TLS detection
71
73
go func () {
72
- p .logger .Info ("Starting HTTP proxy" , "port" , p .httpPort )
73
- err := p .httpServer .ListenAndServe ()
74
- if err != nil && err != http .ErrServerClosed {
75
- p .logger .Error ("HTTP proxy server error" , "error" , err )
74
+ p .logger .Info ("Starting HTTP proxy with TLS termination" , "port" , p .httpPort )
75
+ listener , err := net .Listen ("tcp" , fmt .Sprintf (":%d" , p .httpPort ))
76
+ if err != nil {
77
+ p .logger .Error ("Failed to create HTTP listener" , "error" , err )
78
+ return
79
+ }
80
+
81
+ for {
82
+ conn , err := listener .Accept ()
83
+ if err != nil {
84
+ select {
85
+ case <- ctx .Done ():
86
+ listener .Close ()
87
+ return
88
+ default :
89
+ p .logger .Error ("Failed to accept connection" , "error" , err )
90
+ continue
91
+ }
92
+ }
93
+
94
+ // Handle connection with TLS detection
95
+ go p .handleConnectionWithTLSDetection (conn )
76
96
}
77
97
}()
78
98
@@ -452,4 +472,137 @@ func (p *Server) handleDecryptedHTTPS(w http.ResponseWriter, r *http.Request) {
452
472
453
473
// Forward the HTTPS request
454
474
p .forwardHTTPSRequest (w , r )
475
+ }
476
+
477
+ // handleConnectionWithTLSDetection detects TLS vs HTTP and handles appropriately
478
+ func (p * Server ) handleConnectionWithTLSDetection (conn net.Conn ) {
479
+ defer conn .Close ()
480
+
481
+ // Peek at the first byte to detect TLS handshake
482
+ conn .SetReadDeadline (time .Now ().Add (5 * time .Second ))
483
+ firstByte := make ([]byte , 1 )
484
+ n , err := conn .Read (firstByte )
485
+ if err != nil || n == 0 {
486
+ p .logger .Debug ("Failed to read first byte from connection" , "error" , err )
487
+ return
488
+ }
489
+ conn .SetReadDeadline (time.Time {}) // Clear deadline
490
+
491
+ // TLS handshake starts with 0x16 (TLS Content Type: Handshake)
492
+ if firstByte [0 ] == 0x16 {
493
+ p .logger .Debug ("Detected TLS handshake, performing TLS termination" )
494
+ p .handleTLSTermination (conn , firstByte )
495
+ } else {
496
+ p .logger .Debug ("Detected HTTP request, handling normally" )
497
+ p .handleHTTPConnection (conn , firstByte )
498
+ }
499
+ }
500
+
501
+ // handleTLSTermination performs TLS termination and processes decrypted HTTPS requests
502
+ func (p * Server ) handleTLSTermination (conn net.Conn , firstByte []byte ) {
503
+ // Create a connection that prepends the first byte we already read
504
+ connWithFirstByte := & connectionWithPrefix {
505
+ Conn : conn ,
506
+ prefix : firstByte ,
507
+ }
508
+
509
+ // We need to extract the SNI (Server Name Indication) from the TLS handshake
510
+ // to generate the appropriate certificate. For now, use a default hostname.
511
+ hostname := "unknown-host"
512
+
513
+ // TODO: Extract hostname from TLS ClientHello SNI extension
514
+ // For now, we'll perform TLS termination with a generic certificate
515
+
516
+ // Perform TLS handshake with our certificate
517
+ tlsConn := tls .Server (connWithFirstByte , p .tlsConfig )
518
+ err := tlsConn .Handshake ()
519
+ if err != nil {
520
+ p .logger .Debug ("TLS handshake failed" , "error" , err )
521
+ return
522
+ }
523
+
524
+ p .logger .Debug ("TLS handshake successful, processing decrypted HTTPS traffic" )
525
+
526
+ // Now handle the decrypted HTTPS requests using our existing logic
527
+ p .handleTLSConnection (tlsConn , hostname )
528
+ }
529
+
530
+ // handleHTTPConnection handles regular HTTP connections
531
+ func (p * Server ) handleHTTPConnection (conn net.Conn , firstByte []byte ) {
532
+ // Create a connection that prepends the first byte we already read
533
+ connWithFirstByte := & connectionWithPrefix {
534
+ Conn : conn ,
535
+ prefix : firstByte ,
536
+ }
537
+
538
+ // Create HTTP server to handle this connection
539
+ server := & http.Server {
540
+ Handler : http .HandlerFunc (p .handleHTTP ),
541
+ }
542
+
543
+ // Serve the HTTP request
544
+ err := server .Serve (& singleConnListener {conn : connWithFirstByte })
545
+ if err != nil && err != io .EOF && ! isConnectionClosed (err ) {
546
+ p .logger .Debug ("HTTP connection error" , "error" , err )
547
+ }
548
+ }
549
+
550
+ // handleHTTPWithTLSTermination is the main handler (currently just delegates to regular HTTP)
551
+ func (p * Server ) handleHTTPWithTLSTermination (w http.ResponseWriter , r * http.Request ) {
552
+ // This handler is not used when we do custom connection handling
553
+ // All traffic goes through handleConnectionWithTLSDetection
554
+ p .handleHTTP (w , r )
555
+ }
556
+
557
+ // connectionWithPrefix wraps a connection and prepends some data
558
+ type connectionWithPrefix struct {
559
+ net.Conn
560
+ prefix []byte
561
+ prefixRead bool
562
+ }
563
+
564
+ func (c * connectionWithPrefix ) Read (b []byte ) (n int , err error ) {
565
+ if ! c .prefixRead && len (c .prefix ) > 0 {
566
+ n = copy (b , c .prefix )
567
+ c .prefixRead = true
568
+ return n , nil
569
+ }
570
+ return c .Conn .Read (b )
571
+ }
572
+
573
+ // isConnectionClosed checks if an error indicates a closed connection
574
+ func isConnectionClosed (err error ) bool {
575
+ if err == nil {
576
+ return false
577
+ }
578
+ s := err .Error ()
579
+ return strings .Contains (s , "use of closed network connection" ) ||
580
+ strings .Contains (s , "broken pipe" ) ||
581
+ strings .Contains (s , "connection reset by peer" )
582
+ }
583
+
584
+ // singleConnListener wraps a single connection to implement net.Listener
585
+ type singleConnListener struct {
586
+ conn net.Conn
587
+ used bool
588
+ mu sync.Mutex
589
+ }
590
+
591
+ func (l * singleConnListener ) Accept () (net.Conn , error ) {
592
+ l .mu .Lock ()
593
+ defer l .mu .Unlock ()
594
+
595
+ if l .used {
596
+ return nil , io .EOF
597
+ }
598
+ l .used = true
599
+ return l .conn , nil
600
+ }
601
+
602
+ func (l * singleConnListener ) Close () error {
603
+ return nil // Don't close the underlying connection here
604
+ }
605
+
606
+ func (l * singleConnListener ) Addr () net.Addr {
607
+ return l .conn .LocalAddr ()
455
608
}
0 commit comments