Skip to content

Commit e53f8c6

Browse files
committed
Fix encryption context not set when decryption succeeds
1 parent 860931c commit e53f8c6

File tree

8 files changed

+49
-20
lines changed

8 files changed

+49
-20
lines changed

lib/Commands.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,8 @@ uint64_t Commands::serializeSingleMessagesToBatchPayload(SharedBuffer& batchPayl
906906
}
907907

908908
Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32_t batchIndex,
909-
int32_t batchSize, const BatchMessageAckerPtr& acker) {
909+
int32_t batchSize, const BatchMessageAckerPtr& acker,
910+
const optional<EncryptionContext>& encryptionContext) {
910911
SharedBuffer& uncompressedPayload = batchedMessage.impl_->payload;
911912

912913
// Format of batch message
@@ -927,10 +928,9 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32
927928
auto messageId = MessageIdBuilder::from(m).batchIndex(batchIndex).batchSize(batchSize).build();
928929
auto batchedMessageId = std::make_shared<BatchedMessageIdImpl>(*(messageId.impl_), acker);
929930

930-
// TODO: fix the encryption context is not set
931931
auto msgImpl = std::make_shared<MessageImpl>(messageId, batchedMessage.impl_->brokerEntryMetadata,
932932
batchedMessage.impl_->metadata, payload, metadata,
933-
batchedMessage.impl_->topicName_, std::nullopt);
933+
batchedMessage.impl_->topicName_, encryptionContext);
934934
msgImpl->cnx_ = batchedMessage.impl_->cnx_;
935935

936936
return Message(msgImpl);

lib/Commands.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ class Commands {
155155
const std::vector<Message>& messages);
156156

157157
static Message deSerializeSingleMessageInBatch(Message& batchedMessage, int32_t batchIndex,
158-
int32_t batchSize, const BatchMessageAckerPtr& acker);
158+
int32_t batchSize, const BatchMessageAckerPtr& acker,
159+
const optional<EncryptionContext>& encryptionContext);
159160

160161
static MessageIdImplPtr getMessageIdImpl(const MessageId& messageId);
161162

lib/ConsumerImpl.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,8 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
620620
}
621621
BitSet ackSet{std::move(words)};
622622
Lock lock(mutex_);
623-
numOfMessageReceived = receiveIndividualMessagesFromBatch(cnx, m, ackSet, msg.redelivery_count());
623+
numOfMessageReceived =
624+
receiveIndividualMessagesFromBatch(cnx, m, ackSet, msg.redelivery_count(), encryptionContext);
624625
} else {
625626
// try convert key value data.
626627
m.impl_->convertPayloadToKeyValue(config_.getSchema());
@@ -745,9 +746,9 @@ void ConsumerImpl::notifyPendingReceivedCallback(Result result, Message& msg,
745746
}
746747

747748
// Zero Queue size is not supported with Batch Messages
748-
uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx,
749-
Message& batchedMessage, const BitSet& ackSet,
750-
int redeliveryCount) {
749+
uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(
750+
const ClientConnectionPtr& cnx, Message& batchedMessage, const BitSet& ackSet, int redeliveryCount,
751+
const optional<EncryptionContext>& encryptionContext) {
751752
auto batchSize = batchedMessage.impl_->metadata.num_messages_in_batch();
752753
LOG_DEBUG("Received Batch messages of size - " << batchSize
753754
<< " -- msgId: " << batchedMessage.getMessageId());
@@ -759,7 +760,8 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection
759760
std::vector<Message> possibleToDeadLetter;
760761
for (int i = 0; i < batchSize; i++) {
761762
// This is a cheap copy since message contains only one shared pointer (impl_)
762-
Message msg = Commands::deSerializeSingleMessageInBatch(batchedMessage, i, batchSize, acker);
763+
Message msg =
764+
Commands::deSerializeSingleMessageInBatch(batchedMessage, i, batchSize, acker, encryptionContext);
763765
msg.impl_->setRedeliveryCount(redeliveryCount);
764766
msg.impl_->setTopicName(batchedMessage.impl_->topicName_);
765767
msg.impl_->convertPayloadToKeyValue(config_.getSchema());

lib/ConsumerImpl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ class ConsumerImpl : public ConsumerImplBase {
191191
void increaseAvailablePermits(const Message& msg);
192192
void drainIncomingMessageQueue(size_t count);
193193
uint32_t receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx, Message& batchedMessage,
194-
const BitSet& ackSet, int redeliveryCount);
194+
const BitSet& ackSet, int redeliveryCount,
195+
const optional<EncryptionContext>& encryptionContext);
195196
bool isPriorBatchIndex(int32_t idx);
196197
bool isPriorEntryIndex(int64_t idx);
197198
void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&);

lib/MessageBatch.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ MessageBatch& MessageBatch::parseFrom(const SharedBuffer& payload, uint32_t batc
4949

5050
auto acker = BatchMessageAckerImpl::create(batchSize);
5151
for (int i = 0; i < batchSize; ++i) {
52-
batch_.push_back(Commands::deSerializeSingleMessageInBatch(batchMessage_, i, batchSize, acker));
52+
batch_.push_back(
53+
Commands::deSerializeSingleMessageInBatch(batchMessage_, i, batchSize, acker, std::nullopt));
5354
}
5455
return *this;
5556
}

lib/MessageImpl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ MessageImpl::MessageImpl(const MessageId& messageId, const proto::BrokerEntryMet
2828
const proto::MessageMetadata& metadata, const SharedBuffer& payload,
2929
const optional<proto::SingleMessageMetadata>& singleMetadata,
3030
const std::shared_ptr<std::string>& topicName,
31-
optional<EncryptionContext>&& encryptionContext)
31+
optional<EncryptionContext> encryptionContext)
3232
: messageId(messageId),
3333
brokerEntryMetadata(brokerEntryMetadata),
3434
metadata(metadata),

lib/MessageImpl.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ class MessageImpl {
4444
MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata,
4545
const proto::MessageMetadata& metadata, const SharedBuffer& payload,
4646
const optional<proto::SingleMessageMetadata>& singleMetadata,
47-
const std::shared_ptr<std::string>& topicName,
48-
optional<EncryptionContext>&& encryptionContext);
47+
const std::shared_ptr<std::string>& topicName, optional<EncryptionContext> encryptionContext);
4948

5049
const Message::StringMap& properties();
5150

tests/EncryptionTests.cc

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,8 @@ static std::vector<std::string> decryptValue(const Message& message) {
8080
return values;
8181
}
8282

83-
TEST(EncryptionTests, testEncryptionContext) {
84-
Client client{lookupUrl};
85-
std::string topic = "test-encryption-context-" + std::to_string(time(nullptr));
86-
83+
static void testDecryption(Client& client, const std::string& topic, bool decryptionSucceed,
84+
int numMessageReceived) {
8785
ProducerConfiguration producerConf;
8886
producerConf.setCompressionType(CompressionLZ4);
8987
producerConf.addEncryptionKey("client-rsa.pem");
@@ -104,22 +102,49 @@ TEST(EncryptionTests, testEncryptionContext) {
104102
}
105103
producer.flush();
106104
send("last-msg");
105+
producer.flush();
106+
107+
ASSERT_EQ(ResultOk, client.createProducer(topic, producer));
108+
send("unencrypted-msg");
109+
producer.flush();
110+
producer.close();
107111

108112
ConsumerConfiguration consumerConf;
109113
consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest);
110-
consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME);
114+
if (decryptionSucceed) {
115+
consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());
116+
} else {
117+
consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME);
118+
}
111119
Consumer consumer;
112120
ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer));
113121

114122
std::vector<std::string> values;
115-
for (int i = 0; i < 2; i++) {
123+
for (int i = 0; i < numMessageReceived; i++) {
116124
Message msg;
117125
ASSERT_EQ(ResultOk, consumer.receive(msg, 3000));
126+
if (i < numMessageReceived - 1) {
127+
ASSERT_TRUE(msg.getEncryptionContext().has_value());
128+
}
118129
for (auto&& value : decryptValue(msg)) {
119130
values.emplace_back(value);
120131
}
121132
}
122133
ASSERT_EQ(values, sentValues);
134+
consumer.close();
135+
}
123136

137+
TEST(EncryptionTests, testDecryptionSuccess) {
138+
Client client{lookupUrl};
139+
std::string topic = "test-decryption-success-" + std::to_string(time(nullptr));
140+
testDecryption(client, topic, true, 7);
141+
client.close();
142+
}
143+
144+
TEST(EncryptionTests, testDecryptionFailure) {
145+
Client client{lookupUrl};
146+
std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr));
147+
// The 1st batch that has 5 messages cannot be decrypted, so they can be received only once
148+
testDecryption(client, topic, false, 3);
124149
client.close();
125150
}

0 commit comments

Comments
 (0)