@@ -32,7 +32,7 @@ For more information, please refer to <http://unlicense.org/>
3232#include < sys/epoll.h>
3333#include < stdexcept>
3434#include < map>
35-
35+ # include < chrono >
3636
3737#include " pluginstate.h"
3838#include " curl_functions.h"
@@ -41,9 +41,26 @@ For more information, please refer to <http://unlicense.org/>
4141#include < openssl/bio.h>
4242#include < openssl/evp.h>
4343#include < string>
44- // include jwt-cpp
44+ # include " parallel_hashmap/phmap.h "
4545#include " jwt-cpp/jwt.h"
46+ #include < shared_mutex>
4647
48+ /*
49+ * creates a p hashmap with 4 shards with shared mutex so read/write is thread safe. Read is lock free.
50+ * this is a global variable, so it is shared between all threads because acl check can happen in any thread.
51+ * 2**4 = 16 sub maps, so lock is applied to sub map so concurrency is intrinsinc to the map.
52+ */
53+ const int SHARDS = 4 ;
54+ using TokenCache = phmap::parallel_flat_hash_map<
55+ std::string,
56+ long long ,
57+ phmap::priv::hash_default_hash<std::string>,
58+ phmap::priv::hash_default_eq<std::string>,
59+ phmap::priv::Allocator<phmap::priv::Pair<const std::string, long long >>,
60+ SHARDS,
61+ std::shared_mutex>;
62+
63+ static TokenCache token_expiry_cache;
4764
4865int flashmq_plugin_version ()
4966{
@@ -63,6 +80,8 @@ void flashmq_plugin_main_deinit(std::unordered_map<std::string, std::string> &pl
6380 (void )plugin_opts;
6481
6582 curl_global_cleanup ();
83+ token_expiry_cache.clear ();
84+
6685}
6786// new text test
6887void flashmq_plugin_allocate_thread_memory (void **thread_data, std::unordered_map<std::string, std::string> &plugin_opts)
@@ -136,7 +155,6 @@ void flashmq_plugin_poll_event_received(void *thread_data, int fd, uint32_t even
136155
137156
138157
139-
140158bool allow_user_access (const std::string &username) {
141159 const std::vector<std::string> allowed_users = {
142160 " playbook" ,
@@ -148,9 +166,8 @@ bool allow_user_access(const std::string &username) {
148166 return std::find (allowed_users.begin (), allowed_users.end (), username) != allowed_users.end ();
149167}
150168
151- std::string get_env_var ( std::string const & key )
152- {
153- char * val = getenv ( key.c_str () );
169+ std::string get_env_var (std::string const &key) {
170+ char *val = getenv (key.c_str ());
154171 return val == NULL ? std::string (" " ) : std::string (val);
155172}
156173
@@ -168,27 +185,38 @@ std::string base64_decode(const std::string &in) {
168185 BIO_set_flags (bio, BIO_FLAGS_BASE64_NO_NL);
169186
170187 int len = BIO_read (bio, &out[0 ], (int )in.length ());
171- if (len > 0 ) out.resize (len);
172- else out.clear ();
188+ if (len > 0 )
189+ out.resize (len);
190+ else
191+ out.clear ();
173192
174193 BIO_free_all (bio);
175194 return out;
176195}
177196
197+
198+ void flashmq_plugin_client_disconnected (void *thread_data, const std::string &clientid){
199+ (void )thread_data;
200+ // erase is thread safe as it acquires write lock on the shard.
201+ token_expiry_cache.erase (clientid);
202+ }
203+
204+
178205AuthResult flashmq_plugin_login_check (void *thread_data, const std::string &clientid, const std::string &username, const std::string &password,
179206 const std::vector<std::pair<std::string, std::string>> *userProperties, const std::weak_ptr<Client> &client)
180207{
181208 (void )clientid;
182209 (void )userProperties;
183210 (void )client;
184-
211+ ( void )thread_data;
185212
186213 flashmq_logf (LOG_INFO, " username: %s" , username.c_str ());
187214
188- if (allow_user_access (username)){
215+ if (allow_user_access (username))
216+ {
189217 return AuthResult::success;
190218 }
191-
219+
192220 // base64 decode the environment variable AUTH_PUBLICKEY
193221 const std::string rsa_pub_env_key = get_env_var (" AUTH_PUBLICKEY" );
194222 const std::string rsa_pub_key = base64_decode (rsa_pub_env_key);
@@ -199,25 +227,31 @@ AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clie
199227 flashmq_logf (LOG_ERR, " No token found for username: %s" , username.c_str ());
200228 return AuthResult::error;
201229 }
202-
203-
230+
204231 // decode the username and password, if they are jwt tokens, and check if they are valid.
205- try {
232+ try
233+ {
206234 /* [allow rsa algorithm] */
207- auto verify = jwt::verify ()
208- // We only need an RSA public key to verify tokens
209- .allow_algorithm (jwt::algorithm::rs256 (rsa_pub_key, " " , " " , " " ));
235+ auto jwt_verify = jwt::verify ()
236+ .allow_algorithm (jwt::algorithm::rs256 (rsa_pub_key, " " , " " , " " ));
210237 /* [decode jwt token] */
211238 auto decoded = jwt::decode (token);
212- flashmq_logf (LOG_INFO, " Decoded JWT token successfully" );
213- verify.verify (decoded);
239+ jwt_verify.verify (decoded);
240+ long long exp_epoch = decoded.get_payload_claim (" exp" ).to_json ().get <int64_t >();
241+ /*
242+ * upserts the cache with the clientid and the exp_epoch.
243+ * thread safe write with try_emplace_l as it acquires write lock on the shard (sub maps). Lock is specific to the sub map
244+ */
245+ token_expiry_cache.try_emplace_l (
246+ clientid,
247+ [&](auto & kv) { kv.second = exp_epoch; },
248+ exp_epoch // construct if missing
249+ );
214250 flashmq_logf (LOG_INFO, " Verified JWT token successfully with public key" );
215-
216251 return AuthResult::success;
217252 } catch (const std::exception &e) {
218253 flashmq_logf (LOG_ERR, " Failed to decode JWT token: %s" , e.what ());
219254 std::cout << " Caught exception: " << e.what () << std::endl;
220- // print the exception message
221255 return AuthResult::error;
222256 }
223257
@@ -226,15 +260,14 @@ AuthResult flashmq_plugin_login_check(void *thread_data, const std::string &clie
226260}
227261
228262AuthResult flashmq_plugin_acl_check (void *thread_data, const AclAccess access, const std::string &clientid, const std::string &username,
229- const std::string &topic, const std::vector<std::string> &subtopics, const std::string &shareName,
230- std::string_view payload, const uint8_t qos, const bool retain,
231- const std::optional<std::string> &correlationData, const std::optional<std::string> &responseTopic,
232- const std::vector<std::pair<std::string, std::string>> *userProperties)
263+ const std::string &topic, const std::vector<std::string> &subtopics, const std::string &shareName,
264+ std::string_view payload, const uint8_t qos, const bool retain,
265+ const std::optional<std::string> &correlationData, const std::optional<std::string> &responseTopic,
266+ const std::vector<std::pair<std::string, std::string>> *userProperties)
233267{
234268 (void )thread_data;
235269 (void )access;
236270 (void )clientid;
237- (void )username;
238271 (void )subtopics;
239272 (void )qos;
240273 (void )(retain);
@@ -245,6 +278,43 @@ AuthResult flashmq_plugin_acl_check(void *thread_data, const AclAccess access, c
245278 (void )correlationData;
246279 (void )responseTopic;
247280
248- return AuthResult::success;
249- }
281+ // SYS topics are published every 10 seconds, this allow broker internal $SYS topics to be published
282+ bool is_broker_internal_topic = (username.empty () && clientid.empty ()) && topic.rfind (" $SYS" , 0 ) == 0 && access == AclAccess::write;
283+
284+ if (is_broker_internal_topic)
285+ {
286+ return AuthResult::success;
287+ }
288+
289+ if (allow_user_access (username))
290+ {
291+ return AuthResult::success;
292+ }
250293
294+ long long exp_epoch = 0 ;
295+ // thread safe read with if_contains
296+ bool cache_hit = token_expiry_cache.if_contains (clientid, [&](const TokenCache::value_type &kv) {
297+ exp_epoch = kv.second ;
298+ });
299+
300+
301+
302+ if (cache_hit)
303+ {
304+ flashmq_logf (LOG_DEBUG, " JWT verification cache hit for user: %s and exp: %lld" , username.c_str (), exp_epoch);
305+ // jwt expiry is in epoch seconds
306+ long long now_epoch = std::chrono::duration_cast<std::chrono::seconds>(
307+ std::chrono::system_clock::now ().time_since_epoch ()).count ();
308+
309+ bool token_expired = now_epoch > exp_epoch;
310+ if (token_expired){
311+ flashmq_logf (LOG_DEBUG, " JWT verification cache expired for user: %s" , username.c_str ());
312+ token_expiry_cache.erase (clientid);
313+ return AuthResult::acl_denied;
314+ }
315+
316+ return AuthResult::success;
317+ }
318+
319+ return AuthResult::acl_denied;
320+ }
0 commit comments