Skip to content

Commit 7be0c3c

Browse files
committed
add keykloak support(2)
1 parent 2c47b40 commit 7be0c3c

File tree

5 files changed

+120
-45
lines changed

5 files changed

+120
-45
lines changed

src/Access/AccessTokenProcessor.cpp

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <Access/AccessTokenProcessor.h>
22
#include <Common/logger_useful.h>
3+
#include <Poco/StreamCopier.h>
34
#include <picojson/picojson.h>
45
#include <jwt-cpp/jwt.h>
56

@@ -128,7 +129,7 @@ std::unique_ptr<IAccessTokenProcessor> IAccessTokenProcessor::parseTokenProcesso
128129

129130
if (is_auto && !is_manual)
130131
{
131-
return std::make_unique<OpenIDAccessTokenProcessor>(name, cache_lifetime, email_regex_str, config.getString(prefix + ".configuration_endpoint"));
132+
return std::make_unique<OpenIDAccessTokenProcessor>(name, cache_lifetime, email_regex_str, config.getString(prefix + ".configuration_endpoint"), config.getString(prefix + ".groups_claim_name", ""));
132133
}
133134
else if (!is_auto && is_manual)
134135
{
@@ -335,57 +336,120 @@ String AzureAccessTokenProcessor::validateTokenAndGetUsername(const String & tok
335336
return getValueByKey(user_info_json, "sub");
336337
}
337338

339+
340+
OpenIDAccessTokenProcessor::OpenIDAccessTokenProcessor(const String & name_,
341+
const UInt64 cache_invalidation_interval_,
342+
const String & email_regex_str,
343+
const String & userinfo_endpoint_,
344+
const String & token_introspection_endpoint_,
345+
const String & jwks_uri_,
346+
const String & groups_claim_name_)
347+
: IAccessTokenProcessor(name_, cache_invalidation_interval_, email_regex_str),
348+
userinfo_endpoint(userinfo_endpoint_), token_introspection_endpoint(token_introspection_endpoint_), groups_claim_name(groups_claim_name_)
349+
{
350+
if (!jwks_uri_.empty())
351+
{
352+
jwt_validator.emplace(name_ + "jwks_val", jwks_uri_, cache_invalidation_interval_);
353+
}
354+
}
355+
338356
OpenIDAccessTokenProcessor::OpenIDAccessTokenProcessor(const String & name_,
339357
const UInt64 cache_invalidation_interval_,
340358
const String & email_regex_str,
341-
const String & openid_config_endpoint_)
359+
const String & openid_config_endpoint_,
360+
const String & groups_claim_name_)
342361
: IAccessTokenProcessor(name_, cache_invalidation_interval_, email_regex_str)
343362
{
344363
const picojson::object openid_config = getObjectFromURI(Poco::URI(openid_config_endpoint_));
345364

346365
if (!openid_config.contains("userinfo_endpoint") || !openid_config.contains("introspection_endpoint"))
347366
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: Cannot extract userinfo_endpoint or introspection_endpoint from OIDC configuration, consider manual configuration.", name);
367+
368+
OpenIDAccessTokenProcessor(name_,
369+
cache_invalidation_interval_,
370+
email_regex_str,
371+
getValueByKey(openid_config, "userinfo_endpoint"),
372+
getValueByKey(openid_config, "introspection_endpoint"),
373+
openid_config.contains("jwks_uri") ? getValueByKey(openid_config, "jwks_uri") : "",
374+
groups_claim_name_);
348375
}
349376

350377
bool OpenIDAccessTokenProcessor::resolveAndValidate(const TokenCredentials & credentials)
351378
{
352379
const String & token = credentials.getToken();
380+
String username;
381+
picojson::object user_info_json;
353382

354-
try
383+
if (jwt_validator.has_value() && jwt_validator.value().validate("", token, username))
355384
{
356-
String username = validateTokenAndGetUsername(token);
357-
if (!username.empty())
385+
386+
try
358387
{
359-
/// Credentials are passed as const everywhere up the flow, so we have to comply,
360-
/// in this case const_cast looks acceptable.
361-
const_cast<TokenCredentials &>(credentials).setUserName(username);
388+
auto decoded_token = jwt::decode(token);
389+
user_info_json = decoded_token.get_payload_json();
390+
391+
/// TODO: Now we work only with Keycloak -- and it provides expires_at in token itself. Need to add actual token introspection logic for other OIDC providers.
392+
if (decoded_token.has_expires_at())
393+
const_cast<TokenCredentials &>(credentials).setExpiresAt(decoded_token.get_expires_at());
362394
}
363-
else
364-
LOG_TRACE(getLogger("AccessTokenProcessor"), "{}: Failed to get username with token", name);
395+
catch (const std::exception & ex)
396+
{
397+
LOG_TRACE(getLogger("AccessTokenProcessor"), "{}: Failed to process token as JWT: {}", name, ex.what());
398+
}
399+
}
400+
401+
/// If username or user info is empty -- local validation failed, trying introspection via provider
402+
if (username.empty() || user_info_json.empty())
403+
{
404+
try
405+
{
406+
user_info_json = getObjectFromURI(userinfo_endpoint, token);
407+
username = getValueByKey(user_info_json, "sub");
408+
}
409+
catch (...)
410+
{
411+
return false;
412+
}
413+
}
365414

415+
if (user_info_json.empty())
416+
{
417+
LOG_TRACE(getLogger("AccessTokenProcessor"), "{}: Failed to obtain user info", name);
418+
return false;
366419
}
367-
catch (...)
420+
else if (username.empty())
368421
{
422+
LOG_TRACE(getLogger("AccessTokenProcessor"), "{}: Failed to get username", name);
369423
return false;
370424
}
371425

372-
return true;
426+
/// Credentials are passed as const everywhere up the flow, so we have to comply,
427+
/// in this case const_cast is acceptable.
428+
const_cast<TokenCredentials &>(credentials).setUserName(username);
373429

374-
/// TODO: add proper groups functionality
375-
// try
376-
// {
377-
// const_cast<TokenCredentials &>(credentials).setExpiresAt(jwt::decode(token).get_expires_at());
378-
// }
379-
// catch (...) {
380-
// LOG_TRACE(getLogger("AccessTokenProcessor"),
381-
// "{}: No expiration data found in a valid token, will use default cache lifetime", name);
382-
// }
383-
}
430+
/// For now, list of groups is expected in a claim with specified name either in token itself or in userinfo response (Keycloak works this way)
431+
/// TODO: add support for custom endpoints for retrieving groups. Keycloak lists groups in /userinfo and token itself, which is not always the case.
432+
if (!groups_claim_name.empty() && user_info_json.contains(groups_claim_name))
433+
{
434+
if (!user_info_json[groups_claim_name].is<picojson::array>())
435+
{
436+
LOG_TRACE(getLogger("AccessTokenProcessor"),
437+
"{}: Failed to extract groups: invalid content in user data", name);
438+
return true;
439+
}
384440

385-
String OpenIDAccessTokenProcessor::validateTokenAndGetUsername(const String & token) const
386-
{
387-
picojson::object user_info_json = getObjectFromURI(userinfo_endpoint, token);
388-
return getValueByKey(user_info_json, "sub");
441+
std::set<String> external_groups_names;
442+
443+
picojson::array groups_array = user_info_json[groups_claim_name].get<picojson::array>();
444+
for (const auto & group: groups_array)
445+
{
446+
if (group.is<std::string>())
447+
external_groups_names.insert(group.get<std::string>());
448+
}
449+
const_cast<TokenCredentials &>(credentials).setGroups(external_groups_names);
450+
}
451+
452+
return true;
389453
}
390454

391455
}

src/Access/AccessTokenProcessor.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,24 +106,28 @@ class OpenIDAccessTokenProcessor : public IAccessTokenProcessor
106106
OpenIDAccessTokenProcessor(const String & name_,
107107
const UInt64 cache_invalidation_interval_,
108108
const String & email_regex_str,
109-
const String & openid_config_endpoint_);
109+
const String & openid_config_endpoint_,
110+
const String & groups_claim_name_);
110111

111112
/// Specify endpoints manually
112113
OpenIDAccessTokenProcessor(const String & name_,
113114
const UInt64 cache_invalidation_interval_,
114115
const String & email_regex_str,
115116
const String & userinfo_endpoint_,
116-
const String & token_introspection_endpoint_)
117-
: IAccessTokenProcessor(name_, cache_invalidation_interval_, email_regex_str),
118-
userinfo_endpoint(userinfo_endpoint_), token_introspection_endpoint(token_introspection_endpoint_) {}
117+
const String & token_introspection_endpoint_,
118+
const String & jwks_uri_,
119+
const String & groups_claim_name_);
119120

120121
bool resolveAndValidate(const TokenCredentials & credentials) override;
121122
private:
122-
const Poco::URI userinfo_endpoint;
123-
const Poco::URI token_introspection_endpoint;
123+
Poco::URI userinfo_endpoint;
124+
Poco::URI token_introspection_endpoint;
124125

126+
/// Access token is often a valid JWT, so we can validate it locally to avoid unnecesary network requests.
127+
std::optional<JWKSValidator> jwt_validator = std::nullopt;
125128

126-
String validateTokenAndGetUsername(const String & token) const;
129+
/// groups are expected under /userinfo endpoint under specified name
130+
const String groups_claim_name;
127131
};
128132

129133
}

src/Access/JWTValidator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class JWKSValidator : public IJWTValidator
6161
public:
6262
explicit JWKSValidator(const String & name_, std::shared_ptr<IJWKSProvider> provider_)
6363
: IJWTValidator(name_), provider(provider_) {}
64+
65+
explicit JWKSValidator(const String & name_, const String & uri, const size_t refresh_ms_)
66+
: JWKSValidator(name_, std::make_shared<JWKSClient>(uri, refresh_ms_)) {}
6467
private:
6568
void validateImpl(const jwt::decoded_jwt<jwt::traits::kazuho_picojson> & token) const override;
6669

src/Access/TokenAccessStorage.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,9 @@ void TokenAccessStorage::updateAssignedRolesNoLock(const UUID & id, const String
341341
std::optional<AuthResult> TokenAccessStorage::authenticateImpl(
342342
const Credentials & credentials,
343343
const Poco::Net::IPAddress & address,
344-
[[maybe_unused]] const ExternalAuthenticators & external_authenticators,
345-
[[maybe_unused]] bool throw_if_user_not_exists,
344+
const ExternalAuthenticators & external_authenticators,
345+
const ClientInfo & /* client_info */,
346+
bool throw_if_user_not_exists,
346347
bool /* allow_no_password */,
347348
bool /* allow_plaintext_password */) const
348349
{
@@ -356,7 +357,7 @@ std::optional<AuthResult> TokenAccessStorage::authenticateImpl(
356357
{
357358
// Even though token itself may be valid (especially in case of a jwt token), authentication has just failed.
358359
if (throw_if_user_not_exists)
359-
throwNotFound(AccessEntityType::USER, credentials.getUserName());
360+
throwNotFound(AccessEntityType::USER, credentials.getUserName(), getStorageName());
360361
else
361362
return {};
362363
}

src/Access/TokenAccessStorage.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,21 @@ class TokenAccessStorage : public IAccessStorage
6262

6363
bool areTokenCredentialsValidNoLock(const User & user, const Credentials & credentials, const ExternalAuthenticators & external_authenticators) const;
6464

65+
void applyRoleChangeNoLock(bool grant, const UUID & role_id, const String & role_name);
66+
void assignRolesNoLock(User & user, const std::set<String> & external_roles, std::size_t external_roles_hash) const;
67+
void updateAssignedRolesNoLock(const UUID & id, const String & user_name, const std::set<String> & external_roles) const;
68+
69+
protected:
6570
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
6671
std::vector<UUID> findAllImpl(AccessEntityType type) const override;
6772
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
6873
std::optional<std::pair<String, AccessEntityType>> readNameWithTypeImpl(const UUID & id, bool throw_if_not_exists) const override;
69-
std::optional<AuthResult> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address,
70-
[[maybe_unused]] const ExternalAuthenticators & external_authenticators,
71-
[[maybe_unused]] bool throw_if_user_not_exists,
72-
bool allow_no_password, bool allow_plaintext_password) const override;
73-
74-
75-
void applyRoleChangeNoLock(bool grant, const UUID & role_id, const String & role_name);
76-
void assignRolesNoLock(User & user, const std::set<String> & external_roles, std::size_t external_roles_hash) const;
77-
void updateAssignedRolesNoLock(const UUID & id, const String & user_name, const std::set<String> & external_roles) const;
74+
std::optional<AuthResult> authenticateImpl(const Credentials & credentials,
75+
const Poco::Net::IPAddress & address,
76+
const ExternalAuthenticators & external_authenticators,
77+
const ClientInfo & client_info,
78+
bool throw_if_user_not_exists,
79+
bool allow_no_password,
80+
bool allow_plaintext_password) const override;
7881
};
7982
}

0 commit comments

Comments
 (0)