// mapDialErrorToHTTPStatus maps the text of a dial error to the HTTP status
// code that should be reported to the proxy client.
//
// errStr is matched as a case-insensitive substring search:
//
//   - "timeout", "timed out" or "deadline exceeded" -> 504 Gateway Timeout
//   - "too many open files"                         -> 503 Service Unavailable
//   - anything else, including the empty string     -> 502 Bad Gateway
//
// Refused or reset connections, unreachable hosts/networks, DNS failures and
// TLS/certificate failures are deliberately not listed: they all map to 502,
// which is also the fallback, so enumerating them would not change the result.
func mapDialErrorToHTTPStatus(errStr string) int {
	errLower := strings.ToLower(errStr)

	switch {
	// The timeout check runs first so that messages such as
	// "tls handshake timeout" are reported as a timeout.
	// "timeout" already covers "i/o timeout"; "timed out" covers the
	// ETIMEDOUT wording ("connection timed out" / "operation timed out");
	// "deadline exceeded" covers "context deadline exceeded".
	case strings.Contains(errLower, "timeout"),
		strings.Contains(errLower, "timed out"),
		strings.Contains(errLower, "deadline exceeded"):
		return http.StatusGatewayTimeout

	// File-descriptor exhaustion; also matches the "socket: "-prefixed form.
	case strings.Contains(errLower, "too many open files"):
		return http.StatusServiceUnavailable

	// Every other dial failure is reported as a bad gateway.
	default:
		return http.StatusBadGateway
	}
}
*client.Packet) error { defer func(start time.Time) { metrics.Metrics.ObserveFrontendWriteLatency(time.Since(start)) }(time.Now()) if c.Mode == ModeGRPC { @@ -122,11 +169,21 @@ func (c *ProxyClientConnection) send(pkt *client.Packet) error { _, err := c.HTTP.Write(pkt.GetData().Data) return err } else if pkt.Type == client.PacketType_DIAL_RSP { - if pkt.GetDialResponse().Error != "" { - body := bytes.NewBufferString(pkt.GetDialResponse().Error) + dialErr := pkt.GetDialResponse().Error + if dialErr != "" { + // // Map the error to appropriate HTTP status code + statusCode := mapDialErrorToHTTPStatus(dialErr) + statusText := http.StatusText(statusCode) + body := bytes.NewBufferString(dialErr) t := http.Response{ - StatusCode: 503, + StatusCode: statusCode, + Status: fmt.Sprintf("%d %s", statusCode, statusText), Body: io.NopCloser(body), + Header: http.Header{ + "Content-Type": []string{"text/plain; charset=utf-8"}, + }, + Proto: "HTTP/1.1", + ProtoMinor: 1, } t.Write(c.HTTP) diff --git a/pkg/server/tunnel.go b/pkg/server/tunnel.go index 5f04c02fa..ad81179ca 100644 --- a/pkg/server/tunnel.go +++ b/pkg/server/tunnel.go @@ -60,14 +60,6 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Send the HTTP 200 OK status after a successful hijack - _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) - if err != nil { - klog.ErrorS(err, "failed to send 200 connection established") - conn.Close() - return - } - var closeOnce sync.Once defer closeOnce.Do(func() { conn.Close() }) @@ -110,11 +102,17 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.Server.PendingDial.Add(random, connection) if err := backend.Send(dialRequest); err != nil { klog.ErrorS(err, "failed to tunnel dial request", "host", r.Host, "dialID", connection.dialID, "agentID", connection.agentID) + // Send proper HTTP error response + conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; 
charset=utf-8\r\n\r\nFailed to tunnel dial request: %v\r\n", err))) + conn.Close() return } ctxt := backend.Context() if ctxt.Err() != nil { - klog.ErrorS(err, "context reports failure") + klog.ErrorS(ctxt.Err(), "context reports failure") + conn.Write([]byte(fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nBackend context error: %v\r\n", ctxt.Err()))) + conn.Close() + return } select { @@ -125,6 +123,15 @@ func (t *Tunnel) ServeHTTP(w http.ResponseWriter, r *http.Request) { select { case <-connection.connected: // Waiting for response before we begin full communication. + // Now that connection is established, send 200 OK to switch to tunnel mode + _, err = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + if err != nil { + klog.ErrorS(err, "failed to send 200 connection established", "host", r.Host, "agentID", connection.agentID) + conn.Close() + return + } + klog.V(3).InfoS("Connection established, sent 200 OK", "host", r.Host, "agentID", connection.agentID, "connectionID", connection.connectID) + case <-closed: // Connection was closed before being established } diff --git a/tests/proxy_test.go b/tests/proxy_test.go index 66fe76807..2d019d802 100644 --- a/tests/proxy_test.go +++ b/tests/proxy_test.go @@ -689,7 +689,7 @@ func TestFailedDNSLookupProxy_HTTPCONN(t *testing.T) { t.Error(err) } - urlString := "http://thissssssxxxxx.com:80" + urlString := "http://thisdefinitelydoesnotexist.com:80" serverURL, _ := url.Parse(urlString) // Send HTTP-Connect request @@ -705,36 +705,12 @@ func TestFailedDNSLookupProxy_HTTPCONN(t *testing.T) { t.Errorf("reading HTTP response from CONNECT: %v", err) } - if res.StatusCode != 200 { - t.Errorf("expect 200; got %d", res.StatusCode) - } - if br.Buffered() > 0 { - t.Error("unexpected extra buffer") - } - dialer := func(_, _ string) (net.Conn, error) { - return conn, nil - } - - c := &http.Client{ - Transport: &http.Transport{ - Dial: dialer, - }, - } - - resp, err := 
c.Get(urlString) - if err != nil { - t.Error(err) + if res.StatusCode != 502 { + t.Errorf("expect 502; got %d", res.StatusCode) } - if resp.StatusCode != 503 { - t.Errorf("expect 503; got %d", res.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if !strings.Contains(err.Error(), "connection reset by peer") { - t.Error(err) - } + body, err := io.ReadAll(res.Body) + res.Body.Close() if !strings.Contains(string(body), "no such host") { t.Errorf("Unexpected error: %v", err) @@ -779,37 +755,21 @@ func TestFailedDial_HTTPCONN(t *testing.T) { br := bufio.NewReader(conn) res, err := http.ReadResponse(br, nil) if err != nil { - t.Fatalf("reading HTTP response from CONNECT: %v", err) - } - if res.StatusCode != 200 { - t.Fatalf("expect 200; got %d", res.StatusCode) + t.Errorf("reading HTTP response from CONNECT: %v", err) } - dialer := func(_, _ string) (net.Conn, error) { - return conn, nil + if res.StatusCode != 502 { + t.Errorf("expect 502; got %d", res.StatusCode) } - c := &http.Client{ - Transport: &http.Transport{ - Dial: dialer, - }, - } - - resp, err := c.Get(server.URL) + body, err := io.ReadAll(res.Body) + res.Body.Close() if err != nil { - t.Fatal(err) - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - if err == nil { t.Fatalf("Expected error reading response body; response=%q", body) - } else if !strings.Contains(err.Error(), "connection reset by peer") { - t.Error(err) } if !strings.Contains(string(body), "connection refused") { - t.Errorf("Unexpected error: %v", err) + t.Errorf("Expected 'connection refused' in error body, got: %s", string(body)) } if err := ps.Metrics().ExpectServerDialFailure(metricsserver.DialFailureErrorResponse, 1); err != nil {