Skip to content

Commit 860931c

Browse files
committed
Change the data structure
1 parent a5bbff5 commit 860931c

File tree

10 files changed

+79
-64
lines changed

10 files changed

+79
-64
lines changed

include/pulsar/EncryptionContext.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <cstdint>
2222
#include <string>
2323
#include <unordered_map>
24+
#include <vector>
2425

2526
#include "CompressionType.h"
2627
#include "defines.h"
@@ -34,8 +35,16 @@ class MessageMetadata;
3435
class Message;
3536

3637
struct PULSAR_PUBLIC EncryptionKey {
38+
std::string key;
3739
std::string value;
3840
std::unordered_map<std::string, std::string> metadata;
41+
42+
explicit EncryptionKey() = default;
43+
44+
// Support in-place construction
45+
EncryptionKey(const std::string& key, const std::string& value,
46+
const decltype(EncryptionKey::metadata)& metadata)
47+
: key(key), value(value), metadata(metadata) {}
3948
};
4049

4150
/**
@@ -50,9 +59,10 @@ class PULSAR_PUBLIC EncryptionContext {
5059
batchSize_(-1),
5160
isDecryptionFailed_(false) {}
5261
EncryptionContext(const EncryptionContext&) = default;
62+
EncryptionContext(EncryptionContext&&) noexcept = default;
5363
EncryptionContext(const proto::MessageMetadata& metadata, bool isDecryptionFailed);
5464

55-
using KeysType = std::unordered_map<std::string, EncryptionKey>;
65+
using KeysType = std::vector<EncryptionKey>;
5666

5767
/**
5868
* @return the map of encryption keys used for the message
@@ -95,15 +105,15 @@ class PULSAR_PUBLIC EncryptionContext {
95105
bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; }
96106

97107
private:
98-
const KeysType keys_;
99-
const std::string param_;
100-
const std::string algorithm_;
101-
const CompressionType compressionType_;
102-
const uint32_t uncompressedMessageSize_;
103-
const int32_t batchSize_;
104-
const bool isDecryptionFailed_;
105-
106-
friend class MessageImpl;
108+
KeysType keys_;
109+
std::string param_;
110+
std::string algorithm_;
111+
CompressionType compressionType_;
112+
uint32_t uncompressedMessageSize_;
113+
int32_t batchSize_;
114+
bool isDecryptionFailed_;
115+
116+
friend class ConsumerImpl;
107117
};
108118

109119
} // namespace pulsar

lib/Commands.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -927,9 +927,10 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32
927927
auto messageId = MessageIdBuilder::from(m).batchIndex(batchIndex).batchSize(batchSize).build();
928928
auto batchedMessageId = std::make_shared<BatchedMessageIdImpl>(*(messageId.impl_), acker);
929929

930+
// TODO: fix the encryption context is not set
930931
auto msgImpl = std::make_shared<MessageImpl>(messageId, batchedMessage.impl_->brokerEntryMetadata,
931932
batchedMessage.impl_->metadata, payload, metadata,
932-
batchedMessage.impl_->topicName_, false);
933+
batchedMessage.impl_->topicName_, std::nullopt);
933934
msgImpl->cnx_ = batchedMessage.impl_->cnx_;
934935

935936
return Message(msgImpl);

lib/ConsumerImpl.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -554,10 +554,16 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
554554
return;
555555
}
556556

557-
auto decryptResult = decryptMessageIfNeeded(cnx, msg, metadata, payload);
557+
auto encryptionContext = metadata.encryption_keys_size() > 0
558+
? optional<EncryptionContext>{std::in_place, metadata, false}
559+
: std::nullopt;
560+
561+
auto decryptResult = decryptMessageIfNeeded(cnx, encryptionContext, payload, msg.message_id());
558562
if (decryptResult == FAILED) {
559563
// Message was discarded due to decryption failure or not consumed due to decryption failure
560564
return;
565+
} else if (decryptResult == CONSUME_ENCRYPTED) {
566+
encryptionContext->isDecryptionFailed_ = true;
561567
}
562568

563569
auto redeliveryCount = msg.redelivery_count();
@@ -582,10 +588,8 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
582588
}
583589
}
584590

585-
auto msgImpl =
586-
std::make_shared<MessageImpl>(messageId, brokerEntryMetadata, metadata, payload, std::nullopt,
587-
getTopicPtr(), decryptResult == CONSUME_ENCRYPTED);
588-
Message m(msgImpl);
591+
Message m{std::make_shared<MessageImpl>(messageId, brokerEntryMetadata, metadata, payload, std::nullopt,
592+
getTopicPtr(), std::move(encryptionContext))};
589593
m.impl_->cnx_ = cnx.get();
590594
m.impl_->setRedeliveryCount(msg.redelivery_count());
591595

@@ -811,10 +815,10 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection
811815
return batchSize - skippedMessages;
812816
}
813817

814-
auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
815-
const proto::MessageMetadata& metadata, SharedBuffer& payload)
816-
-> DecryptResult {
817-
if (metadata.encryption_keys_size() == 0) {
818+
auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx,
819+
const optional<EncryptionContext>& context, SharedBuffer& payload,
820+
const proto::MessageIdData& msgId) -> DecryptResult {
821+
if (!context.has_value()) {
818822
return DECRYPTED;
819823
}
820824

@@ -826,18 +830,18 @@ auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const
826830
} else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) {
827831
LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config "
828832
"is set to discard");
829-
discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError);
833+
discardCorruptedMessage(cnx, msgId, CommandAck_ValidationError_DecryptionError);
830834
} else {
831835
LOG_ERROR(getName() << "Message delivery failed since CryptoKeyReader is not implemented to "
832836
"consume encrypted message");
833-
auto messageId = MessageIdBuilder::from(msg.message_id()).build();
837+
auto messageId = MessageIdBuilder::from(msgId).build();
834838
unAckedMessageTrackerPtr_->add(messageId);
835839
}
836840
return FAILED;
837841
}
838842

839843
SharedBuffer decryptedPayload;
840-
if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) {
844+
if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) {
841845
payload = decryptedPayload;
842846
return DECRYPTED;
843847
}
@@ -849,10 +853,10 @@ auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const
849853
return CONSUME_ENCRYPTED;
850854
} else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) {
851855
LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard");
852-
discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError);
856+
discardCorruptedMessage(cnx, msgId, CommandAck_ValidationError_DecryptionError);
853857
} else {
854858
LOG_ERROR(getName() << "Message delivery failed since unable to decrypt incoming message");
855-
auto messageId = MessageIdBuilder::from(msg.message_id()).build();
859+
auto messageId = MessageIdBuilder::from(msgId).build();
856860
unAckedMessageTrackerPtr_->add(messageId);
857861
}
858862
return FAILED;

lib/ConsumerImpl.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ using UnAckedMessageTrackerPtr = std::shared_ptr<UnAckedMessageTrackerInterface>
6464
namespace proto {
6565
class CommandMessage;
6666
class BrokerEntryMetadata;
67+
class MessageIdData;
6768
class MessageMetadata;
6869
} // namespace proto
6970

@@ -201,8 +202,8 @@ class ConsumerImpl : public ConsumerImplBase {
201202
CONSUME_ENCRYPTED,
202203
FAILED
203204
};
204-
DecryptResult decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
205-
const proto::MessageMetadata& metadata, SharedBuffer& payload);
205+
DecryptResult decryptMessageIfNeeded(const ClientConnectionPtr&, const optional<EncryptionContext>&,
206+
SharedBuffer& payload, const proto::MessageIdData&);
206207

207208
// TODO - Convert these functions to lambda when we move to C++11
208209
Result receiveHelper(Message& msg);

lib/EncryptionContext.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,19 @@
1818
*/
1919
#include <pulsar/EncryptionContext.h>
2020

21-
#include <unordered_map>
22-
2321
#include "PulsarApi.pb.h"
2422

2523
namespace pulsar {
2624

2725
static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) {
2826
EncryptionContext::KeysType keys;
2927
for (auto&& key : msgMetadata.encryption_keys()) {
30-
decltype(EncryptionKey::metadata) keyMetadata;
31-
for (auto&& entry : key.metadata()) {
32-
keyMetadata[entry.key()] = entry.value();
28+
decltype(EncryptionKey::metadata) metadata;
29+
for (int i = 0; i < key.metadata_size(); i++) {
30+
const auto& entry = key.metadata(i);
31+
metadata[entry.key()] = entry.value();
3332
}
34-
keys[key.key()] = EncryptionKey{key.value(), keyMetadata};
33+
keys.emplace_back(key.key(), key.value(), std::move(metadata));
3534
}
3635
return keys;
3736
}

lib/MessageCrypto.cc

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,13 @@ bool MessageCrypto::encrypt(const std::set<std::string>& encKeys, const CryptoKe
394394
return true;
395395
}
396396

397-
bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader) {
398-
const auto& keyName = encKeys.key();
399-
const auto& encryptedDataKey = encKeys.value();
400-
const auto& encKeyMeta = encKeys.metadata();
397+
bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader) {
398+
const auto& keyName = encKeys.key;
399+
const auto& encryptedDataKey = encKeys.value;
400+
const auto& encKeyMeta = encKeys.metadata;
401401
StringMap keyMeta;
402402
for (auto iter = encKeyMeta.begin(); iter != encKeyMeta.end(); iter++) {
403-
keyMeta[iter->key()] = iter->value();
403+
keyMeta[iter->first] = iter->second;
404404
}
405405

406406
// Read the private key info using callback
@@ -451,11 +451,10 @@ bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const C
451451
return true;
452452
}
453453

454-
bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata,
454+
bool MessageCrypto::decryptData(const std::string& dataKeySecret, const EncryptionContext& context,
455455
SharedBuffer& payload, SharedBuffer& decryptedPayload) {
456456
// unpack iv and encrypted data
457-
msgMetadata.encryption_param().copy(reinterpret_cast<char*>(iv_.get()),
458-
msgMetadata.encryption_param().size());
457+
context.param().copy(reinterpret_cast<char*>(iv_.get()), context.param().size());
459458

460459
EVP_CIPHER_CTX* cipherCtx = NULL;
461460
decryptedPayload = SharedBuffer::allocate(payload.readableBytes() + EVP_MAX_BLOCK_LENGTH + tagLen_);
@@ -518,15 +517,14 @@ bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::M
518517
return true;
519518
}
520519

521-
bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
520+
bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload,
522521
SharedBuffer& decryptedPayload) {
523522
SharedBuffer decryptedData;
524523
bool dataDecrypted = false;
525524

526-
for (auto iter = msgMetadata.encryption_keys().begin(); iter != msgMetadata.encryption_keys().end();
527-
iter++) {
528-
const std::string& keyName = iter->key();
529-
const std::string& encDataKey = iter->value();
525+
for (auto&& kv : context.keys()) {
526+
const std::string& keyName = kv.key;
527+
const std::string& encDataKey = kv.value;
530528
unsigned char keyDigest[EVP_MAX_MD_SIZE];
531529
unsigned int digestLen = 0;
532530
getDigest(keyName, encDataKey.c_str(), encDataKey.size(), keyDigest, digestLen);
@@ -539,7 +537,7 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada
539537
// retruns a different key, decryption fails. At this point, we would
540538
// call decryptDataKey to refresh the cache and come here again to decrypt.
541539
auto dataKeyEntry = dataKeyCacheIter->second;
542-
if (decryptData(dataKeyEntry.first, msgMetadata, payload, decryptedPayload)) {
540+
if (decryptData(dataKeyEntry.first, context, payload, decryptedPayload)) {
543541
dataDecrypted = true;
544542
break;
545543
}
@@ -552,17 +550,16 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada
552550
return dataDecrypted;
553551
}
554552

555-
bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
553+
bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payload,
556554
const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload) {
557555
// Attempt to decrypt using the existing key
558-
if (getKeyAndDecryptData(msgMetadata, payload, decryptedPayload)) {
556+
if (getKeyAndDecryptData(context, payload, decryptedPayload)) {
559557
return true;
560558
}
561559

562560
// Either first time, or decryption failed. Attempt to regenerate data key
563561
bool isDataKeyDecrypted = false;
564-
for (int index = 0; index < msgMetadata.encryption_keys_size(); index++) {
565-
const proto::EncryptionKeys& encKeys = msgMetadata.encryption_keys(index);
562+
for (auto&& encKeys : context.keys()) {
566563
if (decryptDataKey(encKeys, *keyReader)) {
567564
isDataKeyDecrypted = true;
568565
break;
@@ -574,7 +571,7 @@ bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuf
574571
return false;
575572
}
576573

577-
return getKeyAndDecryptData(msgMetadata, payload, decryptedPayload);
574+
return getKeyAndDecryptData(context, payload, decryptedPayload);
578575
}
579576

580577
} /* namespace pulsar */

lib/MessageCrypto.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
#include <openssl/rsa.h>
2727
#include <openssl/ssl.h>
2828
#include <pulsar/CryptoKeyReader.h>
29+
#include <pulsar/EncryptionContext.h>
2930

3031
#include <boost/date_time/posix_time/ptime.hpp>
3132
#include <boost/scoped_array.hpp>
32-
#include <iostream>
3333
#include <map>
3434
#include <mutex>
3535
#include <set>
@@ -90,15 +90,15 @@ class MessageCrypto {
9090
/*
9191
* Decrypt the payload using the data key. Keys used to encrypt data key can be retrieved from msgMetadata
9292
*
93-
* @param msgMetadata Message Metadata
93+
* @param context the context of encryption
9494
* @param payload Message which needs to be decrypted
9595
* @param keyReader KeyReader implementation to retrieve key value
9696
* @param decryptedPayload Contains decrypted payload if success
9797
*
9898
* @return true if success
9999
*/
100-
bool decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
101-
const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload);
100+
bool decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader,
101+
SharedBuffer& decryptedPayload);
102102

103103
private:
104104
typedef std::unique_lock<std::mutex> Lock;
@@ -137,10 +137,10 @@ class MessageCrypto {
137137

138138
Result addPublicKeyCipher(const std::string& keyName, const CryptoKeyReaderPtr& keyReader);
139139

140-
bool decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader);
141-
bool decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata,
140+
bool decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader);
141+
bool decryptData(const std::string& dataKeySecret, const EncryptionContext& context,
142142
SharedBuffer& payload, SharedBuffer& decPayload);
143-
bool getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload,
143+
bool getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload,
144144
SharedBuffer& decryptedPayload);
145145
std::string stringToHex(const std::string& inputStr, size_t len);
146146
std::string stringToHex(const char* inputStr, size_t len);

lib/MessageImpl.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@ namespace pulsar {
2727
MessageImpl::MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata,
2828
const proto::MessageMetadata& metadata, const SharedBuffer& payload,
2929
const optional<proto::SingleMessageMetadata>& singleMetadata,
30-
const std::shared_ptr<std::string>& topicName, bool isDecryptionFailed)
30+
const std::shared_ptr<std::string>& topicName,
31+
optional<EncryptionContext>&& encryptionContext)
3132
: messageId(messageId),
3233
brokerEntryMetadata(brokerEntryMetadata),
3334
metadata(metadata),
3435
payload(payload),
3536
topicName_(topicName),
36-
encryptionContext_(std::in_place, metadata, isDecryptionFailed) {
37+
encryptionContext_(std::move(encryptionContext)) {
3738
if (singleMetadata.has_value()) {
3839
this->metadata.clear_properties();
3940
if (singleMetadata->properties_size() > 0) {

lib/MessageImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class MessageImpl {
4444
MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata,
4545
const proto::MessageMetadata& metadata, const SharedBuffer& payload,
4646
const optional<proto::SingleMessageMetadata>& singleMetadata,
47-
const std::shared_ptr<std::string>& topicName, bool undecryptedPayload);
47+
const std::shared_ptr<std::string>& topicName,
48+
optional<EncryptionContext>&& encryptionContext);
4849

4950
const Message::StringMap& properties();
5051

tests/EncryptionTests.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ static std::vector<std::string> decryptValue(const Message& message) {
4949
MessageCrypto crypto{"test", false};
5050
auto msgImpl = PulsarFriend::getMessageImplPtr(message);
5151
SharedBuffer decryptedPayload;
52-
// TODO: change the parameters to get context from EncryptionContext directly
53-
if (!crypto.decrypt(msgImpl->metadata, msgImpl->payload, getDefaultCryptoKeyReader(), decryptedPayload)) {
52+
auto originalPayload =
53+
SharedBuffer::copy(static_cast<const char*>(message.getData()), message.getLength());
54+
if (!crypto.decrypt(*context, originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) {
5455
throw std::runtime_error("Decryption failed");
5556
}
5657

0 commit comments

Comments
 (0)