Skip to content

Commit 3f97007

Browse files
feat: add thread safe cache for read and write, configure expiry
1 parent e7e3b32 commit 3f97007

File tree

1 file changed

+92
-28
lines changed

1 file changed

+92
-28
lines changed

auth-plugin/plugin_libcurl.cpp

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ For more information, please refer to <http://unlicense.org/>
3232
#include <sys/epoll.h>
3333
#include <stdexcept>
3434
#include <map>
35-
35+
#include <chrono>
3636

3737
#include "pluginstate.h"
3838
#include "curl_functions.h"
@@ -41,9 +41,26 @@ For more information, please refer to <http://unlicense.org/>
4141
#include <openssl/bio.h>
4242
#include <openssl/evp.h>
4343
#include <string>
44-
// include jwt-cpp
44+
#include "parallel_hashmap/phmap.h"
4545
#include "jwt-cpp/jwt.h"
46+
#include <shared_mutex>
4647

48+
/*
49+
* creates a p hashmap with 4 shards with shared mutex so read/write is thread safe. Read is lock free.
50+
* this is a global variable, so it is shared between all threads because acl check can happen in any thread.
51+
* 2**4 = 16 sub maps, so lock is applied to sub map so concurrency is intrinsinc to the map.
52+
*/
53+
const int SHARDS = 4;
54+
using TokenCache = phmap::parallel_flat_hash_map<
55+
std::string,
56+
long long,
57+
phmap::priv::hash_default_hash<std::string>,
58+
phmap::priv::hash_default_eq<std::string>,
59+
phmap::priv::Allocator<phmap::priv::Pair<const std::string, long long>>,
60+
SHARDS,
61+
std::shared_mutex>;
62+
63+
static TokenCache token_expiry_cache;
4764

4865
int flashmq_plugin_version()
4966
{
@@ -63,6 +80,8 @@ void flashmq_plugin_main_deinit(std::unordered_map<std::string, std::string> &pl
6380
(void)plugin_opts;
6481

6582
curl_global_cleanup();
83+
token_expiry_cache.clear();
84+
6685
}
6786
// new text test
6887
void flashmq_plugin_allocate_thread_memory(void **thread_data, std::unordered_map<std::string, std::string> &plugin_opts)
@@ -136,7 +155,6 @@ void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t even
136155

137156

138157

139-
140158
bool allow_user_access(const std::string &username) {
141159
const std::vector<std::string> allowed_users = {
142160
"playbook",
@@ -148,9 +166,8 @@ bool allow_user_access(const std::string &username) {
148166
return std::find(allowed_users.begin(), allowed_users.end(), username) != allowed_users.end();
149167
}
150168

151-
std::string get_env_var( std::string const & key )
152-
{
153-
char * val = getenv( key.c_str() );
169+
std::string get_env_var(std::string const &key) {
170+
char *val = getenv(key.c_str());
154171
return val == NULL ? std::string("") : std::string(val);
155172
}
156173

@@ -168,27 +185,38 @@ std::string base64_decode(const std::string &in) {
168185
BIO_set_flags(bio, BIO_FLAGS_BASE64_NO_NL);
169186

170187
int len = BIO_read(bio, &out[0], (int)in.length());
171-
if (len > 0) out.resize(len);
172-
else out.clear();
188+
if (len > 0)
189+
out.resize(len);
190+
else
191+
out.clear();
173192

174193
BIO_free_all(bio);
175194
return out;
176195
}
177196

197+
198+
void flashmq_plugin_client_disconnected(void *thread_data, const std::string &clientid){
199+
(void)thread_data;
200+
// erase is thread safe as it acquires write lock on the shard.
201+
token_expiry_cache.erase(clientid);
202+
}
203+
204+
178205
AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clientid, const std::string &username, const std::string &password,
179206
const std::vector<std::pair<std::string, std::string>> *userProperties, const std::weak_ptr<Client> &client)
180207
{
181208
(void)clientid;
182209
(void)userProperties;
183210
(void)client;
184-
211+
(void)thread_data;
185212

186213
flashmq_logf(LOG_INFO, "username: %s", username.c_str());
187214

188-
if (allow_user_access(username)){
215+
if (allow_user_access(username))
216+
{
189217
return AuthResult::success;
190218
}
191-
219+
192220
// base64 decode the environment variable AUTH_PUBLICKEY
193221
const std::string rsa_pub_env_key = get_env_var("AUTH_PUBLICKEY");
194222
const std::string rsa_pub_key = base64_decode(rsa_pub_env_key);
@@ -199,25 +227,30 @@ AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clie
199227
flashmq_logf(LOG_ERR, "No token found for username: %s", username.c_str());
200228
return AuthResult::error;
201229
}
202-
203-
230+
204231
// decode the username and password, if they are jwt tokens, and check if they are valid.
205-
try{
232+
try {
206233
/* [allow rsa algorithm] */
207-
auto verify = jwt::verify()
208-
// We only need an RSA public key to verify tokens
209-
.allow_algorithm(jwt::algorithm::rs256(rsa_pub_key, "", "", ""));
234+
auto jwt_verify = jwt::verify()
235+
.allow_algorithm(jwt::algorithm::rs256(rsa_pub_key, "", "", ""));
210236
/* [decode jwt token] */
211237
auto decoded = jwt::decode(token);
212-
flashmq_logf(LOG_INFO, "Decoded JWT token successfully");
213-
verify.verify(decoded);
238+
jwt_verify.verify(decoded);
239+
long long exp_epoch = decoded.get_payload_claim("exp").to_json().get<int64_t>();
240+
/*
241+
* upserts the cache with the clientid and the exp_epoch.
242+
* thread safe write with try_emplace_l as it acquires write lock on the shard (sub maps). Lock is specific to the sub map
243+
*/
244+
token_expiry_cache.try_emplace_l(
245+
clientid,
246+
[&](auto& kv) { kv.second = exp_epoch; },
247+
exp_epoch // construct if missing
248+
);
214249
flashmq_logf(LOG_INFO, "Verified JWT token successfully with public key");
215-
216250
return AuthResult::success;
217251
} catch (const std::exception &e) {
218252
flashmq_logf(LOG_ERR, "Failed to decode JWT token: %s", e.what());
219253
std::cout << "Caught exception: " << e.what() << std::endl;
220-
// print the exception message
221254
return AuthResult::error;
222255
}
223256

@@ -226,15 +259,14 @@ AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clie
226259
}
227260

228261
AuthResult flashmq_plugin_acl_check(void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username,
229-
const std::string &topic, const std::vector<std::string> &subtopics, const std::string &shareName,
230-
std::string_view payload, const uint8_t qos, const bool retain,
231-
const std::optional<std::string> &correlationData, const std::optional<std::string> &responseTopic,
232-
const std::vector<std::pair<std::string, std::string>> *userProperties)
262+
const std::string &topic, const std::vector<std::string> &subtopics, const std::string &shareName,
263+
std::string_view payload, const uint8_t qos, const bool retain,
264+
const std::optional<std::string> &correlationData, const std::optional<std::string> &responseTopic,
265+
const std::vector<std::pair<std::string, std::string>> *userProperties)
233266
{
234267
(void)thread_data;
235268
(void)access;
236269
(void)clientid;
237-
(void)username;
238270
(void)subtopics;
239271
(void)qos;
240272
(void)(retain);
@@ -245,6 +277,38 @@ AuthResult flashmq_plugin_acl_check(void *thread_data, const AclAccess access, c
245277
(void)correlationData;
246278
(void)responseTopic;
247279

248-
return AuthResult::success;
249-
}
280+
// SYS topics are published every 10 seconds, this allow broker internal $SYS topics to be published
281+
bool is_broker_internal_topic = (username.empty() && clientid.empty()) && topic.rfind("$SYS", 0) == 0 && access == AclAccess::write;
282+
bool is_allowed_user = allow_user_access(username);
283+
if (is_broker_internal_topic || is_allowed_user)
284+
{
285+
return AuthResult::success;
286+
}
287+
288+
long long exp_epoch = 0;
289+
// thread safe read with if_contains
290+
bool cache_hit = token_expiry_cache.if_contains(clientid, [&](const TokenCache::value_type &kv) {
291+
exp_epoch = kv.second;
292+
});
293+
294+
250295

296+
if (cache_hit)
297+
{
298+
flashmq_logf(LOG_DEBUG, "JWT verification cache hit for user: %s and exp: %lld", username.c_str(), exp_epoch);
299+
// jwt expiry is in epoch seconds
300+
long long now_epoch = std::chrono::duration_cast<std::chrono::seconds>(
301+
std::chrono::system_clock::now().time_since_epoch()).count();
302+
303+
bool token_expired = now_epoch > exp_epoch;
304+
if(token_expired){
305+
flashmq_logf(LOG_DEBUG, "JWT verification cache expired for user: %s", username.c_str());
306+
token_expiry_cache.erase(clientid);
307+
return AuthResult::acl_denied;
308+
}
309+
310+
return AuthResult::success;
311+
}
312+
313+
return AuthResult::acl_denied;
314+
}

0 commit comments

Comments
 (0)