diff --git a/dotnet/src/Azure.Iot.Operations.Mqtt/ExtendedPubSubMqttClient.cs b/dotnet/src/Azure.Iot.Operations.Mqtt/ExtendedPubSubMqttClient.cs new file mode 100644 index 0000000000..3f6f15100a --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Mqtt/ExtendedPubSubMqttClient.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Iot.Operations.Protocol; +using Azure.Iot.Operations.Protocol.Connection; +using Azure.Iot.Operations.Protocol.Models; +using IMqttClient = MQTTnet.IMqttClient; + +namespace Azure.Iot.Operations.Mqtt; + +public class ExtendedPubSubMqttClient(IMqttClient mqttNetClient, OrderedAckMqttClientOptions? clientOptions = null) + : OrderedAckMqttClient(mqttNetClient, clientOptions), IExtendedPubSubMqttClient +{ + private MqttClientConnectResult? _connectResult; + + public override async Task ConnectAsync(MqttClientOptions options, CancellationToken cancellationToken = default) + { + var connectResult = await base.ConnectAsync(options, cancellationToken); + _connectResult = connectResult; + return connectResult; + } + + public override async Task ConnectAsync(MqttConnectionSettings settings, CancellationToken cancellationToken = default) + { + var connectResult = await base.ConnectAsync(settings, cancellationToken); + _connectResult = connectResult; + return connectResult; + } + + public MqttClientConnectResult? GetConnectResult() + { + return _connectResult; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/AssemblyInfo.cs b/dotnet/src/Azure.Iot.Operations.Protocol/AssemblyInfo.cs new file mode 100644 index 0000000000..a3c69d3d57 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/AssemblyInfo.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.CompilerServices; + +// TODO: @maximsemenov80 it looks like assembly shoulb be signed for the integration tests to work +//[assembly: InternalsVisibleTo("Azure.Iot.Operations.Protocol.IntegrationTests")] + +namespace Azure.Iot.Operations.Protocol; + +public class AssemblyInfo { + // This class is intentionally left empty. + // It serves as a placeholder for assembly-level attributes and metadata. +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Azure.Iot.Operations.Protocol.csproj b/dotnet/src/Azure.Iot.Operations.Protocol/Azure.Iot.Operations.Protocol.csproj index a8e3587638..1c50d9a748 100644 --- a/dotnet/src/Azure.Iot.Operations.Protocol/Azure.Iot.Operations.Protocol.csproj +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Azure.Iot.Operations.Protocol.csproj @@ -17,6 +17,10 @@ + + + + $(MSBuildProjectDirectory)\..\..\MSSharedLibKey.snk diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChecksumCalculator.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChecksumCalculator.cs new file mode 100644 index 0000000000..e665645035 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChecksumCalculator.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Security.Cryptography; + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Provides checksum calculation for message chunking. +/// +internal static class ChecksumCalculator +{ + /// + /// Calculates a checksum for the given data using the specified algorithm. + /// + /// The data to calculate a checksum for. + /// The algorithm to use for the checksum. + /// A string representation of the checksum. + public static string CalculateChecksum(ReadOnlySequence data, ChunkingChecksumAlgorithm algorithm) + { + ReadOnlySpan hash = CalculateHashBytes(data, algorithm); + return Convert.ToHexString(hash).ToLowerInvariant(); + } + + /// + /// Verifies that the calculated checksum matches the expected checksum. + /// + /// The data to calculate a checksum for. + /// The expected checksum value. + /// The algorithm to use for the checksum. + /// True if the checksums match, false otherwise. + public static bool VerifyChecksum(ReadOnlySequence data, string expectedChecksum, ChunkingChecksumAlgorithm algorithm) + { + string actualChecksum = CalculateChecksum(data, algorithm); + return string.Equals(actualChecksum, expectedChecksum, StringComparison.OrdinalIgnoreCase); + } + + private static byte[] CalculateHashBytes(ReadOnlySequence data, ChunkingChecksumAlgorithm algorithm) + { + using HashAlgorithm hashAlgorithm = CreateHashAlgorithm(algorithm); + + if (data.IsSingleSegment) + { + return hashAlgorithm.ComputeHash(data.FirstSpan.ToArray()); + } + + // Process multiple segments + hashAlgorithm.Initialize(); + + foreach (ReadOnlyMemory segment in data) + { + hashAlgorithm.TransformBlock(segment.Span.ToArray(), 0, segment.Length, null, 0); + } + + hashAlgorithm.TransformFinalBlock([], 0, 0); + return hashAlgorithm.Hash!; + } + + private static HashAlgorithm CreateHashAlgorithm(ChunkingChecksumAlgorithm algorithm) + { + return algorithm switch + { +#pragma warning disable CA5351 + ChunkingChecksumAlgorithm.MD5 => MD5.Create(), +#pragma warning restore CA5351 + ChunkingChecksumAlgorithm.SHA256 => SHA256.Create(), + _ => throw new ArgumentOutOfRangeException(nameof(algorithm), algorithm, null) + }; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkMetadata.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkMetadata.cs new file mode 100644 index 0000000000..accee1de71 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkMetadata.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Text.Json.Serialization; + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Represents the metadata for a chunk of a larger MQTT message. +/// +internal class ChunkMetadata +{ + /// + /// Gets or sets the unique identifier for the chunked message. + /// + [JsonPropertyName(ChunkingConstants.MessageIdField)] + public string MessageId { get; set; } = null!; + + /// + /// Gets or sets the index of this chunk in the sequence. + /// + [JsonPropertyName(ChunkingConstants.ChunkIndexField)] + public int ChunkIndex { get; set; } + + /// + /// Gets or sets the total number of chunks in the message. + /// This property is only present in the first chunk. + /// + [JsonPropertyName(ChunkingConstants.TotalChunksField)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? TotalChunks { get; set; } + + /// + /// Gets or sets the checksum of the complete message. + /// This property is only present in the first chunk. + /// + [JsonPropertyName(ChunkingConstants.ChecksumField)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Checksum { get; set; } + + /// + /// Creates a new instance of the class for a first chunk. + /// + /// The unique message identifier. + /// The total number of chunks in the message. + /// The checksum of the complete message. + /// A new instance of configured for the first chunk. + public static ChunkMetadata CreateFirstChunk(string messageId, int totalChunks, string checksum) + { + return new ChunkMetadata + { + MessageId = messageId, + ChunkIndex = 0, + TotalChunks = totalChunks, + Checksum = checksum + }; + } /// + /// Creates a new instance of the class for subsequent chunks. + /// + /// The unique message identifier. + /// The index of this chunk in the sequence. + /// A new instance of configured for a subsequent chunk. + public static ChunkMetadata CreateSubsequentChunk(string messageId, int chunkIndex) + { + return new ChunkMetadata + { + MessageId = messageId, + ChunkIndex = chunkIndex + }; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkedMessageAssembler.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkedMessageAssembler.cs new file mode 100644 index 0000000000..875e3fefef --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkedMessageAssembler.cs @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Iot.Operations.Protocol.Events; +using Azure.Iot.Operations.Protocol.Models; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Handles the reassembly of chunked MQTT messages. +/// +internal class ChunkedMessageAssembler +{ + private readonly Dictionary _chunks = new(); + private readonly DateTime _creationTime = DateTime.UtcNow; + private readonly object _lock = new(); + private int _totalChunks; + private string? _checksum; + private readonly ChunkingChecksumAlgorithm _checksumAlgorithm; + private TimeSpan? _timeout; + + /// + /// Gets the current buffer size in bytes of all stored chunks. + /// + public long CurrentBufferSize { get; private set; } + + /// + /// Initializes a new instance of the class. + /// + /// The total number of chunks expected (may be updated later). + /// The algorithm to use for checksum verification. + public ChunkedMessageAssembler(int totalChunks, ChunkingChecksumAlgorithm checksumAlgorithm) + { + _totalChunks = totalChunks; + _checksumAlgorithm = checksumAlgorithm; + } + + /// + /// Gets a value indicating whether all chunks have been received. + /// + public bool IsComplete => _totalChunks > 0 && _chunks.Count == _totalChunks; + + /// + /// Updates the metadata for this chunked message when the first chunk is received. + /// + /// The total number of chunks expected. + /// The checksum of the complete message. + /// The timeout duration extracted from MessageExpiryInterval. + public void UpdateMetadata(int totalChunks, string? checksum, TimeSpan? timeout) + { + lock (_lock) + { + _totalChunks = totalChunks; + _checksum = checksum; + _timeout = timeout; + } + } + + /// + /// Adds a chunk to the assembler. + /// + /// The index of the chunk. + /// The MQTT message received event args. + /// True if the chunk was added, false if it was already present. + public bool AddChunk(int chunkIndex, MqttApplicationMessageReceivedEventArgs args) + { + lock (_lock) + { + if (_chunks.ContainsKey(chunkIndex)) + { + return false; + } + + var chunkSize = args.ApplicationMessage.Payload.Length; + _chunks[chunkIndex] = args; + CurrentBufferSize += chunkSize; + return true; + } + } + + /// + /// Attempts to reassemble the complete message from all chunks. + /// + /// The reassembled message event args. + /// True if reassembly was successful, false otherwise. + public bool TryReassemble(out MqttApplicationMessageReceivedEventArgs? reassembledArgs) + { + reassembledArgs = null; + + lock (_lock) + { + if (!IsComplete) + { + return false; + } + + try + { + // Get the first chunk to use as a template for the reassembled message + var firstChunk = _chunks[0]; + var firstMessage = firstChunk.ApplicationMessage; + + // Calculate the total payload size + long totalSize = _chunks.Values.Sum(args => args.ApplicationMessage.Payload.Length); + + // Create a memory stream with the exact capacity we need + using var memoryStream = new MemoryStream((int)totalSize); + + // Write all chunks in order + for (int i = 0; i < _totalChunks; i++) + { + if (!_chunks.TryGetValue(i, out var chunkArgs)) + { + // This should never happen if IsComplete is true + return false; + } + + var payload = chunkArgs.ApplicationMessage.Payload; + foreach (ReadOnlyMemory memory in payload) + { + memoryStream.Write(memory.Span); + } + } + + // Convert to ReadOnlySequence for checksum verification + memoryStream.Position = 0; + ReadOnlySequence reassembledPayload = new ReadOnlySequence(memoryStream.ToArray()); + + // Verify the checksum if provided + if (!string.IsNullOrEmpty(_checksum)) + { + bool checksumValid = ChecksumCalculator.VerifyChecksum(reassembledPayload, _checksum, _checksumAlgorithm); + if (!checksumValid) + { + // Checksum verification failed + return false; + } + } + + // Create a reassembled message without the chunking metadata + var userProperties = firstMessage.UserProperties? + .Where(p => p.Name != ChunkingConstants.ChunkUserProperty) + .ToList(); + + var reassembledMessage = new MqttApplicationMessage(firstMessage.Topic, firstMessage.QualityOfServiceLevel) + { + Retain = firstMessage.Retain, + Payload = reassembledPayload, + ContentType = firstMessage.ContentType, + ResponseTopic = firstMessage.ResponseTopic, + CorrelationData = firstMessage.CorrelationData, + PayloadFormatIndicator = firstMessage.PayloadFormatIndicator, + MessageExpiryInterval = firstMessage.MessageExpiryInterval, + TopicAlias = firstMessage.TopicAlias, + SubscriptionIdentifiers = firstMessage.SubscriptionIdentifiers, + UserProperties = userProperties + }; + + // Create event args for the reassembled message + reassembledArgs = new MqttApplicationMessageReceivedEventArgs( + firstChunk.ClientId, + reassembledMessage, + 1, // TODO: Set the correct packet identifier + AcknowledgeHandler); + + return true; + } + catch (Exception) + { + // If reassembly fails for any reason, return false + return false; + } + } + } + + private async Task AcknowledgeHandler(MqttApplicationMessageReceivedEventArgs reassembledArgs, CancellationToken ct) + { + // When acknowledging the reassembled message, acknowledge all the chunks + var tasks = new List(_totalChunks); + for (int i = 0; i < _totalChunks; i++) + { + if (_chunks.TryGetValue(i, out var chunk)) + { + tasks.Add(chunk.AcknowledgeAsync(ct)); + } + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + } + + /// + /// Checks if this assembler has expired based on the creation time. + /// + /// The timeout duration. + /// True if the assembler has expired, false otherwise. + public bool HasExpired() + { + if (!_timeout.HasValue) + { + return false; // No timeout set, never expires + } + + return DateTime.UtcNow - _creationTime > _timeout.Value; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkedMessageSplitter.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkedMessageSplitter.cs new file mode 100644 index 0000000000..156472820c --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkedMessageSplitter.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Iot.Operations.Protocol.Models; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Handles splitting large MQTT messages into smaller chunks. +/// +internal class ChunkedMessageSplitter +{ + private readonly ChunkingOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// The chunking options. + public ChunkedMessageSplitter(ChunkingOptions options) + { + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + /// + /// Splits a message into smaller chunks if necessary. + /// + /// The original message to split. + /// The maximum packet size allowed. + /// A list of chunked messages. +public IReadOnlyList SplitMessage(MqttApplicationMessage message, int maxPacketSize) + { + var maxChunkSize = ValidateAndGetMaxChunkSize(message, maxPacketSize); + var (payload, totalChunks, messageId, checksum, userProperties) = PrepareChunkingMetadata(message, maxChunkSize); + + // Create chunks + var chunks = new List(totalChunks); + + for (var chunkIndex = 0; chunkIndex < totalChunks; chunkIndex++) + { + var chunkPayload = ChunkedMessageSplitter.ExtractChunkPayload(payload, chunkIndex, maxChunkSize); + var chunkMessage = ChunkedMessageSplitter.CreateChunk(message, chunkPayload, userProperties, messageId, chunkIndex, totalChunks, checksum); + chunks.Add(chunkMessage); + } + + return chunks; + } + + private int ValidateAndGetMaxChunkSize(MqttApplicationMessage message, int maxPacketSize) + { + ArgumentNullException.ThrowIfNull(message); + ArgumentOutOfRangeException.ThrowIfLessThan(maxPacketSize, 128); // minimum MQTT 5.0 protocol compliance. + + // Calculate the maximum size for each chunk's payload + var maxChunkSize = Utils.GetMaxChunkSize(maxPacketSize, _options.StaticOverhead); + if (message.Payload.Length <= maxChunkSize) + { + throw new ArgumentException($"Message size {message.Payload.Length} is less than the maximum chunk size {maxChunkSize}.", nameof(message)); + } + + return maxChunkSize; + } + + private (ReadOnlySequence Payload, int TotalChunks, string MessageId, string Checksum, List UserProperties) + PrepareChunkingMetadata(MqttApplicationMessage message, int maxChunkSize) + { + var payload = message.Payload; + var totalChunks = (int)Math.Ceiling((double)payload.Length / maxChunkSize); + + // Generate a unique message ID + var messageId = Guid.NewGuid().ToString("D"); + + // Calculate checksum for the entire payload + var checksum = ChecksumCalculator.CalculateChecksum(payload, _options.ChecksumAlgorithm); + + // Create a copy of the user properties + var userProperties = new List(message.UserProperties ?? Enumerable.Empty()); + + return (payload, totalChunks, messageId, checksum, userProperties); + } + + private static ReadOnlySequence ExtractChunkPayload(ReadOnlySequence payload, int chunkIndex, int maxChunkSize) + { + var chunkStart = (long)chunkIndex * maxChunkSize; + var chunkLength = Math.Min(maxChunkSize, payload.Length - chunkStart); + return payload.Slice(chunkStart, chunkLength); + } + + private static MqttApplicationMessage CreateChunk( + MqttApplicationMessage originalMessage, + ReadOnlySequence chunkPayload, + List userProperties, + string messageId, + int chunkIndex, + int totalChunks, + string checksum) + { + // Create chunk metadata + var metadata = chunkIndex == 0 + ? ChunkMetadata.CreateFirstChunk(messageId, totalChunks, checksum) + : ChunkMetadata.CreateSubsequentChunk(messageId, chunkIndex); + + // Serialize the metadata to JSON + var metadataJson = JsonSerializer.Serialize(metadata); + + // Create user properties for this chunk + var chunkUserProperties = new List(userProperties) + { + // Add the chunk metadata property + new(ChunkingConstants.ChunkUserProperty, metadataJson) + }; + + // Create a message for this chunk + return new MqttApplicationMessage(originalMessage.Topic, originalMessage.QualityOfServiceLevel) + { + Retain = originalMessage.Retain, + Payload = chunkPayload, + ContentType = originalMessage.ContentType, + ResponseTopic = originalMessage.ResponseTopic, + CorrelationData = originalMessage.CorrelationData, + PayloadFormatIndicator = originalMessage.PayloadFormatIndicator, + MessageExpiryInterval = originalMessage.MessageExpiryInterval, + TopicAlias = originalMessage.TopicAlias, + SubscriptionIdentifiers = originalMessage.SubscriptionIdentifiers, + UserProperties = chunkUserProperties + }; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingChecksumAlgorithm.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingChecksumAlgorithm.cs new file mode 100644 index 0000000000..d9f81be4cb --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingChecksumAlgorithm.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Available checksum algorithms for chunk message integrity verification. +/// +public enum ChunkingChecksumAlgorithm +{ + /// + /// MD5 algorithm - 128-bit hash, good performance but not cryptographically secure + /// + MD5, + + /// + /// SHA-256 algorithm - 256-bit hash, cryptographically secure but larger output size + /// + SHA256 +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingConstants.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingConstants.cs new file mode 100644 index 0000000000..a80ec0393b --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingConstants.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Constants used for the MQTT message chunking feature. +/// +//TODO: @maximsemenov80 public for testing purposes, should be internal +public static class ChunkingConstants +{ + /// + /// The user property name used to store chunking metadata. + /// + public const string ChunkUserProperty = "__chunk"; + + /// + /// JSON field name for the unique message identifier within the chunk metadata. + /// + public const string MessageIdField = "messageId"; + + /// + /// JSON field name for the chunk index within the chunk metadata. + /// + public const string ChunkIndexField = "chunkIndex"; + + /// + /// JSON field name for the total number of chunks within the chunk metadata. + /// + public const string TotalChunksField = "totalChunks"; + + /// + /// JSON field name for the message checksum within the chunk metadata. + /// + public const string ChecksumField = "checksum"; + + /// + /// Default static overhead value subtracted from the maximum packet size. + /// This accounts for MQTT packet headers, topic name, and other metadata. + /// + public const int DefaultStaticOverhead = 1024; + + /// + /// Reason string for successful chunked message transmission. + /// + public const string ChunkedMessageSuccessReasonString = "Chunked message successfully sent"; +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingMqttPubSubClient.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingMqttPubSubClient.cs new file mode 100644 index 0000000000..2477b23a27 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingMqttPubSubClient.cs @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Iot.Operations.Protocol.Events; +using Azure.Iot.Operations.Protocol.Models; +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// MQTT client middleware that provides transparent chunking of large messages. +/// +public class ChunkingMqttPubSubClient : IMqttPubSubClient +{ + private readonly IExtendedPubSubMqttClient _innerClient; + private readonly ChunkingOptions _chunkingOptions; + private readonly ConcurrentDictionary _messageAssemblers = new(); + private readonly ChunkedMessageSplitter _messageSplitter; + private int _maxPacketSize; + private readonly Timer? _cleanupTimer; + + /// + /// Initializes a new instance of the class. + /// + /// The MQTT client to wrap with chunking capabilities. + /// The chunking options. + public ChunkingMqttPubSubClient(IExtendedPubSubMqttClient innerClient, ChunkingOptions? options = null) + { + _innerClient = innerClient ?? throw new ArgumentNullException(nameof(innerClient)); + _chunkingOptions = options ?? new ChunkingOptions(); + _messageSplitter = new ChunkedMessageSplitter(_chunkingOptions); + + UpdateMaxPacketSizeFromConnectResult(_innerClient.GetConnectResult()); + + _innerClient.ApplicationMessageReceivedAsync += HandleApplicationMessageReceivedAsync; + + // Start the cleanup timer + _cleanupTimer = new Timer( + _ => CleanupExpiredAssemblers(), + null, + TimeSpan.FromMinutes(1), + TimeSpan.FromMinutes(1)); + } + + public event Func? ApplicationMessageReceivedAsync; + + /// + public async Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken = default) + { + // If chunking is disabled or the message is small enough, pass through to the inner client + if (!_chunkingOptions.Enabled || applicationMessage.Payload.Length <= Utils.GetMaxChunkSize(_maxPacketSize, _chunkingOptions.StaticOverhead)) + { + return await _innerClient.PublishAsync(applicationMessage, cancellationToken).ConfigureAwait(false); + } + + return await PublishChunkedMessageAsync(applicationMessage, cancellationToken).ConfigureAwait(false); + } + + /// + public Task SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken = default) + { + return _innerClient.SubscribeAsync(options, cancellationToken); + } + + /// + public Task UnsubscribeAsync(MqttClientUnsubscribeOptions options, CancellationToken cancellationToken = default) + { + return _innerClient.UnsubscribeAsync(options, cancellationToken); + } + + public string? ClientId => _innerClient.ClientId; + + public MqttProtocolVersion ProtocolVersion => _innerClient.ProtocolVersion; + + public ValueTask DisposeAsync(bool disposing) + { + return _innerClient.DisposeAsync(disposing); + } + + /// + public ValueTask DisposeAsync() + { + // Clean up resources + _messageAssemblers.Clear(); + + // Dispose cleanup timer + _cleanupTimer?.Dispose(); + + // Detach events + _innerClient.ApplicationMessageReceivedAsync -= HandleApplicationMessageReceivedAsync; + + // Suppress finalization since we're explicitly disposing + GC.SuppressFinalize(this); + + return _innerClient.DisposeAsync(); + } + + private void UpdateMaxPacketSizeFromConnectResult(MqttClientConnectResult? result) + { + if (_chunkingOptions.Enabled && result?.MaximumPacketSize is not > 0) + { + throw new InvalidOperationException("Chunking client requires a defined maximum packet size to function properly."); + } + + // _maxPacketSize = (int)result!.MaximumPacketSize!.Value; + _maxPacketSize = 64 * 1024; + } + + private async Task PublishChunkedMessageAsync(MqttApplicationMessage message, CancellationToken cancellationToken) + { + // Use the message splitter to split the message into chunks + var chunks = _messageSplitter.SplitMessage(message, _maxPacketSize); + + // Publish each chunk + foreach (var chunk in chunks) + { + await _innerClient.PublishAsync(chunk, cancellationToken).ConfigureAwait(false); + } + + // Return a successful result + return new MqttClientPublishResult( + null, + MqttClientPublishReasonCode.Success, + ChunkingConstants.ChunkedMessageSuccessReasonString, + new List(message.UserProperties ?? Enumerable.Empty())); + } + + private async Task HandleApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventArgs args) + { + // Check if this is a chunked message + var onApplicationMessageReceivedAsync = ApplicationMessageReceivedAsync; + if (!TryGetChunkMetadata(args.ApplicationMessage, out var chunkMetadata)) + { + // Not a chunked message, pass it through + if (onApplicationMessageReceivedAsync != null) + { + await onApplicationMessageReceivedAsync.Invoke(args).ConfigureAwait(false); + } + + return; + } + + // This is a chunked message, handle the reassembly + if (TryProcessChunk(args, chunkMetadata!, out var reassembledArgs)) + { + // We have a complete message, invoke the event + if (onApplicationMessageReceivedAsync != null && reassembledArgs != null) + { + await onApplicationMessageReceivedAsync.Invoke(reassembledArgs).ConfigureAwait(false); + } + } + else + { + // Acknowledge the chunk but don't pass it to the application yet + await args.AcknowledgeAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + private bool TryProcessChunk( + MqttApplicationMessageReceivedEventArgs args, + ChunkMetadata metadata, + out MqttApplicationMessageReceivedEventArgs? reassembledArgs) + { + reassembledArgs = null; + + // Check global buffer size limit before processing + if (_chunkingOptions.ReassemblyBufferSizeLimit > 0) + { + var currentTotalBufferSize = CalculateTotalBufferSize(); + var chunkSize = args.ApplicationMessage.Payload.Length; + + // If adding this chunk would exceed the global limit, reject it + if (currentTotalBufferSize + chunkSize > _chunkingOptions.ReassemblyBufferSizeLimit) + { + // Log or handle buffer limit exceeded (could throw exception or return false) + return false; + } + } + + // Get or create the message assembler + var assembler = _messageAssemblers.GetOrAdd( + metadata.MessageId, + _ => new ChunkedMessageAssembler(metadata.TotalChunks ?? 0, _chunkingOptions.ChecksumAlgorithm)); + + // Add this chunk to the assembler + if (assembler.AddChunk(metadata.ChunkIndex, args)) + { + // If this was the first chunk, update total chunks, checksum, and extract timeout from MessageExpiryInterval + if (metadata.ChunkIndex == 0 && metadata.TotalChunks.HasValue) + { + var timeout = args.ApplicationMessage.MessageExpiryInterval > 0 + ? TimeSpan.FromSeconds(args.ApplicationMessage.MessageExpiryInterval) + : (TimeSpan?)null; + assembler.UpdateMetadata(metadata.TotalChunks.Value, metadata.Checksum, timeout); + } + + // Check if we have all the chunks + if (assembler.IsComplete && assembler.TryReassemble(out reassembledArgs)) + { + // Remove the assembler + _messageAssemblers.TryRemove(metadata.MessageId, out _); + return true; + } + } + + return false; + } + + private static bool TryGetChunkMetadata(MqttApplicationMessage message, out ChunkMetadata? metadata) + { + metadata = null; + + if (message.UserProperties == null) + { + return false; + } + + var chunkProperty = message.UserProperties + .FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty) + ?.Value; + + if (string.IsNullOrEmpty(chunkProperty)) + { + return false; + } + + try + { + metadata = JsonSerializer.Deserialize(chunkProperty); + return metadata != null; + } + catch (JsonException) + { + return false; + } + } + + /// + /// Cleans up expired message assemblers to prevent memory leaks. + /// + private void CleanupExpiredAssemblers() + { + var expiredKeys = new List(); + + foreach (var kvp in _messageAssemblers) + { + if (kvp.Value.HasExpired()) + { + expiredKeys.Add(kvp.Key); + } + } + + foreach (var key in expiredKeys) + { + _messageAssemblers.TryRemove(key, out _); + } + } + + /// + /// Calculates the total buffer size across all active message assemblers. + /// + /// The total buffer size in bytes. + private long CalculateTotalBufferSize() + { + return _messageAssemblers.Values.Sum(assembler => assembler.CurrentBufferSize); + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingOptions.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingOptions.cs new file mode 100644 index 0000000000..2ce95f939e --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/ChunkingOptions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.Iot.Operations.Protocol.Chunking; + +/// +/// Configuration options for the MQTT message chunking feature. +/// +public class ChunkingOptions +{ + /// + /// Gets or sets whether chunking is enabled. + /// + public bool Enabled { get; set; } + + /// + /// Gets or sets the static overhead value subtracted from the MQTT maximum packet size + /// to account for headers, topic names, and other metadata. + /// + public int StaticOverhead { get; set; } = ChunkingConstants.DefaultStaticOverhead; + + /// + /// Gets or sets the checksum algorithm to use for message integrity verification. + /// + public ChunkingChecksumAlgorithm ChecksumAlgorithm { get; set; } = ChunkingChecksumAlgorithm.SHA256; + + /// + /// Gets or sets the maximum total size (in bytes) of all chunk payloads that can be buffered + /// simultaneously during message reassembly. When this limit is exceeded, new chunks will be rejected. + /// A value of 0 or negative means no limit. + /// + public long ReassemblyBufferSizeLimit { get; set; } = 10 * 1024 * 1024; // 10 MB default +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/BufferLimitExceededError.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/BufferLimitExceededError.cs new file mode 100644 index 0000000000..1e49af1a7b --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/BufferLimitExceededError.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Azure.Iot.Operations.Protocol.Chunking.Exceptions; + +/// +/// Exception thrown when the reassembly buffer size limit is exceeded. +/// +public class BufferLimitExceededError : ChunkingException +{ + /// + /// Gets the current total buffer size across all active message assemblers. + /// + public long CurrentBufferSize { get; } + + /// + /// Gets the configured buffer size limit that was exceeded. + /// + public long BufferLimit { get; } + + /// + /// Gets the size of the chunk that would have exceeded the limit. + /// + public long ChunkSize { get; } + + /// + /// Gets the number of active message assemblers when the limit was exceeded. + /// + public int ActiveAssemblers { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The message ID of the chunk that would exceed the limit. + /// The index of the chunk that would exceed the limit. + /// The current total buffer size. + /// The configured buffer size limit. + /// The size of the chunk that would exceed the limit. + /// The number of active message assemblers. + /// The inner exception, if any. + public BufferLimitExceededError( + string messageId, + int chunkIndex, + long currentBufferSize, + long bufferLimit, + long chunkSize, + int activeAssemblers, + Exception? innerException = null) + : base(messageId, + $"Reassembly buffer limit exceeded. Current: {currentBufferSize:N0} bytes, Limit: {bufferLimit:N0} bytes, Chunk size: {chunkSize:N0} bytes, Active assemblers: {activeAssemblers}", + chunkIndex, + innerException) + { + CurrentBufferSize = currentBufferSize; + BufferLimit = bufferLimit; + ChunkSize = chunkSize; + ActiveAssemblers = activeAssemblers; + } + + /// + /// Gets the amount by which the buffer limit would be exceeded. + /// + public long ExcessBytes => CurrentBufferSize + ChunkSize - BufferLimit; + + /// + /// Gets the current buffer utilization as a percentage. + /// + public double BufferUtilizationPercent => (double)CurrentBufferSize / BufferLimit * 100.0; + + /// + /// Gets the buffer utilization percentage if the chunk were accepted. + /// + public double ProjectedUtilizationPercent => (double)(CurrentBufferSize + ChunkSize) / BufferLimit * 100.0; +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChecksumMismatchError.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChecksumMismatchError.cs new file mode 100644 index 0000000000..32667fc5c8 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChecksumMismatchError.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Azure.Iot.Operations.Protocol.Chunking.Exceptions; + +/// +/// Exception thrown when the reassembled message checksum doesn't match the expected checksum. +/// +public class ChecksumMismatchError : ChunkingException +{ + /// + /// Gets the expected checksum from the first chunk. + /// + public string ExpectedChecksum { get; } + + /// + /// Gets the actual checksum calculated from the reassembled payload. + /// + public string ActualChecksum { get; } + + /// + /// Gets the size of the reassembled payload. + /// + public long PayloadSize { get; } + + /// + /// Gets the checksum algorithm that was used. + /// + public ChunkingChecksumAlgorithm ChecksumAlgorithm { get; } + + /// + /// Gets the total number of chunks that were reassembled. + /// + public int TotalChunks { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The message ID with the checksum mismatch. + /// The expected checksum from the first chunk. + /// The actual checksum calculated from the reassembled payload. + /// The size of the reassembled payload. + /// The checksum algorithm that was used. + /// The total number of chunks that were reassembled. + /// The inner exception, if any. + public ChecksumMismatchError( + string messageId, + string expectedChecksum, + string actualChecksum, + long payloadSize, + ChunkingChecksumAlgorithm checksumAlgorithm, + int totalChunks, + Exception? innerException = null) + : base(messageId, + $"Checksum verification failed. Expected: {expectedChecksum}, Actual: {actualChecksum}, Algorithm: {checksumAlgorithm}, Payload size: {payloadSize:N0} bytes, Chunks: {totalChunks}", + null, + innerException) + { + ExpectedChecksum = expectedChecksum ?? throw new ArgumentNullException(nameof(expectedChecksum)); + ActualChecksum = actualChecksum ?? throw new ArgumentNullException(nameof(actualChecksum)); + PayloadSize = payloadSize; + ChecksumAlgorithm = checksumAlgorithm; + TotalChunks = totalChunks; + } + + /// + /// Gets a value indicating whether the checksum mismatch might be due to data corruption. + /// This is a heuristic based on the difference between expected and actual checksums. + /// + public bool PossibleDataCorruption => !string.Equals(ExpectedChecksum, ActualChecksum, StringComparison.OrdinalIgnoreCase); + + /// + /// Gets diagnostic information about the checksum mismatch. + /// + /// A string containing diagnostic information. + public string GetDiagnosticInfo() + { + return $"Checksum Mismatch Diagnostics:\n" + + $" Message ID: {MessageId}\n" + + $" Expected: {ExpectedChecksum}\n" + + $" Actual: {ActualChecksum}\n" + + $" Algorithm: {ChecksumAlgorithm}\n" + + $" Payload Size: {PayloadSize:N0} bytes\n" + + $" Total Chunks: {TotalChunks}\n" + + $" Possible Corruption: {PossibleDataCorruption}"; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkAssemblyError.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkAssemblyError.cs new file mode 100644 index 0000000000..477c828414 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkAssemblyError.cs @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; + +namespace Azure.Iot.Operations.Protocol.Chunking.Exceptions; + +/// +/// Exception thrown when chunk assembly fails due to malformed chunks or other assembly issues. +/// +public class ChunkAssemblyError : ChunkingException +{ + /// + /// Gets detailed error information about the assembly failure. + /// + public string ErrorDetails { get; } + + /// + /// Gets the type of assembly error that occurred. + /// + public ChunkAssemblyErrorType ErrorType { get; } + + /// + /// Gets additional context about the assembly state when the error occurred. + /// + public Dictionary Context { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The message ID that failed to assemble. + /// The chunk index where the error occurred, if applicable. + /// The type of assembly error. + /// Detailed error information. + /// Additional context about the assembly state. + /// The inner exception, if any. + public ChunkAssemblyError( + string messageId, + int? chunkIndex, + ChunkAssemblyErrorType errorType, + string errorDetails, + Dictionary? context = null, + Exception? innerException = null) + : base(messageId, + $"Chunk assembly failed: {errorType}. {errorDetails}", + chunkIndex, + innerException) + { + ErrorDetails = errorDetails ?? throw new ArgumentNullException(nameof(errorDetails)); + ErrorType = errorType; + Context = context ?? new Dictionary(); + } + + /// + /// Creates a ChunkAssemblyError for malformed chunk metadata. + /// + /// The message ID. + /// The chunk index. + /// Description of the metadata issue. + /// The inner exception, if any. + /// A new ChunkAssemblyError instance. + public static ChunkAssemblyError MalformedMetadata(string messageId, int chunkIndex, string metadataIssue, Exception? innerException = null) + { + return new ChunkAssemblyError( + messageId, + chunkIndex, + ChunkAssemblyErrorType.MalformedMetadata, + $"Chunk metadata is malformed: {metadataIssue}", + new Dictionary { { "MetadataIssue", metadataIssue } }, + innerException); + } + + /// + /// Creates a ChunkAssemblyError for duplicate chunks. + /// + /// The message ID. + /// The duplicate chunk index. + /// The inner exception, if any. + /// A new ChunkAssemblyError instance. + public static ChunkAssemblyError DuplicateChunk(string messageId, int chunkIndex, Exception? innerException = null) + { + return new ChunkAssemblyError( + messageId, + chunkIndex, + ChunkAssemblyErrorType.DuplicateChunk, + $"Duplicate chunk received for index {chunkIndex}", + new Dictionary { { "DuplicateIndex", chunkIndex } }, + innerException); + } + + /// + /// Creates a ChunkAssemblyError for invalid chunk order. + /// + /// The message ID. + /// The out-of-order chunk index. + /// The expected chunk index range. + /// The inner exception, if any. + /// A new ChunkAssemblyError instance. + public static ChunkAssemblyError InvalidChunkOrder(string messageId, int chunkIndex, string expectedRange, Exception? innerException = null) + { + return new ChunkAssemblyError( + messageId, + chunkIndex, + ChunkAssemblyErrorType.InvalidChunkOrder, + $"Chunk index {chunkIndex} is outside expected range: {expectedRange}", + new Dictionary + { + { "ChunkIndex", chunkIndex }, + { "ExpectedRange", expectedRange } + }, + innerException); + } + + /// + /// Creates a ChunkAssemblyError for payload serialization failures. + /// + /// The message ID. + /// Description of the serialization error. + /// The inner exception, if any. + /// A new ChunkAssemblyError instance. + public static ChunkAssemblyError PayloadSerialization(string messageId, string serializationError, Exception? innerException = null) + { + return new ChunkAssemblyError( + messageId, + null, + ChunkAssemblyErrorType.PayloadSerialization, + $"Failed to serialize reassembled payload: {serializationError}", + new Dictionary { { "SerializationError", serializationError } }, + innerException); + } +} + +/// +/// Defines the types of chunk assembly errors that can occur. +/// +public enum ChunkAssemblyErrorType +{ + /// + /// The chunk metadata is malformed or invalid. + /// + MalformedMetadata, + + /// + /// A duplicate chunk was received. + /// + DuplicateChunk, + + /// + /// Chunks were received in an invalid order or with invalid indices. + /// + InvalidChunkOrder, + + /// + /// Failed to serialize the reassembled payload. + /// + PayloadSerialization, + + /// + /// A general assembly error occurred. + /// + General +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkTimeoutError.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkTimeoutError.cs new file mode 100644 index 0000000000..119b7b6403 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkTimeoutError.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; + +namespace Azure.Iot.Operations.Protocol.Chunking.Exceptions; + +/// +/// Exception thrown when a chunked message assembly times out before all chunks are received. +/// +public class ChunkTimeoutError : ChunkingException +{ + /// + /// Gets the total number of chunks expected for the message. + /// + public int ExpectedChunks { get; } + + /// + /// Gets the number of chunks that were actually received before the timeout. + /// + public int ReceivedChunks { get; } + + /// + /// Gets the timeout duration that was exceeded. + /// + public TimeSpan TimeoutDuration { get; } + + /// + /// Gets the time when the first chunk was received. + /// + public DateTime FirstChunkReceived { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The message ID that timed out. + /// The total number of chunks expected. + /// The number of chunks received before timeout. + /// The timeout duration that was exceeded. + /// The time when the first chunk was received. + /// The inner exception, if any. + public ChunkTimeoutError( + string messageId, + int expectedChunks, + int receivedChunks, + TimeSpan timeoutDuration, + DateTime firstChunkReceived, + Exception? innerException = null) + : base(messageId, + $"Chunked message assembly timed out. Expected {expectedChunks} chunks, received {receivedChunks} chunks. Timeout: {timeoutDuration.TotalSeconds:F1}s", + null, + innerException) + { + ExpectedChunks = expectedChunks; + ReceivedChunks = receivedChunks; + TimeoutDuration = timeoutDuration; + FirstChunkReceived = firstChunkReceived; + } + + /// + /// Gets the missing chunk indices that were not received before the timeout. + /// + /// The indices of chunks that were received. + /// An array of missing chunk indices. + public int[] GetMissingChunkIndices(int[] receivedChunkIndices) + { + var missing = new List(); + var receivedSet = new HashSet(receivedChunkIndices); + + for (int i = 0; i < ExpectedChunks; i++) + { + if (!receivedSet.Contains(i)) + { + missing.Add(i); + } + } + + return missing.ToArray(); + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkingException.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkingException.cs new file mode 100644 index 0000000000..d137a7a5b7 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Exceptions/ChunkingException.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Azure.Iot.Operations.Protocol.Chunking.Exceptions; + +/// +/// Base exception class for all chunking-related errors. +/// +public abstract class ChunkingException : Exception +{ + /// + /// Gets the message ID associated with the chunked message that caused the error. + /// + public string MessageId { get; } + + /// + /// Gets the chunk index that caused the error, if applicable. + /// + public int? ChunkIndex { get; } + + /// + /// Gets the timestamp when the error occurred. + /// + public DateTime Timestamp { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The message ID associated with the error. + /// The error message. + /// The chunk index that caused the error, if applicable. + /// The inner exception, if any. + protected ChunkingException(string messageId, string message, int? chunkIndex = null, Exception? innerException = null) + : base(message, innerException) + { + MessageId = messageId ?? throw new ArgumentNullException(nameof(messageId)); + ChunkIndex = chunkIndex; + Timestamp = DateTime.UtcNow; + } + + /// + /// Returns a string that represents the current exception. + /// + /// A string representation of the exception. + public override string ToString() + { + var result = $"{GetType().Name}: {Message} (MessageId: {MessageId}"; + + if (ChunkIndex.HasValue) + { + result += $", ChunkIndex: {ChunkIndex.Value}"; + } + + result += $", Timestamp: {Timestamp:yyyy-MM-dd HH:mm:ss} UTC)"; + + if (InnerException != null) + { + result += $"\n ---> {InnerException}"; + } + + return result; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Utils.cs b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Utils.cs new file mode 100644 index 0000000000..94dba63fe3 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/Chunking/Utils.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Azure.Iot.Operations.Protocol.Chunking; + +public static class Utils +{ + /// + /// Calculates the maximum size for a message chunk based on max packet size and overhead. + /// + /// The maximum packet size allowed by the broker. + /// The static overhead to account for in each chunk. + /// The maximum size that can be used for a message chunk. + public static int GetMaxChunkSize(int maxPacketSize, int staticOverhead) + { + ArgumentOutOfRangeException.ThrowIfLessThanOrEqual(maxPacketSize, staticOverhead); + return maxPacketSize - staticOverhead; + } +} diff --git a/dotnet/src/Azure.Iot.Operations.Protocol/IExtendedPubSubMqttClient.cs b/dotnet/src/Azure.Iot.Operations.Protocol/IExtendedPubSubMqttClient.cs new file mode 100644 index 0000000000..cc2e0738e5 --- /dev/null +++ b/dotnet/src/Azure.Iot.Operations.Protocol/IExtendedPubSubMqttClient.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Iot.Operations.Protocol.Models; + +namespace Azure.Iot.Operations.Protocol; + +public interface IExtendedPubSubMqttClient : IMqttPubSubClient +{ + MqttClientConnectResult? GetConnectResult(); +} diff --git a/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ChunkingMqttClientIntegrationTests.cs b/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ChunkingMqttClientIntegrationTests.cs new file mode 100644 index 0000000000..c18fa0d5ca --- /dev/null +++ b/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ChunkingMqttClientIntegrationTests.cs @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.Iot.Operations.Protocol.Chunking; +using Azure.Iot.Operations.Protocol.Models; +using System.Buffers; + +namespace Azure.Iot.Operations.Protocol.IntegrationTests +{ + public class ChunkingMqttClientTests + { + [Fact] + public async Task ChunkingMqttClient_SmallMessage_NoChunking() + { + // Arrange + // Create a base client + await using var mqttClient = await ClientFactory.CreateExtendedClientAsyncFromEnvAsync(Guid.NewGuid().ToString()); + + // Create a chunking client with modest settings + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 500 // Use modest overhead to ensure small messages aren't chunked + }; + + await using var chunkingClient = new ChunkingMqttPubSubClient(mqttClient, options); + + var messageReceivedTcs = new TaskCompletionSource(); + chunkingClient.ApplicationMessageReceivedAsync += (args) => + { + messageReceivedTcs.TrySetResult(args.ApplicationMessage); + return Task.CompletedTask; + }; + + // Subscribe to a unique topic + var topic = $"chunking/test/{Guid.NewGuid()}"; + await chunkingClient.SubscribeAsync(new MqttClientSubscribeOptions(topic, MqttQualityOfServiceLevel.AtLeastOnce)); + + // Create a small message - 100 bytes payload + var smallPayload = new byte[100]; + Random.Shared.NextBytes(smallPayload); + + var message = new MqttApplicationMessage(topic, MqttQualityOfServiceLevel.AtLeastOnce) + { + Payload = new ReadOnlySequence(smallPayload), + UserProperties = new List + { + new("testProperty", "testValue") + } + }; + + // Act + var publishResult = await chunkingClient.PublishAsync(message); + + // Wait for the message to be received - timeout after 10 seconds + MqttApplicationMessage? receivedMessage = null; + try + { + receivedMessage = await messageReceivedTcs.Task.WaitAsync(TimeSpan.FromSeconds(10)); + } + catch (TimeoutException) + { + Assert.Fail("Timed out waiting for the message to be received"); + } + + // Assert + Assert.NotNull(receivedMessage); + + // Verify payload is identical + Assert.Equal(smallPayload, receivedMessage.Payload.ToArray()); + + // Verify no chunking metadata was added + var chunkProperty = receivedMessage.UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + Assert.Null(chunkProperty); + + // Verify original properties were preserved + var testProperty = receivedMessage.UserProperties?.FirstOrDefault(p => p.Name == "testProperty"); + Assert.NotNull(testProperty); + Assert.Equal("testValue", testProperty!.Value); + } + + [Fact] + public async Task ChunkingMqttClient_LargeMessage_ChunkingAndReassembly() + { + // Arrange + // Create a base client + await using var mqttClient = await ClientFactory.CreateExtendedClientAsyncFromEnvAsync(Guid.NewGuid().ToString()); + + // Create a chunking client with settings that force chunking + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 500 + }; + + await using var chunkingClient = new ChunkingMqttPubSubClient(mqttClient, options); + + var messageReceivedTcs = new TaskCompletionSource(); + chunkingClient.ApplicationMessageReceivedAsync += (args) => + { + messageReceivedTcs.TrySetResult(args.ApplicationMessage); + return Task.CompletedTask; + }; + + // Subscribe to a unique topic + var topic = $"chunking/test/{Guid.NewGuid()}"; + await chunkingClient.SubscribeAsync(new MqttClientSubscribeOptions(topic, MqttQualityOfServiceLevel.AtLeastOnce)); + + // TODO: @maximsemenov80 for the test purpose UpdateMaxPacketSizeFromConnectResult artificially set MaxPacketSize to 64KB + var largePayloadSize = 1024 * 1024; // 1MB + var largePayload = new byte[largePayloadSize]; + + // Fill with recognizable pattern for verification + for (int i = 0; i < largePayloadSize; i++) + { + largePayload[i] = (byte)(i % 256); + } + + var message = new MqttApplicationMessage(topic, MqttQualityOfServiceLevel.AtLeastOnce) + { + Payload = new ReadOnlySequence(largePayload), + UserProperties = new List + { + new("testProperty", "testValue") + } + }; + + // Act + var publishResult = await chunkingClient.PublishAsync(message); + + // Wait for the reassembled message to be received - timeout after 30 seconds + // Reassembly may take longer than a normal message + MqttApplicationMessage? receivedMessage = null; + try + { + receivedMessage = await messageReceivedTcs.Task.WaitAsync(TimeSpan.FromSeconds(30)); + } + catch (TimeoutException) + { + Assert.Fail("Timed out waiting for the reassembled message to be received"); + } + + // Assert + Assert.NotNull(receivedMessage); + + // Verify payload size is correct + Assert.Equal(largePayloadSize, receivedMessage.Payload.Length); + + // Verify payload content is identical + var reassembledPayload = receivedMessage.Payload.ToArray(); + Assert.Equal(largePayload, reassembledPayload); + + // Verify chunking metadata was removed + var chunkProperty = receivedMessage.UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + Assert.Null(chunkProperty); + + // Verify original properties were preserved + var testProperty = receivedMessage.UserProperties?.FirstOrDefault(p => p.Name == "testProperty"); + Assert.NotNull(testProperty); + Assert.Equal("testValue", testProperty!.Value); + } + } +} diff --git a/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ClientFactory.cs b/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ClientFactory.cs index 4fc9826c4b..facc1b5a01 100644 --- a/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ClientFactory.cs +++ b/dotnet/test/Azure.Iot.Operations.Protocol.IntegrationTests/ClientFactory.cs @@ -35,6 +35,29 @@ public static async Task CreateClientAsyncFromEnvAsync(str return orderedAckClient; } + public static async Task CreateExtendedClientAsyncFromEnvAsync(string clientId, bool withTraces = false, CancellationToken cancellationToken = default) + { + Debug.Assert(Environment.GetEnvironmentVariable("MQTT_TEST_BROKER_CS") != null); + string cs = $"{Environment.GetEnvironmentVariable("MQTT_TEST_BROKER_CS")}"; + MqttConnectionSettings mcs = MqttConnectionSettings.FromConnectionString(cs); + if (string.IsNullOrEmpty(clientId)) + { + mcs.ClientId += Guid.NewGuid(); + } + else + { + mcs.ClientId = clientId; + } + + MQTTnet.IMqttClient mqttClient = withTraces + ? new MQTTnet.MqttClientFactory().CreateMqttClient(MqttNetTraceLogger.CreateTraceLogger()) + : new MQTTnet.MqttClientFactory().CreateMqttClient(); + var extendedPubSubClient = new ExtendedPubSubMqttClient(mqttClient); + await extendedPubSubClient.ConnectAsync(new MqttClientOptions(mcs), cancellationToken); + + return extendedPubSubClient; + } + public static async Task CreateSessionClientForFaultableBrokerFromEnv(List? ConnectUserProperties = null, string? clientId = null) { if (string.IsNullOrEmpty(clientId)) diff --git a/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Azure.Iot.Operations.Protocol.UnitTests.csproj b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Azure.Iot.Operations.Protocol.UnitTests.csproj index 7f34e29c39..18d1f9cfc6 100644 --- a/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Azure.Iot.Operations.Protocol.UnitTests.csproj +++ b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Azure.Iot.Operations.Protocol.UnitTests.csproj @@ -19,6 +19,7 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -57,6 +58,10 @@ + + + + $(MSBuildProjectDirectory)\..\..\MSSharedLibKey.snk diff --git a/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkedMessageAssemblerTests.cs b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkedMessageAssemblerTests.cs new file mode 100644 index 0000000000..6d2f8e0009 --- /dev/null +++ b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkedMessageAssemblerTests.cs @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Azure.Iot.Operations.Protocol.Chunking; +using Azure.Iot.Operations.Protocol.Events; +using Azure.Iot.Operations.Protocol.Models; +using Moq; +using Xunit; + +namespace Azure.Iot.Operations.Protocol.UnitTests.Chunking +{ + public class ChunkedMessageAssemblerTests + { + [Fact] + public void Constructor_SetsProperties_Correctly() + { + // Arrange & Act + var assembler = new ChunkedMessageAssembler(5, ChunkingChecksumAlgorithm.SHA256); + + // Assert + Assert.False(assembler.IsComplete); + } + + [Fact] + public void AddChunk_ReturnsTrueForNewChunk_FalseForDuplicate() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var chunk0 = CreateMqttMessageEventArgs("payload1"); + + // Act & Assert + Assert.True(assembler.AddChunk(0, chunk0)); // First time should return true + Assert.False(assembler.AddChunk(0, chunk0)); // Second time should return false (duplicate) + } + + [Fact] + public void IsComplete_ReturnsTrueWhenAllChunksReceived() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var chunk0 = CreateMqttMessageEventArgs("payload1"); + var chunk1 = CreateMqttMessageEventArgs("payload2"); + + // Act + assembler.AddChunk(0, chunk0); + assembler.AddChunk(1, chunk1); + + // Assert + Assert.True(assembler.IsComplete); + } + + [Fact] + public void TryReassemble_ReturnsFalseWhenNotComplete() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var chunk0 = CreateMqttMessageEventArgs("payload1"); + + // Act + assembler.AddChunk(0, chunk0); + var result = assembler.TryReassemble(out var reassembledArgs); + + // Assert + Assert.False(result); + Assert.Null(reassembledArgs); + } + + [Fact] + public void TryReassemble_ReturnsValidMessageWhenComplete() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var chunk0 = CreateMqttMessageEventArgs("payload1"); + var chunk1 = CreateMqttMessageEventArgs(" payload2"); + + // Act + assembler.AddChunk(0, chunk0); + assembler.AddChunk(1, chunk1); + var result = assembler.TryReassemble(out var reassembledArgs); + + // Assert + Assert.True(result); + Assert.NotNull(reassembledArgs); + + // Convert payload to string for easier assertion + var payload = reassembledArgs!.ApplicationMessage.Payload; + var assembledPayloadAsString = ""; + foreach (var segment in payload) + { + assembledPayloadAsString += Encoding.UTF8.GetString(segment.Span); + } + + Assert.Equal("payload1 payload2", assembledPayloadAsString); + } + + [Fact] + public void TryReassemble_ChecksumVerification_Success() + { + // Arrange + var payload1 = "payload1"; + var payload2 = "payload2"; + var combined = payload1 + payload2; + var combinedBytes = Encoding.UTF8.GetBytes(combined); + var ros = new ReadOnlySequence(combinedBytes); + + // Calculate the actual checksum + var checksum = ChecksumCalculator.CalculateChecksum(ros, ChunkingChecksumAlgorithm.SHA256); + + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + assembler.UpdateMetadata(2, checksum, null); // Set the correct checksum + + var chunk0 = CreateMqttMessageEventArgs(payload1); + var chunk1 = CreateMqttMessageEventArgs(payload2); + + // Act + assembler.AddChunk(0, chunk0); + assembler.AddChunk(1, chunk1); + var result = assembler.TryReassemble(out var reassembledArgs); + + // Assert + Assert.True(result); + Assert.NotNull(reassembledArgs); + } + + [Fact] + public void TryReassemble_ChecksumVerification_Failure() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + assembler.UpdateMetadata(2, "invalid-checksum", null); // Set incorrect checksum + + var chunk0 = CreateMqttMessageEventArgs("payload1"); + var chunk1 = CreateMqttMessageEventArgs("payload2"); + + // Act + assembler.AddChunk(0, chunk0); + assembler.AddChunk(1, chunk1); + var result = assembler.TryReassemble(out var reassembledArgs); + + // Assert + Assert.False(result); + Assert.Null(reassembledArgs); + } + + [Fact] + public void HasExpired_ReturnsTrueWhenTimeoutExceeded() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var shortTimeout = TimeSpan.FromMilliseconds(1); + + // Set timeout via metadata update + assembler.UpdateMetadata(2, "test-checksum", shortTimeout); + + // Act + Thread.Sleep(10); // Ensure timeout is exceeded + var result = assembler.HasExpired(); + + // Assert + Assert.True(result); + } + + [Fact] + public void HasExpired_ReturnsFalseWhenTimeoutNotExceeded() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var longTimeout = TimeSpan.FromMinutes(5); + + // Set timeout via metadata update + assembler.UpdateMetadata(2, "test-checksum", longTimeout); + + // Act + var result = assembler.HasExpired(); + + // Assert + Assert.False(result); + } + + [Fact] + public void HasExpired_ReturnsFalseWhenNoTimeoutSet() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + + // Don't set any timeout via metadata update + + // Act + var result = assembler.HasExpired(); + + // Assert + Assert.False(result); + } + + [Fact] + public async Task AcknowledgeHandler_Calls_AcknowledgeAsync_On_All_Chunks() + { + // Arrange + var assembler = new ChunkedMessageAssembler(2, ChunkingChecksumAlgorithm.SHA256); + var chunk0AckCount = false; + var chunk1AckCount = false; + + // Create mock message args with mock acknowledgeAsync methods + var chunk0 = CreateMqttMessageEventArgsWithAckHandler((_, _) => + { + chunk0AckCount = true; + return Task.CompletedTask; + }); + var chunk1 = CreateMqttMessageEventArgsWithAckHandler((_, _) => + { + chunk1AckCount = true; + return Task.CompletedTask; + }); + + // Act + assembler.AddChunk(0, chunk0); + assembler.AddChunk(1, chunk1); + var result = assembler.TryReassemble(out var reassembledArgs); + + // Simulate acknowledgment of reassembled message + if (reassembledArgs != null) + { + await reassembledArgs.AcknowledgeAsync(CancellationToken.None); + } + + // Assert + Assert.True(result); + Assert.True(chunk0AckCount); + Assert.True(chunk1AckCount); + } + + // Helper method to create a simple MQTT message event args with payload + private static MqttApplicationMessageReceivedEventArgs CreateMqttMessageEventArgs(string payload) + { + var bytes = Encoding.UTF8.GetBytes(payload); + var mqttMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(bytes) + }; + + return new MqttApplicationMessageReceivedEventArgs( + "client1", + mqttMessage, + 1, + (_, _) => Task.CompletedTask); + } + + // Helper method to create a mock MQTT message event args + private static MqttApplicationMessageReceivedEventArgs CreateMqttMessageEventArgsWithAckHandler(Func acknowledgeHandler) + { + var bytes = "testpayload"u8.ToArray(); + var mqttMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(bytes) + }; + + var messageEventArgs = new MqttApplicationMessageReceivedEventArgs( + "client1", + mqttMessage, + 1, + acknowledgeHandler); + + return messageEventArgs; + } + } +} diff --git a/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkedMessageSplitterTests.cs b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkedMessageSplitterTests.cs new file mode 100644 index 0000000000..533c2141ee --- /dev/null +++ b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkedMessageSplitterTests.cs @@ -0,0 +1,390 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Buffers; +using System.Text.Json; +using Azure.Iot.Operations.Protocol.Chunking; +using Azure.Iot.Operations.Protocol.Events; +using Azure.Iot.Operations.Protocol.Models; + +namespace Azure.Iot.Operations.Protocol.UnitTests.Chunking; + +public class ChunkedMessageSplitterTests +{ + [Fact] + public void SplitMessage_SmallMessage_ThrowArgumentException() + { + // Arrange + var options = new ChunkingOptions { Enabled = true, StaticOverhead = 100 }; + var splitter = new ChunkedMessageSplitter(options); + + var payload = "Small message that doesn't need chunking"u8.ToArray(); + var originalMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload), + UserProperties = [new MqttUserProperty("originalProperty", "value")] + }; + + var maxPacketSize = 1000; // Large enough for the small message + + // Act & Assert + // This should throw an exception because the message is too small to be chunked + Assert.Throws(() => splitter.SplitMessage(originalMessage, maxPacketSize)); + } + + [Fact] + public void SplitMessage_LargeMessage_ReturnsMultipleChunks() + { + // Arrange + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 100, + ChecksumAlgorithm = ChunkingChecksumAlgorithm.SHA256 + }; + var splitter = new ChunkedMessageSplitter(options); + + // Create a large payload (2500 bytes) + var payloadSize = 2500; + var payload = new byte[payloadSize]; + Random.Shared.NextBytes(payload); + + + var originalMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload), + UserProperties = [new MqttUserProperty("originalProperty", "value")] + }; + + // Set a max packet size that will force chunking + // MaxChunkSize = MaxPacketSize - StaticOverhead = 900 + var maxPacketSize = 1000; + + // Act + var chunks = splitter.SplitMessage(originalMessage, maxPacketSize); + + // Assert + // Should have 3 chunks (2500 / 900 = 2.78 => 3 chunks) + Assert.Equal(3, chunks.Count); + + // Verify each chunk has the chunk metadata property + foreach (var chunk in chunks) + { + var chunkProperty = chunk.UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + Assert.NotNull(chunkProperty); + + // Check that original properties are preserved + var originalProperty = chunk.UserProperties?.FirstOrDefault(p => p.Name == "originalProperty"); + Assert.NotNull(originalProperty); + Assert.Equal("value", originalProperty!.Value); + } + + // Verify the chunks contain all the original data + var totalSize = chunks.Sum(c => c.Payload.Length); + Assert.Equal(payloadSize, totalSize); + + // Reassemble and verify content + var reassembledPayload = new byte[payloadSize]; + var offset = 0; + + foreach (var chunk in chunks) + foreach (var segment in chunk.Payload) + { + segment.Span.CopyTo(reassembledPayload.AsSpan(offset)); + offset += segment.Length; + } + + Assert.Equal(payload, reassembledPayload); + } + + [Fact] + public void SplitMessage_VerifyChunkMetadata_IsCorrect() + { + // Arrange + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 100, + ChecksumAlgorithm = ChunkingChecksumAlgorithm.SHA256, + }; + var splitter = new ChunkedMessageSplitter(options); + + // Create a payload that needs to be split into exactly 2 chunks + var chunkSize = 900; // maxPacketSize - staticOverhead + var payloadSize = chunkSize + 100; // Just over one chunk + var payload = new byte[payloadSize]; + Random.Shared.NextBytes(payload); + + var originalMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload), + MessageExpiryInterval = 30u, // Set expiry interval to 30 seconds + }; + + var maxPacketSize = 1000; + + // Act + var chunks = splitter.SplitMessage(originalMessage, maxPacketSize); + + // Assert + Assert.Equal(2, chunks.Count); + + // Check first chunk metadata + var firstChunkProperty = chunks[0].UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + Assert.NotNull(firstChunkProperty); + var firstChunkMetadata = JsonSerializer.Deserialize(firstChunkProperty!.Value); + + // First chunk should contain totalChunks and checksum + Assert.NotNull(firstChunkMetadata!.MessageId); + Assert.NotNull(firstChunkMetadata.TotalChunks); + Assert.NotNull(firstChunkMetadata.Checksum); + + Assert.Equal(0, firstChunkMetadata.ChunkIndex); + Assert.Equal(2, firstChunkMetadata.TotalChunks); + + // Check that MessageExpiryInterval is set (30 seconds) + Assert.Equal(30u, chunks[0].MessageExpiryInterval); + + // Get the messageId from the first chunk + var messageId = firstChunkMetadata.MessageId; + + // Check second chunk metadata + var secondChunkProperty = chunks[1].UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + Assert.NotNull(secondChunkProperty); + var secondChunkMetadata = JsonSerializer.Deserialize(secondChunkProperty!.Value); + + // Second chunk should not contain totalChunks or checksum + Assert.NotNull(secondChunkMetadata); + Assert.NotNull(secondChunkMetadata!.MessageId); + Assert.Null(secondChunkMetadata.TotalChunks); + Assert.Null(secondChunkMetadata.Checksum); + + Assert.Equal(messageId, secondChunkMetadata.MessageId); + Assert.Equal(1, secondChunkMetadata.ChunkIndex); + + // Check that MessageExpiryInterval is set on second chunk too (30 seconds) + Assert.Equal(30u, chunks[1].MessageExpiryInterval); + } + + [Fact] + public void SplitMessage_ChecksumVerification_ValidChecksum() + { + // Arrange + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 10, + ChecksumAlgorithm = ChunkingChecksumAlgorithm.SHA256 + }; + var splitter = new ChunkedMessageSplitter(options); + + var payload = new byte[128]; + Random.Shared.NextBytes(payload); + var originalMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload) + }; + + // Force chunking by using a small max packet size + var maxPacketSize = 128; + + // Act + var chunks = splitter.SplitMessage(originalMessage, maxPacketSize); + + // Get the checksum from the first chunk + var firstChunkProperty = chunks[0].UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + var firstChunkMetadata = JsonSerializer.Deserialize>(firstChunkProperty!.Value); + var checksum = firstChunkMetadata![ChunkingConstants.ChecksumField].GetString(); + + // Calculate the checksum directly using the same algorithm + var calculatedChecksum = ChecksumCalculator.CalculateChecksum( + new ReadOnlySequence(payload), + ChunkingChecksumAlgorithm.SHA256); + + // Assert + Assert.Equal(calculatedChecksum, checksum); + } + + [Fact] + public void SplitMessage_PreservesMessageProperties() + { + // Arrange + var options = new ChunkingOptions { Enabled = true, StaticOverhead = 100 }; + var splitter = new ChunkedMessageSplitter(options); + + // Create a large payload that needs chunking + var payloadSize = 1500; + var payload = new byte[payloadSize]; + + // Create a message with various properties + var originalMessage = new MqttApplicationMessage("test/topic", MqttQualityOfServiceLevel.ExactlyOnce) + { + Payload = new ReadOnlySequence(payload), + ContentType = "application/json", + Retain = true, + ResponseTopic = "response/topic", + CorrelationData = [1, 2, 3, 4, 5, 6, 7, 8, 9], + PayloadFormatIndicator = MqttPayloadFormatIndicator.Unspecified, + MessageExpiryInterval = 3600, + TopicAlias = 5, + SubscriptionIdentifiers = [1, 2, 3], + UserProperties = + [ + new MqttUserProperty("prop1", "value1"), + new MqttUserProperty("prop2", "value2") + ] + }; + + var maxPacketSize = 1000; + + // Act + var chunks = splitter.SplitMessage(originalMessage, maxPacketSize); + + // Assert - check that all chunks preserve the original message properties + foreach (var chunk in chunks) + { + // Check basic properties + Assert.Equal(originalMessage.Topic, chunk.Topic); + Assert.Equal(originalMessage.QualityOfServiceLevel, chunk.QualityOfServiceLevel); + Assert.Equal(originalMessage.ContentType, chunk.ContentType); + Assert.Equal(originalMessage.Retain, chunk.Retain); + Assert.Equal(originalMessage.ResponseTopic, chunk.ResponseTopic); + Assert.Equal(originalMessage.PayloadFormatIndicator, chunk.PayloadFormatIndicator); + Assert.Equal(originalMessage.MessageExpiryInterval, chunk.MessageExpiryInterval); + Assert.Equal(originalMessage.TopicAlias, chunk.TopicAlias); + + // Check correlation data + Assert.Equal(originalMessage.CorrelationData, chunk.CorrelationData); + + // Check subscription identifiers + Assert.Equal(originalMessage.SubscriptionIdentifiers, chunk.SubscriptionIdentifiers); + + // Check user properties (excluding the chunk property) + foreach (var originalProp in originalMessage.UserProperties!) + Assert.Contains(chunk.UserProperties!, p => + p.Name == originalProp.Name && p.Value == originalProp.Value); + } + } + + [Fact] + public void SplitMessage_NullMessage_ThrowsArgumentNullException() + { + // Arrange + var options = new ChunkingOptions { Enabled = true }; + var splitter = new ChunkedMessageSplitter(options); + + // Act & Assert + Assert.Throws(() => splitter.SplitMessage(null!, 1000)); + } + + [Fact] + public void Constructor_NullOptions_ThrowsArgumentNullException() + { + // Act & Assert + Assert.Throws(() => new ChunkedMessageSplitter(null!)); + } + + [Fact] + public void SplitMessage_MaxPacketSizeSmallerThanStaticOverhead_ThrowsArgumentOutOfRangeException() + { + // Arrange + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 1000 // Larger than max packet size + }; + var splitter = new ChunkedMessageSplitter(options); + + var payload = "Test message"u8.ToArray(); + var originalMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload) + }; + + var maxPacketSize = 500; // Smaller than the static overhead + + // Act & Assert + // This should not throw + Assert.Throws(() => splitter.SplitMessage(originalMessage, maxPacketSize)); + } + + [Fact] + public void Integration_SplitAndReassemble_RecoversOriginalMessage() + { + // Arrange + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 1024, // 1 KB + ChecksumAlgorithm = ChunkingChecksumAlgorithm.SHA256 + }; + + var splitter = new ChunkedMessageSplitter(options); + + // Create a test payload + var payloadSize = 1024 * 1024; // 1 MB + var payload = new byte[payloadSize]; + Random.Shared.NextBytes(payload); + + var originalMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload), + UserProperties = [new MqttUserProperty("originalProperty", "value")] + }; + + // Force chunking by using a small max packet size + var maxPacketSize = 2048; // 2 KB + + // Act - Split the message + var chunks = splitter.SplitMessage(originalMessage, maxPacketSize); + + // Now reassemble + var assembler = new ChunkedMessageAssembler(0, options.ChecksumAlgorithm); + + // Get metadata from first chunk + var firstChunkProperty = chunks[0].UserProperties!.First(p => p.Name == ChunkingConstants.ChunkUserProperty); + var firstChunkMetadata = JsonSerializer.Deserialize(firstChunkProperty.Value); + + var totalChunks = firstChunkMetadata!.TotalChunks!.Value; + var checksum = firstChunkMetadata.Checksum; + + // Update assembler with metadata + assembler.UpdateMetadata(totalChunks, checksum, null); + + // Add all chunks + foreach (var chunk in chunks) + { + // Extract chunk index from metadata + var chunkProperty = chunk.UserProperties!.First(p => p.Name == ChunkingConstants.ChunkUserProperty); + var chunkMetadata = JsonSerializer.Deserialize(chunkProperty.Value); + + // Simulate receiving the chunk + assembler.AddChunk(chunkMetadata!.ChunkIndex, CreateMqttMessageEventArgs(chunk)); + } + + // Try to reassemble + var success = assembler.TryReassemble(out var reassembledArgs); + + // Assert + Assert.True(success); + Assert.NotNull(reassembledArgs); + + // Verify the content is identical + Assert.Equal(payload, reassembledArgs!.ApplicationMessage.Payload.ToArray()); + + // Check that original properties are preserved but chunk metadata is removed + var properties = reassembledArgs.ApplicationMessage.UserProperties; + Assert.Contains(properties!, p => p.Name == "originalProperty" && p.Value == "value"); + Assert.DoesNotContain(properties!, p => p.Name == ChunkingConstants.ChunkUserProperty); + } + + // Helper method to create message event args for testing + private static MqttApplicationMessageReceivedEventArgs CreateMqttMessageEventArgs(MqttApplicationMessage message) + { + return new MqttApplicationMessageReceivedEventArgs( + "testClient", + message, + 1, + (_, _) => Task.CompletedTask); + } +} diff --git a/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkingMqttClientTests.cs b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkingMqttClientTests.cs new file mode 100644 index 0000000000..93b8816665 --- /dev/null +++ b/dotnet/test/Azure.Iot.Operations.Protocol.UnitTests/Chunking/ChunkingMqttClientTests.cs @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Buffers; +using System.Text; +using System.Text.Json; +using Azure.Iot.Operations.Protocol.Chunking; +using Azure.Iot.Operations.Protocol.Events; +using Azure.Iot.Operations.Protocol.Models; +using Moq; + +namespace Azure.Iot.Operations.Protocol.UnitTests.Chunking; + +public class ChunkingMqttClientTests +{ + [Fact] + public async Task PublishAsync_SmallMessage_PassesThroughToInnerClient() + { + // Arrange + var mockInnerClient = new Mock(); + var expectedResult = new MqttClientPublishResult( + null, + MqttClientPublishReasonCode.Success, + "No chunking result", + new List()); + + MqttApplicationMessage? capturedMessage = null; + mockInnerClient + .Setup(c => c.PublishAsync(It.IsAny(), It.IsAny())) + .Callback((msg, _) => capturedMessage = msg) + .ReturnsAsync(expectedResult); + + // Setup connection result with MaximumPacketSize to be large + uint? maxPacketSize = 10000; + var connectResult = new MqttClientConnectResult + { + IsSessionPresent = true, + ResultCode = MqttClientConnectResultCode.Success, + MaximumPacketSize = maxPacketSize, + UserProperties = new List() + }; + mockInnerClient.Setup(c => c.GetConnectResult()).Returns(connectResult); + + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 100 + }; + + var client = new ChunkingMqttPubSubClient(mockInnerClient.Object, options); + + // Create a small message that doesn't need chunking + var smallPayload = new byte[100]; + var smallMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(smallPayload) + }; + + // Act + var result = await client.PublishAsync(smallMessage, CancellationToken.None); + + // Assert + Assert.NotEqual(ChunkingConstants.ChunkedMessageSuccessReasonString, result.ReasonString); + Assert.NotNull(capturedMessage); + Assert.Same(smallMessage, capturedMessage); + } + + [Fact] + public async Task PublishAsync_LargeMessage_ChunksMessageAndSendsMultipleMessages() + { + // Arrange + var mockInnerClient = new Mock(); + var publishedMessages = new List(); + + var mqttClientPublishResult = new MqttClientPublishResult( + null, + MqttClientPublishReasonCode.Success, + "No chunking result", + new List()); + + mockInnerClient + .Setup(c => c.PublishAsync(It.IsAny(), It.IsAny())) + .Callback((msg, _) => publishedMessages.Add(msg)) + .ReturnsAsync(mqttClientPublishResult); + + var maxPacketSize = 64 * 1024; + var connectResult = new MqttClientConnectResult + { + IsSessionPresent = true, + ResultCode = MqttClientConnectResultCode.Success, + MaximumPacketSize = (uint)maxPacketSize, + MaximumQoS = MqttQualityOfServiceLevel.AtLeastOnce, + UserProperties = new List() + }; + mockInnerClient.Setup(c => c.GetConnectResult()).Returns(connectResult); + + var options = new ChunkingOptions + { + Enabled = true, + StaticOverhead = 500, + ChecksumAlgorithm = ChunkingChecksumAlgorithm.SHA256 + }; + + var client = new ChunkingMqttPubSubClient(mockInnerClient.Object, options); + + // Create a large message that needs chunking + // The max chunk size will be maxPacketSize - staticOverhead = 900 bytes + var largePayloadSize = 2 * 64 * 1024; // This should create 3 chunks + var largePayload = new byte[largePayloadSize]; + // Fill with identifiable content for later verification + for (var i = 0; i < largePayloadSize; i++) largePayload[i] = (byte)(i % 256); + + var largeMessage = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(largePayload) + }; + + // Act + var result = await client.PublishAsync(largeMessage, CancellationToken.None); + + // Assert + Assert.Equal(ChunkingConstants.ChunkedMessageSuccessReasonString, result.ReasonString); + + // Should have 3 chunks + Assert.Equal(3, publishedMessages.Count); + + // Verify all messages have the chunk metadata property + var messageIds = new HashSet(); + foreach (var msg in publishedMessages) + { + var chunkProperty = msg.UserProperties?.FirstOrDefault(p => p.Name == ChunkingConstants.ChunkUserProperty); + Assert.NotNull(chunkProperty); + + // Parse the metadata + var metadata = JsonSerializer.Deserialize(chunkProperty!.Value); + Assert.NotNull(metadata); + Assert.NotEmpty(metadata!.MessageId); + messageIds.Add(metadata.MessageId); + Assert.True(metadata.ChunkIndex >= 0); + + // First chunk should have totalChunks and checksum + if (metadata.ChunkIndex == 0) + { + Assert.NotNull(metadata.TotalChunks); + Assert.NotNull(metadata.Checksum); + Assert.Equal(3, metadata.TotalChunks); + } + } + + Assert.Single(messageIds); // All chunks should have the same messageId + + // Verify total payload size across all chunks equals original payload size + var totalChunkSize = publishedMessages.Sum(m => m.Payload.Length); + Assert.Equal(largePayloadSize, totalChunkSize); + } + + [Fact] + public async Task HandleApplicationMessageReceivedAsync_NonChunkedMessage_PassesThroughToHandler() + { + // Arrange + var mockInnerClient = new Mock(); + var handlerCalled = false; + var capturedArgs = default(MqttApplicationMessageReceivedEventArgs); + + var maxPacketSize = 1000; + var connectResult = new MqttClientConnectResult + { + IsSessionPresent = true, + ResultCode = MqttClientConnectResultCode.Success, + MaximumPacketSize = (uint)maxPacketSize, + MaximumQoS = MqttQualityOfServiceLevel.AtLeastOnce, + UserProperties = new List() + }; + mockInnerClient.Setup(c => c.GetConnectResult()).Returns(connectResult); + + var client = new ChunkingMqttPubSubClient(mockInnerClient.Object); + client.ApplicationMessageReceivedAsync += args => + { + handlerCalled = true; + capturedArgs = args; + return Task.CompletedTask; + }; + + // Create a regular message without chunking metadata + var payload = Encoding.UTF8.GetBytes("Regular non-chunked message"); + var message = new MqttApplicationMessage("test/topic") + { + Payload = new ReadOnlySequence(payload) + }; + var receivedArgs = new MqttApplicationMessageReceivedEventArgs("client1", message, 1, (_, _) => Task.CompletedTask); + + // Act + // Simulate receiving a message from the inner client + await mockInnerClient.RaiseAsync(m => m.ApplicationMessageReceivedAsync += null, receivedArgs); + + // Assert + Assert.True(handlerCalled); + Assert.Same(receivedArgs, capturedArgs); + } + + [Fact] + public async Task HandleApplicationMessageReceivedAsync_ChunkedMessage_ReassemblesBeforeDelivering() + { + // Arrange + var mockInnerClient = new Mock(); + var handlerCalled = false; + var capturedArgs = default(MqttApplicationMessageReceivedEventArgs); + + var maxPacketSize = 1000; + var connectResult = new MqttClientConnectResult + { + IsSessionPresent = true, + ResultCode = MqttClientConnectResultCode.Success, + MaximumPacketSize = (uint)maxPacketSize, + MaximumQoS = MqttQualityOfServiceLevel.AtLeastOnce, + UserProperties = new List() + }; + mockInnerClient.Setup(c => c.GetConnectResult()).Returns(connectResult); + + var client = new ChunkingMqttPubSubClient(mockInnerClient.Object); + client.ApplicationMessageReceivedAsync += args => + { + handlerCalled = true; + capturedArgs = args; + return Task.CompletedTask; + }; + + // Create message ID and checksum + var messageId = Guid.NewGuid().ToString("D"); + var fullMessage = "This is a complete message after reassembly"; + var fullPayload = Encoding.UTF8.GetBytes(fullMessage); + var checksum = ChecksumCalculator.CalculateChecksum(new ReadOnlySequence(fullPayload), ChunkingChecksumAlgorithm.SHA256); + + // Create a chunked message with 2 parts + var chunk1Text = "This is a complete "; + var chunk2Text = "message after reassembly"; + + // Create first chunk with metadata + var chunk1 = CreateChunkedMessage("test/topic", chunk1Text, messageId, 0, 2, checksum); + + // Create second chunk with metadata + var chunk2 = CreateChunkedMessage("test/topic", chunk2Text, messageId, 1); + var receivedArgs1 = new MqttApplicationMessageReceivedEventArgs("client1", chunk1, 1, (_, _) => Task.CompletedTask); + var receivedArgs2 = new MqttApplicationMessageReceivedEventArgs("client1", chunk2, 2, (_, _) => Task.CompletedTask); + + // Act + // Simulate receiving chunks from the inner client + await mockInnerClient.RaiseAsync(m => m.ApplicationMessageReceivedAsync += null, receivedArgs1); + await mockInnerClient.RaiseAsync(m => m.ApplicationMessageReceivedAsync += null, receivedArgs2); + + // Assert + Assert.True(handlerCalled); + Assert.NotNull(capturedArgs); + + Assert.Equal(fullPayload, capturedArgs!.ApplicationMessage.Payload.ToArray()); + + // Verify chunk metadata was removed + Assert.DoesNotContain( + capturedArgs.ApplicationMessage.UserProperties ?? Enumerable.Empty(), + p => p.Name == ChunkingConstants.ChunkUserProperty); + } + + // Helper method to create a chunked message with metadata + private static MqttApplicationMessage CreateChunkedMessage( + string topic, + string payloadText, + string messageId, + int chunkIndex, + int? totalChunks = null, + string? checksum = null) + { + // Create chunk metadata + Dictionary metadata = new() + { + { ChunkingConstants.MessageIdField, messageId }, + { ChunkingConstants.ChunkIndexField, chunkIndex } + }; + + // Add totalChunks and checksum for first chunk + if (totalChunks.HasValue) + { + metadata.Add(ChunkingConstants.TotalChunksField, totalChunks.Value); + } + + if (checksum != null) + { + metadata.Add(ChunkingConstants.ChecksumField, checksum); + } + + // Serialize metadata + var metadataJson = JsonSerializer.Serialize(metadata); + + // Create payload + var payload = Encoding.UTF8.GetBytes(payloadText); + + // Create message + var message = new MqttApplicationMessage(topic) + { + Payload = new ReadOnlySequence(payload), + UserProperties = new List + { + new(ChunkingConstants.ChunkUserProperty, metadataJson) + } + }; + return message; + } +}