Skip to content

Commit 4629281

Browse files
Support encryption context on Message (#276)
* Support encryption context on Message * revert unnecessary change * fix tests * fix docs * fix tests * improve tests
1 parent 813e295 commit 4629281

File tree

3 files changed

+188
-8
lines changed

3 files changed

+188
-8
lines changed

pulsar/__init__.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,122 @@ def wrap(cls, msg_id: _pulsar.MessageId):
166166
self._msg_id = msg_id
167167
return self
168168

169+
170+
class EncryptionKey:
    """
    A single key that was used to encrypt a message.

    Thin read-only wrapper around the EncryptionKey object exposed by the
    C extension.
    """

    def __init__(self, key: _pulsar.EncryptionKey):
        """
        Wrap a native encryption key.

        Parameters
        ----------
        key: _pulsar.EncryptionKey
            The underlying EncryptionKey instance from the C extension.
        """
        self._key = key

    @property
    def key(self) -> str:
        """
        The key identifier, which is usually the name of the key file.
        """
        native = self._key
        return native.key

    @property
    def value(self) -> bytes:
        """
        The key value, which is usually the key bytes used for encryption.
        """
        # The native object exposes `value` as a method, not an attribute.
        native = self._key
        return native.value()

    @property
    def metadata(self) -> dict:
        """
        The metadata associated with the key.
        """
        native = self._key
        return native.metadata

    def __str__(self) -> str:
        # Only the length of the value is rendered, not the bytes themselves.
        fields = (
            f"key={self.key}",
            f"value_len={len(self.value)}",
            f"metadata={self.metadata}",
        )
        return f"EncryptionKey({', '.join(fields)})"

    def __repr__(self) -> str:
        # The debug representation is the same text as the string form.
        return self.__str__()
212+
213+
214+
class EncryptionContext:
    """
    Encryption and compression information attached to a consumed message.

    An application can use it to decrypt a message that was delivered with
    an encrypted payload.
    """

    def __init__(self, context: _pulsar.EncryptionContext):
        """
        Wrap a native encryption context.

        Parameters
        ----------
        context: _pulsar.EncryptionContext
            The underlying EncryptionContext instance from the C extension.
        """
        self._context = context

    def keys(self) -> List[EncryptionKey]:
        """
        Return all keys used when performing encryption, each wrapped as an
        `EncryptionKey`.
        """
        return list(map(EncryptionKey, self._context.keys()))

    def param(self) -> bytes:
        """
        Return the encryption param bytes.
        """
        native = self._context
        return native.param()

    def algorithm(self) -> str:
        """
        Return the encryption algorithm.
        """
        native = self._context
        return native.algorithm()

    def compression_type(self) -> CompressionType:
        """
        Return the compression type of the message.
        """
        native = self._context
        return native.compression_type()

    def uncompressed_message_size(self) -> int:
        """
        Return the uncompressed message size, or 0 if the compression type is NONE.
        """
        native = self._context
        return native.uncompressed_message_size()

    def batch_size(self) -> int:
        """
        Return the number of messages in the batch, or -1 if the message is not batched.
        """
        native = self._context
        return native.batch_size()

    def is_decryption_failed(self) -> bool:
        """
        Return whether decryption has failed for this message.
        """
        native = self._context
        return native.is_decryption_failed()

    def __str__(self) -> str:
        # Fields are evaluated in the order they are rendered. `param` and
        # `batch_size` are not part of the string form.
        fields = (
            f"algorithm={self.algorithm()}",
            f"compression_type={self.compression_type().name}",
            f"uncompressed_message_size={self.uncompressed_message_size()}",
            f"is_decryption_failed={self.is_decryption_failed()}",
            f"keys=[{', '.join(map(str, self.keys()))}]",
        )
        return f"EncryptionContext({', '.join(fields)})"

    def __repr__(self) -> str:
        # The debug representation is the same text as the string form.
        return self.__str__()
283+
284+
169285
class Message:
170286
"""
171287
Message objects are returned by a consumer, either by calling `receive` or
@@ -250,6 +366,15 @@ def producer_name(self) -> str:
250366
"""
251367
return self._message.producer_name()
252368

369+
def encryption_context(self) -> "EncryptionContext | None":
    """
    Get the encryption context for this message or None if it's not encrypted.

    The returned object wraps a reference obtained from the underlying
    message, so it should not be accessed after the current Message
    instance is deleted.

    Returns
    -------
    EncryptionContext | None
        The wrapped native context, or None when the underlying message
        reports no encryption context.
    """
    # The return annotation is quoted so the `|` union is not evaluated when
    # the method is defined; evaluating `SomeClass | None` eagerly raises
    # TypeError on interpreters older than Python 3.10.
    context = self._message.encryption_context()
    if context is None:
        return None
    return EncryptionContext(context)
377+
253378
@staticmethod
254379
def _wrap(_message):
255380
self = Message()

src/message.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,20 @@ void export_message(py::module_& m) {
8686
})
8787
.def_static("deserialize", &MessageId::deserialize);
8888

89+
class_<EncryptionKey>(m, "EncryptionKey")
90+
.def_readonly("key", &EncryptionKey::key)
91+
.def("value", [](const EncryptionKey& key) { return bytes(key.value); })
92+
.def_readonly("metadata", &EncryptionKey::metadata);
93+
94+
class_<EncryptionContext>(m, "EncryptionContext")
95+
.def("keys", &EncryptionContext::keys)
96+
.def("param", [](const EncryptionContext& context) { return bytes(context.param()); })
97+
.def("algorithm", &EncryptionContext::algorithm, return_value_policy::copy)
98+
.def("compression_type", &EncryptionContext::compressionType)
99+
.def("uncompressed_message_size", &EncryptionContext::uncompressedMessageSize)
100+
.def("batch_size", &EncryptionContext::batchSize)
101+
.def("is_decryption_failed", &EncryptionContext::isDecryptionFailed);
102+
89103
class_<Message>(m, "Message")
90104
.def(init<>())
91105
.def("properties", &Message::getProperties)
@@ -106,7 +120,8 @@ void export_message(py::module_& m) {
106120
.def("redelivery_count", &Message::getRedeliveryCount)
107121
.def("int_schema_version", &Message::getLongSchemaVersion)
108122
.def("schema_version", &Message::getSchemaVersion, return_value_policy::copy)
109-
.def("producer_name", &Message::getProducerName, return_value_policy::copy);
123+
.def("producer_name", &Message::getProducerName, return_value_policy::copy)
124+
.def("encryption_context", &Message::getEncryptionContext, return_value_policy::reference);
110125

111126
MessageBatch& (MessageBatch::*MessageBatchParseFromString)(const std::string& payload,
112127
uint32_t batchSize) = &MessageBatch::parseFrom;

tests/pulsar_test.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def test_producer_send(self):
167167
consumer.acknowledge(msg)
168168
print("receive from {}".format(msg.message_id()))
169169
self.assertEqual(msg_id, msg.message_id())
170+
self.assertIsNone(msg.encryption_context())
170171
client.close()
171172

172173
def test_producer_access_mode_exclusive(self):
@@ -489,15 +490,36 @@ def test_encryption_failure(self):
489490
client = Client(self.serviceUrl)
490491
topic = "my-python-test-end-to-end-encryption-failure-" + str(time.time())
491492
producer = client.create_producer(
492-
topic=topic, encryption_key="client-rsa.pem", crypto_key_reader=crypto_key_reader
493+
topic=topic, encryption_key="client-rsa.pem", crypto_key_reader=crypto_key_reader,
494+
compression_type=CompressionType.LZ4
493495
)
494496
producer.send(b"msg-0")
495497

498+
def verify_encryption_context(context: pulsar.EncryptionContext | None, failed: bool, batch_size: int):
499+
if context is None:
500+
self.fail("Encryption context is None")
501+
keys = context.keys()
502+
self.assertEqual(len(keys), 1)
503+
key = keys[0]
504+
self.assertEqual(key.key, "client-rsa.pem")
505+
self.assertGreater(len(key.value), 0)
506+
self.assertEqual(key.metadata, {})
507+
self.assertGreater(len(context.param()), 0)
508+
self.assertEqual(context.algorithm(), "")
509+
self.assertEqual(context.compression_type(), CompressionType.LZ4)
510+
if batch_size == -1:
511+
self.assertEqual(context.uncompressed_message_size(), len(b"msg-0"))
512+
else:
513+
self.assertGreater(context.uncompressed_message_size(), len(b"msg-0"))
514+
self.assertEqual(context.batch_size(), batch_size)
515+
self.assertEqual(context.is_decryption_failed(), failed)
516+
496517
def verify_next_message(value: bytes):
497518
consumer = client.subscribe(topic, subscription,
498519
crypto_key_reader=crypto_key_reader)
499520
msg = consumer.receive(3000)
500521
self.assertEqual(msg.data(), value)
522+
verify_encryption_context(msg.encryption_context(), False, -1)
501523
consumer.acknowledge(msg)
502524
consumer.close()
503525

@@ -520,22 +542,40 @@ def verify_next_message(value: bytes):
520542

521543
producer.send(b"msg-2")
522544
verify_next_message(b"msg-2") # msg-1 is skipped since the crypto failure action is DISCARD
545+
producer.close()
546+
547+
# send batched messages
548+
producer = client.create_producer(
549+
topic=topic,
550+
encryption_key="client-rsa.pem",
551+
crypto_key_reader=crypto_key_reader,
552+
compression_type=CompressionType.LZ4,
553+
batching_enabled=True,
554+
)
555+
producer.send_async(b"msg-3", None)
556+
producer.send_async(b"msg-4", None)
557+
producer.flush()
558+
559+
def verify_undecrypted_message(msg: pulsar.Message, i: int):
560+
self.assertNotEqual(msg.data(), f"msg-{i}".encode())
561+
self.assertGreater(len(msg.data()), 5, f"msg.data() is {msg.data()}")
562+
verify_encryption_context(msg.encryption_context(), True, 2 if i >= 3 else -1)
523563

524564
# Encrypted messages will be consumed since the crypto failure action is CONSUME
565+
# Only 4 messages can be received because msg-3 and msg-4 are sent in batch and they are delivered
566+
# as a single message when decryption fails.
525567
consumer = client.subscribe(topic, 'another-sub',
526568
initial_position=InitialPosition.Earliest,
527569
crypto_failure_action=pulsar.ConsumerCryptoFailureAction.CONSUME)
528-
for i in range(3):
570+
for i in range(4):
529571
msg = consumer.receive(3000)
530-
self.assertNotEqual(msg.data(), f"msg-{i}".encode())
531-
self.assertTrue(len(msg.data()) > 5, f"msg.data() is {msg.data()}")
572+
verify_undecrypted_message(msg, i)
532573

533574
reader = client.create_reader(topic, MessageId.earliest,
534575
crypto_failure_action=pulsar.ConsumerCryptoFailureAction.CONSUME)
535-
for i in range(3):
576+
for i in range(4):
536577
msg = reader.read_next(3000)
537-
self.assertNotEqual(msg.data(), f"msg-{i}".encode())
538-
self.assertTrue(len(msg.data()) > 5, f"msg.data() is {msg.data()}")
578+
verify_undecrypted_message(msg, i)
539579

540580
client.close()
541581

0 commit comments

Comments
 (0)