Skip to content

Commit cbef9bc

Browse files
feat: add thread safe cache for read and write, configure expiry
1 parent e7e3b32 commit cbef9bc

File tree

1 file changed

+98
-28
lines changed

1 file changed

+98
-28
lines changed

auth-plugin/plugin_libcurl.cpp

Lines changed: 98 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ For more information, please refer to <http://unlicense.org/>
3232
#include <sys/epoll.h>
3333
#include <stdexcept>
3434
#include <map>
35-
35+
#include <chrono>
3636

3737
#include "pluginstate.h"
3838
#include "curl_functions.h"
@@ -41,9 +41,26 @@ For more information, please refer to <http://unlicense.org/>
4141
#include <openssl/bio.h>
4242
#include <openssl/evp.h>
4343
#include <string>
44-
// include jwt-cpp
44+
#include "parallel_hashmap/phmap.h"
4545
#include "jwt-cpp/jwt.h"
46+
#include <shared_mutex>
4647

48+
/*
49+
* creates a p hashmap with 4 shards with shared mutex so read/write is thread safe. Read is lock free.
50+
* this is a global variable, so it is shared between all threads because acl check can happen in any thread.
51+
* 2**4 = 16 sub maps, so lock is applied to sub map so concurrency is intrinsinc to the map.
52+
*/
53+
const int SHARDS = 4;
54+
using TokenCache = phmap::parallel_flat_hash_map<
55+
std::string,
56+
long long,
57+
phmap::priv::hash_default_hash<std::string>,
58+
phmap::priv::hash_default_eq<std::string>,
59+
phmap::priv::Allocator<phmap::priv::Pair<const std::string, long long>>,
60+
SHARDS,
61+
std::shared_mutex>;
62+
63+
static TokenCache token_expiry_cache;
4764

4865
int flashmq_plugin_version()
4966
{
@@ -63,6 +80,8 @@ void flashmq_plugin_main_deinit(std::unordered_map<std::string, std::string> &pl
6380
(void)plugin_opts;
6481

6582
curl_global_cleanup();
83+
token_expiry_cache.clear();
84+
6685
}
6786
// new text test
6887
void flashmq_plugin_allocate_thread_memory(void **thread_data, std::unordered_map<std::string, std::string> &plugin_opts)
@@ -136,7 +155,6 @@ void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t even
136155

137156

138157

139-
140158
bool allow_user_access(const std::string &username) {
141159
const std::vector<std::string> allowed_users = {
142160
"playbook",
@@ -148,9 +166,8 @@ bool allow_user_access(const std::string &username) {
148166
return std::find(allowed_users.begin(), allowed_users.end(), username) != allowed_users.end();
149167
}
150168

151-
std::string get_env_var( std::string const & key )
152-
{
153-
char * val = getenv( key.c_str() );
169+
std::string get_env_var(std::string const &key) {
170+
char *val = getenv(key.c_str());
154171
return val == NULL ? std::string("") : std::string(val);
155172
}
156173

@@ -168,27 +185,38 @@ std::string base64_decode(const std::string &in) {
168185
BIO_set_flags(bio, BIO_FLAGS_BASE64_NO_NL);
169186

170187
int len = BIO_read(bio, &out[0], (int)in.length());
171-
if (len > 0) out.resize(len);
172-
else out.clear();
188+
if (len > 0)
189+
out.resize(len);
190+
else
191+
out.clear();
173192

174193
BIO_free_all(bio);
175194
return out;
176195
}
177196

197+
198+
void flashmq_plugin_client_disconnected(void *thread_data, const std::string &clientid){
199+
(void)thread_data;
200+
// erase is thread safe as it acquires write lock on the shard.
201+
token_expiry_cache.erase(clientid);
202+
}
203+
204+
178205
AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clientid, const std::string &username, const std::string &password,
179206
const std::vector<std::pair<std::string, std::string>> *userProperties, const std::weak_ptr<Client> &client)
180207
{
181208
(void)clientid;
182209
(void)userProperties;
183210
(void)client;
184-
211+
(void)thread_data;
185212

186213
flashmq_logf(LOG_INFO, "username: %s", username.c_str());
187214

188-
if (allow_user_access(username)){
215+
if (allow_user_access(username))
216+
{
189217
return AuthResult::success;
190218
}
191-
219+
192220
// base64 decode the environment variable AUTH_PUBLICKEY
193221
const std::string rsa_pub_env_key = get_env_var("AUTH_PUBLICKEY");
194222
const std::string rsa_pub_key = base64_decode(rsa_pub_env_key);
@@ -199,25 +227,31 @@ AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clie
199227
flashmq_logf(LOG_ERR, "No token found for username: %s", username.c_str());
200228
return AuthResult::error;
201229
}
202-
203-
230+
204231
// decode the username and password, if they are jwt tokens, and check if they are valid.
205-
try{
232+
try
233+
{
206234
/* [allow rsa algorithm] */
207-
auto verify = jwt::verify()
208-
// We only need an RSA public key to verify tokens
209-
.allow_algorithm(jwt::algorithm::rs256(rsa_pub_key, "", "", ""));
235+
auto jwt_verify = jwt::verify()
236+
.allow_algorithm(jwt::algorithm::rs256(rsa_pub_key, "", "", ""));
210237
/* [decode jwt token] */
211238
auto decoded = jwt::decode(token);
212-
flashmq_logf(LOG_INFO, "Decoded JWT token successfully");
213-
verify.verify(decoded);
239+
jwt_verify.verify(decoded);
240+
long long exp_epoch = decoded.get_payload_claim("exp").to_json().get<int64_t>();
241+
/*
242+
* upserts the cache with the clientid and the exp_epoch.
243+
* thread safe write with try_emplace_l as it acquires write lock on the shard (sub maps). Lock is specific to the sub map
244+
*/
245+
token_expiry_cache.try_emplace_l(
246+
clientid,
247+
[&](auto& kv) { kv.second = exp_epoch; },
248+
exp_epoch // construct if missing
249+
);
214250
flashmq_logf(LOG_INFO, "Verified JWT token successfully with public key");
215-
216251
return AuthResult::success;
217252
} catch (const std::exception &e) {
218253
flashmq_logf(LOG_ERR, "Failed to decode JWT token: %s", e.what());
219254
std::cout << "Caught exception: " << e.what() << std::endl;
220-
// print the exception message
221255
return AuthResult::error;
222256
}
223257

@@ -226,15 +260,14 @@ AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clie
226260
}
227261

228262
AuthResult flashmq_plugin_acl_check(void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username,
229-
const std::string &topic, const std::vector<std::string> &subtopics, const std::string &shareName,
230-
std::string_view payload, const uint8_t qos, const bool retain,
231-
const std::optional<std::string> &correlationData, const std::optional<std::string> &responseTopic,
232-
const std::vector<std::pair<std::string, std::string>> *userProperties)
263+
const std::string &topic, const std::vector<std::string> &subtopics, const std::string &shareName,
264+
std::string_view payload, const uint8_t qos, const bool retain,
265+
const std::optional<std::string> &correlationData, const std::optional<std::string> &responseTopic,
266+
const std::vector<std::pair<std::string, std::string>> *userProperties)
233267
{
234268
(void)thread_data;
235269
(void)access;
236270
(void)clientid;
237-
(void)username;
238271
(void)subtopics;
239272
(void)qos;
240273
(void)(retain);
@@ -245,6 +278,43 @@ AuthResult flashmq_plugin_acl_check(void *thread_data, const AclAccess access, c
245278
(void)correlationData;
246279
(void)responseTopic;
247280

248-
return AuthResult::success;
249-
}
281+
// SYS topics are published every 10 seconds, this allow broker internal $SYS topics to be published
282+
bool is_broker_internal_topic = (username.empty() && clientid.empty()) && topic.rfind("$SYS", 0) == 0 && access == AclAccess::write;
283+
284+
if (is_broker_internal_topic)
285+
{
286+
return AuthResult::success;
287+
}
288+
289+
if (allow_user_access(username))
290+
{
291+
return AuthResult::success;
292+
}
250293

294+
long long exp_epoch = 0;
295+
// thread safe read with if_contains
296+
bool cache_hit = token_expiry_cache.if_contains(clientid, [&](const TokenCache::value_type &kv) {
297+
exp_epoch = kv.second;
298+
});
299+
300+
301+
302+
if (cache_hit)
303+
{
304+
flashmq_logf(LOG_DEBUG, "JWT verification cache hit for user: %s and exp: %lld", username.c_str(), exp_epoch);
305+
// jwt expiry is in epoch seconds
306+
long long now_epoch = std::chrono::duration_cast<std::chrono::seconds>(
307+
std::chrono::system_clock::now().time_since_epoch()).count();
308+
309+
bool token_expired = now_epoch > exp_epoch;
310+
if(token_expired){
311+
flashmq_logf(LOG_DEBUG, "JWT verification cache expired for user: %s", username.c_str());
312+
token_expiry_cache.erase(clientid);
313+
return AuthResult::acl_denied;
314+
}
315+
316+
return AuthResult::success;
317+
}
318+
319+
return AuthResult::acl_denied;
320+
}

0 commit comments

Comments
 (0)