Skip to content
Draft
11 changes: 11 additions & 0 deletions cmd/serve_cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ func serve(ctx context.Context) {
defer wg.Wait() // Do not return to caller until this goroutine is done.

mrCache := templatemailer.NewCache()
if !config.Mailer.TemplateReloadingEnabled {
// If template reloading is disabled attempt an initial reload at
// startup for fault tolerance.
wg.Add(1)
go func() {
defer wg.Done()

mrCache.Reload(ctx, config)
}()
}

limiterOpts := api.NewLimiterOptions(config)
initialAPI := api.NewAPIWithVersion(
config, db, utilities.Version,
Expand Down
1 change: 1 addition & 0 deletions example.env
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ GOTRUE_LOG_LEVEL="debug"
GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED="false"
GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL="0"
GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION="false"
GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_CURRENT_PASSWORD="false"
GOTRUE_OPERATOR_TOKEN="unused-operator-token"
GOTRUE_RATE_LIMIT_HEADER="X-Forwarded-For"
GOTRUE_RATE_LIMIT_EMAIL_SENT="100"
Expand Down
4 changes: 4 additions & 0 deletions hack/test.env
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ GOTRUE_EXTERNAL_TWITTER_ENABLED=true
GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_TWITTER_SECRET=testsecret
GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI=https://identity.services.netlify.com/callback
GOTRUE_EXTERNAL_X_ENABLED=true
GOTRUE_EXTERNAL_X_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_X_SECRET=testsecret
GOTRUE_EXTERNAL_X_REDIRECT_URI=https://identity.services.netlify.com/callback
GOTRUE_EXTERNAL_ZOOM_ENABLED=true
GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_ZOOM_SECRET=testsecret
Expand Down
2 changes: 1 addition & 1 deletion internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
}

metering.RecordLogin(metering.LoginTypeAnonymous, newUser.ID, nil)
return sendJSON(w, http.StatusOK, token)
return sendTokenJSON(w, http.StatusOK, token)
}
3 changes: 3 additions & 0 deletions internal/api/apierrors/errorcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const (
ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used"
ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found"
ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired"
ErrorCodeOAuthClientStateNotFound ErrorCode = "oauth_client_state_not_found"
ErrorCodeOAuthClientStateExpired ErrorCode = "oauth_client_state_expired"
ErrorCodeOAuthInvalidState ErrorCode = "oauth_invalid_state"
ErrorCodeSignupDisabled ErrorCode = "signup_disabled"
ErrorCodeUserBanned ErrorCode = "user_banned"
ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification"
Expand Down
15 changes: 15 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/url"

"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
Expand Down Expand Up @@ -33,6 +34,7 @@ const (
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
oauthClientStateKey = contextKey("oauth_client_state_id")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -137,6 +139,19 @@ func getFlowStateID(ctx context.Context) string {
return obj.(string)
}

func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context {
return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID)
}

func getOAuthClientStateID(ctx context.Context) uuid.UUID {
obj := ctx.Value(oauthClientStateKey)
if obj == nil {
return uuid.Nil
}

return obj.(uuid.UUID)
}

func getInviteToken(ctx context.Context) string {
obj := ctx.Value(inviteTokenKey)
if obj == nil {
Expand Down
78 changes: 51 additions & 27 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import (
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
type ExternalProviderClaims struct {
AuthMicroserviceClaims
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
EmailOptional bool `json:"email_optional,omitempty"`
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
OAuthClientStateID string `json:"oauth_client_state_id,omitempty"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
EmailOptional bool `json:"email_optional,omitempty"`
}

// ExternalProviderRedirect redirects the request to the oauth provider
Expand Down Expand Up @@ -90,6 +91,32 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
flowStateID = flowState.ID.String()
}

authUrlParams := make([]oauth2.AuthCodeOption, 0)
query.Del("scopes")
query.Del("provider")
query.Del("code_challenge")
query.Del("code_challenge_method")
for key := range query {
if key == "workos_provider" {
// See https://workos.com/docs/reference/sso/authorize/get
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
} else {
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
}
}

oauthClientStateID := ""
if oauthProvider, ok := p.(provider.OAuthProvider); ok && oauthProvider.RequiresPKCE() {
codeVerifier := oauth2.GenerateVerifier()
oauthClientState := models.NewOAuthClientState(providerType, &codeVerifier)
err := db.Create(oauthClientState)
if err != nil {
return "", err
}
oauthClientStateID = oauthClientState.ID.String()
authUrlParams = append(authUrlParams, oauth2.S256ChallengeOption(codeVerifier))
}

claims := ExternalProviderClaims{
AuthMicroserviceClaims: AuthMicroserviceClaims{
RegisteredClaims: jwt.RegisteredClaims{
Expand All @@ -98,11 +125,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
SiteURL: config.SiteURL,
InstanceID: uuid.Nil.String(),
},
Provider: providerType,
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
EmailOptional: pConfig.EmailOptional,
Provider: providerType,
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
OAuthClientStateID: oauthClientStateID,
EmailOptional: pConfig.EmailOptional,
}

if linkingTargetUser != nil {
Expand All @@ -115,20 +143,6 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
return "", apierrors.NewInternalServerError("Error creating state").WithInternalError(err)
}

authUrlParams := make([]oauth2.AuthCodeOption, 0)
query.Del("scopes")
query.Del("provider")
query.Del("code_challenge")
query.Del("code_challenge_method")
for key := range query {
if key == "workos_provider" {
// See https://workos.com/docs/reference/sso/authorize/get
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
} else {
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
}
}

authURL := p.AuthCodeURL(tokenString, authUrlParams...)

return authURL, nil
Expand Down Expand Up @@ -565,6 +579,13 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag
if claims.FlowStateID != "" {
ctx = withFlowStateID(ctx, claims.FlowStateID)
}
if claims.OAuthClientStateID != "" {
oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)")
}
ctx = withOAuthClientStateID(ctx, oauthClientStateID)
}
if claims.LinkingTargetID != "" {
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
if err != nil {
Expand Down Expand Up @@ -634,7 +655,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
p, err = provider.NewLinkedinProvider(pConfig, scopes)
case "linkedin_oidc":
pConfig = config.External.LinkedinOIDC
p, err = provider.NewLinkedinOIDCProvider(pConfig, scopes)
p, err = provider.NewLinkedinOIDCProvider(ctx, pConfig, scopes)
case "notion":
pConfig = config.External.Notion
p, err = provider.NewNotionProvider(pConfig)
Expand All @@ -656,9 +677,12 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
case "twitter":
pConfig = config.External.Twitter
p, err = provider.NewTwitterProvider(pConfig, scopes)
case "x":
pConfig = config.External.X
p, err = provider.NewXProvider(pConfig, scopes)
case "vercel_marketplace":
pConfig = config.External.VercelMarketplace
p, err = provider.NewVercelMarketplaceProvider(pConfig, scopes)
p, err = provider.NewVercelMarketplaceProvider(ctx, pConfig, scopes)
case "workos":
pConfig = config.External.WorkOS
p, err = provider.NewWorkOSProvider(pConfig)
Expand Down
43 changes: 38 additions & 5 deletions internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import (
"net/http"
"net/url"

"github.com/gofrs/uuid"
"github.com/mrjones/oauth"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
"golang.org/x/oauth2"
)

// OAuthProviderData contains the userData and token returned by the oauth provider
Expand Down Expand Up @@ -55,6 +58,8 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
}

func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) {
db := a.db.WithContext(ctx)

var rq url.Values
if err := r.ParseForm(); r.Method == http.MethodPost && err == nil {
rq = r.Form
Expand All @@ -72,28 +77,56 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing")
}

oAuthProvider, _, err := a.OAuthProvider(ctx, providerType)
oauthProvider, _, err := a.OAuthProvider(ctx, providerType)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
}

log := observability.GetLogEntry(r).Entry

var oauthClientState *models.OAuthClientState
// if there's a non-empty OAuthClientStateID we perform PKCE Flow for the external provider
if oauthClientStateID := getOAuthClientStateID(ctx); oauthClientStateID != uuid.Nil {
oauthClientState, err = models.FindAndDeleteOAuthClientStateByID(db, oauthClientStateID)
if models.IsNotFoundError(err) {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateNotFound, "OAuth state not found").WithInternalError(err)
} else if err != nil {
return nil, apierrors.NewInternalServerError("Failed to find OAuth state").WithInternalError(err)
}

if oauthClientState.ProviderType != providerType {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth provider mismatch")
}

if oauthClientState.IsExpired() {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateExpired, "OAuth state expired")
}
}

if oauthProvider.RequiresPKCE() && oauthClientState == nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth PKCE code verifier missing")
}

log.WithFields(logrus.Fields{
"provider": providerType,
"code": oauthCode,
}).Debug("Exchanging oauth code")
}).Debug("Exchanging OAuth code")

token, err := oAuthProvider.GetOAuthToken(oauthCode)
var tokenOpts []oauth2.AuthCodeOption
if oauthClientState != nil {
tokenOpts = append(tokenOpts, oauth2.VerifierOption(*oauthClientState.CodeVerifier))
}
token, err := oauthProvider.GetOAuthToken(ctx, oauthCode, tokenOpts...)
if err != nil {
return nil, apierrors.NewInternalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
}

userData, err := oAuthProvider.GetUserData(ctx, token)
userData, err := oauthProvider.GetUserData(ctx, token)
if err != nil {
return nil, apierrors.NewInternalServerError("Error getting user profile from external provider").WithInternalError(err)
}

switch externalProvider := oAuthProvider.(type) {
switch externalProvider := oauthProvider.(type) {
case *provider.AppleProvider:
// apple only returns user info the first time
oauthUser := rq.Get("user")
Expand Down
Loading
Loading