Skip to content

Commit e6b24fc

Browse files
committed
fix role mapping
1 parent 4c5d9b9 commit e6b24fc

File tree

3 files changed

+48
-32
lines changed

3 files changed

+48
-32
lines changed

src/Access/TokenAccessStorage.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ namespace ErrorCodes
2020

2121
TokenAccessStorage::TokenAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config_, const String & prefix_)
2222
: IAccessStorage(storage_name_), access_control(access_control_), config(config_), prefix(prefix_),
23-
roles_filter(config.getString(prefix.empty() ? "" : prefix + "." + "roles_filter", "")),
2423
memory_storage(storage_name_, access_control.getChangesNotifier(), false)
2524
{
2625
std::lock_guard lock(mutex);
2726

2827
const String prefix_str = (prefix.empty() ? "" : prefix + ".");
2928

29+
if (config.has(prefix_str + "roles_filter"))
30+
roles_filter.emplace(config.getString(prefix_str + "roles_filter"));
31+
3032
provider_name = config.getString(prefix_str + "processor");
3133
if (provider_name.empty())
3234
throw Exception(ErrorCodes::BAD_ARGUMENTS, "'processor' must be specified for Token user directory");
@@ -35,7 +37,7 @@ TokenAccessStorage::TokenAccessStorage(const String & storage_name_, AccessContr
3537
if (config.has(prefix_str + "common_roles"))
3638
{
3739
Poco::Util::AbstractConfiguration::Keys role_names;
38-
config.keys(prefix_str + "roles", role_names);
40+
config.keys(prefix_str + "common_roles", role_names);
3941

4042
common_roles_cfg.insert(role_names.begin(), role_names.end());
4143
}
@@ -369,21 +371,22 @@ std::optional<AuthResult> TokenAccessStorage::authenticateImpl(
369371
throwAddressNotAllowed(address);
370372

371373
std::set<String> external_roles;
372-
if (!roles_filter.ok())
373-
{
374-
external_roles = token_credentials.getGroups();
375-
LOG_TRACE(getLogger(), "{}: No external role filtering set, applying all available groups", getStorageName());
376-
}
377-
else
374+
if (roles_filter.has_value() && roles_filter.value().ok())
378375
{
376+
LOG_TRACE(getLogger(), "{}: External role filter found, applying only matching groups", getStorageName());
379377
for (const auto & group: token_credentials.getGroups()) {
380-
if (RE2::FullMatch(group, roles_filter))
378+
if (RE2::FullMatch(group, roles_filter.value()))
381379
{
382380
external_roles.insert(group);
383381
LOG_TRACE(getLogger(), "{}: Granted role (group) {} to user", getStorageName(), user->getName());
384382
}
385383
}
386384
}
385+
else
386+
{
387+
LOG_TRACE(getLogger(), "{}: No external role filtering set, applying all available groups", getStorageName());
388+
external_roles = token_credentials.getGroups();
389+
}
387390

388391
if (new_user)
389392
{

src/Access/TokenAccessStorage.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class TokenAccessStorage : public IAccessStorage
4848
const String & prefix;
4949

5050
String provider_name;
51-
re2::RE2 roles_filter;
51+
std::optional<re2::RE2> roles_filter = std::nullopt;
5252

5353
std::set<String> common_role_names; // role name that should be granted to all users at all times
5454
mutable std::map<String, std::size_t> external_role_hashes;

src/Access/TokenProcessorsOpaque.cpp

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,23 @@ namespace
3030
return jsonValue.get<picojson::object>();
3131
}
3232

33-
template<typename ValueType = std::string>
34-
ValueType getValueByKey(const picojson::object & jsonObject, const std::string & key) {
33+
template<typename ValueType = std::string, bool throw_on_exception = true>
34+
std::optional<ValueType> getValueByKey(const picojson::object & jsonObject, const std::string & key) {
3535
auto it = jsonObject.find(key); // Find the key in the object
3636
if (it == jsonObject.end())
3737
{
38-
throw std::runtime_error("Key not found: " + key);
38+
if constexpr (throw_on_exception)
39+
throw std::runtime_error("Key not found: " + key);
40+
else
41+
return std::nullopt;
3942
}
4043

4144
const picojson::value & value = it->second;
4245
if (!value.is<ValueType>()) {
43-
throw std::runtime_error("Value for key '" + key + "' has incorrect type.");
46+
if constexpr (throw_on_exception)
47+
throw std::runtime_error("Value for key '" + key + "' has incorrect type.");
48+
else
49+
return std::nullopt;
4450
}
4551

4652
return value.get<ValueType>();
@@ -94,11 +100,9 @@ bool GoogleTokenProcessor::resolveAndValidate(const TokenCredentials & credentia
94100
throw Exception(ErrorCodes::AUTHENTICATION_FAILED,
95101
"{}: Specified username_claim {} not found in token", processor_name, username_claim);
96102

97-
bool has_email = user_info_json.contains("email");
98-
if (has_email)
99-
user_info["email"] = getValueByKey(user_info_json, "email");
103+
user_info["email"] = getValueByKey<std::string, false>(user_info_json, "email").value_or("");
100104

101-
user_info[username_claim] = getValueByKey(user_info_json, username_claim);
105+
user_info[username_claim] = getValueByKey(user_info_json, username_claim).value();
102106

103107
String user_name = user_info[username_claim];
104108

@@ -109,11 +113,11 @@ bool GoogleTokenProcessor::resolveAndValidate(const TokenCredentials & credentia
109113

110114
auto token_info = getObjectFromURI(Poco::URI("https://www.googleapis.com/oauth2/v3/tokeninfo"), token);
111115
if (token_info.contains("exp"))
112-
const_cast<TokenCredentials &>(credentials).setExpiresAt(std::chrono::system_clock::from_time_t((getValueByKey<time_t>(token_info, "exp"))));
116+
const_cast<TokenCredentials &>(credentials).setExpiresAt(std::chrono::system_clock::from_time_t((getValueByKey<time_t>(token_info, "exp").value())));
113117

114118
/// Groups info can only be retrieved if user email is known.
115119
/// If no email found in user info, we skip this step and there are no external roles for the user.
116-
if (has_email)
120+
if (!user_info["email"].empty())
117121
{
118122
std::set<String> external_groups_names;
119123
const Poco::URI get_groups_uri = Poco::URI("https://cloudidentity.googleapis.com/v1/groups/-/memberships:searchDirectGroups?query=member_key_id==" + user_info["email"] + "'");
@@ -139,10 +143,13 @@ bool GoogleTokenProcessor::resolveAndValidate(const TokenCredentials & credentia
139143
}
140144

141145
auto group_data = group.get<picojson::object>();
142-
String group_name = getValueByKey(group_data["groupKey"].get<picojson::object>(), "id");
143-
external_groups_names.insert(group_name);
144-
LOG_TRACE(getLogger("TokenAuthentication"),
145-
"{}: User {}: new external group {}", processor_name, user_name, group_name);
146+
String group_name = getValueByKey<std::string, false>(group_data["groupKey"].get<picojson::object>(), "id").value_or("");
147+
if (!group_name.empty())
148+
{
149+
external_groups_names.insert(group_name);
150+
LOG_TRACE(getLogger("TokenAuthentication"),
151+
"{}: User {}: new external group {}", processor_name, user_name, group_name);
152+
}
146153
}
147154

148155
const_cast<TokenCredentials &>(credentials).setGroups(external_groups_names);
@@ -172,7 +179,7 @@ bool AzureTokenProcessor::resolveAndValidate(const TokenCredentials & credential
172179
try
173180
{
174181
picojson::object user_info_json = getObjectFromURI(Poco::URI("https://graph.microsoft.com/oidc/userinfo"), token);
175-
String username = getValueByKey(user_info_json, username_claim);
182+
String username = getValueByKey(user_info_json, username_claim).value();
176183
if (!username.empty())
177184
{
178185
/// Credentials are passed as const everywhere up the flow, so we have to comply,
@@ -224,9 +231,15 @@ bool AzureTokenProcessor::resolveAndValidate(const TokenCredentials & credential
224231
}
225232

226233
auto group_data = group.get<picojson::object>();
227-
String group_name = getValueByKey(group_data, "id");
228-
external_groups_names.insert(group_name);
229-
LOG_TRACE(getLogger("TokenAuthentication"), "{}: User {}: new external group {}", processor_name, credentials.getUserName(), group_name);
234+
if (!group_data.contains("displayName"))
235+
continue;
236+
237+
String group_name = getValueByKey<std::string, false>(group_data, "displayName").value_or("");
238+
if (!group_name.empty())
239+
{
240+
external_groups_names.insert(group_name);
241+
LOG_TRACE(getLogger("TokenAuthentication"), "{}: User {}: new external group {}", processor_name, credentials.getUserName(), group_name);
242+
}
230243
}
231244
}
232245
catch (const Exception & e)
@@ -282,7 +295,7 @@ OpenIdTokenProcessor::OpenIdTokenProcessor(const String & processor_name_,
282295
if (!openid_config.contains("userinfo_endpoint") || !openid_config.contains("introspection_endpoint"))
283296
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: Cannot extract userinfo_endpoint or introspection_endpoint from OIDC configuration, consider manual configuration.", processor_name);
284297

285-
if (!openid_config.contains("jwks_uri"))
298+
if (openid_config.contains("jwks_uri"))
286299
{
287300
LOG_TRACE(getLogger("TokenAuthentication"), "{}: JWKS URI set, local JWT processing will be attempted", processor_name_);
288301
jwt_validator.emplace(processor_name_ + "jwks_val",
@@ -291,7 +304,7 @@ OpenIdTokenProcessor::OpenIdTokenProcessor(const String & processor_name_,
291304
groups_claim_,
292305
"",
293306
verifier_leeway_,
294-
getValueByKey(openid_config, "jwks_uri"),
307+
getValueByKey(openid_config, "jwks_uri").value(),
295308
jwks_cache_lifetime_);
296309
}
297310
}
@@ -308,7 +321,7 @@ bool OpenIdTokenProcessor::resolveAndValidate(const TokenCredentials & credentia
308321
{
309322
auto decoded_token = jwt::decode(token);
310323
user_info_json = decoded_token.get_payload_json();
311-
username = getValueByKey(user_info_json, username_claim);
324+
username = getValueByKey(user_info_json, username_claim).value();
312325

313326
/// TODO: Now we work only with Keycloak -- and it provides expires_at in token itself. Need to add actual token introspection logic for other OIDC providers.
314327
if (decoded_token.has_expires_at())
@@ -326,7 +339,7 @@ bool OpenIdTokenProcessor::resolveAndValidate(const TokenCredentials & credentia
326339
try
327340
{
328341
user_info_json = getObjectFromURI(userinfo_endpoint, token);
329-
username = getValueByKey(user_info_json, username_claim);
342+
username = getValueByKey(user_info_json, username_claim).value();
330343
}
331344
catch (...)
332345
{

0 commit comments

Comments
 (0)