Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions common/include/common/IAuthenticationManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

// Standard imports
#include <string>

namespace SDMS {

struct LogContext;
/**
* Interface class for managing authenticating
*
Expand All @@ -26,7 +25,7 @@ class IAuthenticationManager {
* Increments the number of times that the key has been accessed, this is
*useful information when deciding if a key should be purged.
**/
virtual void incrementKeyAccessCounter(const std::string &public_key) = 0;
virtual void incrementKeyAccessCounter(const std::string &public_key, LogContext log_context) = 0;

/**
* Will return true if the public key is known. This is also dependent on the
Expand All @@ -39,7 +38,7 @@ class IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual bool hasKey(const std::string &pub_key) const = 0;
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const = 0;

/**
* Will get the unique id or throw an error
Expand All @@ -49,7 +48,7 @@ class IAuthenticationManager {
* - SESSION
* - PERSISTENT - user or repo
**/
virtual std::string getUID(const std::string &pub_key) const = 0;
virtual std::string getUID(const std::string &pub_key, LogContext log_context) const = 0;

/**
* Purge keys if needed
Expand Down
15 changes: 12 additions & 3 deletions common/source/operators/AuthenticationOperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

// Local public includes
#include "common/TraceException.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <any>
#include <iostream>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>

namespace SDMS {

Expand All @@ -25,17 +29,22 @@ void AuthenticationOperator::execute(IMessage &message) {
if (message.exists(MessageAttribute::KEY) == 0) {
EXCEPT(1, "'KEY' attribute not defined.");
}
// 🔹 Generate correlation ID for this request
boost::uuids::random_generator generator;
boost::uuids::uuid uuid = generator();

LogContext log_context;
log_context.correlation_id = boost::uuids::to_string(uuid);
m_authentication_manager->purge();

std::string key = std::get<std::string>(message.get(MessageAttribute::KEY));

std::string uid = "anon";
if (m_authentication_manager->hasKey(key)) {
m_authentication_manager->incrementKeyAccessCounter(key);
if (m_authentication_manager->hasKey(key, log_context)) {
m_authentication_manager->incrementKeyAccessCounter(key, log_context);

try {
uid = m_authentication_manager->getUID(key);
uid = m_authentication_manager->getUID(key, log_context);
} catch (const std::exception& e) {
// Log the exception to help diagnose authentication issues
std::cerr << "[AuthenticationOperator] Failed to get UID for key: "
Expand Down
7 changes: 4 additions & 3 deletions common/tests/unit/test_OperatorFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common/MessageFactory.hpp"
#include "common/OperatorFactory.hpp"
#include "common/OperatorTypes.hpp"
#include "common/DynaLog.hpp"

// Third party includes
#include <google/protobuf/stubs/common.h>
Expand Down Expand Up @@ -38,15 +39,15 @@ class DummyAuthManager : public IAuthenticationManager {
/**
* Methods only available via the interface
**/
virtual void incrementKeyAccessCounter(const std::string &pub_key) final {
virtual void incrementKeyAccessCounter(const std::string &pub_key, LogContext log_context) final {
++m_counters.at(pub_key);
}

virtual bool hasKey(const std::string &pub_key) const {
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const {
return m_counters.count(pub_key);
}
// Just assume all keys map to the anon_uid
virtual std::string getUID(const std::string &) const {
virtual std::string getUID(const std::string &, LogContext log_context) const {
return "authenticated_uid";
}

Expand Down
18 changes: 11 additions & 7 deletions core/server/AuthMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ size_t AuthMap::size(const PublicKeyType pub_key_type) const {
}

void AuthMap::incrementKeyAccessCounter(const PublicKeyType pub_key_type,
const std::string &public_key) {
const std::string &public_key,
LogContext log_context) {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
if (m_trans_auth_clients.count(public_key)) {
Expand All @@ -183,7 +184,8 @@ void AuthMap::incrementKeyAccessCounter(const PublicKeyType pub_key_type,
}

bool AuthMap::hasKey(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
return m_trans_auth_clients.count(public_key) > 0;
Expand All @@ -203,7 +205,7 @@ bool AuthMap::hasKey(const PublicKeyType pub_key_type,
try {
DatabaseAPI db(m_db_url, m_db_user, m_db_pass);
std::string uid;
if (db.uidByPubKey(public_key, uid)) {
if (db.uidByPubKey(public_key, uid, log_context)) {
return true;
}
} catch (const std::exception& e) {
Expand All @@ -217,9 +219,10 @@ bool AuthMap::hasKey(const PublicKeyType pub_key_type,
}

std::string AuthMap::getUID(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {

std::string uid = getUIDSafe(pub_key_type, public_key);
std::string uid = getUIDSafe(pub_key_type, public_key, log_context);

if (uid.empty()) {
if (pub_key_type == PublicKeyType::TRANSIENT) {
Expand All @@ -238,7 +241,8 @@ std::string AuthMap::getUID(const PublicKeyType pub_key_type,
}

std::string AuthMap::getUIDSafe(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
if (m_trans_auth_clients.count(public_key)) {
Expand All @@ -261,7 +265,7 @@ std::string AuthMap::getUIDSafe(const PublicKeyType pub_key_type,
// Check database for user keys
DatabaseAPI db(m_db_url, m_db_user, m_db_pass);
std::string uid;
if (db.uidByPubKey(public_key, uid)) {
if (db.uidByPubKey(public_key, uid, log_context)) {
return uid;
}
}
Expand Down
13 changes: 9 additions & 4 deletions core/server/AuthMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// Local common includes
#include "common/IAuthenticationManager.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <map>
Expand Down Expand Up @@ -113,13 +114,15 @@ class AuthMap {
*does not exist. Best to call hasKey first.
**/
std::string getUID(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/**
* Safe version that returns empty string if key not found
**/
std::string getUIDSafe(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/**
* Will return the number of keys of the provided type. Does not currently
Expand All @@ -128,7 +131,8 @@ class AuthMap {
size_t size(const PublicKeyType pub_key_type) const;

bool hasKey(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/***********************************************************************************
* Manipulators
Expand All @@ -138,7 +142,8 @@ class AuthMap {
* Increase the recorded times the the public key has been accessed by one.
**/
void incrementKeyAccessCounter(const PublicKeyType pub_key_type,
const std::string &public_key);
const std::string &public_key,
LogContext log_context);

/**
* Adds the key to the AuthMap object
Expand Down
47 changes: 25 additions & 22 deletions core/server/AuthenticationManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Common includes
#include "common/TraceException.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <iostream>
Expand Down Expand Up @@ -69,46 +70,47 @@ void AuthenticationManager::purge(const PublicKeyType pub_key_type) {
}

void AuthenticationManager::incrementKeyAccessCounter(
const std::string &public_key) {
const std::string &public_key,
LogContext log_context) {
std::lock_guard<std::mutex> lock(m_lock);
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::TRANSIENT,
public_key);
} else if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::SESSION, public_key);
public_key, log_context);
} else if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::SESSION, public_key, log_context);
}
// Ignore persistent cases because counter does nothing for them
}

bool AuthenticationManager::hasKey(const std::string &public_key) const {
bool AuthenticationManager::hasKey(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
return true;
}

if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
return true;
}

if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key, log_context)) {
return true;
}

return false;
}

std::string AuthenticationManager::getUID(const std::string &public_key) const {
std::string AuthenticationManager::getUID(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::TRANSIENT, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::TRANSIENT, public_key, log_context);
}
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::SESSION, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::SESSION, public_key, log_context);
}
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::PERSISTENT, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::PERSISTENT, public_key, log_context);
}

EXCEPT(1, "Unrecognized public_key during execution of getUID.");
Expand All @@ -122,9 +124,10 @@ void AuthenticationManager::addKey(const PublicKeyType &pub_key_type,
}

bool AuthenticationManager::hasKey(const PublicKeyType &pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);
return m_auth_mapper.hasKey(pub_key_type, public_key);
return m_auth_mapper.hasKey(pub_key_type, public_key, log_context);
}

void AuthenticationManager::migrateKey(const PublicKeyType &from_type,
Expand All @@ -150,21 +153,21 @@ void AuthenticationManager::clearAllNonPersistentKeys() {
m_auth_mapper.clearAllNonPersistentKeys();
}

std::string AuthenticationManager::getUIDSafe(const std::string &public_key) const {
std::string AuthenticationManager::getUIDSafe(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

// Try each key type in order
std::string uid = m_auth_mapper.getUIDSafe(PublicKeyType::TRANSIENT, public_key);
std::string uid = m_auth_mapper.getUIDSafe(PublicKeyType::TRANSIENT, public_key, log_context);
if (!uid.empty()) {
return uid;
}

uid = m_auth_mapper.getUIDSafe(PublicKeyType::SESSION, public_key);
uid = m_auth_mapper.getUIDSafe(PublicKeyType::SESSION, public_key, log_context);
if (!uid.empty()) {
return uid;
}

uid = m_auth_mapper.getUIDSafe(PublicKeyType::PERSISTENT, public_key);
uid = m_auth_mapper.getUIDSafe(PublicKeyType::PERSISTENT, public_key, log_context);
if (!uid.empty()) {
return uid;
}
Expand Down
10 changes: 5 additions & 5 deletions core/server/AuthenticationManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class AuthenticationManager : public IAuthenticationManager {
*allotted purge time frame. If the count is above one then the session key
*not be purged.
**/
virtual void incrementKeyAccessCounter(const std::string &public_key) final;
virtual void incrementKeyAccessCounter(const std::string &public_key, LogContext log_context) final;

/**
* This will purge all keys of a particular type that have expired.
Expand All @@ -79,15 +79,15 @@ class AuthenticationManager : public IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual bool hasKey(const std::string &pub_key) const final;
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const final;

void addKey(const PublicKeyType &pub_key_type, const std::string &public_key,
const std::string &uid);

/**
* Check if a specific key exists in a specific map type
**/
bool hasKey(const PublicKeyType &pub_key_type, const std::string &public_key) const;
bool hasKey(const PublicKeyType &pub_key_type, const std::string &public_key, LogContext log_context) const;

/**
* Migrate a key from one type to another
Expand Down Expand Up @@ -121,13 +121,13 @@ class AuthenticationManager : public IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual std::string getUID(const std::string &pub_key) const final;
virtual std::string getUID(const std::string &pub_key, LogContext log_context) const final;

/**
* Safe version that returns empty string if key not found
* instead of throwing an exception
**/
std::string getUIDSafe(const std::string &pub_key) const;
std::string getUIDSafe(const std::string &pub_key, LogContext log_context) const;
};

} // namespace Core
Expand Down
11 changes: 10 additions & 1 deletion core/server/Condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@

// Standard includes
#include <iostream>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>

namespace SDMS {
namespace Core {

void Promote::enforce(AuthMap &auth_map, const std::string &public_key) {
if (auth_map.hasKeyType(m_promote_from, public_key)) {
size_t access_count = auth_map.getAccessCount(m_promote_from, public_key);
boost::uuids::random_generator generator;
boost::uuids::uuid uuid = generator();

LogContext log_context;
log_context.correlation_id = boost::uuids::to_string(uuid);

if (access_count >= m_transient_to_session_count_threshold) {
// Convert transient key to session key if has been accessed more than the
// threshold
std::string uid = auth_map.getUID(m_promote_from, public_key);
std::string uid = auth_map.getUID(m_promote_from, public_key, log_context);
auth_map.addKey(m_promote_to, public_key, uid);
}
// Remove expired short lived transient key
Expand Down
Loading