Skip to content

Commit 8259fb5

Browse files
committed
Add support for more JWT algo methods
1 parent 6c2a4c3 commit 8259fb5

File tree

11 files changed

+151
-120
lines changed

11 files changed

+151
-120
lines changed

server/constants/env.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ const (
4343
EnvKeyJwtType = "JWT_TYPE"
4444
// EnvKeyJwtSecret key for env variable JWT_SECRET
4545
EnvKeyJwtSecret = "JWT_SECRET"
46+
// EnvKeyJwtPrivateKey key for env variable JWT_PRIVATE_KEY
47+
EnvKeyJwtPrivateKey = "JWT_PRIVATE_KEY"
48+
// EnvKeyJwtPublicKey key for env variable JWT_PUBLIC_KEY
49+
EnvKeyJwtPublicKey = "JWT_PUBLIC_KEY"
4650
// EnvKeyAllowedOrigins key for env variable ALLOWED_ORIGINS
4751
EnvKeyAllowedOrigins = "ALLOWED_ORIGINS"
4852
// EnvKeyAppURL key for env variable APP_URL

server/env/env.go

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func InitEnv() {
1919
envData := envstore.EnvInMemoryStoreObj.GetEnvStoreClone()
2020

2121
if envData.StringEnv[constants.EnvKeyEnv] == "" {
22-
envData.StringEnv[constants.EnvKeyEnv] = os.Getenv("ENV")
22+
envData.StringEnv[constants.EnvKeyEnv] = os.Getenv(constants.EnvKeyEnv)
2323
if envData.StringEnv[constants.EnvKeyEnv] == "" {
2424
envData.StringEnv[constants.EnvKeyEnv] = "production"
2525
}
@@ -50,18 +50,18 @@ func InitEnv() {
5050
}
5151

5252
if envData.StringEnv[constants.EnvKeyPort] == "" {
53-
envData.StringEnv[constants.EnvKeyPort] = os.Getenv("PORT")
53+
envData.StringEnv[constants.EnvKeyPort] = os.Getenv(constants.EnvKeyPort)
5454
if envData.StringEnv[constants.EnvKeyPort] == "" {
5555
envData.StringEnv[constants.EnvKeyPort] = "8080"
5656
}
5757
}
5858

5959
if envData.StringEnv[constants.EnvKeyAdminSecret] == "" {
60-
envData.StringEnv[constants.EnvKeyAdminSecret] = os.Getenv("ADMIN_SECRET")
60+
envData.StringEnv[constants.EnvKeyAdminSecret] = os.Getenv(constants.EnvKeyAdminSecret)
6161
}
6262

6363
if envData.StringEnv[constants.EnvKeyDatabaseType] == "" {
64-
envData.StringEnv[constants.EnvKeyDatabaseType] = os.Getenv("DATABASE_TYPE")
64+
envData.StringEnv[constants.EnvKeyDatabaseType] = os.Getenv(constants.EnvKeyDatabaseType)
6565

6666
if envstore.ARG_DB_TYPE != nil && *envstore.ARG_DB_TYPE != "" {
6767
envData.StringEnv[constants.EnvKeyDatabaseType] = *envstore.ARG_DB_TYPE
@@ -73,7 +73,7 @@ func InitEnv() {
7373
}
7474

7575
if envData.StringEnv[constants.EnvKeyDatabaseURL] == "" {
76-
envData.StringEnv[constants.EnvKeyDatabaseURL] = os.Getenv("DATABASE_URL")
76+
envData.StringEnv[constants.EnvKeyDatabaseURL] = os.Getenv(constants.EnvKeyDatabaseURL)
7777

7878
if envstore.ARG_DB_URL != nil && *envstore.ARG_DB_URL != "" {
7979
envData.StringEnv[constants.EnvKeyDatabaseURL] = *envstore.ARG_DB_URL
@@ -85,97 +85,105 @@ func InitEnv() {
8585
}
8686

8787
if envData.StringEnv[constants.EnvKeyDatabaseName] == "" {
88-
envData.StringEnv[constants.EnvKeyDatabaseName] = os.Getenv("DATABASE_NAME")
88+
envData.StringEnv[constants.EnvKeyDatabaseName] = os.Getenv(constants.EnvKeyDatabaseName)
8989
if envData.StringEnv[constants.EnvKeyDatabaseName] == "" {
9090
envData.StringEnv[constants.EnvKeyDatabaseName] = "authorizer"
9191
}
9292
}
9393

9494
if envData.StringEnv[constants.EnvKeySmtpHost] == "" {
95-
envData.StringEnv[constants.EnvKeySmtpHost] = os.Getenv("SMTP_HOST")
95+
envData.StringEnv[constants.EnvKeySmtpHost] = os.Getenv(constants.EnvKeySmtpHost)
9696
}
9797

9898
if envData.StringEnv[constants.EnvKeySmtpPort] == "" {
99-
envData.StringEnv[constants.EnvKeySmtpPort] = os.Getenv("SMTP_PORT")
99+
envData.StringEnv[constants.EnvKeySmtpPort] = os.Getenv(constants.EnvKeySmtpPort)
100100
}
101101

102102
if envData.StringEnv[constants.EnvKeySmtpUsername] == "" {
103-
envData.StringEnv[constants.EnvKeySmtpUsername] = os.Getenv("SMTP_USERNAME")
103+
envData.StringEnv[constants.EnvKeySmtpUsername] = os.Getenv(constants.EnvKeySmtpUsername)
104104
}
105105

106106
if envData.StringEnv[constants.EnvKeySmtpPassword] == "" {
107-
envData.StringEnv[constants.EnvKeySmtpPassword] = os.Getenv("SMTP_PASSWORD")
107+
envData.StringEnv[constants.EnvKeySmtpPassword] = os.Getenv(constants.EnvKeySmtpPassword)
108108
}
109109

110110
if envData.StringEnv[constants.EnvKeySenderEmail] == "" {
111-
envData.StringEnv[constants.EnvKeySenderEmail] = os.Getenv("SENDER_EMAIL")
111+
envData.StringEnv[constants.EnvKeySenderEmail] = os.Getenv(constants.EnvKeySenderEmail)
112112
}
113113

114114
if envData.StringEnv[constants.EnvKeyJwtSecret] == "" {
115-
envData.StringEnv[constants.EnvKeyJwtSecret] = os.Getenv("JWT_SECRET")
115+
envData.StringEnv[constants.EnvKeyJwtSecret] = os.Getenv(constants.EnvKeyJwtSecret)
116116
if envData.StringEnv[constants.EnvKeyJwtSecret] == "" {
117117
envData.StringEnv[constants.EnvKeyJwtSecret] = uuid.New().String()
118118
}
119119
}
120120

121+
if envData.StringEnv[constants.EnvKeyJwtPrivateKey] == "" {
122+
envData.StringEnv[constants.EnvKeyJwtPrivateKey] = os.Getenv(constants.EnvKeyJwtPrivateKey)
123+
}
124+
125+
if envData.StringEnv[constants.EnvKeyJwtPublicKey] == "" {
126+
envData.StringEnv[constants.EnvKeyJwtPublicKey] = os.Getenv(constants.EnvKeyJwtPublicKey)
127+
}
128+
121129
if envData.StringEnv[constants.EnvKeyJwtType] == "" {
122-
envData.StringEnv[constants.EnvKeyJwtType] = os.Getenv("JWT_TYPE")
130+
envData.StringEnv[constants.EnvKeyJwtType] = os.Getenv(constants.EnvKeyJwtType)
123131
if envData.StringEnv[constants.EnvKeyJwtType] == "" {
124132
envData.StringEnv[constants.EnvKeyJwtType] = "HS256"
125133
}
126134
}
127135

128136
if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" {
129-
envData.StringEnv[constants.EnvKeyJwtRoleClaim] = os.Getenv("JWT_ROLE_CLAIM")
137+
envData.StringEnv[constants.EnvKeyJwtRoleClaim] = os.Getenv(constants.EnvKeyJwtRoleClaim)
130138

131139
if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" {
132140
envData.StringEnv[constants.EnvKeyJwtRoleClaim] = "role"
133141
}
134142
}
135143

136144
if envData.StringEnv[constants.EnvKeyRedisURL] == "" {
137-
envData.StringEnv[constants.EnvKeyRedisURL] = os.Getenv("REDIS_URL")
145+
envData.StringEnv[constants.EnvKeyRedisURL] = os.Getenv(constants.EnvKeyRedisURL)
138146
}
139147

140148
if envData.StringEnv[constants.EnvKeyCookieName] == "" {
141-
envData.StringEnv[constants.EnvKeyCookieName] = os.Getenv("COOKIE_NAME")
149+
envData.StringEnv[constants.EnvKeyCookieName] = os.Getenv(constants.EnvKeyCookieName)
142150
if envData.StringEnv[constants.EnvKeyCookieName] == "" {
143151
envData.StringEnv[constants.EnvKeyCookieName] = "authorizer"
144152
}
145153
}
146154

147155
if envData.StringEnv[constants.EnvKeyGoogleClientID] == "" {
148-
envData.StringEnv[constants.EnvKeyGoogleClientID] = os.Getenv("GOOGLE_CLIENT_ID")
156+
envData.StringEnv[constants.EnvKeyGoogleClientID] = os.Getenv(constants.EnvKeyGoogleClientID)
149157
}
150158

151159
if envData.StringEnv[constants.EnvKeyGoogleClientSecret] == "" {
152-
envData.StringEnv[constants.EnvKeyGoogleClientSecret] = os.Getenv("GOOGLE_CLIENT_SECRET")
160+
envData.StringEnv[constants.EnvKeyGoogleClientSecret] = os.Getenv(constants.EnvKeyGoogleClientSecret)
153161
}
154162

155163
if envData.StringEnv[constants.EnvKeyGithubClientID] == "" {
156-
envData.StringEnv[constants.EnvKeyGithubClientID] = os.Getenv("GITHUB_CLIENT_ID")
164+
envData.StringEnv[constants.EnvKeyGithubClientID] = os.Getenv(constants.EnvKeyGithubClientID)
157165
}
158166

159167
if envData.StringEnv[constants.EnvKeyGithubClientSecret] == "" {
160-
envData.StringEnv[constants.EnvKeyGithubClientSecret] = os.Getenv("GITHUB_CLIENT_SECRET")
168+
envData.StringEnv[constants.EnvKeyGithubClientSecret] = os.Getenv(constants.EnvKeyGithubClientSecret)
161169
}
162170

163171
if envData.StringEnv[constants.EnvKeyFacebookClientID] == "" {
164-
envData.StringEnv[constants.EnvKeyFacebookClientID] = os.Getenv("FACEBOOK_CLIENT_ID")
172+
envData.StringEnv[constants.EnvKeyFacebookClientID] = os.Getenv(constants.EnvKeyFacebookClientID)
165173
}
166174

167175
if envData.StringEnv[constants.EnvKeyFacebookClientSecret] == "" {
168-
envData.StringEnv[constants.EnvKeyFacebookClientSecret] = os.Getenv("FACEBOOK_CLIENT_SECRET")
176+
envData.StringEnv[constants.EnvKeyFacebookClientSecret] = os.Getenv(constants.EnvKeyFacebookClientSecret)
169177
}
170178

171179
if envData.StringEnv[constants.EnvKeyResetPasswordURL] == "" {
172-
envData.StringEnv[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv("RESET_PASSWORD_URL"), "/")
180+
envData.StringEnv[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv(constants.EnvKeyResetPasswordURL), "/")
173181
}
174182

175-
envData.BoolEnv[constants.EnvKeyDisableBasicAuthentication] = os.Getenv("DISABLE_BASIC_AUTHENTICATION") == "true"
176-
envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = os.Getenv("DISABLE_EMAIL_VERIFICATION") == "true"
177-
envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv("DISABLE_MAGIC_LINK_LOGIN") == "true"
178-
envData.BoolEnv[constants.EnvKeyDisableLoginPage] = os.Getenv("DISABLE_LOGIN_PAGE") == "true"
183+
envData.BoolEnv[constants.EnvKeyDisableBasicAuthentication] = os.Getenv(constants.EnvKeyDisableBasicAuthentication) == "true"
184+
envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = os.Getenv(constants.EnvKeyDisableEmailVerification) == "true"
185+
envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv(constants.EnvKeyDisableMagicLinkLogin) == "true"
186+
envData.BoolEnv[constants.EnvKeyDisableLoginPage] = os.Getenv(constants.EnvKeyDisableLoginPage) == "true"
179187

180188
// no need to add nil check as its already done above
181189
if envData.StringEnv[constants.EnvKeySmtpHost] == "" || envData.StringEnv[constants.EnvKeySmtpUsername] == "" || envData.StringEnv[constants.EnvKeySmtpPassword] == "" || envData.StringEnv[constants.EnvKeySenderEmail] == "" && envData.StringEnv[constants.EnvKeySmtpPort] == "" {
@@ -187,7 +195,7 @@ func InitEnv() {
187195
envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true
188196
}
189197

190-
allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
198+
allowedOriginsSplit := strings.Split(os.Getenv(constants.EnvKeyAllowedOrigins), ",")
191199
allowedOrigins := []string{}
192200
hasWildCard := false
193201

@@ -215,22 +223,22 @@ func InitEnv() {
215223

216224
envData.SliceEnv[constants.EnvKeyAllowedOrigins] = allowedOrigins
217225

218-
rolesEnv := strings.TrimSpace(os.Getenv("ROLES"))
226+
rolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyRoles))
219227
rolesSplit := strings.Split(rolesEnv, ",")
220228
roles := []string{}
221229
if len(rolesEnv) == 0 {
222230
roles = []string{"user"}
223231
}
224232

225-
defaultRolesEnv := strings.TrimSpace(os.Getenv("DEFAULT_ROLES"))
233+
defaultRolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyDefaultRoles))
226234
defaultRoleSplit := strings.Split(defaultRolesEnv, ",")
227235
defaultRoles := []string{}
228236

229237
if len(defaultRolesEnv) == 0 {
230238
defaultRoles = []string{"user"}
231239
}
232240

233-
protectedRolesEnv := strings.TrimSpace(os.Getenv("PROTECTED_ROLES"))
241+
protectedRolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyProtectedRoles))
234242
protectedRolesSplit := strings.Split(protectedRolesEnv, ",")
235243
protectedRoles := []string{}
236244

@@ -259,12 +267,12 @@ func InitEnv() {
259267
envData.SliceEnv[constants.EnvKeyDefaultRoles] = defaultRoles
260268
envData.SliceEnv[constants.EnvKeyProtectedRoles] = protectedRoles
261269

262-
if os.Getenv("ORGANIZATION_NAME") != "" {
263-
envData.StringEnv[constants.EnvKeyOrganizationName] = os.Getenv("ORGANIZATION_NAME")
270+
if os.Getenv(constants.EnvKeyOrganizationName) != "" {
271+
envData.StringEnv[constants.EnvKeyOrganizationName] = os.Getenv(constants.EnvKeyOrganizationName)
264272
}
265273

266-
if os.Getenv("ORGANIZATION_LOGO") != "" {
267-
envData.StringEnv[constants.EnvKeyOrganizationLogo] = os.Getenv("ORGANIZATION_LOGO")
274+
if os.Getenv(constants.EnvKeyOrganizationLogo) != "" {
275+
envData.StringEnv[constants.EnvKeyOrganizationLogo] = os.Getenv(constants.EnvKeyOrganizationLogo)
268276
}
269277

270278
envstore.EnvInMemoryStoreObj.UpdateEnvStore(envData)

server/handlers/verify_email.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ func VerifyEmailHandler() gin.HandlerFunc {
3333
}
3434

3535
// verify if token exists in db
36-
claim, err := token.VerifyVerificationToken(tokenInQuery)
36+
claim, err := token.ParseJWTToken(tokenInQuery)
3737
if err != nil {
3838
c.JSON(400, errorRes)
3939
return
4040
}
4141

42-
user, err := db.Provider.GetUserByEmail(claim.Email)
42+
user, err := db.Provider.GetUserByEmail(claim["email"].(string))
4343
if err != nil {
4444
c.JSON(400, gin.H{
4545
"message": err.Error(),
@@ -68,6 +68,6 @@ func VerifyEmailHandler() gin.HandlerFunc {
6868
cookie.SetCookie(c, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash)
6969
utils.SaveSessionInDB(user.ID, c)
7070

71-
c.Redirect(http.StatusTemporaryRedirect, claim.RedirectURL)
71+
c.Redirect(http.StatusTemporaryRedirect, claim["redirect_url"].(string))
7272
}
7373
}

server/resolvers/is_valid_jwt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func IsValidJwtResolver(ctx context.Context, params *model.IsValidJWTQueryInput)
2626
}
2727
}
2828

29-
claims, err := tokenHelper.VerifyJWTToken(token)
29+
claims, err := tokenHelper.ParseJWTToken(token)
3030
if err != nil {
3131
return nil, err
3232
}

server/resolvers/logout.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) {
3838
fingerPrint := string(decryptedFingerPrint)
3939

4040
// verify refresh token and fingerprint
41-
claims, err := token.VerifyJWTToken(refreshToken)
41+
claims, err := token.ParseJWTToken(refreshToken)
4242
if err != nil {
4343
return res, err
4444
}

server/resolvers/reset_password.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput)
3131
}
3232

3333
// verify if token exists in db
34-
claim, err := token.VerifyVerificationToken(params.Token)
34+
claim, err := token.ParseJWTToken(params.Token)
3535
if err != nil {
3636
return res, fmt.Errorf(`invalid token`)
3737
}
3838

39-
user, err := db.Provider.GetUserByEmail(claim.Email)
39+
user, err := db.Provider.GetUserByEmail(claim["email"].(string))
4040
if err != nil {
4141
return res, err
4242
}

server/resolvers/session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod
4141
fingerPrint := string(decryptedFingerPrint)
4242

4343
// verify refresh token and fingerprint
44-
claims, err := token.VerifyJWTToken(refreshToken)
44+
claims, err := token.ParseJWTToken(refreshToken)
4545
if err != nil {
4646
return res, err
4747
}

server/resolvers/verify_email.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m
2828
}
2929

3030
// verify if token exists in db
31-
claim, err := token.VerifyVerificationToken(params.Token)
31+
claim, err := token.ParseJWTToken(params.Token)
3232
if err != nil {
3333
return res, fmt.Errorf(`invalid token`)
3434
}
3535

36-
user, err := db.Provider.GetUserByEmail(claim.Email)
36+
user, err := db.Provider.GetUserByEmail(claim["email"].(string))
3737
if err != nil {
3838
return res, err
3939
}

server/token/auth_token.go

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ func CreateAuthToken(user models.User, roles []string) (*Token, error) {
6262

6363
// CreateRefreshToken util to create JWT token
6464
func CreateRefreshToken(user models.User, roles []string) (string, int64, error) {
65-
t := jwt.New(jwt.GetSigningMethod(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType)))
6665
// expires in 1 year
6766
expiryBound := time.Hour * 8760
6867
expiresAt := time.Now().Add(expiryBound).Unix()
@@ -75,8 +74,7 @@ func CreateRefreshToken(user models.User, roles []string) (string, int64, error)
7574
"id": user.ID,
7675
}
7776

78-
t.Claims = customClaims
79-
token, err := t.SignedString([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)))
77+
token, err := SignJWTToken(customClaims)
8078
if err != nil {
8179
return "", 0, err
8280
}
@@ -86,9 +84,7 @@ func CreateRefreshToken(user models.User, roles []string) (string, int64, error)
8684
// CreateAccessToken util to create JWT token, based on
8785
// user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT
8886
func CreateAccessToken(user models.User, roles []string) (string, int64, error) {
89-
t := jwt.New(jwt.GetSigningMethod(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType)))
9087
expiryBound := time.Minute * 30
91-
9288
expiresAt := time.Now().Add(expiryBound).Unix()
9389

9490
resUser := user.AsAPIUser()
@@ -141,9 +137,7 @@ func CreateAccessToken(user models.User, roles []string) (string, int64, error)
141137
}
142138
}
143139

144-
t.Claims = customClaims
145-
146-
token, err := t.SignedString([]byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)))
140+
token, err := SignJWTToken(customClaims)
147141
if err != nil {
148142
return "", 0, err
149143
}
@@ -187,43 +181,13 @@ func GetFingerPrint(gc *gin.Context) (string, error) {
187181
return fingerPrint, nil
188182
}
189183

190-
// VerifyJWTToken helps in verifying the JWT token
191-
func VerifyJWTToken(token string) (map[string]interface{}, error) {
192-
var res map[string]interface{}
193-
claims := jwt.MapClaims{}
194-
195-
t, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
196-
return []byte(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil
197-
})
198-
if err != nil {
199-
return res, err
200-
}
201-
202-
if !t.Valid {
203-
return res, fmt.Errorf(`invalid token`)
204-
}
205-
206-
// claim parses exp & iat into float 64 with e^10,
207-
// but we expect it to be int64
208-
// hence we need to assert interface and convert to int64
209-
intExp := int64(claims["exp"].(float64))
210-
intIat := int64(claims["iat"].(float64))
211-
212-
data, _ := json.Marshal(claims)
213-
json.Unmarshal(data, &res)
214-
res["exp"] = intExp
215-
res["iat"] = intIat
216-
217-
return res, nil
218-
}
219-
220184
func ValidateAccessToken(gc *gin.Context) (map[string]interface{}, error) {
221185
token, err := GetAccessToken(gc)
222186
if err != nil {
223187
return nil, err
224188
}
225189

226-
claims, err := VerifyJWTToken(token)
190+
claims, err := ParseJWTToken(token)
227191
if err != nil {
228192
return nil, err
229193
}

0 commit comments

Comments
 (0)