diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index e9d87574c..6db884187 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -63,6 +63,17 @@ func serve(ctx context.Context) { defer wg.Wait() // Do not return to caller until this goroutine is done. mrCache := templatemailer.NewCache() + if !config.Mailer.TemplateReloadingEnabled { + // If template reloading is disabled attempt an initial reload at + // startup for fault tolerance. + wg.Add(1) + go func() { + defer wg.Done() + + mrCache.Reload(ctx, config) + }() + } + limiterOpts := api.NewLimiterOptions(config) initialAPI := api.NewAPIWithVersion( config, db, utilities.Version, diff --git a/example.env b/example.env index b98824643..3be32d933 100644 --- a/example.env +++ b/example.env @@ -244,6 +244,7 @@ GOTRUE_LOG_LEVEL="debug" GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED="false" GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL="0" GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION="false" +GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_CURRENT_PASSWORD="false" GOTRUE_OPERATOR_TOKEN="unused-operator-token" GOTRUE_RATE_LIMIT_HEADER="X-Forwarded-For" GOTRUE_RATE_LIMIT_EMAIL_SENT="100" diff --git a/hack/test.env b/hack/test.env index 97a01ba03..dc4769eaa 100644 --- a/hack/test.env +++ b/hack/test.env @@ -105,6 +105,10 @@ GOTRUE_EXTERNAL_TWITTER_ENABLED=true GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=testclientid GOTRUE_EXTERNAL_TWITTER_SECRET=testsecret GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_X_ENABLED=true +GOTRUE_EXTERNAL_X_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_X_SECRET=testsecret +GOTRUE_EXTERNAL_X_REDIRECT_URI=https://identity.services.netlify.com/callback GOTRUE_EXTERNAL_ZOOM_ENABLED=true GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=testclientid GOTRUE_EXTERNAL_ZOOM_SECRET=testsecret diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index a1b445791..9d184e7ad 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -58,5 +58,5 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { } metering.RecordLogin(metering.LoginTypeAnonymous, newUser.ID, nil) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index 5eb90f8b0..58963eea3 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -23,6 +23,9 @@ const ( ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used" ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeOAuthClientStateNotFound ErrorCode = "oauth_client_state_not_found" + ErrorCodeOAuthClientStateExpired ErrorCode = "oauth_client_state_expired" + ErrorCodeOAuthInvalidState ErrorCode = "oauth_invalid_state" ErrorCodeSignupDisabled ErrorCode = "signup_disabled" ErrorCodeUserBanned ErrorCode = "user_banned" ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" diff --git a/internal/api/context.go b/internal/api/context.go index e1d285cb1..77dbfdf5b 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -4,6 +4,7 @@ import ( "context" "net/url" + "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" @@ -33,6 +34,7 @@ const ( ssoProviderKey = contextKey("sso_provider") externalHostKey = contextKey("external_host") flowStateKey = contextKey("flow_state_id") + oauthClientStateKey = contextKey("oauth_client_state_id") ) // withToken adds the JWT token to the context. @@ -137,6 +139,19 @@ func getFlowStateID(ctx context.Context) string { return obj.(string) } +func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context { + return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID) +} + +func getOAuthClientStateID(ctx context.Context) uuid.UUID { + obj := ctx.Value(oauthClientStateKey) + if obj == nil { + return uuid.Nil + } + + return obj.(uuid.UUID) +} + func getInviteToken(ctx context.Context) string { obj := ctx.Value(inviteTokenKey) if obj == nil { diff --git a/internal/api/external.go b/internal/api/external.go index 9611c24c2..dc2fd6e00 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -27,12 +27,13 @@ import ( // ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow type ExternalProviderClaims struct { AuthMicroserviceClaims - Provider string `json:"provider"` - InviteToken string `json:"invite_token,omitempty"` - Referrer string `json:"referrer,omitempty"` - FlowStateID string `json:"flow_state_id"` - LinkingTargetID string `json:"linking_target_id,omitempty"` - EmailOptional bool `json:"email_optional,omitempty"` + Provider string `json:"provider"` + InviteToken string `json:"invite_token,omitempty"` + Referrer string `json:"referrer,omitempty"` + FlowStateID string `json:"flow_state_id"` + OAuthClientStateID string `json:"oauth_client_state_id,omitempty"` + LinkingTargetID string `json:"linking_target_id,omitempty"` + EmailOptional bool `json:"email_optional,omitempty"` } // ExternalProviderRedirect redirects the request to the oauth provider @@ -90,6 +91,32 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ flowStateID = flowState.ID.String() } + authUrlParams := make([]oauth2.AuthCodeOption, 0) + query.Del("scopes") + query.Del("provider") + query.Del("code_challenge") + query.Del("code_challenge_method") + for key := range query { + if key == "workos_provider" { + // See https://workos.com/docs/reference/sso/authorize/get + authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key))) + } else { + authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key))) + } + } + + oauthClientStateID := "" + if oauthProvider, ok := p.(provider.OAuthProvider); ok && oauthProvider.RequiresPKCE() { + codeVerifier := oauth2.GenerateVerifier() + oauthClientState := models.NewOAuthClientState(providerType, &codeVerifier) + err := db.Create(oauthClientState) + if err != nil { + return "", err + } + oauthClientStateID = oauthClientState.ID.String() + authUrlParams = append(authUrlParams, oauth2.S256ChallengeOption(codeVerifier)) + } + claims := ExternalProviderClaims{ AuthMicroserviceClaims: AuthMicroserviceClaims{ RegisteredClaims: jwt.RegisteredClaims{ @@ -98,11 +125,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ SiteURL: config.SiteURL, InstanceID: uuid.Nil.String(), }, - Provider: providerType, - InviteToken: inviteToken, - Referrer: redirectURL, - FlowStateID: flowStateID, - EmailOptional: pConfig.EmailOptional, + Provider: providerType, + InviteToken: inviteToken, + Referrer: redirectURL, + FlowStateID: flowStateID, + OAuthClientStateID: oauthClientStateID, + EmailOptional: pConfig.EmailOptional, } if linkingTargetUser != nil { @@ -115,20 +143,6 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ return "", apierrors.NewInternalServerError("Error creating state").WithInternalError(err) } - authUrlParams := make([]oauth2.AuthCodeOption, 0) - query.Del("scopes") - query.Del("provider") - query.Del("code_challenge") - query.Del("code_challenge_method") - for key := range query { - if key == "workos_provider" { - // See https://workos.com/docs/reference/sso/authorize/get - authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key))) - } else { - authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key))) - } - } - authURL := p.AuthCodeURL(tokenString, authUrlParams...) return authURL, nil @@ -565,6 +579,13 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag if claims.FlowStateID != "" { ctx = withFlowStateID(ctx, claims.FlowStateID) } + if claims.OAuthClientStateID != "" { + oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID) + if err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)") + } + ctx = withOAuthClientStateID(ctx, oauthClientStateID) + } if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { @@ -634,7 +655,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide p, err = provider.NewLinkedinProvider(pConfig, scopes) case "linkedin_oidc": pConfig = config.External.LinkedinOIDC - p, err = provider.NewLinkedinOIDCProvider(pConfig, scopes) + p, err = provider.NewLinkedinOIDCProvider(ctx, pConfig, scopes) case "notion": pConfig = config.External.Notion p, err = provider.NewNotionProvider(pConfig) @@ -656,9 +677,12 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide case "twitter": pConfig = config.External.Twitter p, err = provider.NewTwitterProvider(pConfig, scopes) + case "x": + pConfig = config.External.X + p, err = provider.NewXProvider(pConfig, scopes) case "vercel_marketplace": pConfig = config.External.VercelMarketplace - p, err = provider.NewVercelMarketplaceProvider(pConfig, scopes) + p, err = provider.NewVercelMarketplaceProvider(ctx, pConfig, scopes) case "workos": pConfig = config.External.WorkOS p, err = provider.NewWorkOSProvider(pConfig) diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index a02623a38..40e737a04 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -6,13 +6,16 @@ import ( "net/http" "net/url" + "github.com/gofrs/uuid" "github.com/mrjones/oauth" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/provider" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" ) // OAuthProviderData contains the userData and token returned by the oauth provider @@ -55,6 +58,8 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { + db := a.db.WithContext(ctx) + var rq url.Values if err := r.ParseForm(); r.Method == http.MethodPost && err == nil { rq = r.Form @@ -72,28 +77,56 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } - oAuthProvider, _, err := a.OAuthProvider(ctx, providerType) + oauthProvider, _, err := a.OAuthProvider(ctx, providerType) if err != nil { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r).Entry + + var oauthClientState *models.OAuthClientState + // if there's a non-empty OAuthClientStateID we perform PKCE Flow for the external provider + if oauthClientStateID := getOAuthClientStateID(ctx); oauthClientStateID != uuid.Nil { + oauthClientState, err = models.FindAndDeleteOAuthClientStateByID(db, oauthClientStateID) + if models.IsNotFoundError(err) { + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateNotFound, "OAuth state not found").WithInternalError(err) + } else if err != nil { + return nil, apierrors.NewInternalServerError("Failed to find OAuth state").WithInternalError(err) + } + + if oauthClientState.ProviderType != providerType { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth provider mismatch") + } + + if oauthClientState.IsExpired() { + return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateExpired, "OAuth state expired") + } + } + + if oauthProvider.RequiresPKCE() && oauthClientState == nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth PKCE code verifier missing") + } + log.WithFields(logrus.Fields{ "provider": providerType, "code": oauthCode, - }).Debug("Exchanging oauth code") + }).Debug("Exchanging OAuth code") - token, err := oAuthProvider.GetOAuthToken(oauthCode) + var tokenOpts []oauth2.AuthCodeOption + if oauthClientState != nil { + tokenOpts = append(tokenOpts, oauth2.VerifierOption(*oauthClientState.CodeVerifier)) + } + token, err := oauthProvider.GetOAuthToken(ctx, oauthCode, tokenOpts...) if err != nil { return nil, apierrors.NewInternalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err) } - userData, err := oAuthProvider.GetUserData(ctx, token) + userData, err := oauthProvider.GetUserData(ctx, token) if err != nil { return nil, apierrors.NewInternalServerError("Error getting user profile from external provider").WithInternalError(err) } - switch externalProvider := oAuthProvider.(type) { + switch externalProvider := oauthProvider.(type) { case *provider.AppleProvider: // apple only returns user info the first time oauthUser := rq.Get("user") diff --git a/internal/api/external_x_test.go b/internal/api/external_x_test.go new file mode 100644 index 000000000..3ab6b8077 --- /dev/null +++ b/internal/api/external_x_test.go @@ -0,0 +1,197 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + xUser string = `{"data":{"id":"xTestId","name":"X Test","username":"xtest","confirmed_email":"x@example.com","profile_image_url":"https://pbs.twimg.com/profile_images/test.jpg","url":"https://example.com","created_at":"2020-01-01T00:00:00.000Z"}}` + xUserWrongEmail string = `{"data":{"id":"xTestId","name":"X Test","username":"xtest","confirmed_email":"other@example.com","profile_image_url":"https://pbs.twimg.com/profile_images/test.jpg","url":"https://example.com","created_at":"2020-01-01T00:00:00.000Z"}}` + xUserNoEmail string = `{"data":{"id":"xTestId","name":"X Test","username":"xtest","profile_image_url":"https://pbs.twimg.com/profile_images/test.jpg","url":"https://example.com","created_at":"2020-01-01T00:00:00.000Z"}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalX() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=x", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.X.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.X.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("users.email tweet.read users.read offline.access", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("x", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func XTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/2/oauth2/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.X.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"x_token","expires_in":100000}`) + case "/2/users/me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown X oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.X.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalX_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUser) + defer server.Close() + + u := performAuthorization(ts, "x", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "x@example.com", "X Test", "xTestId", "https://pbs.twimg.com/profile_images/test.jpg") +} + +func (ts *ExternalTestSuite) TestSignupExternalX_AuthorizationCode_NoEmailWithEmailOptional() { + // When EmailOptional is true, signup should succeed without email + ts.Config.DisableSignup = false + ts.Config.External.X.EmailOptional = true + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "x", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "", "X Test", "xTestId", "https://pbs.twimg.com/profile_images/test.jpg") +} + +func (ts *ExternalTestSuite) TestSignupExternalX_AuthorizationCode_NoEmailWithoutEmailOptional() { + // When EmailOptional is false, signup should fail without email + ts.Config.DisableSignup = false + ts.Config.External.X.EmailOptional = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "x", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalXDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUser) + defer server.Close() + + u := performAuthorization(ts, "x", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "x@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalXDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + ts.Config.External.X.EmailOptional = false + + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "x", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "x@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalXDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("xTestId", "x@example.com", "X Test", "https://pbs.twimg.com/profile_images/test.jpg", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUser) + defer server.Close() + + u := performAuthorization(ts, "x", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "x@example.com", "X Test", "xTestId", "https://pbs.twimg.com/profile_images/test.jpg") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalXSuccessWhenMatchingToken() { + // name and avatar should be populated from X API + ts.createUser("xTestId", "x@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUser) + defer server.Close() + + u := performAuthorization(ts, "x", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "x@example.com", "X Test", "xTestId", "https://pbs.twimg.com/profile_images/test.jpg") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalXErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "x", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalXErrorWhenWrongToken() { + ts.createUser("xTestId", "x@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "x", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalXErrorWhenEmailDoesntMatch() { + ts.createUser("xTestId", "x@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := XTestSignupSetup(ts, &tokenCount, &userCount, code, xUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "x", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 24643f736..a13acb89b 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -19,6 +19,12 @@ func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { return shared.SendJSON(w, status, obj) } +func sendTokenJSON(w http.ResponseWriter, status int, obj interface{}) error { + w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate") + w.Header().Set("Pragma", "no-cache") + return shared.SendJSON(w, status, obj) +} + func isAdmin(u *models.User, config *conf.GlobalConfiguration) bool { return config.JWT.Aud == u.Aud && u.HasRole(config.JWT.AdminGroupName) } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 81523363f..007e5b12c 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -752,7 +752,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V Provider: metering.ProviderMFATOTP, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } @@ -892,7 +892,7 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * Provider: metering.ProviderMFAPhone, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { @@ -1012,7 +1012,7 @@ func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, param Provider: metering.ProviderMFAWebAuthn, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { diff --git a/internal/api/provider/apple.go b/internal/api/provider/apple.go index 260a3605c..30064255c 100644 --- a/internal/api/provider/apple.go +++ b/internal/api/provider/apple.go @@ -118,12 +118,17 @@ func NewAppleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration) } // GetOAuthToken returns the apple provider access token -func (p AppleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - opts := []oauth2.AuthCodeOption{ +func (p AppleProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + appleOpts := []oauth2.AuthCodeOption{ oauth2.SetAuthURLParam("client_id", p.ClientID), oauth2.SetAuthURLParam("secret", p.ClientSecret), } - return p.Exchange(context.Background(), code, opts...) + appleOpts = append(appleOpts, opts...) + return p.Exchange(ctx, code, appleOpts...) +} + +func (p AppleProvider) RequiresPKCE() bool { + return false } func (p AppleProvider) AuthCodeURL(state string, args ...oauth2.AuthCodeOption) string { diff --git a/internal/api/provider/azure.go b/internal/api/provider/azure.go index 4a341f4d6..d81a4ffe3 100644 --- a/internal/api/provider/azure.go +++ b/internal/api/provider/azure.go @@ -91,8 +91,12 @@ func NewAzureProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth }, nil } -func (g azureProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g azureProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g azureProvider) RequiresPKCE() bool { + return false } func DetectAzureIDTokenIssuer(ctx context.Context, idToken string) (string, error) { diff --git a/internal/api/provider/bitbucket.go b/internal/api/provider/bitbucket.go index e5fae5c91..f10bcf819 100644 --- a/internal/api/provider/bitbucket.go +++ b/internal/api/provider/bitbucket.go @@ -59,8 +59,12 @@ func NewBitbucketProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, e }, nil } -func (g bitbucketProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g bitbucketProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g bitbucketProvider) RequiresPKCE() bool { + return false } func (g bitbucketProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/discord.go b/internal/api/provider/discord.go index 50d413b7c..7e0199bb9 100644 --- a/internal/api/provider/discord.go +++ b/internal/api/provider/discord.go @@ -61,8 +61,12 @@ func NewDiscordProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu }, nil } -func (g discordProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g discordProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g discordProvider) RequiresPKCE() bool { + return false } func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/facebook.go b/internal/api/provider/facebook.go index e73c419da..5940cf57c 100644 --- a/internal/api/provider/facebook.go +++ b/internal/api/provider/facebook.go @@ -70,8 +70,12 @@ func NewFacebookProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA }, nil } -func (p facebookProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return p.Exchange(context.Background(), code) +func (p facebookProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.Exchange(ctx, code, opts...) +} + +func (p facebookProvider) RequiresPKCE() bool { + return false } func (p facebookProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/figma.go b/internal/api/provider/figma.go index d9b5b6b6a..e6777770a 100644 --- a/internal/api/provider/figma.go +++ b/internal/api/provider/figma.go @@ -60,8 +60,12 @@ func NewFigmaProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth }, nil } -func (p figmaProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return p.Exchange(context.Background(), code) +func (p figmaProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.Exchange(ctx, code, opts...) +} + +func (p figmaProvider) RequiresPKCE() bool { + return false } func (p figmaProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/fly.go b/internal/api/provider/fly.go index d9337524f..2f863b192 100644 --- a/internal/api/provider/fly.go +++ b/internal/api/provider/fly.go @@ -65,8 +65,12 @@ func NewFlyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthPr }, nil } -func (p flyProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return p.Exchange(context.Background(), code) +func (p flyProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.Exchange(ctx, code, opts...) +} + +func (p flyProvider) RequiresPKCE() bool { + return false } func (p flyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/github.go b/internal/api/provider/github.go index 0da3e8842..d6d8b5504 100644 --- a/internal/api/provider/github.go +++ b/internal/api/provider/github.go @@ -70,8 +70,12 @@ func NewGithubProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut }, nil } -func (g githubProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g githubProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g githubProvider) RequiresPKCE() bool { + return false } func (g githubProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/gitlab.go b/internal/api/provider/gitlab.go index 4b5d70cb8..02411bad4 100644 --- a/internal/api/provider/gitlab.go +++ b/internal/api/provider/gitlab.go @@ -61,8 +61,12 @@ func NewGitlabProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut }, nil } -func (g gitlabProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g gitlabProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g gitlabProvider) RequiresPKCE() bool { + return false } func (g gitlabProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/google.go b/internal/api/provider/google.go index 03b76aebe..b4f40c82c 100644 --- a/internal/api/provider/google.go +++ b/internal/api/provider/google.go @@ -72,8 +72,12 @@ func NewGoogleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration, }, nil } -func (g googleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g googleProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g googleProvider) RequiresPKCE() bool { + return false } const UserInfoEndpointGoogle = "https://www.googleapis.com/userinfo/v2/me" diff --git a/internal/api/provider/kakao.go b/internal/api/provider/kakao.go index 2482b97a8..a5588d1f8 100644 --- a/internal/api/provider/kakao.go +++ b/internal/api/provider/kakao.go @@ -33,8 +33,12 @@ type kakaoUser struct { } `json:"kakao_account"` } -func (p kakaoProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return p.Exchange(context.Background(), code) +func (p kakaoProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.Exchange(ctx, code, opts...) +} + +func (p kakaoProvider) RequiresPKCE() bool { + return false } func (p kakaoProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/keycloak.go b/internal/api/provider/keycloak.go index 480a46724..48d3cdd67 100644 --- a/internal/api/provider/keycloak.go +++ b/internal/api/provider/keycloak.go @@ -85,8 +85,12 @@ func NewKeycloakProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA }, nil } -func (g keycloakProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g keycloakProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g keycloakProvider) RequiresPKCE() bool { + return false } func (g keycloakProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/linkedin.go b/internal/api/provider/linkedin.go index bc33515e7..fa2ffe11e 100644 --- a/internal/api/provider/linkedin.go +++ b/internal/api/provider/linkedin.go @@ -97,8 +97,12 @@ func NewLinkedinProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA }, nil } -func (g linkedinProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g linkedinProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g linkedinProvider) RequiresPKCE() bool { + return false } func GetName(name linkedinName) string { diff --git a/internal/api/provider/linkedin_oidc.go b/internal/api/provider/linkedin_oidc.go index a5d94fa09..64260cb25 100644 --- a/internal/api/provider/linkedin_oidc.go +++ b/internal/api/provider/linkedin_oidc.go @@ -21,7 +21,7 @@ type linkedinOIDCProvider struct { } // NewLinkedinOIDCProvider creates a Linkedin account provider via OIDC. -func NewLinkedinOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { +func NewLinkedinOIDCProvider(ctx context.Context, ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -38,7 +38,7 @@ func NewLinkedinOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) } - oidcProvider, err := oidc.NewProvider(context.Background(), IssuerLinkedin) + oidcProvider, err := oidc.NewProvider(ctx, IssuerLinkedin) if err != nil { return nil, err } @@ -59,8 +59,12 @@ func NewLinkedinOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) }, nil } -func (g linkedinOIDCProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g linkedinOIDCProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g linkedinOIDCProvider) RequiresPKCE() bool { + return false } func (g linkedinOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/notion.go b/internal/api/provider/notion.go index f8d0ee706..6f9844572 100644 --- a/internal/api/provider/notion.go +++ b/internal/api/provider/notion.go @@ -59,8 +59,12 @@ func NewNotionProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, erro }, nil } -func (g notionProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g notionProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g notionProvider) RequiresPKCE() bool { + return false } func (g notionProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/provider.go b/internal/api/provider/provider.go index 5471e1cee..9ced63937 100644 --- a/internal/api/provider/provider.go +++ b/internal/api/provider/provider.go @@ -104,7 +104,8 @@ type Provider interface { type OAuthProvider interface { AuthCodeURL(string, ...oauth2.AuthCodeOption) string GetUserData(context.Context, *oauth2.Token) (*UserProvidedData, error) - GetOAuthToken(string) (*oauth2.Token, error) + GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + RequiresPKCE() bool } func chooseHost(base, defaultHost string) string { diff --git a/internal/api/provider/slack.go b/internal/api/provider/slack.go index 40377b0aa..5949b2f30 100644 --- a/internal/api/provider/slack.go +++ b/internal/api/provider/slack.go @@ -57,8 +57,12 @@ func NewSlackProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuth }, nil } -func (g slackProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g slackProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g slackProvider) RequiresPKCE() bool { + return false } func (g slackProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/slack_oidc.go b/internal/api/provider/slack_oidc.go index 3c7a5eb62..9485949fd 100644 --- a/internal/api/provider/slack_oidc.go +++ b/internal/api/provider/slack_oidc.go @@ -60,8 +60,12 @@ func NewSlackOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) (O }, nil } -func (g slackOIDCProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g slackOIDCProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g slackOIDCProvider) RequiresPKCE() bool { + return false } func (g slackOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/snapchat.go b/internal/api/provider/snapchat.go index 6c64b86e2..54abe6c56 100644 --- a/internal/api/provider/snapchat.go +++ b/internal/api/provider/snapchat.go @@ -71,8 +71,12 @@ func NewSnapchatProvider(ext conf.OAuthProviderConfiguration, scopes string) (OA }, nil } -func (p snapchatProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return p.Exchange(context.Background(), code) +func (p snapchatProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.Exchange(ctx, code, opts...) +} + +func (p snapchatProvider) RequiresPKCE() bool { + return false } func (p snapchatProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/spotify.go b/internal/api/provider/spotify.go index e6d2f383c..ceaf104d4 100644 --- a/internal/api/provider/spotify.go +++ b/internal/api/provider/spotify.go @@ -63,8 +63,12 @@ func NewSpotifyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu }, nil } -func (g spotifyProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g spotifyProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g spotifyProvider) RequiresPKCE() bool { + return false } func (g spotifyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/twitch.go b/internal/api/provider/twitch.go index defb1983a..f662b8e28 100644 --- a/internal/api/provider/twitch.go +++ b/internal/api/provider/twitch.go @@ -75,8 +75,12 @@ func NewTwitchProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAut }, nil } -func (t twitchProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return t.Exchange(context.Background(), code) +func (t twitchProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return t.Exchange(ctx, code, opts...) +} + +func (t twitchProvider) RequiresPKCE() bool { + return false } func (t twitchProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/twitter.go b/internal/api/provider/twitter.go index 8dc5a4c64..57af16306 100644 --- a/internal/api/provider/twitter.go +++ b/internal/api/provider/twitter.go @@ -60,10 +60,14 @@ func NewTwitterProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAu } // GetOAuthToken is a stub method for OAuthProvider interface, unused in OAuth1.0 protocol -func (t TwitterProvider) GetOAuthToken(_ string) (*oauth2.Token, error) { +func (t TwitterProvider) GetOAuthToken(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { return &oauth2.Token{}, nil } +func (t TwitterProvider) RequiresPKCE() bool { + return false +} + // GetUserData is a stub method for OAuthProvider interface, unused in OAuth1.0 protocol func (t TwitterProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { return &UserProvidedData{}, nil diff --git a/internal/api/provider/vercel_marketplace.go b/internal/api/provider/vercel_marketplace.go index ba76a7412..c74c8e4ad 100644 --- a/internal/api/provider/vercel_marketplace.go +++ b/internal/api/provider/vercel_marketplace.go @@ -22,7 +22,7 @@ type vercelMarketplaceProvider struct { } // NewVercelMarketplaceProvider creates a VercelMarketplace account provider via OIDC. -func NewVercelMarketplaceProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { +func NewVercelMarketplaceProvider(ctx context.Context, ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { if err := ext.ValidateOAuth(); err != nil { return nil, err } @@ -35,7 +35,7 @@ func NewVercelMarketplaceProvider(ext conf.OAuthProviderConfiguration, scopes st oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) } - oidcProvider, err := oidc.NewProvider(context.Background(), IssuerVercelMarketplace) + oidcProvider, err := oidc.NewProvider(ctx, IssuerVercelMarketplace) if err != nil { return nil, err } @@ -56,8 +56,12 @@ func NewVercelMarketplaceProvider(ext conf.OAuthProviderConfiguration, scopes st }, nil } -func (g vercelMarketplaceProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g vercelMarketplaceProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g vercelMarketplaceProvider) RequiresPKCE() bool { + return false } func (g vercelMarketplaceProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/workos.go b/internal/api/provider/workos.go index 75cafa27d..74be569c2 100644 --- a/internal/api/provider/workos.go +++ b/internal/api/provider/workos.go @@ -53,8 +53,12 @@ func NewWorkOSProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, erro }, nil } -func (g workosProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g workosProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g workosProvider) RequiresPKCE() bool { + return false } func (g workosProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/provider/x.go b/internal/api/provider/x.go new file mode 100644 index 000000000..dbac8a4c5 --- /dev/null +++ b/internal/api/provider/x.go @@ -0,0 +1,137 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// X (formerly Twitter) API v2 OAuth 2.0 endpoints +// See: https://developer.x.com/en/docs/authentication/oauth-2-0/authorization-code +const ( + defaultXAuthBase = "x.com" + defaultXAPIBase = "api.x.com" +) + +type xProvider struct { + *oauth2.Config + APIHost string +} + +// xUser represents the user object from X API v2 +// See: https://developer.x.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me +type xUser struct { + ID string `json:"id"` + Name string `json:"name"` + Username string `json:"username"` + ConfirmedEmail string `json:"confirmed_email"` + ProfileImageURL string `json:"profile_image_url"` + URL string `json:"url"` + CreatedAt string `json:"created_at"` +} + +// xUserResponse is the wrapper for the X API v2 response +type xUserResponse struct { + Data xUser `json:"data"` +} + +// NewXProvider creates an X (formerly Twitter) v2 OAuth 2.0 provider. +// This uses OAuth 2.0 with PKCE instead of OAuth 1.0a. +// See: https://developer.x.com/en/docs/authentication/oauth-2-0/authorization-code +func NewXProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultXAuthBase) + apiHost := chooseHost(ext.URL, defaultXAPIBase) + + // Default scopes for user authentication + // users.email: Access to the user's email address (confirmed_email field) + // users.read: Read user profile information + // tweet.read: Required scope for OAuth 2.0 user context even if not accessing tweets + // offline.access: Get refresh tokens for long-lived access + // See: https://developer.x.com/en/docs/authentication/oauth-2-0/authorization-code + // and: https://docs.x.com/fundamentals/authentication/guides/v2-authentication-mapping + oauthScopes := []string{ + "users.email", + "tweet.read", + "users.read", + "offline.access", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &xProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/i/oauth2/authorize", + TokenURL: apiHost + "/2/oauth2/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} + +func (x xProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return x.Exchange(ctx, code, opts...) +} + +func (x xProvider) RequiresPKCE() bool { + return true +} + +func (x xProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var resp xUserResponse + + // See: https://developer.x.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me + userInfoURL := x.APIHost + "/2/users/me?user.fields=id,name,username,confirmed_email,profile_image_url,url,created_at" + + if err := makeRequest(ctx, tok, x.Config, userInfoURL, &resp); err != nil { + return nil, err + } + + u := resp.Data + + data := &UserProvidedData{ + Metadata: &Claims{ + Issuer: x.APIHost, + Subject: u.ID, + Name: u.Name, + PreferredUsername: u.Username, + Picture: u.ProfileImageURL, + Profile: "https://x.com/" + u.Username, + Website: u.URL, + + // Custom claims for X specific data + CustomClaims: map[string]any{ + "created_at": u.CreatedAt, + }, + + // To be deprecated + AvatarURL: u.ProfileImageURL, + FullName: u.Name, + ProviderId: u.ID, + UserNameKey: u.Username, + }, + } + + if u.ConfirmedEmail != "" { + data.Emails = []Email{{ + Email: u.ConfirmedEmail, + // X returns only confirmed emails + Verified: true, + Primary: true, + }} + } + + return data, nil +} diff --git a/internal/api/provider/zoom.go b/internal/api/provider/zoom.go index 8e2e9fa4d..df12d72b1 100644 --- a/internal/api/provider/zoom.go +++ b/internal/api/provider/zoom.go @@ -51,8 +51,12 @@ func NewZoomProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) }, nil } -func (g zoomProvider) GetOAuthToken(code string) (*oauth2.Token, error) { - return g.Exchange(context.Background(), code) +func (g zoomProvider) GetOAuthToken(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return g.Exchange(ctx, code, opts...) +} + +func (g zoomProvider) RequiresPKCE() bool { + return false } func (g zoomProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { diff --git a/internal/api/signup.go b/internal/api/signup.go index 89c79f889..1af7c6a7a 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -329,7 +329,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { "immediate_login_after_signup": true, }, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } if user.HasBeenInvited() { // Remove sensitive fields diff --git a/internal/api/token.go b/internal/api/token.go index 5ab02ad24..367138127 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -36,6 +36,14 @@ type PKCEGrantParams struct { const useCookieHeader = "x-use-cookie" const InvalidLoginMessage = "Invalid login credentials" +const dummyPasswordHash = "$2a$10$JUbiChr4qVqzEEHDLbRmgOvGTUajEl0g6JJjOzN.drbF9oX.iL/sq" + +// performDummyPasswordVerification prevents user enumeration via timing attacks +// by performing a bcrypt comparison even when user is not found +func (a *API) performDummyPasswordVerification(ctx context.Context, password string) { + _ = crypto.CompareHashAndPassword(ctx, dummyPasswordHash, password) +} + // Token is the endpoint for OAuth access token requests func (a *API) Token(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -108,12 +116,14 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if err != nil { if models.IsNotFoundError(err) { + a.performDummyPasswordVerification(ctx, params.Password) return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) } return apierrors.NewInternalServerError("Database error querying schema").WithInternalError(err) } if !user.HasPassword() { + a.performDummyPasswordVerification(ctx, params.Password) return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) } @@ -207,7 +217,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri metering.RecordLogin(metering.LoginTypePassword, user.ID, &metering.LoginData{ Provider: provider, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) error { @@ -284,7 +294,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) metering.RecordLogin(metering.LoginTypePKCE, user.ID, &metering.LoginData{ Provider: flowState.ProviderType, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index e62940abc..c0ec6d3e2 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -319,5 +319,5 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R Provider: providerType, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 3178b9ec9..f27847c4b 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -57,5 +57,5 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h return err } - return sendJSON(w, http.StatusOK, tokenResponse) + return sendTokenJSON(w, http.StatusOK, tokenResponse) } diff --git a/internal/api/user.go b/internal/api/user.go index 723bce144..2458e8837 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -18,6 +18,7 @@ import ( type UserUpdateParams struct { Email string `json:"email"` Password *string `json:"password"` + CurrentPassword *string `json:"current_password"` Nonce string `json:"nonce"` Data map[string]interface{} `json:"data"` AppData map[string]interface{} `json:"app_metadata,omitempty"` @@ -147,6 +148,22 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } if params.Password != nil { + if config.Security.UpdatePasswordRequireCurrentPassword { + if params.CurrentPassword == nil || *params.CurrentPassword == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Current password is required to update password") + } + + if user.HasPassword() { + authenticated, _, err := user.Authenticate(ctx, db, *params.CurrentPassword, config.Security.DBEncryption.DecryptionKeys, false, "") + if err != nil { + return apierrors.NewInternalServerError("Error verifying current password").WithInternalError(err) + } + if !authenticated { + return apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + } + } + if config.Security.UpdatePasswordRequireReauthentication { now := time.Now() // we require reauthentication if the user hasn't signed in recently in the current session diff --git a/internal/api/verify.go b/internal/api/verify.go index 6209774bb..b33586b04 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -35,6 +35,11 @@ const ( singleConfirmation ) +// OTP brute force protection +const ( + maxOTPVerificationAttempts = 3 +) + // Only applicable when SECURE_EMAIL_CHANGE_ENABLED const singleConfirmationAccepted = "Confirmation link accepted. Please proceed to confirm link sent to the other email" @@ -307,7 +312,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP Provider: provider, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.Connection, user *models.User) (*models.User, error) { @@ -723,6 +728,15 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeUserBanned, "User is banned") } + // OTP Protection: Check if token is invalidated before attempting verification + tokenType := getTokenTypeForVerification(params.Type) + if tokenType != "" { + invalidated, err := checkOTPTokenInvalidated(conn, user.ID.String(), tokenType) + if err == nil && invalidated { + return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has been invalidated due to too many failed attempts. Please request a new verification code.") + } + } + var isValid bool smsProvider, _ := sms_provider.GetSmsProvider(*config) @@ -770,6 +784,14 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, isValid = isOtpValid(tokenHash, expectedToken, sentAt, config.Sms.OtpExp) } + // OTP Protection: Record attempt + if tokenType != "" { + if err := recordOTPAttempt(conn, user.ID.String(), tokenType, isValid); err != nil { + // Log error but don't fail the request + logrus.WithError(err).Warn("Failed to record OTP attempt") + } + } + if !isValid { return nil, apierrors.NewForbiddenError(apierrors.ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") } @@ -811,3 +833,92 @@ func emailAddressChanged(oldEmail, newEmail string) bool { func phoneNumberChanged(oldPhone, newPhone string) bool { return oldPhone != "" && newPhone != "" && oldPhone != newPhone } + +// ============================================ +// OTP Brute Force Protection Functions +// ============================================ + +// getTokenTypeForVerification maps verification type to one_time_token type +func getTokenTypeForVerification(verificationType string) string { + switch verificationType { + case mail.SignupVerification, mail.InviteVerification, mail.EmailOTPVerification: + return "confirmation_token" + case mail.RecoveryVerification, mail.MagicLinkVerification: + return "recovery_token" + case mail.EmailChangeVerification: + return "email_change_token_current" + case smsVerification: + return "phone_confirmation_token" + case phoneChangeVerification: + return "phone_change_token" + default: + return "" + } +} + +// checkOTPTokenInvalidated checks if an OTP token has been invalidated due to too many failed attempts +func checkOTPTokenInvalidated(conn *storage.Connection, userID string, tokenType string) (bool, error) { + var invalidatedAt *time.Time + err := conn.RawQuery(` + SELECT invalidated_at + FROM auth.one_time_tokens + WHERE user_id = $1 AND token_type = $2::auth.one_time_token_type + `, userID, tokenType).First(&invalidatedAt) + + if err != nil { + if storage.IsNotFoundError(err) { + return false, nil + } + return false, err + } + + return invalidatedAt != nil, nil +} + +// recordOTPAttempt records a failed OTP verification attempt and invalidates token after max failures +func recordOTPAttempt(conn *storage.Connection, userID string, tokenType string, isValid bool) error { + // If token is valid, reset attempts + if isValid { + _, err := conn.RawQuery(` + UPDATE auth.one_time_tokens + SET attempt_count = 0, invalidated_at = NULL + WHERE user_id = $1 AND token_type = $2::auth.one_time_token_type + `, userID, tokenType).Exec() + return err + } + + // Token is invalid - increment attempt count + var attemptCount int + err := conn.RawQuery(` + UPDATE auth.one_time_tokens + SET attempt_count = attempt_count + 1 + WHERE user_id = $1 AND token_type = $2::auth.one_time_token_type + RETURNING attempt_count + `, userID, tokenType).First(&attemptCount) + + if err != nil { + return err + } + + // If max attempts reached, invalidate the token + if attemptCount >= maxOTPVerificationAttempts { + _, err = conn.RawQuery(` + UPDATE auth.one_time_tokens + SET invalidated_at = NOW() + WHERE user_id = $1 AND token_type = $2::auth.one_time_token_type + `, userID, tokenType).Exec() + return err + } + + return nil +} + +// clearOTPAttempts resets attempt tracking when a new OTP is generated +func clearOTPAttempts(conn *storage.Connection, userID string, tokenType string) error { + _, err := conn.RawQuery(` + UPDATE auth.one_time_tokens + SET attempt_count = 0, invalidated_at = NULL + WHERE user_id = $1 AND token_type = $2::auth.one_time_token_type + `, userID, tokenType).Exec() + return err +} diff --git a/internal/api/web3.go b/internal/api/web3.go index 3928b1a25..d2105aa20 100644 --- a/internal/api/web3.go +++ b/internal/api/web3.go @@ -200,7 +200,7 @@ func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *htt }, }) - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } func (a *API) web3GrantEthereum(ctx context.Context, w http.ResponseWriter, r *http.Request, params *Web3GrantParams) error { @@ -335,5 +335,5 @@ func (a *API) web3GrantEthereum(ctx context.Context, w http.ResponseWriter, r *h return err } } - return sendJSON(w, http.StatusOK, token) + return sendTokenJSON(w, http.StatusOK, token) } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index aa744c26f..dd98a26f1 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -439,6 +439,7 @@ type ProviderConfiguration struct { WorkOS OAuthProviderConfiguration `json:"workos"` Email EmailProviderConfiguration `json:"email"` Phone PhoneProviderConfiguration `json:"phone"` + X OAuthProviderConfiguration `json:"x" envconfig:"X"` Zoom OAuthProviderConfiguration `json:"zoom"` IosBundleId string `json:"ios_bundle_id" split_words:"true"` RedirectURL string `json:"redirect_url"` @@ -729,6 +730,7 @@ type SecurityConfiguration struct { RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"` RefreshTokenAllowReuse bool `json:"refresh_token_allow_reuse" split_words:"true"` UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` + UpdatePasswordRequireCurrentPassword bool `json:"update_password_require_current_password" split_words:"true" default:"false"` ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"` diff --git a/internal/indexworker/indexworker.go b/internal/indexworker/indexworker.go index 64e8f0415..bed42f47f 100644 --- a/internal/indexworker/indexworker.go +++ b/internal/indexworker/indexworker.go @@ -20,6 +20,14 @@ import ( var ErrAdvisoryLockAlreadyAcquired = errors.New("advisory lock already acquired by another process") var ErrExtensionNotFound = errors.New("extension not found") +type Outcome string + +const ( + OutcomeSuccess Outcome = "success" + OutcomeFailure Outcome = "failure" + OutcomeSkipped Outcome = "skipped" +) + // CreateIndexes ensures that the necessary indexes on the users table exist. // If the indexes already exist and are valid, it skips creation. // It uses a Postgres advisory lock to prevent concurrent index creation @@ -67,16 +75,22 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo lockQuery := fmt.Sprintf("SELECT pg_try_advisory_lock(hashtext('%s')::bigint)", lockName) if err := db.RawQuery(lockQuery).First(&lockAcquired); err != nil { - le.Errorf("Failed to attempt advisory lock acquisition: %+v", err) + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "advisory_lock_acquisition_failed", + }).WithError(err).Error("Failed to attempt advisory lock acquisition") return err } if !lockAcquired { - le.Infof("Another process is currently creating indexes. Skipping index creation.") + le.WithFields(logrus.Fields{ + "outcome": OutcomeSkipped, + "code": "advisory_lock_already_acquired", + }).Info("Another process is currently holding the advisory lock, skipping index creation") return ErrAdvisoryLockAlreadyAcquired } - le.Infof("Successfully acquired advisory lock for index creation.") + le.Debug("Successfully acquired advisory lock for index creation") // Ensure lock is released on function exit defer func() { @@ -84,28 +98,35 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo var unlocked bool if err := db.RawQuery(unlockQuery).First(&unlocked); err != nil { if ctx.Err() != nil { - le.Infof("Context cancelled. Advisory lock will be released upon session termination.") + le.Debug("Context cancelled, advisory lock will be released upon session termination") } else { - le.Errorf("Failed to release advisory lock: %+v", err) + le.WithError(err).Error("Failed to release advisory lock") } } else if unlocked { - le.Infof("Successfully released advisory lock.") + le.Debug("Successfully released advisory lock") } else { - le.Warnf("Advisory lock was not held when attempting to release.") + le.Debug("Advisory lock was not held when attempting to release") } }() // Ensure either auth_trgm or pg_trgm extension is installed extName, err := ensureTrgmExtension(db, config.DB.Namespace, le) if err != nil { - le.Errorf("Failed to ensure trgm extension is available: %+v", err) + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "trgm_extension_unavailable", + }).WithError(err).Error("Failed to ensure trgm extension is available") return err } // Look up which schema the trgm extension is installed in trgmSchema, err := getTrgmExtensionSchema(db, extName) if err != nil { - le.Errorf("Failed to find %s extension schema: %+v", extName, err) + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "extension_schema_not_found", + "extension": extName, + }).WithError(err).Error("Failed to find extension schema") return ErrExtensionNotFound } @@ -118,64 +139,104 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo // Check existing indexes and their statuses. If all exist and are valid, skip creation. existingIndexes, err := getIndexStatuses(db, config.DB.Namespace, indexNames) if err != nil { - le.Warnf("Failed to check existing indexes: %+v. Proceeding with index creation.", err) + le.WithError(err).Warn("Failed to check existing index statuses, proceeding with index creation") } else { if len(existingIndexes) == len(indexes) { allHealthy := true for _, idx := range existingIndexes { if !idx.IsValid || !idx.IsReady { - le.Infof("Index %s exists but is not healthy (valid: %v, ready: %v)", idx.IndexName, idx.IsValid, idx.IsReady) + le.WithFields(logrus.Fields{ + "code": "index_unhealthy", + "index_name": idx.IndexName, + "index_valid": idx.IsValid, + "index_ready": idx.IsReady, + }).Info("Index exists but is not healthy") allHealthy = false break } } if allHealthy { - le.Infof("All %d indexes on auth.users already exist and are ready. Skipping index creation.", len(indexes)) + le.WithFields(logrus.Fields{ + "outcome": OutcomeSkipped, + "code": "indexes_already_exist", + "index_count": len(indexes), + }).Debug("All indexes on auth.users already exist and are ready, skipping index creation") return nil } } else { - le.Infof("Found %d of %d expected indexes. Proceeding with index creation.", len(existingIndexes), len(indexes)) + le.WithFields(logrus.Fields{ + "code": "indexes_missing", + "existing_count": len(existingIndexes), + "expected_count": len(indexes), + }).Info("Found fewer indexes than expected, proceeding with index creation") } } userCount, err := getApproximateUserCount(db, config.DB.Namespace) if err != nil { - le.Warnf("Failed to get approximate user count: %+v. Proceeding with index creation.", err) + le.WithError(err).Warn("Failed to get approximate user count, proceeding with index creation") } - le.Infof("User count: %d. Starting index creation...", userCount) + le.WithFields(logrus.Fields{ + "code": "index_creation_starting", + "user_count": userCount, + }).Info("Starting index creation") // First, clean up any invalid indexes from previous interrupted attempts dropInvalidIndexes(db, le, config.DB.Namespace, indexNames) // Create indexes one by one var failedIndexes []string + var succeededIndexes []string totalStartTime := time.Now() for _, idx := range indexes { startTime := time.Now() - le.Infof("Creating index: %s", idx.name) + le.WithFields(logrus.Fields{ + "code": "index_creating", + "index_name": idx.name, + }).Info("Creating index") if err := db.RawQuery(idx.query).Exec(); err != nil { duration := time.Since(startTime).Milliseconds() - - le.Errorf("Failed to create index %s after %d ms: %v", idx.name, duration, err) + le.WithFields(logrus.Fields{ + "code": "index_creation_failed", + "index_name": idx.name, + "duration_ms": duration, + }).WithError(err).Error("Failed to create index") failedIndexes = append(failedIndexes, idx.name) } else { duration := time.Since(startTime).Milliseconds() - le.Infof("Successfully created index %s in %d ms", idx.name, duration) + le.WithFields(logrus.Fields{ + "code": "index_created", + "index_name": idx.name, + "duration_ms": duration, + }).Info("Successfully created index") + succeededIndexes = append(succeededIndexes, idx.name) } } totalDuration := time.Since(totalStartTime).Milliseconds() if len(failedIndexes) > 0 { - le.Warnf("Index creation completed in %d ms with some failures: %v", totalDuration, failedIndexes) + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "index_creation_partial_failure", + "duration_ms": totalDuration, + "failed_indexes": failedIndexes, + "succeeded_indexes": succeededIndexes, + }).Error("Index creation completed with some failures") + return fmt.Errorf("failed to create indexes: %v", failedIndexes) - } else { - le.Infof("All indexes created successfully in %d ms", totalDuration) } + le.WithFields(logrus.Fields{ + "outcome": OutcomeSuccess, + "code": "index_creation_completed", + "duration_ms": totalDuration, + "succeeded_indexes": succeededIndexes, + }).Info("All indexes created successfully") + return nil } @@ -249,19 +310,30 @@ func ensureTrgmExtension(db *pop.Connection, authSchema string, le *logrus.Entry if authTrgmStatus.Available { if !authTrgmStatus.Installed { - le.Infof("auth_trgm extension is available but not installed. Installing...") + le.Debug("auth_trgm extension is available but not installed, installing") + if err := installExtension(db, "auth_trgm", authSchema); err != nil { - le.Errorf("Failed to install auth_trgm extension: %v", err) + le.WithFields(logrus.Fields{ + "outcome": OutcomeFailure, + "code": "extension_install_failed", + "extension": "auth_trgm", + }).WithError(err).Error("Failed to install auth_trgm extension") + return "", fmt.Errorf("auth_trgm extension is available but failed to install: %w", err) } - le.Infof("Successfully installed auth_trgm extension") + + le.WithFields(logrus.Fields{ + "code": "extension_installed", + "extension": "auth_trgm", + }).Info("Successfully installed auth_trgm extension") } else { - le.Infof("auth_trgm extension is already installed") + le.Debug("auth_trgm extension is already installed") } + return "auth_trgm", nil } - le.Infof("auth_trgm extension is not available, checking pg_trgm...") + le.Debug("auth_trgm extension is not available, checking pg_trgm") pgTrgmStatus, err := getExtensionStatus(db, "pg_trgm") if err != nil { @@ -273,14 +345,23 @@ func ensureTrgmExtension(db *pop.Connection, authSchema string, le *logrus.Entry } if !pgTrgmStatus.Installed { - le.Infof("pg_trgm extension is available but not installed. Installing...") + le.Debug("pg_trgm extension is available but not installed, installing") + if err := installExtension(db, "pg_trgm", "pg_catalog"); err != nil { - le.Errorf("Failed to install pg_trgm extension: %v", err) + le.WithFields(logrus.Fields{ + "code": "extension_install_failed", + "extension": "pg_trgm", + }).WithError(err).Error("Failed to install pg_trgm extension") + return "", fmt.Errorf("pg_trgm extension is available but failed to install: %w", err) } - le.Infof("Successfully installed pg_trgm extension") + + le.WithFields(logrus.Fields{ + "code": "extension_installed", + "extension": "pg_trgm", + }).Info("Successfully installed pg_trgm extension") } else { - le.Infof("pg_trgm extension is already installed") + le.Debug("pg_trgm extension is already installed") } return "pg_trgm", nil @@ -308,12 +389,6 @@ func getUsersIndexes(namespace, trgmSchema string) []struct { query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email_trgm ON %q.users USING gin (email %s.gin_trgm_ops);`, namespace, trgmSchema), }, - // enables exact-match and prefix searches and sorting by phone number - { - name: "idx_users_phone_pattern", - query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_phone_pattern - ON %q.users USING btree (phone text_pattern_ops);`, namespace), - }, // for range queries and sorting on created_at and last_sign_in_at { name: "idx_users_created_at_desc", @@ -403,10 +478,16 @@ func dropInvalidIndexes(db *pop.Connection, le *logrus.Entry, namespace string, var invalidIndexes []invalidIndex if err := db.RawQuery(cleanupQuery).All(&invalidIndexes); err == nil && len(invalidIndexes) > 0 { for _, idx := range invalidIndexes { - le.Warnf("Dropping invalid index from previous interrupted run: %s", idx.IndexName) + le.WithFields(logrus.Fields{ + "code": "dropping_invalid_index", + "index_name": idx.IndexName, + }).Info("Dropping invalid index from previous interrupted run") dropQuery := fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %q.%s", namespace, idx.IndexName) if err := db.RawQuery(dropQuery).Exec(); err != nil { - le.Errorf("Failed to drop invalid index %s: %v", idx.IndexName, err) + le.WithFields(logrus.Fields{ + "code": "drop_invalid_index_failed", + "index_name": idx.IndexName, + }).WithError(err).Error("Failed to drop invalid index") } } } diff --git a/internal/models/cleanup.go b/internal/models/cleanup.go index 9669c8d4b..e7a8fc75d 100644 --- a/internal/models/cleanup.go +++ b/internal/models/cleanup.go @@ -40,6 +40,7 @@ func NewCleanup(config *conf.GlobalConfiguration) *Cleanup { tableFlowStates := FlowState{}.TableName() tableMFAChallenges := Challenge{}.TableName() tableMFAFactors := Factor{}.TableName() + tableOAuthClientStates := OAuthClientState{}.TableName() c := &Cleanup{} @@ -56,6 +57,7 @@ func NewCleanup(config *conf.GlobalConfiguration) *Cleanup { fmt.Sprintf("delete from %q where id in (select id from %q where not_after < now() - interval '72 hours' limit 10 for update skip locked);", tableSessions, tableSessions), fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRelayStates, tableRelayStates), fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableFlowStates, tableFlowStates), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableOAuthClientStates, tableOAuthClientStates), fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableMFAChallenges, tableMFAChallenges), fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' and status = 'unverified' limit 100 for update skip locked);", tableMFAFactors, tableMFAFactors), ) diff --git a/internal/models/errors.go b/internal/models/errors.go index fa3a2674e..4f1c95e60 100644 --- a/internal/models/errors.go +++ b/internal/models/errors.go @@ -31,6 +31,8 @@ func IsNotFoundError(err error) bool { return true case OAuthServerAuthorizationNotFoundError, *OAuthServerAuthorizationNotFoundError: return true + case OAuthClientStateNotFoundError, *OAuthClientStateNotFoundError: + return true } return false } @@ -127,3 +129,9 @@ type UserEmailUniqueConflictError struct{} func (e UserEmailUniqueConflictError) Error() string { return "User email unique constraint violated" } + +type OAuthClientStateNotFoundError struct{} + +func (e OAuthClientStateNotFoundError) Error() string { + return "OAuth state not found" +} diff --git a/internal/models/oauth_client_state.go b/internal/models/oauth_client_state.go new file mode 100644 index 000000000..f89efdd4e --- /dev/null +++ b/internal/models/oauth_client_state.go @@ -0,0 +1,45 @@ +package models + +import ( + "database/sql" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +const OAuthClientStateTimeout = 5 * time.Minute + +type OAuthClientState struct { + ID uuid.UUID `json:"id" db:"id"` + ProviderType string `json:"provider_type" db:"provider_type"` + CodeVerifier *string `json:"code_verifier,omitempty" db:"code_verifier"` + CreatedAt time.Time `json:"created_at" db:"created_at"` +} + +func (OAuthClientState) TableName() string { + return "oauth_client_states" +} + +func NewOAuthClientState(providerType string, codeVerifier *string) *OAuthClientState { + return &OAuthClientState{ + ID: uuid.Must(uuid.NewV4()), + ProviderType: providerType, + CodeVerifier: codeVerifier, + } +} +func FindAndDeleteOAuthClientStateByID(tx *storage.Connection, id uuid.UUID) (*OAuthClientState, error) { + obj := &OAuthClientState{} + if err := tx.RawQuery("DELETE FROM "+obj.TableName()+" WHERE id = ? RETURNING *", id).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthClientStateNotFoundError{} + } + return nil, errors.Wrap(err, "error deleting oauth state") + } + return obj, nil +} + +func (s *OAuthClientState) IsExpired() bool { + return time.Since(s.CreatedAt) > OAuthClientStateTimeout +} diff --git a/internal/reloader/reloader_test.go b/internal/reloader/reloader_test.go index 331252c57..8e0943567 100644 --- a/internal/reloader/reloader_test.go +++ b/internal/reloader/reloader_test.go @@ -217,55 +217,62 @@ func TestWatchNotify(t *testing.T) { }) t.Run("ErrorChanClosed", func(t *testing.T) { - dir, cleanup := helpTestDir(t) - defer cleanup() + fn := func() string { + ctx, cancel := context.WithCancel(ctx) + defer cancel() - cfg := e2e.Must(e2e.Config()).Reloading - cfg.SignalEnabled = false - cfg.PollerEnabled = false + dir, cleanup := helpTestDir(t) + defer cleanup() - rr := mockReloadRecorder() - wr := newMockWatcher(nil) - wr.errorCh <- errors.New("sentinel") - close(wr.errorCh) + cfg := e2e.Must(e2e.Config()).Reloading + cfg.SignalEnabled = false + cfg.PollerEnabled = false - rl := NewReloader(cfg, dir) - rl.watchFn = func() (watcher, error) { return wr, nil } + rr := mockReloadRecorder() + wr := newMockWatcher(nil) + wr.errorCh <- errors.New("sentinel") + close(wr.errorCh) - err := rl.Watch(ctx, rr.configFn) - require.NotNil(t, err) + rl := NewReloader(cfg, dir) + rl.watchFn = func() (watcher, error) { return wr, nil } - msg := "reloader: fsnotify error channel was closed" - if exp, got := msg, err.Error(); exp != got { - require.Equal(t, exp, got) + err := rl.Watch(ctx, rr.configFn) + require.NotNil(t, err) + return err.Error() } + + const exp = "reloader: fsnotify error channel was closed" + runUntilErrorStr(t, exp, fn) }) t.Run("EventChanClosed", func(t *testing.T) { - dir, cleanup := helpTestDir(t) - defer cleanup() + fn := func() string { + ctx, cancel := context.WithCancel(ctx) + defer cancel() - cfg := e2e.Must(e2e.Config()).Reloading - cfg.SignalEnabled = false - cfg.PollerEnabled = false - cfg.GracePeriodInterval = time.Second / 100 + dir, cleanup := helpTestDir(t) + defer cleanup() - rr := mockReloadRecorder() - wr := newMockWatcher(nil) - close(wr.eventCh) + cfg := e2e.Must(e2e.Config()).Reloading + cfg.SignalEnabled = false + cfg.PollerEnabled = false + cfg.GracePeriodInterval = time.Second / 100 - rl := NewReloader(cfg, dir) - rl.watchFn = func() (watcher, error) { return wr, nil } + rr := mockReloadRecorder() + wr := newMockWatcher(nil) + close(wr.eventCh) - err := rl.Watch(ctx, rr.configFn) - if err == nil { + rl := NewReloader(cfg, dir) + rl.watchFn = func() (watcher, error) { return wr, nil } + + err := rl.Watch(ctx, rr.configFn) require.NotNil(t, err) - } - msg := "reloader: fsnotify event channel was closed" - if exp, got := msg, err.Error(); exp != got { - require.Equal(t, exp, got) + return err.Error() } + + const exp = "reloader: fsnotify event channel was closed" + runUntilErrorStr(t, exp, fn) }) t.Run("ErrorChan", func(t *testing.T) { @@ -739,6 +746,16 @@ func TestReloadCheckAt(t *testing.T) { } } +func runUntilErrorStr(t testing.TB, exp string, fn func() string) { + var got string + for range 100 { + if got = fn(); got == exp { + break + } + } + require.Equal(t, exp, got) +} + func helpTestDir(t testing.TB) (dir string, cleanup func()) { name := fmt.Sprintf("%v_%v", t.Name(), time.Now().Nanosecond()) dir = filepath.Join("testdata", name) diff --git a/migrations/20251201000000_add_oauth_client_states_table.up.sql b/migrations/20251201000000_add_oauth_client_states_table.up.sql new file mode 100644 index 000000000..6bece1a57 --- /dev/null +++ b/migrations/20251201000000_add_oauth_client_states_table.up.sql @@ -0,0 +1,11 @@ +/* auth_migration: 20251201000000 */ +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.oauth_client_states( + id UUID PRIMARY KEY, + provider_type TEXT NOT NULL, + code_verifier TEXT, + created_at TIMESTAMPTZ NOT NULL +); +/* auth_migration: 20251201000000 */ +CREATE INDEX IF NOT EXISTS idx_oauth_client_states_created_at ON {{ index .Options "Namespace" }}.oauth_client_states(created_at); +/* auth_migration: 20251201000000 */ +COMMENT ON TABLE {{ index .Options "Namespace" }}.oauth_client_states IS 'Stores OAuth states for third-party provider authentication flows where Supabase acts as the OAuth client.'; diff --git a/migrations/20251203120046_add_otp_attempt_tracking.up.sql b/migrations/20251203120046_add_otp_attempt_tracking.up.sql new file mode 100644 index 000000000..cb2579348 --- /dev/null +++ b/migrations/20251203120046_add_otp_attempt_tracking.up.sql @@ -0,0 +1,10 @@ +-- Add OTP brute force protection columns to one_time_tokens table + +ALTER TABLE {{ index .Options "Namespace" }}.one_time_tokens +ADD COLUMN IF NOT EXISTS attempt_count INT DEFAULT 0, +ADD COLUMN IF NOT EXISTS invalidated_at timestamptz NULL; + +-- Add index for invalidated tokens +CREATE INDEX IF NOT EXISTS one_time_tokens_invalidated_at_idx +ON {{ index .Options "Namespace" }}.one_time_tokens(invalidated_at) +WHERE invalidated_at IS NOT NULL;