Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added boundary
Binary file not shown.
210 changes: 180 additions & 30 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"log/slog"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
Expand Down Expand Up @@ -325,55 +324,55 @@ func (p *Server) handleConnect(w http.ResponseWriter, r *http.Request) {
p.logger.Debug("HTTPS request handling completed", "hostname", hostname)
}

// handleTLSConnection processes decrypted HTTPS requests over the TLS connection
// handleTLSConnection processes decrypted HTTPS requests over the TLS connection with streaming support
func (p *Server) handleTLSConnection(tlsConn *tls.Conn, hostname string) {
p.logger.Debug("Creating HTTP server for TLS connection", "hostname", hostname)
p.logger.Debug("Creating streaming HTTP handler for TLS connection", "hostname", hostname)

// Set read timeout to detect hanging connections
tlsConn.SetReadDeadline(time.Now().Add(5 * time.Second))

// Use ReadRequest to manually read HTTP requests from the TLS connection
// Use streaming HTTP parsing instead of ReadRequest
bufReader := bufio.NewReader(tlsConn)
for {
// Read HTTP request from TLS connection
req, err := http.ReadRequest(bufReader)
// Parse HTTP request headers incrementally
req, err := p.parseHTTPRequestHeaders(bufReader, hostname)
if err != nil {
if err == io.EOF {
p.logger.Debug("TLS connection closed by client", "hostname", hostname)
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
p.logger.Debug("TLS connection read timeout - client not sending HTTP requests", "hostname", hostname)
} else {
p.logger.Debug("Failed to read HTTP request", "hostname", hostname, "error", err)
p.logger.Debug("Failed to parse HTTP request headers", "hostname", hostname, "error", err)
}
break
}

p.logger.Debug("Processing decrypted HTTPS request", "hostname", hostname, "method", req.Method, "path", req.URL.Path)
p.logger.Debug("Processing streaming HTTPS request", "hostname", hostname, "method", req.Method, "path", req.URL.Path)

// Set the hostname and scheme if not already set
if req.URL.Host == "" {
req.URL.Host = hostname
}
if req.URL.Scheme == "" {
req.URL.Scheme = "https"
// Handle CONNECT method for HTTPS tunneling
if req.Method == "CONNECT" {
p.handleConnectStreaming(tlsConn, req, hostname)
return // CONNECT takes over the entire connection
}

// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Check if request should be allowed (based on headers only)
fullURL := p.constructFullURL(req, hostname)
result := p.ruleEngine.Evaluate(req.Method, fullURL)

// Process the HTTPS request
p.handleDecryptedHTTPS(recorder, req)
// Audit the request
p.auditor.AuditRequest(audit.Request{
Method: req.Method,
URL: fullURL,
Allowed: result.Allowed,
Rule: result.Rule,
})

// Write the response back to the TLS connection
resp := recorder.Result()
err = resp.Write(tlsConn)
if !result.Allowed {
p.writeBlockedResponseStreaming(tlsConn, req)
continue
}

// Stream the request to target server
err = p.streamRequestToTarget(tlsConn, bufReader, req, hostname)
if err != nil {
p.logger.Debug("Failed to write response", "hostname", hostname, "error", err)
p.logger.Debug("Error streaming request", "hostname", hostname, "error", err)
break
}

// Reset read deadline for next request
tlsConn.SetReadDeadline(time.Now().Add(5 * time.Second))
}

p.logger.Debug("TLS connection handling completed", "hostname", hostname)
Expand Down Expand Up @@ -532,3 +531,154 @@ func (sl *singleConnectionListener) Addr() net.Addr {
}
return sl.conn.LocalAddr()
}

// parseHTTPRequestHeaders parses HTTP request headers incrementally without reading the body
func (p *Server) parseHTTPRequestHeaders(bufReader *bufio.Reader, hostname string) (*http.Request, error) {
// Read the request line (e.g., "GET /path HTTP/1.1")
requestLine, _, err := bufReader.ReadLine()
if err != nil {
return nil, err
}

// Parse request line
parts := strings.Fields(string(requestLine))
if len(parts) != 3 {
return nil, fmt.Errorf("invalid request line: %s", requestLine)
}

method := parts[0]
requestURI := parts[1]
proto := parts[2]

// Parse URL
var url *url.URL
if strings.HasPrefix(requestURI, "http://") || strings.HasPrefix(requestURI, "https://") {
url, err = url.Parse(requestURI)
} else {
// Relative URL, construct with hostname
url, err = url.Parse("https://" + hostname + requestURI)
}
if err != nil {
return nil, fmt.Errorf("invalid request URI: %s", requestURI)
}

// Read headers
headers := make(http.Header)
for {
headerLine, _, err := bufReader.ReadLine()
if err != nil {
return nil, err
}

// Empty line indicates end of headers
if len(headerLine) == 0 {
break
}

// Parse header
headerStr := string(headerLine)
colonIdx := strings.Index(headerStr, ":")
if colonIdx == -1 {
continue // Skip malformed headers
}

headerName := strings.TrimSpace(headerStr[:colonIdx])
headerValue := strings.TrimSpace(headerStr[colonIdx+1:])
headers.Add(headerName, headerValue)
}

// Create request object (without body)
req := &http.Request{
Method: method,
URL: url,
Proto: proto,
Header: headers,
Host: url.Host,
// Note: Body is intentionally nil - we'll stream it separately
}

return req, nil
}

// constructFullURL builds the full URL from request and hostname
func (p *Server) constructFullURL(req *http.Request, hostname string) string {
if req.URL.Host == "" {
req.URL.Host = hostname
}
if req.URL.Scheme == "" {
req.URL.Scheme = "https"
}
return req.URL.String()
}

// writeBlockedResponseStreaming writes a blocked response directly to the TLS connection
func (p *Server) writeBlockedResponseStreaming(tlsConn *tls.Conn, req *http.Request) {
response := fmt.Sprintf("HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n🚫 Request Blocked by Boundary\n\nRequest: %s %s\nHost: %s\n\nTo allow this request, restart boundary with:\n --allow \"%s\"\n",
req.Method, req.URL.Path, req.Host, req.Host)
tlsConn.Write([]byte(response))
}

// streamRequestToTarget streams the HTTP request (including body) to the target server
func (p *Server) streamRequestToTarget(clientConn *tls.Conn, bufReader *bufio.Reader, req *http.Request, hostname string) error {
// Connect to target server
targetConn, err := tls.Dial("tcp", hostname+":443", &tls.Config{ServerName: hostname})
if err != nil {
return fmt.Errorf("failed to connect to target %s: %v", hostname, err)
}
defer targetConn.Close()

// Send HTTP request headers to target
reqLine := fmt.Sprintf("%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
targetConn.Write([]byte(reqLine))

// Send headers
for name, values := range req.Header {
for _, value := range values {
headerLine := fmt.Sprintf("%s: %s\r\n", name, value)
targetConn.Write([]byte(headerLine))
}
}
targetConn.Write([]byte("\r\n")) // End of headers

// Stream request body and response bidirectionally
go func() {
// Stream request body: client -> target
io.Copy(targetConn, bufReader)
}()

// Stream response: target -> client
io.Copy(clientConn, targetConn)
return nil
}

// handleConnectStreaming handles CONNECT requests with streaming TLS termination
func (p *Server) handleConnectStreaming(tlsConn *tls.Conn, req *http.Request, hostname string) {
p.logger.Debug("Handling CONNECT request with streaming", "hostname", hostname)

// For CONNECT, we need to establish a tunnel but still maintain TLS termination
// This is the tricky part - we're already inside a TLS connection from the client
// The client is asking us to CONNECT to another server, but we want to intercept that too

// Send CONNECT response
response := "HTTP/1.1 200 Connection established\r\n\r\n"
tlsConn.Write([]byte(response))

// Now the client will try to do TLS handshake for the target server
// But we want to intercept and terminate it
// This means we need to do another level of TLS termination

// For now, let's create a simple tunnel and log that we're not inspecting
p.logger.Warn("CONNECT tunnel established - content not inspected", "hostname", hostname)

// Create connection to real target
targetConn, err := net.Dial("tcp", req.Host)
if err != nil {
p.logger.Error("Failed to connect to CONNECT target", "target", req.Host, "error", err)
return
}
defer targetConn.Close()

// Bidirectional copy
go io.Copy(targetConn, tlsConn)
io.Copy(tlsConn, targetConn)
}
Loading