Skip to content

Commit be62b6e

Browse files
author
zheyuan.xing
committed
feature: use jmespath library to support parsing multi-level username key when using ouath & add unit test for this
Signed-off-by: zheyuan.xing <[email protected]>
1 parent 12c997b commit be62b6e

File tree

4 files changed

+448
-11
lines changed

4 files changed

+448
-11
lines changed

connector/oauth/oauth.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"github.com/dexidp/dex/connector"
2020
"github.com/dexidp/dex/pkg/log"
21+
"github.com/jmespath/go-jmespath"
2122
)
2223

2324
type oauthConnector struct {
@@ -201,29 +202,34 @@ func (c *oauthConnector) HandleCallback(s connector.Scopes, r *http.Request) (id
201202
return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: status %d", userInfoResp.StatusCode)
202203
}
203204

204-
var userInfoResult map[string]interface{}
205+
var userInfoResult interface{}
205206
err = json.NewDecoder(userInfoResp.Body).Decode(&userInfoResult)
206207
if err != nil {
207208
c.logger.Errorf("OAuth Connector: failed to parse userinfo: %v", err)
208209
return identity, fmt.Errorf("OAuth Connector: failed to parse userinfo: %v", err)
209210
}
210211

211-
userID, found := userInfoResult[c.userIDKey].(string)
212+
tmpUserID, _ := jmespath.Search(c.userIDKey, userInfoResult)
213+
userID, found := tmpUserID.(string)
212214
if !found {
213215
c.logger.Errorf("OAuth Connector: not found %v claim", c.userIDKey)
214216
return identity, fmt.Errorf("OAuth Connector: not found %v claim", c.userIDKey)
215217
}
216218

217219
identity.UserID = userID
218-
identity.Username, _ = userInfoResult[c.userNameKey].(string)
219-
identity.PreferredUsername, _ = userInfoResult[c.preferredUsernameKey].(string)
220-
identity.Email, _ = userInfoResult[c.emailKey].(string)
221-
identity.EmailVerified, _ = userInfoResult[c.emailVerifiedKey].(bool)
220+
username, _ := jmespath.Search(c.userNameKey, userInfoResult)
221+
identity.Username, _ = username.(string)
222+
preferredUsername, _ := jmespath.Search(c.preferredUsernameKey, userInfoResult)
223+
identity.PreferredUsername, _ = preferredUsername.(string)
224+
email, _ := jmespath.Search(c.emailKey, userInfoResult)
225+
identity.Email, _ = email.(string)
226+
emailVerified, _ := jmespath.Search(c.emailVerifiedKey, userInfoResult)
227+
identity.EmailVerified, _ = emailVerified.(bool)
222228

223229
if s.Groups {
224230
groups := map[string]struct{}{}
225231

226-
c.addGroupsFromMap(groups, userInfoResult)
232+
c.addGroups(groups, userInfoResult)
227233
c.addGroupsFromToken(groups, token.AccessToken)
228234

229235
for groupName := range groups {
@@ -243,15 +249,22 @@ func (c *oauthConnector) HandleCallback(s connector.Scopes, r *http.Request) (id
243249
return identity, nil
244250
}
245251

246-
func (c *oauthConnector) addGroupsFromMap(groups map[string]struct{}, result map[string]interface{}) error {
247-
groupsClaim, ok := result[c.groupsKey].([]interface{})
252+
func (c *oauthConnector) addGroups(groups map[string]struct{}, result interface{}) error {
253+
tmpGroupsClaim, _ := jmespath.Search(c.groupsKey, result)
254+
groupsClaim, ok := tmpGroupsClaim.([]interface{})
248255
if !ok {
249256
return errors.New("cannot convert to slice")
250257
}
251258

252259
for _, group := range groupsClaim {
253260
if groupString, ok := group.(string); ok {
254261
groups[groupString] = struct{}{}
262+
continue
263+
}
264+
if groupMap, ok := group.(map[string]interface{}); ok {
265+
if groupName, ok := groupMap["name"].(string); ok {
266+
groups[groupName] = struct{}{}
267+
}
255268
}
256269
}
257270

@@ -269,13 +282,13 @@ func (c *oauthConnector) addGroupsFromToken(groups map[string]struct{}, token st
269282
return err
270283
}
271284

272-
var claimsMap map[string]interface{}
285+
var claimsMap interface{}
273286
err = json.Unmarshal(decoded, &claimsMap)
274287
if err != nil {
275288
return err
276289
}
277290

278-
return c.addGroupsFromMap(groups, claimsMap)
291+
return c.addGroups(groups, claimsMap)
279292
}
280293

281294
func decode(seg string) ([]byte, error) {

0 commit comments

Comments
 (0)