diff --git a/pkg/gofr/auth.go b/pkg/gofr/auth.go index 98ddea6a3..1eb18cf47 100644 --- a/pkg/gofr/auth.go +++ b/pkg/gofr/auth.go @@ -2,6 +2,8 @@ package gofr import ( "net/http" + "net/url" + "strings" "time" "github.com/golang-jwt/jwt/v5" @@ -112,11 +114,35 @@ func (a *App) EnableOAuth(jwksEndpoint string, refreshInterval int, options ...jwt.ParserOption, ) { - a.AddHTTPService("gofr_oauth", jwksEndpoint) + parsedURL, err := url.Parse(jwksEndpoint) + if err != nil { + a.container.Errorf("invalid JWKS endpoint URL: %v", err) + return + } + + if parsedURL.Scheme == "" || parsedURL.Host == "" { + a.container.Errorf("invalid JWKS endpoint URL: missing scheme or host in %q", jwksEndpoint) + return + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + a.container.Errorf("invalid JWKS endpoint URL: unsupported scheme %q", parsedURL.Scheme) + return + } + + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + jwksPath := strings.TrimPrefix(parsedURL.Path, "/") + + if parsedURL.RawQuery != "" { + jwksPath += "?" + parsedURL.RawQuery + } + + a.AddHTTPService("gofr_oauth", baseURL) oauthOption := middleware.OauthConfigs{ Provider: a.container.GetHTTPService("gofr_oauth"), RefreshInterval: time.Second * time.Duration(refreshInterval), + Path: jwksPath, } publicKeyProvider := middleware.NewOAuth(oauthOption) diff --git a/pkg/gofr/gofr_test.go b/pkg/gofr/gofr_test.go index 87ae99ef9..037a66039 100644 --- a/pkg/gofr/gofr_test.go +++ b/pkg/gofr/gofr_test.go @@ -421,6 +421,85 @@ func TestEnableBasicAuthWithFunc(t *testing.T) { assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, "TestEnableBasicAuthWithFunc Failed!") } +func TestEnableOAuth_HealthCheckEndpoint(t *testing.T) { + port := testutil.GetFreePort(t) + + // Mock server that serves both /.well-known/alive and /.well-known/jwks.json + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/alive": + w.WriteHeader(http.StatusOK) + case "/.well-known/jwks.json": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"keys":[]}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer mockServer.Close() + + c := container.NewContainer(config.NewMockConfig(nil)) + + a := &App{ + httpServer: &httpServer{ + router: gofrHTTP.NewRouter(), + port: port, + }, + container: c, + } + + // Pass full JWKS URL with path — the fix should extract the base URL + a.EnableOAuth(mockServer.URL+"/.well-known/jwks.json", 600) + + // Verify the service is registered + oauthService := a.container.GetHTTPService("gofr_oauth") + require.NotNil(t, oauthService, "gofr_oauth service should be registered") + + // Health check should hit mockServer/.well-known/alive (not mockServer/.well-known/jwks.json/.well-known/alive) + health := oauthService.HealthCheck(t.Context()) + assert.Equal(t, "UP", health.Status, "Health check should hit the host root, not the JWKS path") + + // JWKS fetch should hit mockServer/.well-known/jwks.json (not mockServer//.well-known/jwks.json) + resp, err := oauthService.GetWithHeaders(t.Context(), ".well-known/jwks.json", nil, nil) + require.NoError(t, err) + + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, "JWKS fetch should hit the correct path without double slash") +} + +func TestEnableOAuth_InvalidEndpoints(t *testing.T) { + invalidEndpoints := []string{ + "", + "not-a-url", + "/.well-known/jwks.json", + "http://", + "ftp://host/.well-known/jwks.json", + } + + for _, endpoint := range invalidEndpoints { + t.Run(endpoint, func(t *testing.T) { + port := testutil.GetFreePort(t) + c := container.NewContainer(config.NewMockConfig(nil)) + + a := &App{ + httpServer: &httpServer{ + router: gofrHTTP.NewRouter(), + port: port, + }, + container: c, + } + + a.EnableOAuth(endpoint, 600) + + // Service should NOT be registered for invalid endpoints + assert.Nil(t, a.container.GetHTTPService("gofr_oauth"), + "gofr_oauth service should not be registered for invalid endpoint: %q", endpoint) + }) + } +} + func encodeBasicAuthorization(t *testing.T, arg string) string { t.Helper()