diff --git a/proxy/proxy.go b/proxy/proxy.go index e5d7fc6d352..6b404189a53 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -220,11 +220,23 @@ func (p *Proxy) proxyHTTPRequest( _, ttfbSpan := tr.Tracer().Start(tr.Context(), "ttfb_origin") resp, err := httpService.RoundTrip(roundTripReq) if err != nil { - tracing.EndWithErrorStatus(ttfbSpan, err) - if err := roundTripReq.Context().Err(); err != nil { - return errors.Wrap(err, "Incoming request ended abruptly") + // Check for GOAWAY error and retry once if applicable + const goawayMsg = "http2: Transport received Server's graceful shutdown GOAWAY" + if err.Error() == goawayMsg && roundTripReq.GetBody != nil { + // Reset the body for retry + newBody, getBodyErr := roundTripReq.GetBody() + if getBodyErr == nil { + roundTripReq.Body = newBody + resp, err = httpService.RoundTrip(roundTripReq) + } + } + if err != nil { + tracing.EndWithErrorStatus(ttfbSpan, err) + if err := roundTripReq.Context().Err(); err != nil { + return errors.Wrap(err, "Incoming request ended abruptly") + } + return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared") } - return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared") } tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index f50381228c0..f97978dbec9 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "errors" "flag" "fmt" "io" @@ -1014,3 +1015,61 @@ func runEchoWSService(t *testing.T, l net.Listener) { } }() } + +func TestHandleGOAWAYRetry(t *testing.T) { + // Simulate a request body + bodyContent := "test body content" + body := io.NopCloser(strings.NewReader(bodyContent)) + + // Create a mock request with a body + roundTripReq := &http.Request{ + Body: body, + } + + // Simulate the GOAWAY error + goawayError := errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + + // Assign the GetBody function + roundTripReq.GetBody = func() (io.ReadCloser, error) { + if goawayError.Error() == "http2: Transport received Server's graceful shutdown GOAWAY" { + return roundTripReq.Body, nil + } + return nil, goawayError + } + + // Test the GetBody function + retriedBody, err := roundTripReq.GetBody() + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Verify the retried body content + retriedContent, _ := io.ReadAll(retriedBody) + if string(retriedContent) != bodyContent { + t.Fatalf("Expected body content '%s', got '%s'", bodyContent, string(retriedContent)) + } +} +func TestHandleGOAWAYRetryError(t *testing.T) { + // Simulate a request body + bodyContent := "test body content" + body := io.NopCloser(strings.NewReader(bodyContent)) + + // Create a mock request with a body + roundTripReq := &http.Request{ + Body: body, + } + + // Simulate the GOAWAY error + goawayError := errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + + // Assign the GetBody function to return an error + roundTripReq.GetBody = func() (io.ReadCloser, error) { + return nil, goawayError + } + + // Test the GetBody function + retriedBody, err := roundTripReq.GetBody() + if retriedBody != nil || err == nil { + t.Fatalf("Expected error, got body: %v, error: %v", retriedBody, err) + } +}