Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,7 @@ public class OauthAuthorityExtractor implements ProviderAuthorityExtractor {

@Override
public boolean isApplicable(String provider, Map<String, String> customParams) {
var typeMatch = OAUTH.equalsIgnoreCase(provider) || OAUTH.equalsIgnoreCase(customParams.get(TYPE));

if (!typeMatch) {
return false;
}

var containsRolesFieldNameParam = customParams.containsKey(ROLES_FIELD_PARAM_NAME);
if (!containsRolesFieldNameParam) {
log.debug("Provider [{}] doesn't contain a roles field param name, mapping won't be performed", provider);
return false;
}

return true;
return OAUTH.equalsIgnoreCase(provider) || OAUTH.equalsIgnoreCase(customParams.get(TYPE));
}

@Override
Expand All @@ -60,15 +48,25 @@ public Mono<Set<String>> extract(AccessControlService acs, Object value, Map<Str
}

private Set<String> extractUsernameRoles(AccessControlService acs, DefaultOAuth2User principal) {
return acs.getRoles()
var principalName = principal.getName();

log.debug("Principal name is: [{}]", principalName);

var roles = acs.getRoles()
.stream()
.filter(r -> r.getSubjects()
.stream()
.filter(s -> s.getProvider().equals(Provider.OAUTH))
.filter(s -> s.getType().equals("user"))
.anyMatch(s -> s.getValue().equals(principal.getName())))
.peek(s -> log.trace("[{}] matches [{}]? [{}]", s.getValue(), principalName,
s.getValue().equalsIgnoreCase(principalName)))
.anyMatch(s -> s.getValue().equalsIgnoreCase(principalName)))
.map(Role::getName)
.collect(Collectors.toSet());

log.debug("Matched roles by username: [{}]", String.join(", ", roles));

return roles;
}

private Set<String> extractRoles(AccessControlService acs, DefaultOAuth2User principal,
Expand All @@ -77,7 +75,17 @@ private Set<String> extractRoles(AccessControlService acs, DefaultOAuth2User pri
Assert.notNull(provider, "provider is null");
var rolesFieldName = provider.getCustomParams().get(ROLES_FIELD_PARAM_NAME);

if (rolesFieldName == null) {
log.warn("Provider [{}] doesn't contain a roles field param name, won't map roles", provider);
return Collections.emptySet();
}

var principalRoles = convertRoles(principal.getAttribute(rolesFieldName));
if (principalRoles.isEmpty()) {
log.debug("Principal [{}] doesn't have any roles, nothing to do", principal.getName());
return Collections.emptySet();
}

log.debug("Token's groups: [{}]", String.join(",", principalRoles));

Set<String> roles = acs.getRoles()
Expand All @@ -94,15 +102,15 @@ private Set<String> extractRoles(AccessControlService acs, DefaultOAuth2User pri
.map(Role::getName)
.collect(Collectors.toSet());

log.debug("Matched roles: [{}]", String.join(", ", roles));
log.debug("Matched group roles: [{}]", String.join(", ", roles));

return roles;
}

@SuppressWarnings("unchecked")
private Collection<String> convertRoles(Object roles) {
if (roles == null) {
log.debug("Param missing from attributes, skipping");
log.warn("Param missing in attributes, nothing to do");
return Collections.emptySet();
}

Expand All @@ -112,7 +120,7 @@ private Collection<String> convertRoles(Object roles) {
}

if (!(roles instanceof String)) {
log.debug("The field is not a string, skipping");
log.trace("The field is not a string, skipping");
return Collections.emptySet();
}

Expand Down
Loading