Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Pulsar.Client/Common/DTO.fs
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,11 @@ type EncryptionContext =
CompressionType: CompressionType
UncompressedMessageSize: int
BatchSize: Nullable<int>
// Indicates whether the message payload remains encrypted (true) or has been successfully decrypted (false)
IsEncrypted: bool
}
with
static member internal FromMetadata(metadata: Metadata) =
static member internal FromMetadata(metadata: Metadata, isEncrypted: bool) =
if metadata.EncryptionKeys.Length > 0 then
{
Keys = metadata.EncryptionKeys
Expand All @@ -254,6 +256,7 @@ type EncryptionContext =
CompressionType = metadata.CompressionType
UncompressedMessageSize = metadata.UncompressedMessageSize
BatchSize = if metadata.HasNumMessagesInBatch then Nullable(metadata.NumMessages) else Nullable()
IsEncrypted = isEncrypted
} |> Some
else
None
Expand Down
18 changes: 9 additions & 9 deletions src/Pulsar.Client/Internal/ConsumerImpl.fs
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ type internal ConsumerImpl<'T> (consumerConfig: ConsumerConfiguration<'T>, clien
trackMessage rawMessage.MessageId
None

let handleSingleMessagePayload (rawMessage: RawMessage) msgId payload hasWaitingChannel hasWaitingBatchChannel schemaDecodeFunction =
let handleSingleMessagePayload (rawMessage: RawMessage) msgId payload hasWaitingChannel hasWaitingBatchChannel isMessageUndecryptable schemaDecodeFunction =
if duringSeek.IsSome || (isSameEntry(rawMessage.MessageId) && isPriorEntryIndex(rawMessage.MessageId.EntryId)) then
// We need to discard entries that were prior to startMessageId
Log.Logger.LogInformation("{0} Ignoring message from before the startMessageId: {1}", prefix, startMessageId)
Expand All @@ -683,7 +683,7 @@ type internal ConsumerImpl<'T> (consumerConfig: ConsumerConfiguration<'T>, clien
%msgKey,
rawMessage.IsKeyBase64Encoded,
rawMessage.Properties,
EncryptionContext.FromMetadata rawMessage.Metadata,
EncryptionContext.FromMetadata(rawMessage.Metadata, isEncrypted = isMessageUndecryptable),
getSchemaVersionBytes rawMessage.Metadata.SchemaVersion,
rawMessage.Metadata.SequenceId,
rawMessage.Metadata.OrderingKey,
Expand Down Expand Up @@ -721,17 +721,17 @@ type internal ConsumerImpl<'T> (consumerConfig: ConsumerConfiguration<'T>, clien
if isChunkedMessage then
match processMessageChunk rawMessage msgId with
| Some (chunkedPayload, msgIdWithChunk) ->
handleSingleMessagePayload rawMessage msgIdWithChunk chunkedPayload hasWaitingChannel hasWaitingBatchChannel schemaDecodeFunction
handleSingleMessagePayload rawMessage msgIdWithChunk chunkedPayload hasWaitingChannel hasWaitingBatchChannel isMessageUndecryptable schemaDecodeFunction
| None ->
rawMessage.Payload.Dispose()
else
let bytes = rawMessage.Payload.ToArray()
rawMessage.Payload.Dispose()
handleSingleMessagePayload rawMessage msgId bytes hasWaitingChannel hasWaitingBatchChannel schemaDecodeFunction
handleSingleMessagePayload rawMessage msgId bytes hasWaitingChannel hasWaitingBatchChannel isMessageUndecryptable schemaDecodeFunction
elif rawMessage.Metadata.NumMessages > 0 then
// handle batch message enqueuing; uncompressed payload has all messages in batch
match wrapException (fun () ->
this.ReceiveIndividualMessagesFromBatch rawMessage schemaDecodeFunction) with
this.ReceiveIndividualMessagesFromBatch rawMessage schemaDecodeFunction isMessageUndecryptable) with
| Ok () ->
// try respond to channel
if hasWaitingChannel && incomingMessages.Count > 0 then
Expand Down Expand Up @@ -1324,8 +1324,8 @@ type internal ConsumerImpl<'T> (consumerConfig: ConsumerConfiguration<'T>, clien
do startStatTimer()
do startChunkTimer()

abstract member ReceiveIndividualMessagesFromBatch: RawMessage -> (byte [] -> 'T) -> unit
default this.ReceiveIndividualMessagesFromBatch (rawMessage: RawMessage) schemaDecodeFunction =
abstract member ReceiveIndividualMessagesFromBatch: RawMessage -> (byte [] -> 'T) -> bool -> unit
default this.ReceiveIndividualMessagesFromBatch (rawMessage: RawMessage) schemaDecodeFunction isMessageUndecryptable =
let batchSize = rawMessage.Metadata.NumMessages
let acker = BatchMessageAcker(batchSize)
let mutable skippedMessages = 0
Expand Down Expand Up @@ -1379,7 +1379,7 @@ type internal ConsumerImpl<'T> (consumerConfig: ConsumerConfiguration<'T>, clien
%msgKey,
singleMessageMetadata.PartitionKeyB64Encoded,
properties,
EncryptionContext.FromMetadata rawMessage.Metadata,
EncryptionContext.FromMetadata(rawMessage.Metadata, isEncrypted = isMessageUndecryptable),
getSchemaVersionBytes rawMessage.Metadata.SchemaVersion,
%(int64 singleMessageMetadata.SequenceId),
singleMessageMetadata.OrderingKey,
Expand Down Expand Up @@ -1719,7 +1719,7 @@ and internal ZeroQueueConsumerImpl<'T> (consumerConfig: ConsumerConfiguration<'T
if this.Waiters.Count > 0 then
this.SendFlowPermits this.Waiters.Count

override this.ReceiveIndividualMessagesFromBatch (_: RawMessage) _ =
override this.ReceiveIndividualMessagesFromBatch (_: RawMessage) _ _ =
Log.Logger.LogError("{0} Closing consumer due to unsupported received batch-message with zero receiver queue size", prefix)
let _ = postAndAsyncReply this.Mb ConsumerMessage.Close
let exn =
Expand Down
34 changes: 30 additions & 4 deletions tests/IntegrationTests/MessageCrypto.fs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,32 @@ type Consumer2KeyReader() =
{ Key = Encoding.UTF8.GetBytes(privateKeysConsumer2.Item keyName); Metadata = null }


let consumeMessagesAndCheckEncryption (consumer: IConsumer<byte[]>) number consumerName =
task {
for i in 1..number do
let! message = consumer.ReceiveAsync()
let received = Encoding.UTF8.GetString(message.Data)
Log.Debug("{0} received {1}", consumerName, received)

// Check EncryptionContext
match message.EncryptionContext with
| Some context ->
Log.Debug("{0} message {1} EncryptionContext.IsEncrypted = {2}", consumerName, i, context.IsEncrypted)
if context.IsEncrypted then
failwith $"Message {i} should not be encrypted, but IsEncrypted = true"
| None ->
failwith $"Message {i} should have EncryptionContext but it is None"

do! consumer.AcknowledgeAsync(message.MessageId)
Log.Debug("{0} acknowledged {1}", consumerName, received)
let expected = "Message #" + string i
if received.StartsWith(expected) |> not then
failwith $"Incorrect message expected {expected} received {received} consumer {consumerName}"

Log.Debug("{0} consumed {1} messages, all EncryptionContext.IsEncrypted = false", consumerName, number)
}


[<Tests>]
let tests =
testList "MessageCrypto" [
Expand Down Expand Up @@ -129,7 +155,7 @@ let tests =
let consumerTask =
Task.Run(fun () ->
task {
do! consumeMessages consumer numberOfMessages consumerName
do! consumeMessagesAndCheckEncryption consumer numberOfMessages consumerName
} :> Task)

do! Task.WhenAll(producerTask, consumerTask)
Expand Down Expand Up @@ -169,7 +195,7 @@ let tests =
let consumer1Task =
Task.Run(fun () ->
task {
do! consumeMessages consumer1 numberOfMessages consumerName
do! consumeMessagesAndCheckEncryption consumer1 numberOfMessages consumerName
} :> Task)

do! Task.WhenAll(producerTask, consumer1Task)
Expand All @@ -196,7 +222,7 @@ let tests =
let consumer2Task =
Task.Run(fun () ->
task {
do! consumeMessages consumer2 numberOfMessages consumerName
do! consumeMessagesAndCheckEncryption consumer2 numberOfMessages consumerName
} :> Task)

do! Task.WhenAll(producerTask2, consumer2Task)
Expand Down Expand Up @@ -242,7 +268,7 @@ let tests =
Expect.isNonEmpty context.Param ""
Expect.isNonEmpty context.Keys ""
Expect.equal context.CompressionType compressionType ""

Expect.equal context.IsEncrypted true "Message should remain encrypted when decryption fails"
do! Task.Delay 100
Log.Debug("Ended Encryption send message and consume on fail")
}
Expand Down
Loading