Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions pkg/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,122 @@ func TestWellKnownReverseProxy(t *testing.T) {
})
}

func TestWellKnownHeaderPropagation(t *testing.T) {
cases := []string{
".well-known/oauth-authorization-server",
".well-known/oauth-protected-resource",
".well-known/openid-configuration",
}
var receivedRequestHeaders http.Header
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.EscapedPath(), "/.well-known/") {
http.NotFound(w, r)
return
}
// Capture headers received from the proxy
receivedRequestHeaders = r.Header.Clone()
// Set response headers that should be propagated back
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "https://example.com")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("X-Custom-Backend-Header", "backend-value")
_, _ = w.Write([]byte(`{"issuer": "https://example.com"}`))
}))
t.Cleanup(testServer.Close)
staticConfig := &config.StaticConfig{
AuthorizationURL: testServer.URL,
RequireOAuth: true,
ValidateToken: true,
ClusterProviderStrategy: config.ClusterProviderKubeConfig,
}
testCaseWithContext(t, &httpContext{StaticConfig: staticConfig}, func(ctx *httpContext) {
for _, path := range cases {
receivedRequestHeaders = nil
req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/%s", ctx.HttpAddress, path), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Add various headers to test propagation
req.Header.Set("Origin", "https://example.com")
req.Header.Set("User-Agent", "Test-Agent/1.0")
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Language", "en-US")
req.Header.Set("X-Custom-Header", "custom-value")
req.Header.Set("Referer", "https://example.com/page")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to get %s endpoint: %v", path, err)
}
t.Cleanup(func() { _ = resp.Body.Close() })

t.Run("Well-known proxy propagates Origin header to backend for "+path, func(t *testing.T) {
if receivedRequestHeaders == nil {
t.Fatal("Backend did not receive any headers")
}
if receivedRequestHeaders.Get("Origin") != "https://example.com" {
t.Errorf("Expected Origin header 'https://example.com', got '%s'", receivedRequestHeaders.Get("Origin"))
}
})

t.Run("Well-known proxy propagates User-Agent header to backend for "+path, func(t *testing.T) {
if receivedRequestHeaders.Get("User-Agent") != "Test-Agent/1.0" {
t.Errorf("Expected User-Agent header 'Test-Agent/1.0', got '%s'", receivedRequestHeaders.Get("User-Agent"))
}
})

t.Run("Well-known proxy propagates Accept header to backend for "+path, func(t *testing.T) {
if receivedRequestHeaders.Get("Accept") != "application/json" {
t.Errorf("Expected Accept header 'application/json', got '%s'", receivedRequestHeaders.Get("Accept"))
}
})

t.Run("Well-known proxy propagates Accept-Language header to backend for "+path, func(t *testing.T) {
if receivedRequestHeaders.Get("Accept-Language") != "en-US" {
t.Errorf("Expected Accept-Language header 'en-US', got '%s'", receivedRequestHeaders.Get("Accept-Language"))
}
})

t.Run("Well-known proxy propagates custom headers to backend for "+path, func(t *testing.T) {
if receivedRequestHeaders.Get("X-Custom-Header") != "custom-value" {
t.Errorf("Expected X-Custom-Header 'custom-value', got '%s'", receivedRequestHeaders.Get("X-Custom-Header"))
}
})

t.Run("Well-known proxy propagates Referer header to backend for "+path, func(t *testing.T) {
if receivedRequestHeaders.Get("Referer") != "https://example.com/page" {
t.Errorf("Expected Referer header 'https://example.com/page', got '%s'", receivedRequestHeaders.Get("Referer"))
}
})

t.Run("Well-known proxy returns Access-Control-Allow-Origin from backend for "+path, func(t *testing.T) {
if resp.Header.Get("Access-Control-Allow-Origin") != "https://example.com" {
t.Errorf("Expected Access-Control-Allow-Origin header 'https://example.com', got '%s'", resp.Header.Get("Access-Control-Allow-Origin"))
}
})

t.Run("Well-known proxy returns Access-Control-Allow-Methods from backend for "+path, func(t *testing.T) {
if resp.Header.Get("Access-Control-Allow-Methods") != "GET, POST, OPTIONS" {
t.Errorf("Expected Access-Control-Allow-Methods header 'GET, POST, OPTIONS', got '%s'", resp.Header.Get("Access-Control-Allow-Methods"))
}
})

t.Run("Well-known proxy returns Cache-Control from backend for "+path, func(t *testing.T) {
if resp.Header.Get("Cache-Control") != "no-cache" {
t.Errorf("Expected Cache-Control header 'no-cache', got '%s'", resp.Header.Get("Cache-Control"))
}
})

t.Run("Well-known proxy returns custom response headers from backend for "+path, func(t *testing.T) {
if resp.Header.Get("X-Custom-Backend-Header") != "backend-value" {
t.Errorf("Expected X-Custom-Backend-Header 'backend-value', got '%s'", resp.Header.Get("X-Custom-Backend-Header"))
}
})
}
})
}

func TestWellKnownOverrides(t *testing.T) {
cases := []string{
".well-known/oauth-authorization-server",
Expand Down
7 changes: 6 additions & 1 deletion pkg/http/wellknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var _ http.Handler = &WellKnown{}

func WellKnownHandler(staticConfig *config.StaticConfig, httpClient *http.Client) http.Handler {
authorizationUrl := staticConfig.AuthorizationURL
if authorizationUrl != "" && strings.HasSuffix("authorizationUrl", "/") {
if authorizationUrl != "" && strings.HasSuffix(authorizationUrl, "/") {
authorizationUrl = strings.TrimSuffix(authorizationUrl, "/")
}
if httpClient == nil {
Expand All @@ -56,6 +56,11 @@ func (w WellKnown) ServeHTTP(writer http.ResponseWriter, request *http.Request)
http.Error(writer, "Failed to create request: "+err.Error(), http.StatusInternalServerError)
return
}
for key, values := range request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
resp, err := w.httpClient.Do(req.WithContext(request.Context()))
if err != nil {
http.Error(writer, "Failed to perform request: "+err.Error(), http.StatusInternalServerError)
Expand Down
Loading