Skip to content

Commit f499201

Browse files
authored
fix: multi role with oauth (#61)
1 parent b376ee3 commit f499201

File tree

6 files changed

+82
-78
lines changed

6 files changed

+82
-78
lines changed

server/env.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func InitEnv() {
163163
roles = append(roles, trimVal)
164164
}
165165

166-
if utils.StringContains(defaultRoleSplit, trimVal) {
166+
if utils.StringSliceContains(defaultRoleSplit, trimVal) {
167167
defaultRoles = append(defaultRoles, trimVal)
168168
}
169169
}

server/handlers/oauthCallback.go

Lines changed: 73 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
"golang.org/x/oauth2"
2020
)
2121

22-
func processGoogleUserInfo(code string, roles []string, c *gin.Context) (db.User, error) {
22+
func processGoogleUserInfo(code string) (db.User, error) {
2323
user := db.User{}
2424
token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code)
2525
if err != nil {
@@ -40,38 +40,18 @@ func processGoogleUserInfo(code string, roles []string, c *gin.Context) (db.User
4040
userRawData := make(map[string]string)
4141
json.Unmarshal(body, &userRawData)
4242

43-
existingUser, err := db.Mgr.GetUserByEmail(userRawData["email"])
4443
user = db.User{
4544
FirstName: userRawData["given_name"],
4645
LastName: userRawData["family_name"],
4746
Image: userRawData["picture"],
4847
Email: userRawData["email"],
4948
EmailVerifiedAt: time.Now().Unix(),
5049
}
51-
if err != nil {
52-
// user not registered, register user and generate session token
53-
user.SignupMethod = enum.Google.String()
54-
user.Roles = strings.Join(roles, ",")
55-
} else {
56-
// user exists in db, check if method was google
57-
// if not append google to existing signup method and save it
58-
59-
signupMethod := existingUser.SignupMethod
60-
if !strings.Contains(signupMethod, enum.Google.String()) {
61-
signupMethod = signupMethod + "," + enum.Google.String()
62-
}
63-
user.SignupMethod = signupMethod
64-
user.Password = existingUser.Password
65-
if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) {
66-
return user, fmt.Errorf("invalid role")
67-
}
6850

69-
user.Roles = existingUser.Roles
70-
}
7151
return user, nil
7252
}
7353

74-
func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User, error) {
54+
func processGithubUserInfo(code string) (db.User, error) {
7555
user := db.User{}
7656
token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code)
7757
if err != nil {
@@ -100,7 +80,6 @@ func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User
10080
userRawData := make(map[string]string)
10181
json.Unmarshal(body, &userRawData)
10282

103-
existingUser, err := db.Mgr.GetUserByEmail(userRawData["email"])
10483
name := strings.Split(userRawData["name"], " ")
10584
firstName := ""
10685
lastName := ""
@@ -117,32 +96,11 @@ func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User
11796
Email: userRawData["email"],
11897
EmailVerifiedAt: time.Now().Unix(),
11998
}
120-
if err != nil {
121-
// user not registered, register user and generate session token
122-
user.SignupMethod = enum.Github.String()
123-
user.Roles = strings.Join(roles, ",")
124-
} else {
125-
// user exists in db, check if method was google
126-
// if not append google to existing signup method and save it
127-
128-
signupMethod := existingUser.SignupMethod
129-
if !strings.Contains(signupMethod, enum.Github.String()) {
130-
signupMethod = signupMethod + "," + enum.Github.String()
131-
}
132-
user.SignupMethod = signupMethod
133-
user.Password = existingUser.Password
134-
135-
if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) {
136-
return user, fmt.Errorf("invalid role")
137-
}
138-
139-
user.Roles = existingUser.Roles
140-
}
14199

142100
return user, nil
143101
}
144102

145-
func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.User, error) {
103+
func processFacebookUserInfo(code string) (db.User, error) {
146104
user := db.User{}
147105
token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code)
148106
if err != nil {
@@ -170,7 +128,6 @@ func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.Us
170128
json.Unmarshal(body, &userRawData)
171129

172130
email := fmt.Sprintf("%v", userRawData["email"])
173-
existingUser, err := db.Mgr.GetUserByEmail(email)
174131

175132
picObject := userRawData["picture"].(map[string]interface{})["data"]
176133
picDataObject := picObject.(map[string]interface{})
@@ -182,28 +139,6 @@ func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.Us
182139
EmailVerifiedAt: time.Now().Unix(),
183140
}
184141

185-
if err != nil {
186-
// user not registered, register user and generate session token
187-
user.SignupMethod = enum.Github.String()
188-
user.Roles = strings.Join(roles, ",")
189-
} else {
190-
// user exists in db, check if method was google
191-
// if not append google to existing signup method and save it
192-
193-
signupMethod := existingUser.SignupMethod
194-
if !strings.Contains(signupMethod, enum.Github.String()) {
195-
signupMethod = signupMethod + "," + enum.Github.String()
196-
}
197-
user.SignupMethod = signupMethod
198-
user.Password = existingUser.Password
199-
200-
if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) {
201-
return user, fmt.Errorf("invalid role")
202-
}
203-
204-
user.Roles = existingUser.Roles
205-
}
206-
207142
return user, nil
208143
}
209144

@@ -226,19 +161,19 @@ func OAuthCallbackHandler() gin.HandlerFunc {
226161
return
227162
}
228163

229-
roles := strings.Split(sessionSplit[2], ",")
164+
inputRoles := strings.Split(sessionSplit[2], ",")
230165
redirectURL := sessionSplit[1]
231166

232167
var err error
233168
user := db.User{}
234169
code := c.Request.FormValue("code")
235170
switch provider {
236171
case enum.Google.String():
237-
user, err = processGoogleUserInfo(code, roles, c)
172+
user, err = processGoogleUserInfo(code)
238173
case enum.Github.String():
239-
user, err = processGithubUserInfo(code, roles, c)
174+
user, err = processGithubUserInfo(code)
240175
case enum.Facebook.String():
241-
user, err = processFacebookUserInfo(code, roles, c)
176+
user, err = processFacebookUserInfo(code)
242177
default:
243178
err = fmt.Errorf(`invalid oauth provider`)
244179
}
@@ -248,12 +183,76 @@ func OAuthCallbackHandler() gin.HandlerFunc {
248183
return
249184
}
250185

186+
existingUser, err := db.Mgr.GetUserByEmail(user.Email)
187+
188+
if err != nil {
189+
// user not registered, register user and generate session token
190+
user.SignupMethod = provider
191+
// make sure inputRoles don't include protected roles
192+
hasProtectedRole := false
193+
for _, ir := range inputRoles {
194+
if utils.StringSliceContains(constants.PROTECTED_ROLES, ir) {
195+
hasProtectedRole = true
196+
}
197+
}
198+
199+
if hasProtectedRole {
200+
c.JSON(400, gin.H{"error": "invalid role"})
201+
return
202+
}
203+
204+
user.Roles = strings.Join(inputRoles, ",")
205+
} else {
206+
// user exists in db, check if method was google
207+
// if not append google to existing signup method and save it
208+
209+
signupMethod := existingUser.SignupMethod
210+
if !strings.Contains(signupMethod, provider) {
211+
signupMethod = signupMethod + "," + enum.Github.String()
212+
}
213+
user.SignupMethod = signupMethod
214+
user.Password = existingUser.Password
215+
216+
// There multiple scenarios with roles here in social login
217+
// 1. user has access to protected roles + roles and trying to login
218+
// 2. user has not signed up for one of the available role but trying to signup.
219+
// Need to modify roles in this case
220+
221+
// find the unassigned roles
222+
existingRoles := strings.Split(existingUser.Roles, ",")
223+
unasignedRoles := []string{}
224+
for _, ir := range inputRoles {
225+
if !utils.StringSliceContains(existingRoles, ir) {
226+
unasignedRoles = append(unasignedRoles, ir)
227+
}
228+
}
229+
230+
if len(unasignedRoles) > 0 {
231+
// check if it contains protected unassigned role
232+
hasProtectedRole := false
233+
for _, ur := range unasignedRoles {
234+
if utils.StringSliceContains(constants.PROTECTED_ROLES, ur) {
235+
hasProtectedRole = true
236+
}
237+
}
238+
239+
if hasProtectedRole {
240+
c.JSON(400, gin.H{"error": "invalid role"})
241+
return
242+
} else {
243+
user.Roles = existingUser.Roles + "," + strings.Join(unasignedRoles, ",")
244+
}
245+
} else {
246+
user.Roles = existingUser.Roles
247+
}
248+
}
249+
251250
user, _ = db.Mgr.SaveUser(user)
252251
user, _ = db.Mgr.GetUserByEmail(user.Email)
253252
userIdStr := fmt.Sprintf("%v", user.ID)
254-
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles)
253+
refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, inputRoles)
255254

256-
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, roles)
255+
accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, inputRoles)
257256
utils.SetCookie(c, accessToken)
258257
session.SetToken(userIdStr, refreshToken)
259258

server/resolvers/token.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func Token(ctx context.Context, roles []string) (*model.AuthResponse, error) {
5454

5555
if len(roles) > 0 {
5656
for _, v := range roles {
57-
if !utils.StringContains(claimRoles, v) {
57+
if !utils.StringSliceContains(claimRoles, v) {
5858
return res, fmt.Errorf(`unauthorized`)
5959
}
6060
}

server/utils/common.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func WriteToFile(filename string, data string) error {
1919
return file.Sync()
2020
}
2121

22-
func StringContains(s []string, e string) bool {
22+
func StringSliceContains(s []string, e string) bool {
2323
for _, a := range s {
2424
if a == e {
2525
return true

server/utils/initServer.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ func InitServer() {
1818
Role: val,
1919
})
2020
}
21+
for _, val := range constants.PROTECTED_ROLES {
22+
roles = append(roles, db.Role{
23+
Role: val,
24+
})
25+
}
2126
err := db.Mgr.SaveRoles(roles)
2227
if err != nil {
2328
log.Println(`Error saving roles`, err)

server/utils/validator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func IsSuperAdmin(gc *gin.Context) bool {
4343
func IsValidRoles(userRoles []string, roles []string) bool {
4444
valid := true
4545
for _, role := range roles {
46-
if !StringContains(userRoles, role) {
46+
if !StringSliceContains(userRoles, role) {
4747
valid = false
4848
break
4949
}

0 commit comments

Comments
 (0)