diff --git a/pkg/http/http_test.go b/pkg/http/http_test.go index cac23314..ab531813 100644 --- a/pkg/http/http_test.go +++ b/pkg/http/http_test.go @@ -390,6 +390,122 @@ func TestWellKnownReverseProxy(t *testing.T) { }) } +func TestWellKnownHeaderPropagation(t *testing.T) { + cases := []string{ + ".well-known/oauth-authorization-server", + ".well-known/oauth-protected-resource", + ".well-known/openid-configuration", + } + var receivedRequestHeaders http.Header + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") { + http.NotFound(w, r) + return + } + // Capture headers received from the proxy + receivedRequestHeaders = r.Header.Clone() + // Set response headers that should be propagated back + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "https://example.com") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("X-Custom-Backend-Header", "backend-value") + _, _ = w.Write([]byte(`{"issuer": "https://example.com"}`)) + })) + t.Cleanup(testServer.Close) + staticConfig := &config.StaticConfig{ + AuthorizationURL: testServer.URL, + RequireOAuth: true, + ValidateToken: true, + ClusterProviderStrategy: config.ClusterProviderKubeConfig, + } + testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) { + for _, path := range cases { + receivedRequestHeaders = nil + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + // Add various headers to test propagation + req.Header.Set("Origin", "https://example.com") + req.Header.Set("User-Agent", "Test-Agent/1.0") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US") + req.Header.Set("X-Custom-Header", "custom-value") + req.Header.Set("Referer", "https://example.com/page") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to get %s endpoint: %v", path, err) + } + t.Cleanup(func() { _ = resp.Body.Close() }) + + t.Run("Well-known proxy propagates Origin header to backend for "+path, func(t *testing.T) { + if receivedRequestHeaders == nil { + t.Fatal("Backend did not receive any headers") + } + if receivedRequestHeaders.Get("Origin") != "https://example.com" { + t.Errorf("Expected Origin header 'https://example.com', got '%s'", receivedRequestHeaders.Get("Origin")) + } + }) + + t.Run("Well-known proxy propagates User-Agent header to backend for "+path, func(t *testing.T) { + if receivedRequestHeaders.Get("User-Agent") != "Test-Agent/1.0" { + t.Errorf("Expected User-Agent header 'Test-Agent/1.0', got '%s'", receivedRequestHeaders.Get("User-Agent")) + } + }) + + t.Run("Well-known proxy propagates Accept header to backend for "+path, func(t *testing.T) { + if receivedRequestHeaders.Get("Accept") != "application/json" { + t.Errorf("Expected Accept header 'application/json', got '%s'", receivedRequestHeaders.Get("Accept")) + } + }) + + t.Run("Well-known proxy propagates Accept-Language header to backend for "+path, func(t *testing.T) { + if receivedRequestHeaders.Get("Accept-Language") != "en-US" { + t.Errorf("Expected Accept-Language header 'en-US', got '%s'", receivedRequestHeaders.Get("Accept-Language")) + } + }) + + t.Run("Well-known proxy propagates custom headers to backend for "+path, func(t *testing.T) { + if receivedRequestHeaders.Get("X-Custom-Header") != "custom-value" { + t.Errorf("Expected X-Custom-Header 'custom-value', got '%s'", receivedRequestHeaders.Get("X-Custom-Header")) + } + }) + + t.Run("Well-known proxy propagates Referer header to backend for "+path, func(t *testing.T) { + if receivedRequestHeaders.Get("Referer") != "https://example.com/page" { + t.Errorf("Expected Referer header 'https://example.com/page', got '%s'", receivedRequestHeaders.Get("Referer")) + } + }) + + t.Run("Well-known proxy returns Access-Control-Allow-Origin from backend for "+path, func(t *testing.T) { + if resp.Header.Get("Access-Control-Allow-Origin") != "https://example.com" { + t.Errorf("Expected Access-Control-Allow-Origin header 'https://example.com', got '%s'", resp.Header.Get("Access-Control-Allow-Origin")) + } + }) + + t.Run("Well-known proxy returns Access-Control-Allow-Methods from backend for "+path, func(t *testing.T) { + if resp.Header.Get("Access-Control-Allow-Methods") != "GET, POST, OPTIONS" { + t.Errorf("Expected Access-Control-Allow-Methods header 'GET, POST, OPTIONS', got '%s'", resp.Header.Get("Access-Control-Allow-Methods")) + } + }) + + t.Run("Well-known proxy returns Cache-Control from backend for "+path, func(t *testing.T) { + if resp.Header.Get("Cache-Control") != "no-cache" { + t.Errorf("Expected Cache-Control header 'no-cache', got '%s'", resp.Header.Get("Cache-Control")) + } + }) + + t.Run("Well-known proxy returns custom response headers from backend for "+path, func(t *testing.T) { + if resp.Header.Get("X-Custom-Backend-Header") != "backend-value" { + t.Errorf("Expected X-Custom-Backend-Header 'backend-value', got '%s'", resp.Header.Get("X-Custom-Backend-Header")) + } + }) + } + }) +} + func TestWellKnownOverrides(t *testing.T) { cases := []string{ ".well-known/oauth-authorization-server", diff --git a/pkg/http/wellknown.go b/pkg/http/wellknown.go index 6c065fa5..01ff3092 100644 --- a/pkg/http/wellknown.go +++ b/pkg/http/wellknown.go @@ -32,7 +32,7 @@ var _ http.Handler = &WellKnown{} func WellKnownHandler(staticConfig *config.StaticConfig, httpClient *http.Client) http.Handler { authorizationUrl := staticConfig.AuthorizationURL - if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") { + if authorizationUrl != "" && strings.HasSuffix(authorizationUrl, "/") { authorizationUrl = strings.TrimSuffix(authorizationUrl, "/") } if httpClient == nil { @@ -56,6 +56,11 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request) http.Error(writer, "Failed to create request: "+err.Error(), http.StatusInternalServerError) return } + for key, values := range request.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } resp, err := w.httpClient.Do(req.WithContext(request.Context())) if err != nil { http.Error(writer, "Failed to perform request: "+err.Error(), http.StatusInternalServerError)