Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 111 additions & 78 deletions src/AMSlib/wf/basedb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <iostream>
#include <iterator>
#include <mutex>
#include <shared_mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
Expand Down Expand Up @@ -658,6 +659,8 @@ class RedisDB : public BaseDB<TypeValue>
#ifdef __ENABLE_RMQ__

enum RMQConnectionStatus { FAILED, CONNECTED, CLOSED, ERROR };
// Forward declaration
class RMQPublisher;

/**
* @brief AMS represents the header as follows:
Expand Down Expand Up @@ -832,10 +835,7 @@ class AMSMessage
domain_name.c_str(),
domain_name.size());
current_offset += domain_name.size();
current_offset +=
encode_data(_data + current_offset,
inputs,
outputs);
current_offset += encode_data(_data + current_offset, inputs, outputs);
DBG(AMSMessage, "Allocated message %d: %p", _id, _data);
CALIPER(CALI_MARK_END("AMS_MESSAGE");)
}
Expand All @@ -850,7 +850,12 @@ class AMSMessage

AMSMessage(const AMSMessage& other)
{
DBG(AMSMessage, "Copy AMSMessage (%d, %p) <- (%d, %p)", _id, _data, other._id, other._data);
DBG(AMSMessage,
"Copy AMSMessage (%d, %p) <- (%d, %p)",
_id,
_data,
other._id,
other._data);
swap(other);
};

Expand All @@ -866,7 +871,12 @@ class AMSMessage

AMSMessage& operator=(AMSMessage&& other) noexcept
{
DBG(AMSMessage, "Move AMSMessage (%d, %p) <- (%d, %p)", _id, _data, other._id, other._data);
DBG(AMSMessage,
"Move AMSMessage (%d, %p) <- (%d, %p)",
_id,
_data,
other._id,
other._data);
if (this != &other) {
swap(other);
other._data = nullptr;
Expand All @@ -892,12 +902,16 @@ class AMSMessage

// Creating the body part of the message
for (size_t i = 0; i < _input_dim; i++) {
std::memcpy(data_blob + offset, inputs[i], _num_elements * sizeof(TypeValue));
std::memcpy(data_blob + offset,
inputs[i],
_num_elements * sizeof(TypeValue));
offset += (_num_elements * sizeof(TypeValue));
}

for (size_t i = 0; i < _output_dim; i++) {
std::memcpy(data_blob + offset, outputs[i], _num_elements * sizeof(TypeValue));
std::memcpy(data_blob + offset,
outputs[i],
_num_elements * sizeof(TypeValue));
offset += (_num_elements * sizeof(TypeValue));
}

Expand Down Expand Up @@ -936,16 +950,68 @@ class AMSMessage
* @return Byte size of data pointer
*/
size_t size() const { return _total_size; }

~AMSMessage()
{
DBG(AMSMessage,
"Destroying message %d: %p (underlying memory NOT freed)",
_id,
_data)
}
}; // class AMSMessage

/**
* @brief Class responsible to keep track of which AMSMessage has been not correctly
* acknowledged. If a given message has not been acked then it is stored in
* an internal hashmap to be re-send later.
*/
class AMSMessageRecords
{
using record_t = std::pair<std::shared_ptr<uint8_t>, size_t>;
using iterator_t = std::unordered_map<int, record_t>::iterator;

private:
/** @brief Internal data structure that keeps messages nack */
std::unordered_map<int, record_t> _msgs;
/** @brief Shared mutex to ensure thread-safe access */
std::shared_mutex _mutex;

AMSMessageRecords() = default;

public:
AMSMessageRecords(AMSMessageRecords&) = delete;
AMSMessageRecords& operator=(AMSMessageRecords&) = delete;

AMSMessageRecords(AMSMessageRecords&&) = delete;
AMSMessageRecords& operator=(AMSMessageRecords&&) = delete;

/**
* @brief Return an iterator at the beggining of the records
*/
iterator_t begin() { return std::begin(_msgs); }

/**
* @brief Return an iterator pointing at the end of the records
*/
iterator_t end() { return std::end(_msgs); }

/**
* @brief Insert a new record
* @param[in] id Message ID
* @param[out] value Record that will be inserted
*/
void insert(int id, const record_t& value);

/**
* @brief Print the hashmap for debugging
*/
void print();

/**
* @brief pubslishes all the messages in map
*/
void publishUnacknowledged(RMQPublisher& publisher);

/**
* @brief Return the number of records in the underlying structure
* @return Return the size of the structure.
*/
size_t size();

static AMSMessageRecords& getInstance();
};

/**
* @brief Structure that represents incoming RabbitMQ messages.
Expand Down Expand Up @@ -1015,7 +1081,7 @@ class AMSMessageInbound
class RMQHandler : public AMQP::LibEventHandler
{
protected:
/** @brief Path to TLS certificate (if empty, no TLS certificate)*/
/** @brief Path to TLS certificate (if empty, no TLS certificate) */
std::string _cacert;
/** @brief MPI rank (0 if no MPI support) */
uint64_t _rId;
Expand Down Expand Up @@ -1323,10 +1389,6 @@ class RMQPublisherHandler final : public RMQHandler
int _nb_msg;
/** @brief Number of messages successfully acknowledged */
int _nb_msg_ack;
/** @brief Mutex to protect multithread accesses to _messages */
std::mutex _mutex;
/** @brief Messages that have not been successfully acknowledged */
std::list<AMSMessage> _messages;

public:
/**
Expand All @@ -1346,18 +1408,9 @@ class RMQPublisherHandler final : public RMQHandler
* @brief Publish data on RMQ queue.
* @param[in] msg The AMSMessage to publish
*/
void publish(AMSMessage&& msg);

/**
* @brief Return the messages that have NOT been acknowledged by the RabbitMQ server.
* @return A vector of AMSMessage
*/
std::list<AMSMessage>& msgBuffer();

/**
* @brief Free AMSMessages held by the handler
*/
void cleanup();
// void publish(AMSMessage&& msg);
void publish(int message_id,
const std::pair<std::shared_ptr<uint8_t>, size_t>&);

/**
* @brief Total number of messages sent
Expand Down Expand Up @@ -1390,20 +1443,6 @@ class RMQPublisherHandler final : public RMQHandler
* @param[in] connection The connection that can now be used
*/
virtual void onReady(AMQP::TcpConnection* connection) override;

/**
* @brief Free the data pointed pointer in a vector and update vector.
* @param[in] addr Address of memory to free.
* @param[in] buffer The vector containing memory buffers
*/
void freeMessage(int msg_id, std::list<AMSMessage>& buffer);

/**
* @brief Free the data pointed by each pointer in a vector.
* @param[in] buffer The vector containing memory buffers
*/
void freeAllMessages(std::list<AMSMessage>& buffer);

}; // class RMQPublisherHandler


Expand All @@ -1428,19 +1467,15 @@ class RMQPublisher
std::shared_ptr<struct event_base> _loop;
/** @brief The handler which contains various callbacks for the sender */
std::shared_ptr<RMQPublisherHandler> _handler;
/** @brief Buffer holding unacknowledged messages in case of crash */
std::list<AMSMessage> _buffer_msg;

public:
RMQPublisher(const RMQPublisher&) = delete;
RMQPublisher& operator=(const RMQPublisher&) = delete;

RMQPublisher(
uint64_t rId,
const AMQP::Address& address,
std::string cacert,
std::string queue,
std::list<AMSMessage>&& msgs_to_send = std::list<AMSMessage>());
RMQPublisher(uint64_t rId,
const AMQP::Address& address,
std::string cacert,
std::string queue);

/**
* @brief Check if the underlying RabbitMQ connection is ready and usable
Expand Down Expand Up @@ -1477,20 +1512,12 @@ class RMQPublisher
bool connectionValid();

/**
* @brief Return the messages that have not been acknowledged.
* It does not mean they have not been delivered but the
* acknowledgements have not arrived yet.
* @return A vector of AMSMessage
*/
std::list<AMSMessage>& getMsgBuffer();

/**
* @brief Total number of messages successfully acknowledged
* @return Number of messages
* @brief Publish a message on attached RMQ connection
* @param[in] id The ID of the message
* @param[in] record A pair of the memory content (ptr) and its size in byte
*/
void cleanup();

void publish(AMSMessage&& message);
void publish(int id,
const std::pair<std::shared_ptr<uint8_t>, size_t>& record);

/**
* @brief Total number of messages sent
Expand Down Expand Up @@ -1620,15 +1647,15 @@ class RMQInterface
* @return True, True if connection succeeded for both publisher/consumer
*/
std::pair<bool, bool> connect(std::string rmq_password,
std::string rmq_user,
std::string rmq_vhost,
int service_port,
std::string service_host,
std::string rmq_cert,
std::string outbound_queue,
std::string exchange,
std::string routing_key,
bool update_surrogate);
std::string rmq_user,
std::string rmq_vhost,
int service_port,
std::string service_host,
std::string rmq_cert,
std::string outbound_queue,
std::string exchange,
std::string routing_key,
bool update_surrogate);

/**
* @brief Check if the RabbitMQ connection is connected.
Expand Down Expand Up @@ -1687,7 +1714,13 @@ class RMQInterface
AMSMessage msg(_msg_tag, _rId, domain_name, num_elements, inputs, outputs);

if (!_publisher->connectionValid()) restartPublisher();
_publisher->publish(std::move(msg));

std::shared_ptr<uint8_t> ptr(msg.data());
auto record = std::make_pair(std::move(ptr), msg.size());

// if we have some messages to send first (from a potential restart)
AMSMessageRecords::getInstance().publishUnacknowledged(*_publisher);
_publisher->publish(msg.id(), record);
_msg_tag++;
CALIPER(CALI_MARK_END("STORE_RMQ");)
}
Expand Down
Loading