Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions common/include/common/IAuthenticationManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

// Standard imports
#include <string>

namespace SDMS {

struct LogContext;
/**
* Interface class for managing authenticating
*
Expand All @@ -26,7 +25,7 @@ class IAuthenticationManager {
* Increments the number of times that the key has been accessed, this is
*useful information when deciding if a key should be purged.
**/
virtual void incrementKeyAccessCounter(const std::string &public_key) = 0;
virtual void incrementKeyAccessCounter(const std::string &public_key, LogContext log_context) = 0;

/**
* Will return true if the public key is known. This is also dependent on the
Expand All @@ -39,7 +38,7 @@ class IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual bool hasKey(const std::string &pub_key) const = 0;
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const = 0;

/**
* Will get the unique id or throw an error
Expand All @@ -49,7 +48,7 @@ class IAuthenticationManager {
* - SESSION
* - PERSISTENT - user or repo
**/
virtual std::string getUID(const std::string &pub_key) const = 0;
virtual std::string getUID(const std::string &pub_key, LogContext log_context) const = 0;

/**
* Purge keys if needed
Expand Down
9 changes: 6 additions & 3 deletions common/source/operators/AuthenticationOperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Local public includes
#include "common/TraceException.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <any>
Expand All @@ -26,16 +27,18 @@ void AuthenticationOperator::execute(IMessage &message) {
EXCEPT(1, "'KEY' attribute not defined.");
}

LogContext log_context;
log_context.correlation_id = std::get<std::string>(message.get(MessageAttribute::CORRELATION_ID));
m_authentication_manager->purge();

std::string key = std::get<std::string>(message.get(MessageAttribute::KEY));

std::string uid = "anon";
if (m_authentication_manager->hasKey(key)) {
m_authentication_manager->incrementKeyAccessCounter(key);
if (m_authentication_manager->hasKey(key, log_context)) {
m_authentication_manager->incrementKeyAccessCounter(key, log_context);

try {
uid = m_authentication_manager->getUID(key);
uid = m_authentication_manager->getUID(key, log_context);
} catch (const std::exception& e) {
// Log the exception to help diagnose authentication issues
std::cerr << "[AuthenticationOperator] Failed to get UID for key: "
Expand Down
7 changes: 4 additions & 3 deletions common/tests/unit/test_OperatorFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "common/MessageFactory.hpp"
#include "common/OperatorFactory.hpp"
#include "common/OperatorTypes.hpp"
#include "common/DynaLog.hpp"

// Third party includes
#include <google/protobuf/stubs/common.h>
Expand Down Expand Up @@ -38,15 +39,15 @@ class DummyAuthManager : public IAuthenticationManager {
/**
* Methods only available via the interface
**/
virtual void incrementKeyAccessCounter(const std::string &pub_key) final {
virtual void incrementKeyAccessCounter(const std::string &pub_key, LogContext log_context) final {
++m_counters.at(pub_key);
}

virtual bool hasKey(const std::string &pub_key) const {
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const {
return m_counters.count(pub_key);
}
// Just assume all keys map to the anon_uid
virtual std::string getUID(const std::string &) const {
virtual std::string getUID(const std::string &, LogContext log_context) const {
return "authenticated_uid";
}

Expand Down
18 changes: 11 additions & 7 deletions core/server/AuthMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ size_t AuthMap::size(const PublicKeyType pub_key_type) const {
}

void AuthMap::incrementKeyAccessCounter(const PublicKeyType pub_key_type,
const std::string &public_key) {
const std::string &public_key,
LogContext log_context) {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
if (m_trans_auth_clients.count(public_key)) {
Expand All @@ -183,7 +184,8 @@ void AuthMap::incrementKeyAccessCounter(const PublicKeyType pub_key_type,
}

bool AuthMap::hasKey(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
return m_trans_auth_clients.count(public_key) > 0;
Expand All @@ -203,7 +205,7 @@ bool AuthMap::hasKey(const PublicKeyType pub_key_type,
try {
DatabaseAPI db(m_db_url, m_db_user, m_db_pass);
std::string uid;
if (db.uidByPubKey(public_key, uid)) {
if (db.uidByPubKey(public_key, uid, log_context)) {
return true;
}
} catch (const std::exception& e) {
Expand All @@ -217,9 +219,10 @@ bool AuthMap::hasKey(const PublicKeyType pub_key_type,
}

std::string AuthMap::getUID(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {

std::string uid = getUIDSafe(pub_key_type, public_key);
std::string uid = getUIDSafe(pub_key_type, public_key, log_context);

if (uid.empty()) {
if (pub_key_type == PublicKeyType::TRANSIENT) {
Expand All @@ -238,7 +241,8 @@ std::string AuthMap::getUID(const PublicKeyType pub_key_type,
}

std::string AuthMap::getUIDSafe(const PublicKeyType pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
if (pub_key_type == PublicKeyType::TRANSIENT) {
lock_guard<mutex> lock(m_trans_clients_mtx);
if (m_trans_auth_clients.count(public_key)) {
Expand All @@ -261,7 +265,7 @@ std::string AuthMap::getUIDSafe(const PublicKeyType pub_key_type,
// Check database for user keys
DatabaseAPI db(m_db_url, m_db_user, m_db_pass);
std::string uid;
if (db.uidByPubKey(public_key, uid)) {
if (db.uidByPubKey(public_key, uid, log_context)) {
return uid;
}
}
Expand Down
13 changes: 9 additions & 4 deletions core/server/AuthMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// Local common includes
#include "common/IAuthenticationManager.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <map>
Expand Down Expand Up @@ -113,13 +114,15 @@ class AuthMap {
*does not exist. Best to call hasKey first.
**/
std::string getUID(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/**
* Safe version that returns empty string if key not found
**/
std::string getUIDSafe(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/**
* Will return the number of keys of the provided type. Does not currently
Expand All @@ -128,7 +131,8 @@ class AuthMap {
size_t size(const PublicKeyType pub_key_type) const;

bool hasKey(const PublicKeyType pub_key_type,
const std::string &public_key) const;
const std::string &public_key,
LogContext log_context) const;

/***********************************************************************************
* Manipulators
Expand All @@ -138,7 +142,8 @@ class AuthMap {
* Increase the recorded times the the public key has been accessed by one.
**/
void incrementKeyAccessCounter(const PublicKeyType pub_key_type,
const std::string &public_key);
const std::string &public_key,
LogContext log_context);

/**
* Adds the key to the AuthMap object
Expand Down
47 changes: 25 additions & 22 deletions core/server/AuthenticationManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

// Common includes
#include "common/TraceException.hpp"
#include "common/DynaLog.hpp"

// Standard includes
#include <iostream>
Expand Down Expand Up @@ -69,46 +70,47 @@ void AuthenticationManager::purge(const PublicKeyType pub_key_type) {
}

void AuthenticationManager::incrementKeyAccessCounter(
const std::string &public_key) {
const std::string &public_key,
LogContext log_context) {
std::lock_guard<std::mutex> lock(m_lock);
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::TRANSIENT,
public_key);
} else if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::SESSION, public_key);
public_key, log_context);
} else if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
m_auth_mapper.incrementKeyAccessCounter(PublicKeyType::SESSION, public_key, log_context);
}
// Ignore persistent cases because counter does nothing for them
}

bool AuthenticationManager::hasKey(const std::string &public_key) const {
bool AuthenticationManager::hasKey(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
return true;
}

if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
return true;
}

if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key)) {
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key, log_context)) {
return true;
}

return false;
}

std::string AuthenticationManager::getUID(const std::string &public_key) const {
std::string AuthenticationManager::getUID(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::TRANSIENT, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::TRANSIENT, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::TRANSIENT, public_key, log_context);
}
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::SESSION, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::SESSION, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::SESSION, public_key, log_context);
}
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key)) {
return m_auth_mapper.getUID(PublicKeyType::PERSISTENT, public_key);
if (m_auth_mapper.hasKey(PublicKeyType::PERSISTENT, public_key, log_context)) {
return m_auth_mapper.getUID(PublicKeyType::PERSISTENT, public_key, log_context);
}

EXCEPT(1, "Unrecognized public_key during execution of getUID.");
Expand All @@ -122,9 +124,10 @@ void AuthenticationManager::addKey(const PublicKeyType &pub_key_type,
}

bool AuthenticationManager::hasKey(const PublicKeyType &pub_key_type,
const std::string &public_key) const {
const std::string &public_key,
LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);
return m_auth_mapper.hasKey(pub_key_type, public_key);
return m_auth_mapper.hasKey(pub_key_type, public_key, log_context);
}

void AuthenticationManager::migrateKey(const PublicKeyType &from_type,
Expand All @@ -150,21 +153,21 @@ void AuthenticationManager::clearAllNonPersistentKeys() {
m_auth_mapper.clearAllNonPersistentKeys();
}

std::string AuthenticationManager::getUIDSafe(const std::string &public_key) const {
std::string AuthenticationManager::getUIDSafe(const std::string &public_key, LogContext log_context) const {
std::lock_guard<std::mutex> lock(m_lock);

// Try each key type in order
std::string uid = m_auth_mapper.getUIDSafe(PublicKeyType::TRANSIENT, public_key);
std::string uid = m_auth_mapper.getUIDSafe(PublicKeyType::TRANSIENT, public_key, log_context);
if (!uid.empty()) {
return uid;
}

uid = m_auth_mapper.getUIDSafe(PublicKeyType::SESSION, public_key);
uid = m_auth_mapper.getUIDSafe(PublicKeyType::SESSION, public_key, log_context);
if (!uid.empty()) {
return uid;
}

uid = m_auth_mapper.getUIDSafe(PublicKeyType::PERSISTENT, public_key);
uid = m_auth_mapper.getUIDSafe(PublicKeyType::PERSISTENT, public_key, log_context);
if (!uid.empty()) {
return uid;
}
Expand Down
10 changes: 5 additions & 5 deletions core/server/AuthenticationManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class AuthenticationManager : public IAuthenticationManager {
*allotted purge time frame. If the count is above one then the session key
*not be purged.
**/
virtual void incrementKeyAccessCounter(const std::string &public_key) final;
virtual void incrementKeyAccessCounter(const std::string &public_key, LogContext log_context) final;

/**
* This will purge all keys of a particular type that have expired.
Expand All @@ -79,15 +79,15 @@ class AuthenticationManager : public IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual bool hasKey(const std::string &pub_key) const final;
virtual bool hasKey(const std::string &pub_key, LogContext log_context) const final;

void addKey(const PublicKeyType &pub_key_type, const std::string &public_key,
const std::string &uid);

/**
* Check if a specific key exists in a specific map type
**/
bool hasKey(const PublicKeyType &pub_key_type, const std::string &public_key) const;
bool hasKey(const PublicKeyType &pub_key_type, const std::string &public_key, LogContext log_context) const;

/**
* Migrate a key from one type to another
Expand Down Expand Up @@ -121,13 +121,13 @@ class AuthenticationManager : public IAuthenticationManager {
* - SESSION
* - PERSISTENT
**/
virtual std::string getUID(const std::string &pub_key) const final;
virtual std::string getUID(const std::string &pub_key, LogContext log_context) const final;

/**
* Safe version that returns empty string if key not found
* instead of throwing an exception
**/
std::string getUIDSafe(const std::string &pub_key) const;
std::string getUIDSafe(const std::string &pub_key, LogContext log_context) const;
};

} // namespace Core
Expand Down
11 changes: 10 additions & 1 deletion core/server/Condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@

// Standard includes
#include <iostream>
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>

namespace SDMS {
namespace Core {

void Promote::enforce(AuthMap &auth_map, const std::string &public_key) {
if (auth_map.hasKeyType(m_promote_from, public_key)) {
size_t access_count = auth_map.getAccessCount(m_promote_from, public_key);
boost::uuids::random_generator generator;
boost::uuids::uuid uuid = generator();

LogContext log_context;
log_context.correlation_id = boost::uuids::to_string(uuid);

if (access_count >= m_transient_to_session_count_threshold) {
// Convert transient key to session key if has been accessed more than the
// threshold
std::string uid = auth_map.getUID(m_promote_from, public_key);
std::string uid = auth_map.getUID(m_promote_from, public_key, log_context);
auth_map.addKey(m_promote_to, public_key, uid);
}
// Remove expired short lived transient key
Expand Down
Loading
Loading