From 897c551d8cb9ef0e6d47380f4a0460c108730b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 7 Nov 2024 21:37:36 +0800 Subject: [PATCH 01/85] Use IConnectionFactory to create ConnectionContext to replace SocketConnection --- Samples/MQTTnet.Samples.csproj | 2 +- .../MQTTnet.AspTestApp.csproj | 2 +- .../MQTTnet.AspNetCore.csproj | 2 +- .../MqttClientConnectionContextFactory.cs | 39 ++- .../MqttConnectionContext.cs | 38 +-- .../MqttConnectionHandler.cs | 10 +- .../ServiceCollectionExtensions.cs | 20 +- Source/MQTTnet.AspnetCore/SocketAwaitable.cs | 77 ------ Source/MQTTnet.AspnetCore/SocketConnection.cs | 261 ------------------ Source/MQTTnet.AspnetCore/SocketReceiver.cs | 36 --- Source/MQTTnet.AspnetCore/SocketSender.cs | 93 ------- .../MQTTnet.Benchmarks/AsyncLockBenchmark.cs | 2 +- Source/MQTTnet.Benchmarks/LoggerBenchmark.cs | 2 +- .../MQTTnet.Benchmarks.csproj | 2 +- .../MQTTnet.Benchmarks/MemoryCopyBenchmark.cs | 2 +- .../MessageProcessingBenchmark.cs | 2 +- ...rocessingMqttConnectionContextBenchmark.cs | 23 +- .../MqttBufferReaderBenchmark.cs | 2 +- .../MqttPacketReaderWriterBenchmark.cs | 2 +- .../MqttTcpChannelBenchmark.cs | 2 +- .../ReaderExtensionsBenchmark.cs | 2 +- .../RoundtripProcessingBenchmark.cs | 2 +- .../SendPacketAsyncBenchmark.cs | 2 +- .../MQTTnet.Benchmarks/SerializerBenchmark.cs | 2 +- .../ServerProcessingBenchmark.cs | 2 +- .../MQTTnet.Benchmarks/TcpPipesBenchmark.cs | 35 ++- .../TopicFilterComparerBenchmark.cs | 2 +- .../MQTTnet.Extensions.Rpc.csproj | 2 +- Source/MQTTnet.Server/MQTTnet.Server.csproj | 2 +- Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj | 2 +- Source/MQTTnet.Tests/MQTTnet.Tests.csproj | 2 +- .../Adapter/IMqttClientAdapterFactory.cs | 3 +- .../MqttClientAdapterFactory.cs | 10 +- .../LowLevelClient/LowLevelMqttClient.cs | 2 +- Source/MQTTnet/MQTTnet.csproj | 2 +- Source/MQTTnet/MqttClient.cs | 114 ++++---- 36 files changed, 199 insertions(+), 606 deletions(-) delete mode 100644 Source/MQTTnet.AspnetCore/SocketAwaitable.cs delete mode 100644 Source/MQTTnet.AspnetCore/SocketConnection.cs delete mode 100644 Source/MQTTnet.AspnetCore/SocketReceiver.cs delete mode 100644 Source/MQTTnet.AspnetCore/SocketSender.cs diff --git a/Samples/MQTTnet.Samples.csproj b/Samples/MQTTnet.Samples.csproj index 5fe84d380..2441f1c9c 100644 --- a/Samples/MQTTnet.Samples.csproj +++ b/Samples/MQTTnet.Samples.csproj @@ -14,7 +14,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj b/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj index 254069b60..ede29a49f 100644 --- a/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj +++ b/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj @@ -13,7 +13,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 5357dd702..6b7f60ce2 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -37,7 +37,7 @@ true low low - latest-Recommended + diff --git a/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs b/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs index 0ddbbd8f1..a8405cefa 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs @@ -2,16 +2,28 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; +using Microsoft.AspNetCore.Connections; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; +using System; +using System.Linq; +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; namespace MQTTnet.AspNetCore { public sealed class MqttClientConnectionContextFactory : IMqttClientAdapterFactory { - public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) + private readonly IConnectionFactory connectionFactory; + + public MqttClientConnectionContextFactory(IConnectionFactory connectionFactory) + { + this.connectionFactory = connectionFactory; + } + + public async ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -19,8 +31,8 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa { case MqttClientTcpOptions tcpOptions: { - var tcpConnection = new SocketConnection(tcpOptions.RemoteEndpoint); - + var endPoint = await CreateIPEndPointAsync(tcpOptions.RemoteEndpoint); + var tcpConnection = await connectionFactory.ConnectAsync(endPoint); var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); return new MqttConnectionContext(formatter, tcpConnection); } @@ -30,5 +42,24 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa } } } + + private static async ValueTask CreateIPEndPointAsync(EndPoint endpoint) + { + if (endpoint is IPEndPoint ipEndPoint) + { + return ipEndPoint; + } + + if (endpoint is DnsEndPoint dnsEndPoint) + { + var hostEntry = await Dns.GetHostEntryAsync(dnsEndPoint.Host); + var address = hostEntry.AddressList.OrderBy(item => item.AddressFamily).FirstOrDefault(); + return address == null + ? throw new SocketException((int)SocketError.HostNotFound) + : new IPEndPoint(address, dnsEndPoint.Port); + } + + throw new NotSupportedException("Only supports IPEndPoint or DnsEndPoint for now."); + } } } diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs index c16e5f483..b7b1c775d 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs @@ -2,13 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Buffers; -using System.IO.Pipelines; -using System.Net; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Features; @@ -17,6 +10,13 @@ using MQTTnet.Formatter; using MQTTnet.Internal; using MQTTnet.Packets; +using System; +using System.Buffers; +using System.IO.Pipelines; +using System.Net; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet.AspNetCore; @@ -25,19 +25,15 @@ public sealed class MqttConnectionContext : IMqttChannelAdapter readonly ConnectionContext _connection; readonly AsyncLock _writerLock = new(); - PipeReader _input; - PipeWriter _output; + readonly PipeReader _input; + readonly PipeWriter _output; public MqttConnectionContext(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { PacketFormatterAdapter = packetFormatterAdapter ?? throw new ArgumentNullException(nameof(packetFormatterAdapter)); _connection = connection ?? throw new ArgumentNullException(nameof(connection)); - - if (!(_connection is SocketConnection tcp) || tcp.IsConnected) - { - _input = connection.Transport.Input; - _output = connection.Transport.Output; - } + _input = connection.Transport.Input; + _output = connection.Transport.Output; } public long BytesReceived { get; private set; } @@ -106,15 +102,9 @@ public bool IsSecureConnection public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } - public async Task ConnectAsync(CancellationToken cancellationToken) + public Task ConnectAsync(CancellationToken cancellationToken) { - if (_connection is SocketConnection tcp && !tcp.IsConnected) - { - await tcp.StartAsync().ConfigureAwait(false); - } - - _input = _connection.Transport.Input; - _output = _connection.Transport.Output; + return Task.CompletedTask; } public Task DisconnectAsync(CancellationToken cancellationToken) @@ -126,7 +116,7 @@ public Task DisconnectAsync(CancellationToken cancellationToken) } public void Dispose() - { + { _writerLock.Dispose(); } diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs index b4cbc42a8..dfeb2f3f7 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs @@ -33,13 +33,11 @@ public override async Task OnConnectedAsync(ConnectionContext connection) } var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); - using (var adapter = new MqttConnectionContext(formatter, connection)) + using var adapter = new MqttConnectionContext(formatter, connection); + var clientHandler = ClientHandler; + if (clientHandler != null) { - var clientHandler = ClientHandler; - if (clientHandler != null) - { - await clientHandler(adapter).ConfigureAwait(false); - } + await clientHandler(adapter).ConfigureAwait(false); } } diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 915f6791c..bcb8b38e7 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -2,18 +2,25 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; +using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; +using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Server; using MQTTnet.Server.Internal.Adapter; +using System; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; namespace MQTTnet.AspNetCore; public static class ServiceCollectionExtensions { + const string SocketConnectionFactoryTypeName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketConnectionFactory"; + const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; + public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, MqttServerOptions options) { ArgumentNullException.ThrowIfNull(services); @@ -106,4 +113,15 @@ public static IServiceCollection AddMqttWebSocketServerAdapter(this IServiceColl return services; } + + + [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] + public static IServiceCollection AddMqttClientConnectionContextFactory(this IServiceCollection services) + { + var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); + services.AddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); + services.TryAddSingleton(); + services.TryAddSingleton(serviceProvider => serviceProvider.GetRequiredService()); + return services; + } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketAwaitable.cs b/Source/MQTTnet.AspnetCore/SocketAwaitable.cs deleted file mode 100644 index 2c9607279..000000000 --- a/Source/MQTTnet.AspnetCore/SocketAwaitable.cs +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.AspNetCore; - -public class SocketAwaitable : ICriticalNotifyCompletion -{ - static readonly Action _callbackCompleted = () => - { - }; - - readonly PipeScheduler _ioScheduler; - int _bytesTransferred; - - Action _callback; - SocketError _error; - - public SocketAwaitable(PipeScheduler ioScheduler) - { - _ioScheduler = ioScheduler; - } - - public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); - - public void Complete(int bytesTransferred, SocketError socketError) - { - _error = socketError; - _bytesTransferred = bytesTransferred; - var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); - - if (continuation != null) - { - _ioScheduler.Schedule(state => ((Action)state)(), continuation); - } - } - - public SocketAwaitable GetAwaiter() - { - return this; - } - - public int GetResult() - { - Debug.Assert(ReferenceEquals(_callback, _callbackCompleted)); - - _callback = null; - - if (_error != SocketError.Success) - { - throw new SocketException((int)_error); - } - - return _bytesTransferred; - } - - public void OnCompleted(Action continuation) - { - if (ReferenceEquals(_callback, _callbackCompleted) || ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) - { - Task.Run(continuation); - } - } - - public void UnsafeOnCompleted(Action continuation) - { - OnCompleted(continuation); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketConnection.cs b/Source/MQTTnet.AspnetCore/SocketConnection.cs deleted file mode 100644 index 2021eccac..000000000 --- a/Source/MQTTnet.AspnetCore/SocketConnection.cs +++ /dev/null @@ -1,261 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.IO; -using System.IO.Pipelines; -using System.Net; -using System.Net.Sockets; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http.Features; -using MQTTnet.Exceptions; - -namespace MQTTnet.AspNetCore; - -public sealed class SocketConnection : ConnectionContext -{ - readonly EndPoint _endPoint; - volatile bool _aborted; - IDuplexPipe _application; - SocketReceiver _receiver; - SocketSender _sender; - - Socket _socket; - - public SocketConnection(EndPoint endPoint) - { - _endPoint = endPoint; - } - - public SocketConnection(Socket socket) - { - _socket = socket; - _endPoint = socket.RemoteEndPoint; - - _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); - _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); - } - - public override string ConnectionId { get; set; } - public override IFeatureCollection Features { get; } - - public bool IsConnected { get; private set; } - public override IDictionary Items { get; set; } - public override IDuplexPipe Transport { get; set; } - - public override ValueTask DisposeAsync() - { - IsConnected = false; - - Transport?.Output.Complete(); - Transport?.Input.Complete(); - - _socket?.Dispose(); - - return base.DisposeAsync(); - } - - public async Task StartAsync() - { - if (_socket == null) - { - _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); - _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); - await _socket.ConnectAsync(_endPoint); - } - - var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); - - Transport = pair.Transport; - _application = pair.Application; - - _ = ExecuteAsync(); - - IsConnected = true; - } - - Exception ConnectionAborted() - { - return new MqttCommunicationException("Connection Aborted"); - } - - async Task DoReceive() - { - Exception error = null; - - try - { - await ProcessReceives(); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.ConnectionReset) - { - error = new MqttCommunicationException(ex); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted || ex.SocketErrorCode == SocketError.ConnectionAborted || - ex.SocketErrorCode == SocketError.Interrupted || ex.SocketErrorCode == SocketError.InvalidArgument) - { - if (!_aborted) - { - // Calling Dispose after ReceiveAsync can cause an "InvalidArgument" error on *nix. - error = ConnectionAborted(); - } - } - catch (ObjectDisposedException) - { - if (!_aborted) - { - error = ConnectionAborted(); - } - } - catch (IOException ex) - { - error = ex; - } - catch (Exception ex) - { - error = new IOException(ex.Message, ex); - } - finally - { - if (_aborted) - { - error = error ?? ConnectionAborted(); - } - - _application.Output.Complete(error); - } - } - - async Task DoSend() - { - Exception error = null; - - try - { - await ProcessSends(); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted) - { - } - catch (ObjectDisposedException) - { - } - catch (IOException ex) - { - error = ex; - } - catch (Exception ex) - { - error = new IOException(ex.Message, ex); - } - finally - { - _aborted = true; - _socket.Shutdown(SocketShutdown.Both); - } - - return error; - } - - async Task ExecuteAsync() - { - Exception sendError = null; - try - { - // Spawn send and receive logic - var receiveTask = DoReceive(); - var sendTask = DoSend(); - - // If the sending task completes then close the receive - // We don't need to do this in the other direction because the kestrel - // will trigger the output closing once the input is complete. - if (await Task.WhenAny(receiveTask, sendTask).ConfigureAwait(false) == sendTask) - { - // Tell the reader it's being aborted - _socket.Dispose(); - } - - // Now wait for both to complete - await receiveTask; - sendError = await sendTask; - - // Dispose the socket(should noop if already called) - _socket.Dispose(); - } - catch (Exception ex) - { - Console.WriteLine($"Unexpected exception in {nameof(SocketConnection)}.{nameof(StartAsync)}: " + ex); - } - finally - { - // Complete the output after disposing the socket - await _application.Input.CompleteAsync(sendError).ConfigureAwait(false); - } - } - - async Task ProcessReceives() - { - while (true) - { - // Ensure we have some reasonable amount of buffer space - var buffer = _application.Output.GetMemory(); - - var bytesReceived = await _receiver.ReceiveAsync(buffer); - - if (bytesReceived == 0) - { - // FIN - break; - } - - _application.Output.Advance(bytesReceived); - - var flushTask = _application.Output.FlushAsync(); - - if (!flushTask.IsCompleted) - { - await flushTask; - } - - var result = flushTask.GetAwaiter().GetResult(); - if (result.IsCompleted) - { - // Pipe consumer is shut down, do we stop writing - break; - } - } - } - - async Task ProcessSends() - { - while (true) - { - // Wait for data to write from the pipe producer - var result = await _application.Input.ReadAsync(); - var buffer = result.Buffer; - - if (result.IsCanceled) - { - break; - } - - var end = buffer.End; - var isCompleted = result.IsCompleted; - if (!buffer.IsEmpty) - { - await _sender.SendAsync(buffer); - } - - _application.Input.AdvanceTo(end); - - if (isCompleted) - { - break; - } - } - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketReceiver.cs b/Source/MQTTnet.AspnetCore/SocketReceiver.cs deleted file mode 100644 index f8b628fb5..000000000 --- a/Source/MQTTnet.AspnetCore/SocketReceiver.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.IO.Pipelines; -using System.Net.Sockets; - -namespace MQTTnet.AspNetCore; - -public sealed class SocketReceiver -{ - readonly SocketAwaitable _awaitable; - readonly SocketAsyncEventArgs _eventArgs = new(); - readonly Socket _socket; - - public SocketReceiver(Socket socket, PipeScheduler scheduler) - { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); - } - - public SocketAwaitable ReceiveAsync(Memory buffer) - { - _eventArgs.SetBuffer(buffer); - - if (!_socket.ReceiveAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketSender.cs b/Source/MQTTnet.AspnetCore/SocketSender.cs deleted file mode 100644 index fc06ea6cf..000000000 --- a/Source/MQTTnet.AspnetCore/SocketSender.cs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Buffers; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Net.Sockets; -using System.Runtime.InteropServices; - -namespace MQTTnet.AspNetCore; - -public sealed class SocketSender -{ - readonly SocketAwaitable _awaitable; - readonly SocketAsyncEventArgs _eventArgs = new(); - readonly Socket _socket; - - List> _bufferList; - - public SocketSender(Socket socket, PipeScheduler scheduler) - { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); - } - - public SocketAwaitable SendAsync(in ReadOnlySequence buffers) - { - if (buffers.IsSingleSegment) - { - return SendAsync(buffers.First); - } - - if (!_eventArgs.MemoryBuffer.Equals(Memory.Empty)) - { - _eventArgs.SetBuffer(null, 0, 0); - } - - _eventArgs.BufferList = GetBufferList(buffers); - - if (!_socket.SendAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; - } - - List> GetBufferList(in ReadOnlySequence buffer) - { - Debug.Assert(!buffer.IsEmpty); - Debug.Assert(!buffer.IsSingleSegment); - - if (_bufferList == null) - { - _bufferList = new List>(); - } - else - { - // Buffers are pooled, so it's OK to root them until the next multi-buffer write. - _bufferList.Clear(); - } - - foreach (var b in buffer) - { - _bufferList.Add(b.GetArray()); - } - - return _bufferList; - } - - SocketAwaitable SendAsync(ReadOnlyMemory memory) - { - // The BufferList getter is much less expensive then the setter. - if (_eventArgs.BufferList != null) - { - _eventArgs.BufferList = null; - } - - _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); - - if (!_socket.SendAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs b/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs index 27a348d4f..96b40c568 100644 --- a/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class AsyncLockBenchmark : BaseBenchmark { diff --git a/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs b/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs index 6b8b5e410..45e536e74 100644 --- a/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [MemoryDiagnoser] public class LoggerBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index d5e1f11f9..286e2b6ba 100644 --- a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -14,7 +14,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs b/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs index 0733e2bb9..012f8d847 100644 --- a/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class MemoryCopyBenchmark diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs index 894ef19e5..cca6e4804 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Benchmarks; -[SimpleJob(RuntimeMoniker.Net60)] +[SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [RankColumn] [MemoryDiagnoser] diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs index b22d365e8..1225b1c06 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -6,12 +6,13 @@ using BenchmarkDotNet.Jobs; using Microsoft.AspNetCore; using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; using MQTTnet.AspNetCore; using MQTTnet.Diagnostics.Logger; namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MessageProcessingMqttConnectionContextBenchmark : BaseBenchmark { @@ -24,20 +25,26 @@ public void Setup() { _host = WebHost.CreateDefaultBuilder() .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) - .ConfigureServices(services => { - services - .AddHostedMqttServer(mqttServerOptions => mqttServerOptions.WithoutDefaultEndpoint()) - .AddMqttConnectionHandler(); + .ConfigureServices(services => + { + services + .AddHostedMqttServer(mqttServerOptions => mqttServerOptions.WithoutDefaultEndpoint()) + .AddMqttConnectionHandler() + .AddMqttClientConnectionContextFactory(); }) - .Configure(app => { - app.UseMqttServer(s => { + .Configure(app => + { + app.UseMqttServer(s => + { }); }) .Build(); + var factory = new MqttClientFactory(); - _mqttClient = factory.CreateMqttClient(new MqttNetEventLogger(), new MqttClientConnectionContextFactory()); + var mqttClientConnectionContextFactory = _host.Services.GetRequiredService(); + _mqttClient = factory.CreateMqttClient(new MqttNetEventLogger(), mqttClientConnectionContextFactory); _host.StartAsync().GetAwaiter().GetResult(); diff --git a/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs index bfa3d209c..c8529535c 100644 --- a/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs @@ -10,7 +10,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MqttBufferReaderBenchmark { diff --git a/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs index 0efc7ffac..4190d4bfb 100644 --- a/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs @@ -10,7 +10,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MqttPacketReaderWriterBenchmark : BaseBenchmark { diff --git a/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs index 613647471..95257ef21 100644 --- a/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs @@ -17,7 +17,7 @@ namespace MQTTnet.Benchmarks; -[SimpleJob(RuntimeMoniker.Net60)] +[SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MqttTcpChannelBenchmark : BaseBenchmark { diff --git a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs index 5f2242461..debc48aa2 100644 --- a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs @@ -14,7 +14,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class ReaderExtensionsBenchmark diff --git a/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs index e3358fb91..69fedda41 100644 --- a/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class RoundtripProcessingBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs b/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs index b31782e66..a4e7d05d6 100644 --- a/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class SendPacketAsyncBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs index 0ddea15f1..68dc3568a 100644 --- a/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -18,7 +18,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [MemoryDiagnoser] public class SerializerBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs index fbac6dc02..f2e582af7 100644 --- a/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class ServerProcessingBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs b/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs index 7692f78b3..4111be6e8 100644 --- a/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; using System.IO.Pipelines; using System.Net; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; -using MQTTnet.AspNetCore; namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class TcpPipesBenchmark : BaseBenchmark { @@ -29,17 +28,17 @@ public void Setup() var task = Task.Run(() => server.AcceptSocket()); - var clientConnection = new SocketConnection(new IPEndPoint(IPAddress.Loopback, 1883)); + var clientConnection = new Socket(SocketType.Stream, ProtocolType.Tcp); + clientConnection.Connect(new IPEndPoint(IPAddress.Loopback, 1883)); + _client = new SocketDuplexPipe(clientConnection); - clientConnection.StartAsync().GetAwaiter().GetResult(); - _client = clientConnection.Transport; - - var serverConnection = new SocketConnection(task.GetAwaiter().GetResult()); - serverConnection.StartAsync().GetAwaiter().GetResult(); - _server = serverConnection.Transport; + var serverConnection =task.GetAwaiter().GetResult(); + _server = new SocketDuplexPipe(serverConnection); } + + [Benchmark] public async Task Send_10000_Chunks_Pipe() { @@ -76,5 +75,19 @@ async Task WriteAsync(int iterations, int size) await output.WriteAsync(new byte[size], CancellationToken.None).ConfigureAwait(false); } } + + private class SocketDuplexPipe : IDuplexPipe + { + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public SocketDuplexPipe(Socket socket) + { + var stream = new NetworkStream(socket); + this.Input = PipeReader.Create(stream); + this.Output = PipeWriter.Create(stream); + } + } } } diff --git a/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs b/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs index f78cb81c3..be8f2e7c8 100644 --- a/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [MemoryDiagnoser] public class TopicFilterComparerBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj index b38fb489d..f4353d21a 100644 --- a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj +++ b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj @@ -35,7 +35,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Server/MQTTnet.Server.csproj b/Source/MQTTnet.Server/MQTTnet.Server.csproj index df4863607..094fb10f8 100644 --- a/Source/MQTTnet.Server/MQTTnet.Server.csproj +++ b/Source/MQTTnet.Server/MQTTnet.Server.csproj @@ -36,7 +36,7 @@ low enable disable - latest-Recommended + diff --git a/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj b/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj index 374ded794..dec52e679 100644 --- a/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj +++ b/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj @@ -13,7 +13,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj index b29190740..c13308c4c 100644 --- a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj +++ b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj @@ -11,7 +11,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs b/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs index 3ea49381b..6c181e4c6 100644 --- a/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs +++ b/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs @@ -3,10 +3,11 @@ // See the LICENSE file in the project root for more information. using MQTTnet.Diagnostics.Logger; +using System.Threading.Tasks; namespace MQTTnet.Adapter; public interface IMqttClientAdapterFactory { - IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger); + ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger); } \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs b/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs index 0a4031f31..16d1dd9d5 100644 --- a/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs +++ b/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs @@ -3,16 +3,17 @@ // See the LICENSE file in the project root for more information. using MQTTnet.Adapter; -using MQTTnet.Formatter; -using System; using MQTTnet.Channel; using MQTTnet.Diagnostics.Logger; +using MQTTnet.Formatter; +using System; +using System.Threading.Tasks; namespace MQTTnet.Implementations { public sealed class MqttClientAdapterFactory : IMqttClientAdapterFactory { - public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) + public ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) { ArgumentNullException.ThrowIfNull(options); @@ -40,11 +41,12 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); var packetFormatterAdapter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); - return new MqttChannelAdapter(channel, packetFormatterAdapter, logger) + IMqttChannelAdapter adapter = new MqttChannelAdapter(channel, packetFormatterAdapter, logger) { AllowPacketFragmentation = options.AllowPacketFragmentation, PacketInspector = packetInspector }; + return ValueTask.FromResult(adapter); } } } diff --git a/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs b/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs index 186e8bfd1..f72fc7ed7 100644 --- a/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs +++ b/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs @@ -54,7 +54,7 @@ public async Task ConnectAsync(MqttClientOptions options, CancellationToken canc packetInspector = new MqttPacketInspector(_inspectPacketEvent, _rootLogger); } - var newAdapter = _clientAdapterFactory.CreateClientAdapter(options, packetInspector, _rootLogger); + var newAdapter = await _clientAdapterFactory.CreateClientAdapterAsync(options, packetInspector, _rootLogger); try { diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index ece6812a0..7c7a93f19 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -44,7 +44,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet/MqttClient.cs b/Source/MQTTnet/MqttClient.cs index 9d19ce574..fe8b50742 100644 --- a/Source/MQTTnet/MqttClient.cs +++ b/Source/MQTTnet/MqttClient.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Diagnostics.PacketInspection; @@ -15,6 +11,10 @@ using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet; @@ -121,7 +121,7 @@ public async Task ConnectAsync(MqttClientOptions option _mqttClientAlive = new CancellationTokenSource(); var mqttClientAliveToken = _mqttClientAlive.Token; - var adapter = _adapterFactory.CreateClientAdapter(options, new MqttPacketInspector(_events.InspectPacketEvent, _rootLogger), _rootLogger); + var adapter = await _adapterFactory.CreateClientAdapterAsync(options, new MqttPacketInspector(_events.InspectPacketEvent, _rootLogger), _rootLogger); _adapter = adapter ?? throw new InvalidOperationException("The adapter factory did not provide an adapter."); _unexpectedDisconnectPacket = null; @@ -270,21 +270,21 @@ public Task PublishAsync(MqttApplicationMessage applica switch (applicationMessage.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: - { - return PublishAtMostOnce(publishPacket, cancellationToken); - } + { + return PublishAtMostOnce(publishPacket, cancellationToken); + } case MqttQualityOfServiceLevel.AtLeastOnce: - { - return PublishAtLeastOnce(publishPacket, cancellationToken); - } + { + return PublishAtLeastOnce(publishPacket, cancellationToken); + } case MqttQualityOfServiceLevel.ExactlyOnce: - { - return PublishExactlyOnce(publishPacket, cancellationToken); - } + { + return PublishExactlyOnce(publishPacket, cancellationToken); + } default: - { - throw new NotSupportedException(); - } + { + throw new NotSupportedException(); + } } } @@ -395,34 +395,34 @@ Task AcknowledgeReceivedPublishPacket(MqttApplicationMessageReceivedEventArgs ev switch (eventArgs.PublishPacket.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: - { - // no response required - break; - } - case MqttQualityOfServiceLevel.AtLeastOnce: - { - if (!eventArgs.ProcessingFailed) { - var pubAckPacket = MqttPubAckPacketFactory.Create(eventArgs); - return Send(pubAckPacket, cancellationToken); + // no response required + break; } + case MqttQualityOfServiceLevel.AtLeastOnce: + { + if (!eventArgs.ProcessingFailed) + { + var pubAckPacket = MqttPubAckPacketFactory.Create(eventArgs); + return Send(pubAckPacket, cancellationToken); + } - break; - } + break; + } case MqttQualityOfServiceLevel.ExactlyOnce: - { - if (!eventArgs.ProcessingFailed) { - var pubRecPacket = MqttPubRecPacketFactory.Create(eventArgs); - return Send(pubRecPacket, cancellationToken); - } + if (!eventArgs.ProcessingFailed) + { + var pubRecPacket = MqttPubRecPacketFactory.Create(eventArgs); + return Send(pubRecPacket, cancellationToken); + } - break; - } + break; + } default: - { - throw new MqttProtocolViolationException("Received a not supported QoS level."); - } + { + throw new MqttProtocolViolationException("Received a not supported QoS level."); + } } return CompletedTask.Instance; @@ -442,22 +442,22 @@ async Task Authenticate(IMqttChannelAdapter channelAdap switch (receivedPacket) { case MqttConnAckPacket connAckPacket: - { - result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion); - break; - } + { + result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion); + break; + } case MqttAuthPacket _: - { - throw new NotSupportedException("Extended authentication handler is not yet supported"); - } + { + throw new NotSupportedException("Extended authentication handler is not yet supported"); + } case null: - { - throw new MqttCommunicationException("Connection closed."); - } + { + throw new MqttCommunicationException("Connection closed."); + } default: - { - throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket})."); - } + { + throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket})."); + } } } catch (Exception exception) @@ -967,14 +967,14 @@ async Task TryProcessReceivedPacket(MqttPacket packet, CancellationToken cancell case MqttPingReqPacket _: throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a client to the server only."); default: - { - if (!_packetDispatcher.TryDispatch(packet)) { - throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); - } + if (!_packetDispatcher.TryDispatch(packet)) + { + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + } - break; - } + break; + } } } catch (Exception exception) From f38cca37ba404eb7591f837b6e1f337287446250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Tue, 12 Nov 2024 23:42:28 +0800 Subject: [PATCH 02/85] Refactoring AspNetMQTT --- Samples/Server/Server_ASP_NET_Samples.cs | 23 ++--- Source/MQTTnet.AspTestApp/Program.cs | 12 ++- Source/MQTTnet.AspTestApp/appsettings.json | 10 ++ .../ApplicationBuilderExtensions.cs | 40 +------- .../AspNetMqttServerOptionsBuilder.cs | 18 ---- Source/MQTTnet.AspnetCore/BufferExtensions.cs | 21 ---- .../ConnectionBuilderExtensions.cs | 4 +- Source/MQTTnet.AspnetCore/DuplexPipe.cs | 44 --------- .../EndpointRouteBuilderExtensions.cs | 48 +++++++++ .../EndpointRouterExtensions.cs | 24 ----- .../AspNetCoreMqttChannelAdapter.cs} | 6 +- .../AspNetCoreMqttClientAdapterFactory.cs} | 6 +- .../Internal/AspNetCoreMqttHostedServer.cs | 26 +++++ .../Internal/AspNetCoreMqttServer.cs | 30 ++++++ .../{ => Internal}/MqttConnectionHandler.cs | 8 +- .../MqttPacketFormatterAdapterExtensions.cs} | 8 +- Source/MQTTnet.AspnetCore/InternalsVisible.cs | 3 + .../MQTTnet.AspNetCore.csproj | 10 +- Source/MQTTnet.AspnetCore/MqttHostedServer.cs | 48 --------- .../MqttSubProtocolSelector.cs | 34 ------- .../MqttWebSocketServerAdapter.cs | 66 ------------- .../ServiceCollectionExtensions.cs | 99 ++++--------------- ...rocessingMqttConnectionContextBenchmark.cs | 12 ++- .../ReaderExtensionsBenchmark.cs | 2 +- .../ASP/Mockups/ConnectionHandlerMockup.cs | 4 +- .../ASP/MqttConnectionContextTest.cs | 6 +- 26 files changed, 189 insertions(+), 423 deletions(-) delete mode 100644 Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs delete mode 100644 Source/MQTTnet.AspnetCore/BufferExtensions.cs delete mode 100644 Source/MQTTnet.AspnetCore/DuplexPipe.cs create mode 100644 Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs delete mode 100644 Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs rename Source/MQTTnet.AspnetCore/{MqttConnectionContext.cs => Internal/AspNetCoreMqttChannelAdapter.cs} (97%) rename Source/MQTTnet.AspnetCore/{MqttClientConnectionContextFactory.cs => Internal/AspNetCoreMqttClientAdapterFactory.cs} (90%) create mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs rename Source/MQTTnet.AspnetCore/{ => Internal}/MqttConnectionHandler.cs (90%) rename Source/MQTTnet.AspnetCore/{ReaderExtensions.cs => Internal/MqttPacketFormatterAdapterExtensions.cs} (98%) create mode 100644 Source/MQTTnet.AspnetCore/InternalsVisible.cs delete mode 100644 Source/MQTTnet.AspnetCore/MqttHostedServer.cs delete mode 100644 Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs delete mode 100644 Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 9247093e2..3466a8b95 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -9,6 +9,7 @@ // ReSharper disable MemberCanBeMadeStatic.Local using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -71,15 +72,7 @@ sealed class Startup public void Configure(IApplicationBuilder app, IWebHostEnvironment environment, MqttController mqttController) { app.UseRouting(); - - app.UseEndpoints( - endpoints => - { - endpoints.MapConnectionHandler( - "/mqtt", - httpConnectionDispatcherOptions => httpConnectionDispatcherOptions.WebSockets.SubProtocolSelector = - protocolList => protocolList.FirstOrDefault() ?? string.Empty); - }); + app.UseEndpoints(endpoints => endpoints.MapMqtt("/mqtt")); app.UseMqttServer( server => @@ -95,14 +88,10 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment environment, public void ConfigureServices(IServiceCollection services) { - services.AddHostedMqttServer( - optionsBuilder => - { - optionsBuilder.WithDefaultEndpoint(); - }); - - services.AddMqttConnectionHandler(); - services.AddConnections(); + services.AddMqttServer(optionsBuilder => + { + optionsBuilder.WithDefaultEndpoint(); + }); services.AddSingleton(); } diff --git a/Source/MQTTnet.AspTestApp/Program.cs b/Source/MQTTnet.AspTestApp/Program.cs index 317c5f7eb..42c67ef06 100644 --- a/Source/MQTTnet.AspTestApp/Program.cs +++ b/Source/MQTTnet.AspTestApp/Program.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Connections; using MQTTnet; using MQTTnet.AspNetCore; using MQTTnet.Server; @@ -12,10 +13,15 @@ // Setup MQTT stuff. builder.Services.AddMqttServer(); -builder.Services.AddConnections(); -var app = builder.Build(); +// UseMqttEndPoint +builder.WebHost.ConfigureKestrel((context, serverOptions) => +{ + var kestrelSection = context.Configuration.GetSection("Kestrel"); + serverOptions.Configure(kestrelSection).Endpoint("Mqtt", mqtt => mqtt.ListenOptions.UseMqtt()); +}); +var app = builder.Build(); if (!app.Environment.IsDevelopment()) { app.UseExceptionHandler("/Error"); @@ -29,7 +35,7 @@ app.MapRazorPages(); -// Setup MQTT stuff. +// mqtt over websocket app.MapMqtt("/mqtt"); app.UseMqttServer(server => diff --git a/Source/MQTTnet.AspTestApp/appsettings.json b/Source/MQTTnet.AspTestApp/appsettings.json index 10f68b8c8..9d11d79ad 100644 --- a/Source/MQTTnet.AspTestApp/appsettings.json +++ b/Source/MQTTnet.AspTestApp/appsettings.json @@ -1,4 +1,14 @@ { + "Kestrel": { + "Endpoints": { + "Http": { + "Url": "http://localhost:5000" + }, + "Mqtt": { + "Url": "http://localhost:1883" + } + } + }, "Logging": { "LogLevel": { "Default": "Information", diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index f98983da9..32835b791 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -2,52 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using MQTTnet.Server; +using System; -namespace MQTTnet.AspNetCore; +namespace Microsoft.AspNetCore.Builder; public static class ApplicationBuilderExtensions -{ - [Obsolete( - "This class is obsolete and will be removed in a future version. The recommended alternative is to use MapMqtt inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] - public static IApplicationBuilder UseMqttEndpoint(this IApplicationBuilder app, string path = "/mqtt") - { - app.UseWebSockets(); - app.Use( - async (context, next) => - { - if (!context.WebSockets.IsWebSocketRequest || context.Request.Path != path) - { - await next(); - return; - } - - string subProtocol = null; - - if (context.Request.Headers.TryGetValue("Sec-WebSocket-Protocol", out var requestedSubProtocolValues)) - { - subProtocol = MqttSubProtocolSelector.SelectSubProtocol(requestedSubProtocolValues); - } - - var adapter = app.ApplicationServices.GetRequiredService(); - using (var webSocket = await context.WebSockets.AcceptWebSocketAsync(subProtocol).ConfigureAwait(false)) - { - await adapter.RunWebSocketConnectionAsync(webSocket, context); - } - }); - - return app; - } - +{ public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app, Action configure) { var server = app.ApplicationServices.GetRequiredService(); - configure(server); - return app; } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs b/Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs deleted file mode 100644 index 394483959..000000000 --- a/Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs +++ /dev/null @@ -1,18 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class AspNetMqttServerOptionsBuilder : MqttServerOptionsBuilder -{ - public AspNetMqttServerOptionsBuilder(IServiceProvider serviceProvider) - { - ServiceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); - } - - public IServiceProvider ServiceProvider { get; } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/BufferExtensions.cs b/Source/MQTTnet.AspnetCore/BufferExtensions.cs deleted file mode 100644 index 47a5c0747..000000000 --- a/Source/MQTTnet.AspnetCore/BufferExtensions.cs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.InteropServices; - -namespace MQTTnet.AspNetCore; - -public static class BufferExtensions -{ - public static ArraySegment GetArray(this ReadOnlyMemory memory) - { - if (!MemoryMarshal.TryGetArray(memory, out var result)) - { - throw new InvalidOperationException("Buffer backed by array was expected"); - } - - return result; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 9ea8922ea..0d4e2eba0 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -2,9 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; +using MQTTnet.AspNetCore; -namespace MQTTnet.AspNetCore +namespace Microsoft.AspNetCore.Connections { public static class ConnectionBuilderExtensions { diff --git a/Source/MQTTnet.AspnetCore/DuplexPipe.cs b/Source/MQTTnet.AspnetCore/DuplexPipe.cs deleted file mode 100644 index 35075e800..000000000 --- a/Source/MQTTnet.AspnetCore/DuplexPipe.cs +++ /dev/null @@ -1,44 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.IO.Pipelines; - -namespace MQTTnet.AspNetCore; - -public class DuplexPipe : IDuplexPipe -{ - public DuplexPipe(PipeReader reader, PipeWriter writer) - { - Input = reader; - Output = writer; - } - - public PipeReader Input { get; } - - public PipeWriter Output { get; } - - public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - var transportToApplication = new DuplexPipe(output.Reader, input.Writer); - var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } - - // This class exists to work around issues with value tuple on .NET Framework - public readonly struct DuplexPipePair - { - public IDuplexPipe Transport { get; } - public IDuplexPipe Application { get; } - - public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) - { - Transport = transport; - Application = application; - } - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs new file mode 100644 index 000000000..1577ef49e --- /dev/null +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Routing; +using MQTTnet.AspNetCore; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.AspNetCore.Builder +{ + public static class EndpointRouteBuilderExtensions + { + /// + /// mqtt over websocket + /// + /// + /// + /// + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern) + { + ArgumentNullException.ThrowIfNull(endpoints); + return endpoints.MapMqtt(pattern, options => options.WebSockets.SubProtocolSelector = SelectSubProtocol); + + static string SelectSubProtocol(IList requestedSubProtocolValues) + { + // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. + return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt")); + } + } + + /// + /// mqtt over websocket + /// + /// + /// + /// + /// + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern, Action options) + { + ArgumentNullException.ThrowIfNull(endpoints); + return endpoints.MapConnectionHandler(pattern, options); + } + } +} + diff --git a/Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs deleted file mode 100644 index 8e96a6c28..000000000 --- a/Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Routing; - -namespace MQTTnet.AspNetCore -{ - public static class EndpointRouterExtensions - { - public static void MapMqtt(this IEndpointRouteBuilder endpoints, string pattern) - { - ArgumentNullException.ThrowIfNull(endpoints); - - endpoints.MapConnectionHandler(pattern, options => - { - options.WebSockets.SubProtocolSelector = MqttSubProtocolSelector.SelectSubProtocol; - }); - } - } -} - diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs similarity index 97% rename from Source/MQTTnet.AspnetCore/MqttConnectionContext.cs rename to Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs index b7b1c775d..38b9f066a 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs @@ -20,7 +20,7 @@ namespace MQTTnet.AspNetCore; -public sealed class MqttConnectionContext : IMqttChannelAdapter +sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter { readonly ConnectionContext _connection; readonly AsyncLock _writerLock = new(); @@ -28,7 +28,7 @@ public sealed class MqttConnectionContext : IMqttChannelAdapter readonly PipeReader _input; readonly PipeWriter _output; - public MqttConnectionContext(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) + public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { PacketFormatterAdapter = packetFormatterAdapter ?? throw new ArgumentNullException(nameof(packetFormatterAdapter)); _connection = connection ?? throw new ArgumentNullException(nameof(connection)); @@ -116,7 +116,7 @@ public Task DisconnectAsync(CancellationToken cancellationToken) } public void Dispose() - { + { _writerLock.Dispose(); } diff --git a/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs similarity index 90% rename from Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs rename to Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index a8405cefa..920464504 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -14,11 +14,11 @@ namespace MQTTnet.AspNetCore { - public sealed class MqttClientConnectionContextFactory : IMqttClientAdapterFactory + sealed class AspNetCoreMqttClientAdapterFactory : IMqttClientAdapterFactory { private readonly IConnectionFactory connectionFactory; - public MqttClientConnectionContextFactory(IConnectionFactory connectionFactory) + public AspNetCoreMqttClientAdapterFactory(IConnectionFactory connectionFactory) { this.connectionFactory = connectionFactory; } @@ -34,7 +34,7 @@ public async ValueTask CreateClientAdapterAsync(MqttClientO var endPoint = await CreateIPEndPointAsync(tcpOptions.RemoteEndpoint); var tcpConnection = await connectionFactory.ConnectAsync(endPoint); var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); - return new MqttConnectionContext(formatter, tcpConnection); + return new AspNetCoreMqttChannelAdapter(formatter, tcpConnection); } default: { diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs new file mode 100644 index 000000000..0aeb7413b --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -0,0 +1,26 @@ +using Microsoft.Extensions.Hosting; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore.Internal +{ + sealed class AspNetCoreMqttHostedServer : IHostedService + { + private readonly AspNetCoreMqttServer _aspNetCoreMqttServer; + + public AspNetCoreMqttHostedServer(AspNetCoreMqttServer aspNetCoreMqttServer) + { + _aspNetCoreMqttServer = aspNetCoreMqttServer; + } + + public Task StartAsync(CancellationToken cancellationToken) + { + return _aspNetCoreMqttServer.StartAsync(); + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return _aspNetCoreMqttServer.StopAsync(); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs new file mode 100644 index 000000000..66f289b2d --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Server; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class AspNetCoreMqttServer : MqttServer +{ + private readonly IOptions _stopOptions; + + public AspNetCoreMqttServer( + IOptions serverOptions, + IOptions stopOptions, + IEnumerable adapters, + IMqttNetLogger logger) : base(serverOptions.Value.Build(), adapters, logger) + { + _stopOptions = stopOptions; + } + + public Task StopAsync() + { + return base.StopAsync(_stopOptions.Value.Build()); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs similarity index 90% rename from Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs rename to Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index dfeb2f3f7..14bfdefaf 100644 --- a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -2,18 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using MQTTnet.Server; +using System; +using System.Threading.Tasks; namespace MQTTnet.AspNetCore; -public sealed class MqttConnectionHandler : ConnectionHandler, IMqttServerAdapter +sealed class MqttConnectionHandler : ConnectionHandler, IMqttServerAdapter { MqttServerOptions _serverOptions; @@ -33,7 +33,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) } var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); - using var adapter = new MqttConnectionContext(formatter, connection); + using var adapter = new AspNetCoreMqttChannelAdapter(formatter, connection); var clientHandler = ClientHandler; if (clientHandler != null) { diff --git a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs similarity index 98% rename from Source/MQTTnet.AspnetCore/ReaderExtensions.cs rename to Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs index 9b4f24ca5..bd4c08b39 100644 --- a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs @@ -2,17 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Buffers; -using System.Runtime.InteropServices; using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Formatter; using MQTTnet.Packets; +using System; +using System.Buffers; +using System.Runtime.InteropServices; namespace MQTTnet.AspNetCore; -public static class ReaderExtensions +static class MqttPacketFormatterAdapterExtensions { public static bool TryDecode( this MqttPacketFormatterAdapter formatter, diff --git a/Source/MQTTnet.AspnetCore/InternalsVisible.cs b/Source/MQTTnet.AspnetCore/InternalsVisible.cs new file mode 100644 index 000000000..92a58575a --- /dev/null +++ b/Source/MQTTnet.AspnetCore/InternalsVisible.cs @@ -0,0 +1,3 @@ +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.Tests")] +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.AspTestApp")] +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.Benchmarks")] \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 6b7f60ce2..3afdd218a 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -1,4 +1,4 @@ - + net8.0 @@ -48,16 +48,16 @@ - - + + - - + + diff --git a/Source/MQTTnet.AspnetCore/MqttHostedServer.cs b/Source/MQTTnet.AspnetCore/MqttHostedServer.cs deleted file mode 100644 index 4c74f6a43..000000000 --- a/Source/MQTTnet.AspnetCore/MqttHostedServer.cs +++ /dev/null @@ -1,48 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Hosting; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class MqttHostedServer : MqttServer, IHostedService -{ - readonly IHostApplicationLifetime _hostApplicationLifetime; - readonly MqttServerFactory _mqttFactory; - - public MqttHostedServer( - IHostApplicationLifetime hostApplicationLifetime, - MqttServerFactory mqttFactory, - MqttServerOptions options, - IEnumerable adapters, - IMqttNetLogger logger) : base(options, adapters, logger) - { - _mqttFactory = mqttFactory ?? throw new ArgumentNullException(nameof(mqttFactory)); - _hostApplicationLifetime = hostApplicationLifetime; - } - - public async Task StartAsync(CancellationToken cancellationToken) - { - // The yield makes sure that the hosted service is considered up and running. - await Task.Yield(); - - _hostApplicationLifetime.ApplicationStarted.Register(OnStarted); - } - - public Task StopAsync(CancellationToken cancellationToken) - { - return StopAsync(_mqttFactory.CreateMqttServerStopOptionsBuilder().Build()); - } - - void OnStarted() - { - _ = StartAsync(); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs b/Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs deleted file mode 100644 index c6acdfa8e..000000000 --- a/Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.AspNetCore.Http; - -namespace MQTTnet.AspNetCore; - -public static class MqttSubProtocolSelector -{ - public static string SelectSubProtocol(HttpRequest request) - { - ArgumentNullException.ThrowIfNull(request); - - string subProtocol = null; - if (request.Headers.TryGetValue("Sec-WebSocket-Protocol", out var requestedSubProtocolValues)) - { - subProtocol = SelectSubProtocol(requestedSubProtocolValues); - } - - return subProtocol; - } - - public static string SelectSubProtocol(IList requestedSubProtocolValues) - { - ArgumentNullException.ThrowIfNull(requestedSubProtocolValues); - - // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. - return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt")); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs b/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs deleted file mode 100644 index 272ead6bf..000000000 --- a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net.WebSockets; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using MQTTnet.Adapter; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Formatter; -using MQTTnet.Implementations; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class MqttWebSocketServerAdapter : IMqttServerAdapter -{ - IMqttNetLogger _logger = MqttNetNullLogger.Instance; - - public Func ClientHandler { get; set; } - - public void Dispose() - { - } - - public async Task RunWebSocketConnectionAsync(WebSocket webSocket, HttpContext httpContext) - { - ArgumentNullException.ThrowIfNull(webSocket); - - var endpoint = $"{httpContext.Connection.RemoteIpAddress}:{httpContext.Connection.RemotePort}"; - - var clientCertificate = await httpContext.Connection.GetClientCertificateAsync().ConfigureAwait(false); - try - { - var isSecureConnection = clientCertificate != null; - - var clientHandler = ClientHandler; - if (clientHandler != null) - { - var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(4096, 65535)); - var channel = new MqttWebSocketChannel(webSocket, endpoint, isSecureConnection, clientCertificate); - - using (var channelAdapter = new MqttChannelAdapter(channel, formatter, _logger)) - { - await clientHandler(channelAdapter).ConfigureAwait(false); - } - } - } - finally - { - clientCertificate?.Dispose(); - } - } - - public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) - { - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - return Task.CompletedTask; - } - - public Task StopAsync() - { - return Task.CompletedTask; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index bcb8b38e7..a410b4f14 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -5,8 +5,8 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.Extensions.Hosting; using MQTTnet.Adapter; +using MQTTnet.AspNetCore.Internal; using MQTTnet.Diagnostics.Logger; using MQTTnet.Server; using MQTTnet.Server.Internal.Adapter; @@ -21,107 +21,48 @@ public static class ServiceCollectionExtensions const string SocketConnectionFactoryTypeName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketConnectionFactory"; const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; - public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, MqttServerOptions options) + public static IServiceCollection AddMqttServer(this IServiceCollection services, Action configure, Action stopConfigure) { - ArgumentNullException.ThrowIfNull(services); - ArgumentNullException.ThrowIfNull(options); - - services.AddSingleton(options); - services.AddHostedMqttServer(); - - return services; + services.AddOptions().Configure(stopConfigure); + return services.AddMqttServer(configure); } - public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, Action configure) + public static IServiceCollection AddMqttServer(this IServiceCollection services, Action configure) { - ArgumentNullException.ThrowIfNull(services); - - var serverOptionsBuilder = new MqttServerOptionsBuilder(); - - configure?.Invoke(serverOptionsBuilder); - - var options = serverOptionsBuilder.Build(); - - return AddHostedMqttServer(services, options); + services.AddOptions().Configure(configure); + return services.AddMqttServer(); } - public static void AddHostedMqttServer(this IServiceCollection services) + public static IServiceCollection AddMqttServer(this IServiceCollection services) { - // The user may have these services already registered. + services.AddConnections(); services.TryAddSingleton(MqttNetNullLogger.Instance); - services.TryAddSingleton(new MqttServerFactory()); - - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - services.AddSingleton(s => s.GetService()); - } - - public static IServiceCollection AddHostedMqttServerWithServices(this IServiceCollection services, Action configure) - { - ArgumentNullException.ThrowIfNull(services); - - services.AddSingleton( - s => - { - var builder = new AspNetMqttServerOptionsBuilder(s); - configure(builder); - return builder.Build(); - }); - - services.AddHostedMqttServer(); - - return services; - } - public static IServiceCollection AddMqttConnectionHandler(this IServiceCollection services) - { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - - return services; - } - - public static void AddMqttLogger(this IServiceCollection services, IMqttNetLogger logger) - { - ArgumentNullException.ThrowIfNull(services); - - services.AddSingleton(logger); - } + var mqttConnectionHandler = new MqttConnectionHandler(); + services.TryAddSingleton(mqttConnectionHandler); + services.TryAddEnumerable(ServiceDescriptor.Singleton(mqttConnectionHandler)); - public static IServiceCollection AddMqttServer(this IServiceCollection serviceCollection, Action configure = null) - { - ArgumentNullException.ThrowIfNull(serviceCollection); - - serviceCollection.AddMqttConnectionHandler(); - serviceCollection.AddHostedMqttServer(configure); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService()); + services.AddHostedService(); - return serviceCollection; + return services.AddOptions(); } public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection services) { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - - return services; - } - - public static IServiceCollection AddMqttWebSocketServerAdapter(this IServiceCollection services) - { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - + services.TryAddEnumerable(ServiceDescriptor.Singleton()); return services; } [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] - public static IServiceCollection AddMqttClientConnectionContextFactory(this IServiceCollection services) + public static IServiceCollection AddMqttClientAdapterFactory(this IServiceCollection services) { var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); services.AddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); - services.TryAddSingleton(); - services.TryAddSingleton(serviceProvider => serviceProvider.GetRequiredService()); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService()); return services; } } \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs index 1225b1c06..8f3eb3d2c 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -7,8 +7,11 @@ using Microsoft.AspNetCore; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Adapter; using MQTTnet.AspNetCore; using MQTTnet.Diagnostics.Logger; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Builder; namespace MQTTnet.Benchmarks { @@ -28,9 +31,8 @@ public void Setup() .ConfigureServices(services => { services - .AddHostedMqttServer(mqttServerOptions => mqttServerOptions.WithoutDefaultEndpoint()) - .AddMqttConnectionHandler() - .AddMqttClientConnectionContextFactory(); + .AddMqttServer(mqttServerOptions => mqttServerOptions.WithoutDefaultEndpoint()) + .AddMqttClientAdapterFactory(); }) .Configure(app => { @@ -43,8 +45,8 @@ public void Setup() var factory = new MqttClientFactory(); - var mqttClientConnectionContextFactory = _host.Services.GetRequiredService(); - _mqttClient = factory.CreateMqttClient(new MqttNetEventLogger(), mqttClientConnectionContextFactory); + var mqttClientAdapterFactory = _host.Services.GetRequiredService(); + _mqttClient = factory.CreateMqttClient(new MqttNetEventLogger(), mqttClientAdapterFactory); _host.StartAsync().GetAwaiter().GetResult(); diff --git a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs index debc48aa2..7b9cb19f3 100644 --- a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs @@ -116,7 +116,7 @@ public async Task After() { if (!buffer.IsEmpty) { - if (ReaderExtensions.TryDecode(mqttPacketFormatter, buffer, out var packet, out consumed, out observed, out var received)) + if (MqttPacketFormatterAdapterExtensions.TryDecode(mqttPacketFormatter, buffer, out var packet, out consumed, out observed, out var received)) { break; } diff --git a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs index f45284cc2..d221470ea 100644 --- a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs +++ b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs @@ -16,7 +16,7 @@ namespace MQTTnet.Tests.ASP.Mockups; public sealed class ConnectionHandlerMockup : IMqttServerAdapter { public Func ClientHandler { get; set; } - public TaskCompletionSource Context { get; } = new(); + TaskCompletionSource Context { get; } = new(); public void Dispose() { @@ -27,7 +27,7 @@ public async Task OnConnectedAsync(ConnectionContext connection) try { var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(4096, 65535)); - var context = new MqttConnectionContext(formatter, connection); + var context = new AspNetCoreMqttChannelAdapter(formatter, connection); Context.TrySetResult(context); await ClientHandler(context); diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs index bfd0f8431..b616392b7 100644 --- a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs @@ -30,7 +30,7 @@ public async Task TestCorruptedConnectPacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttConnectionContext(serializer, connection); + var ctx = new AspNetCoreMqttChannelAdapter(serializer, connection); await pipe.Receive.Writer.WriteAsync(writer.AddMqttHeader(MqttControlPacketType.Connect, Array.Empty())); @@ -98,7 +98,7 @@ public async Task TestLargePacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttConnectionContext(serializer, connection); + var ctx = new AspNetCoreMqttChannelAdapter(serializer, connection); await ctx.SendPacketAsync(new MqttPublishPacket { PayloadSegment = new byte[20_000] }, CancellationToken.None).ConfigureAwait(false); @@ -113,7 +113,7 @@ public async Task TestReceivePacketAsyncThrowsWhenReaderCompleted() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttConnectionContext(serializer, connection); + var ctx = new AspNetCoreMqttChannelAdapter(serializer, connection); pipe.Receive.Writer.Complete(); From d8c94bd68eb3db2c95b81b5066e995119a7a339f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 10:41:51 +0800 Subject: [PATCH 03/85] Separate AspNetCoreMqttServerAdapter from MqttConnectionHandler --- Source/MQTTnet.AspTestApp/Program.cs | 1 - .../ApplicationBuilderExtensions.cs | 12 +++++- .../ConnectionBuilderExtensions.cs | 9 +++- .../EndpointRouteBuilderExtensions.cs | 8 ++-- .../Internal/AspNetCoreMqttServerAdapter.cs | 41 ++++++++++++++++++ .../Internal/MqttConnectionHandler.cs | 38 +++++++--------- .../MQTTnet.AspNetCore.csproj | 1 + .../ServiceCollectionExtensions.cs | 43 ++++++++++++++++--- 8 files changed, 116 insertions(+), 37 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs diff --git a/Source/MQTTnet.AspTestApp/Program.cs b/Source/MQTTnet.AspTestApp/Program.cs index 42c67ef06..ba64ccf79 100644 --- a/Source/MQTTnet.AspTestApp/Program.cs +++ b/Source/MQTTnet.AspTestApp/Program.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; using MQTTnet; using MQTTnet.AspNetCore; using MQTTnet.Server; diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index 32835b791..a0de842de 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -2,14 +2,22 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using MQTTnet.Server; using System; -namespace Microsoft.AspNetCore.Builder; +namespace MQTTnet.AspNetCore; public static class ApplicationBuilderExtensions -{ +{ + /// + /// Get and use MqttServer + /// Also, you can inject MqttServer into your service + /// + /// + /// + /// public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app, Action configure) { var server = app.ApplicationServices.GetRequiredService(); diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 0d4e2eba0..0060c0f63 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -2,12 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using MQTTnet.AspNetCore; +using Microsoft.AspNetCore.Connections; -namespace Microsoft.AspNetCore.Connections +namespace MQTTnet.AspNetCore { public static class ConnectionBuilderExtensions { + /// + /// Treat the obtained connection as an mqtt connection + /// + /// + /// public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder) { return builder.UseConnectionHandler(); diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index 1577ef49e..6b7e790ff 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -2,19 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Routing; -using MQTTnet.AspNetCore; using System; using System.Collections.Generic; using System.Linq; -namespace Microsoft.AspNetCore.Builder +namespace MQTTnet.AspNetCore { public static class EndpointRouteBuilderExtensions { /// - /// mqtt over websocket + /// Treat the obtained WebSocket as an mqtt connection /// /// /// @@ -32,7 +32,7 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) } /// - /// mqtt over websocket + /// Treat the obtained WebSocket as an mqtt connection /// /// /// diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs new file mode 100644 index 000000000..c471f0d1d --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Server; +using System; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class AspNetCoreMqttServerAdapter : IMqttServerAdapter +{ + readonly MqttConnectionHandler _connectionHandler; + + public Func ClientHandler + { + get => _connectionHandler.ClientHandler; + set => _connectionHandler.ClientHandler = value; + } + + public AspNetCoreMqttServerAdapter(MqttConnectionHandler connectionHandler) + { + _connectionHandler = connectionHandler; + } + + public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) + { + return Task.CompletedTask; + } + + public Task StopAsync() + { + return Task.CompletedTask; + } + + public void Dispose() + { + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 14bfdefaf..57c28ddf9 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -4,8 +4,8 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.Extensions.Options; using MQTTnet.Adapter; -using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using MQTTnet.Server; using System; @@ -13,18 +13,27 @@ namespace MQTTnet.AspNetCore; -sealed class MqttConnectionHandler : ConnectionHandler, IMqttServerAdapter +sealed class MqttConnectionHandler : ConnectionHandler { - MqttServerOptions _serverOptions; + readonly IOptions _serverOptions; public Func ClientHandler { get; set; } - public void Dispose() + public MqttConnectionHandler(IOptions serverOptions) { + _serverOptions = serverOptions; } public override async Task OnConnectedAsync(ConnectionContext connection) { + var clientHandler = ClientHandler; + if (clientHandler == null) + { + // MqttServer has not been initialized yet. + connection.Abort(); + return; + } + // required for websocket transport to work var transferFormatFeature = connection.Features.Get(); if (transferFormatFeature != null) @@ -32,24 +41,9 @@ public override async Task OnConnectedAsync(ConnectionContext connection) transferFormatFeature.ActiveFormat = TransferFormat.Binary; } - var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); + var options = _serverOptions.Value; + var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax)); using var adapter = new AspNetCoreMqttChannelAdapter(formatter, connection); - var clientHandler = ClientHandler; - if (clientHandler != null) - { - await clientHandler(adapter).ConfigureAwait(false); - } - } - - public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) - { - _serverOptions = options; - - return Task.CompletedTask; - } - - public Task StopAsync() - { - return Task.CompletedTask; + await clientHandler(adapter).ConfigureAwait(false); } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 3afdd218a..741d3e5a3 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -45,6 +45,7 @@ True \ + diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index a410b4f14..7591b638c 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -21,26 +21,48 @@ public static class ServiceCollectionExtensions const string SocketConnectionFactoryTypeName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketConnectionFactory"; const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; - public static IServiceCollection AddMqttServer(this IServiceCollection services, Action configure, Action stopConfigure) + /// + /// Register MqttServer as a service + /// + /// + /// serverOptions configure + /// server stop configure + /// + public static IServiceCollection AddMqttServer( + this IServiceCollection services, + Action configure, + Action stopConfigure) { services.AddOptions().Configure(stopConfigure); return services.AddMqttServer(configure); } - public static IServiceCollection AddMqttServer(this IServiceCollection services, Action configure) + /// + /// Register MqttServer as a service + /// + /// serverOptions configure + /// + /// + public static IServiceCollection AddMqttServer( + this IServiceCollection services, + Action configure) { services.AddOptions().Configure(configure); return services.AddMqttServer(); } + /// + /// Register MqttServer as a service + /// + /// + /// public static IServiceCollection AddMqttServer(this IServiceCollection services) { services.AddConnections(); services.TryAddSingleton(MqttNetNullLogger.Instance); - var mqttConnectionHandler = new MqttConnectionHandler(); - services.TryAddSingleton(mqttConnectionHandler); - services.TryAddEnumerable(ServiceDescriptor.Singleton(mqttConnectionHandler)); + services.TryAddSingleton(); + services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddSingleton(); services.TryAddSingleton(s => s.GetRequiredService()); @@ -49,13 +71,22 @@ public static IServiceCollection AddMqttServer(this IServiceCollection services) return services.AddOptions(); } + /// + /// Register MqttTcpServerAdapter as a IMqttServerAdapter + /// + /// + /// public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection services) { services.TryAddEnumerable(ServiceDescriptor.Singleton()); return services; } - + /// + /// Register IMqttClientAdapterFactory as a service + /// + /// + /// [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] public static IServiceCollection AddMqttClientAdapterFactory(this IServiceCollection services) { From 73502e459a9bfd79f26cdc0ae0d55a4e8c843b30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 11:30:26 +0800 Subject: [PATCH 04/85] TryAdd IConnectionFactory as Singleton --- .../MQTTnet.AspnetCore/ServiceCollectionExtensions.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 7591b638c..dd9d72de0 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -58,6 +58,7 @@ public static IServiceCollection AddMqttServer( /// public static IServiceCollection AddMqttServer(this IServiceCollection services) { + services.AddOptions(); services.AddConnections(); services.TryAddSingleton(MqttNetNullLogger.Instance); @@ -65,10 +66,10 @@ public static IServiceCollection AddMqttServer(this IServiceCollection services) services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddSingleton(); - services.TryAddSingleton(s => s.GetRequiredService()); services.AddHostedService(); + services.TryAddSingleton(s => s.GetRequiredService()); - return services.AddOptions(); + return services; } /// @@ -91,9 +92,8 @@ public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection public static IServiceCollection AddMqttClientAdapterFactory(this IServiceCollection services) { var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); - services.AddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); - services.TryAddSingleton(); - services.TryAddSingleton(s => s.GetRequiredService()); + services.TryAddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); + services.TryAddSingleton(); return services; } } \ No newline at end of file From 3a641abc5177ff072ca099a0159d9ac4f5c4c49a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 12:14:12 +0800 Subject: [PATCH 05/85] Add some remarks --- Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs | 2 +- .../Internal/AspNetCoreMqttClientAdapterFactory.cs | 6 +++--- .../Internal/AspNetCoreMqttHostedServer.cs | 1 + Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs | 2 +- Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs | 3 ++- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index a0de842de..96d15d000 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -13,8 +13,8 @@ public static class ApplicationBuilderExtensions { /// /// Get and use MqttServer - /// Also, you can inject MqttServer into your service /// + /// Also, you can inject MqttServer into your service /// /// /// diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index 920464504..255b5f08b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -16,11 +16,11 @@ namespace MQTTnet.AspNetCore { sealed class AspNetCoreMqttClientAdapterFactory : IMqttClientAdapterFactory { - private readonly IConnectionFactory connectionFactory; + private readonly IConnectionFactory _connectionFactory; public AspNetCoreMqttClientAdapterFactory(IConnectionFactory connectionFactory) { - this.connectionFactory = connectionFactory; + _connectionFactory = connectionFactory; } public async ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) @@ -32,7 +32,7 @@ public async ValueTask CreateClientAdapterAsync(MqttClientO case MqttClientTcpOptions tcpOptions: { var endPoint = await CreateIPEndPointAsync(tcpOptions.RemoteEndpoint); - var tcpConnection = await connectionFactory.ConnectAsync(endPoint); + var tcpConnection = await _connectionFactory.ConnectAsync(endPoint); var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); return new AspNetCoreMqttChannelAdapter(formatter, tcpConnection); } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index 0aeb7413b..7d1358cf1 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -15,6 +15,7 @@ public AspNetCoreMqttHostedServer(AspNetCoreMqttServer aspNetCoreMqttServer) public Task StartAsync(CancellationToken cancellationToken) { + // We need to set up ClientHandler for MqttConnectionHandler as soon as possible. return _aspNetCoreMqttServer.StartAsync(); } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 57c28ddf9..870fa8725 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -29,7 +29,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) var clientHandler = ClientHandler; if (clientHandler == null) { - // MqttServer has not been initialized yet. + // MqttServer has not been started yet. connection.Abort(); return; } diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index dd9d72de0..0a476032d 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -73,8 +73,9 @@ public static IServiceCollection AddMqttServer(this IServiceCollection services) } /// - /// Register MqttTcpServerAdapter as a IMqttServerAdapter + /// Register MqttTcpServerAdapter as a IMqttServerAdapter /// + /// We recommend using ListenOptions.UseMqtt() instead of using MqttTcpServerAdapter in an AspNetCore environment /// /// public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection services) From 91763aeab6080c5b66aaeeb0d42c51fd401b51cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 12:29:50 +0800 Subject: [PATCH 06/85] Add AspNetCoreMqttNetLogger --- .../Internal/AspNetCoreMqttNetLogger.cs | 37 +++++++++++++++++++ .../ServiceCollectionExtensions.cs | 3 +- 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs new file mode 100644 index 000000000..72cd4a1d9 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -0,0 +1,37 @@ +using Microsoft.Extensions.Logging; +using MQTTnet.Diagnostics.Logger; +using System; + +namespace MQTTnet.AspNetCore.Internal +{ + sealed class AspNetCoreMqttNetLogger : IMqttNetLogger + { + private readonly ILoggerFactory _loggerFactory; + private const string categoryNamePrefix = "MQTTnet.AspNetCore."; + + public bool IsEnabled => true; + + public AspNetCoreMqttNetLogger(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + + public void Publish(MqttNetLogLevel logLevel, string source, string message, object[] parameters, Exception exception) + { + var logger = _loggerFactory.CreateLogger($"{categoryNamePrefix}{source}"); + logger.Log(CastLogLevel(logLevel), exception, message, parameters); + } + + private static LogLevel CastLogLevel(MqttNetLogLevel level) + { + return level switch + { + MqttNetLogLevel.Verbose => LogLevel.Trace, + MqttNetLogLevel.Info => LogLevel.Information, + MqttNetLogLevel.Warning => LogLevel.Warning, + MqttNetLogLevel.Error => LogLevel.Error, + _ => LogLevel.None + }; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 0a476032d..f08eebbbb 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -60,7 +60,8 @@ public static IServiceCollection AddMqttServer(this IServiceCollection services) { services.AddOptions(); services.AddConnections(); - services.TryAddSingleton(MqttNetNullLogger.Instance); + services.AddLogging(); + services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddEnumerable(ServiceDescriptor.Singleton()); From 9c142222f23a0290aea38e9db10c48b7ed35c527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 13:48:04 +0800 Subject: [PATCH 07/85] Delayed start of AspNetCoreMqttServer --- .../Internal/AspNetCoreMqttHostedServer.cs | 13 ++++++++++--- .../Internal/MqttConnectionHandler.cs | 9 +++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index 7d1358cf1..643727a60 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -8,15 +8,22 @@ sealed class AspNetCoreMqttHostedServer : IHostedService { private readonly AspNetCoreMqttServer _aspNetCoreMqttServer; - public AspNetCoreMqttHostedServer(AspNetCoreMqttServer aspNetCoreMqttServer) + public AspNetCoreMqttHostedServer( + AspNetCoreMqttServer aspNetCoreMqttServer, + IHostApplicationLifetime hostApplicationLifetime) { _aspNetCoreMqttServer = aspNetCoreMqttServer; + hostApplicationLifetime.ApplicationStarted.Register(ApplicationStarted); } public Task StartAsync(CancellationToken cancellationToken) { - // We need to set up ClientHandler for MqttConnectionHandler as soon as possible. - return _aspNetCoreMqttServer.StartAsync(); + return Task.CompletedTask; + } + + private void ApplicationStarted() + { + _ = _aspNetCoreMqttServer.StartAsync(); } public Task StopAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 870fa8725..556bdaeb6 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Connections.Features; using Microsoft.Extensions.Options; using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using MQTTnet.Server; using System; @@ -15,12 +16,16 @@ namespace MQTTnet.AspNetCore; sealed class MqttConnectionHandler : ConnectionHandler { + readonly IMqttNetLogger _logger; readonly IOptions _serverOptions; public Func ClientHandler { get; set; } - public MqttConnectionHandler(IOptions serverOptions) + public MqttConnectionHandler( + IMqttNetLogger logger, + IOptions serverOptions) { + _logger = logger; _serverOptions = serverOptions; } @@ -29,8 +34,8 @@ public override async Task OnConnectedAsync(ConnectionContext connection) var clientHandler = ClientHandler; if (clientHandler == null) { - // MqttServer has not been started yet. connection.Abort(); + _logger.Publish(MqttNetLogLevel.Warning, nameof(MqttConnectionHandler), "MqttServer has not been started yet.", null, null); return; } From 23723526b9837e71bc9ebab6b53dd1bd58e92aae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 14:19:54 +0800 Subject: [PATCH 08/85] Using fields to cache IHttpContextFeature --- .../Internal/AspNetCoreMqttChannelAdapter.cs | 46 ++++++------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs index 38b9f066a..0c362e1c3 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs @@ -27,6 +27,7 @@ sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter readonly PipeReader _input; readonly PipeWriter _output; + readonly IHttpContextFeature _httpContextFeature; public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { @@ -34,7 +35,9 @@ public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAd _connection = connection ?? throw new ArgumentNullException(nameof(connection)); _input = connection.Transport.Input; _output = connection.Transport.Output; + _httpContextFeature = connection.Features.Get(); } + public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } public long BytesReceived { get; private set; } @@ -44,16 +47,13 @@ public X509Certificate2 ClientCertificate { get { - // mqtt over tcp - var tlsFeature = _connection.Features.Get(); - if (tlsFeature != null) + if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) { - return tlsFeature.ClientCertificate; + return _httpContextFeature.HttpContext.Connection.ClientCertificate; } - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - return httpFeature?.HttpContext?.Connection.ClientCertificate; + var tlsFeature = _connection.Features.Get(); + return tlsFeature?.ClientCertificate; } } @@ -61,20 +61,13 @@ public string Endpoint { get { - // mqtt over tcp - if (_connection.RemoteEndPoint != null) - { - return _connection.RemoteEndPoint.ToString(); - } - - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - if (httpFeature?.RemoteIpAddress != null) + if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) { - return new IPEndPoint(httpFeature.RemoteIpAddress, httpFeature.RemotePort).ToString(); + var httpConnection = _httpContextFeature.HttpContext.Connection; + return httpConnection == null ? null : new IPEndPoint(httpConnection.RemoteIpAddress, httpConnection.RemotePort).ToString(); } - return null; + return _connection.RemoteEndPoint?.ToString(); } } @@ -82,25 +75,16 @@ public bool IsSecureConnection { get { - // mqtt over tcp - var tlsFeature = _connection.Features.Get(); - if (tlsFeature != null) - { - return true; - } - - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - if (httpFeature?.HttpContext != null) + if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) { - return httpFeature.HttpContext.Request.IsHttps; + return _httpContextFeature.HttpContext.Request.IsHttps; } - return false; + var tlsFeature = _connection.Features.Get(); + return tlsFeature != null; } } - public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } public Task ConnectAsync(CancellationToken cancellationToken) { From 8d96e1911974692099caeebf6bb390399511dfd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 15:45:10 +0800 Subject: [PATCH 09/85] Update Server_ASP_NET_Samples --- Samples/Server/Server_ASP_NET_Samples.cs | 85 +++++++------------ .../ApplicationBuilderExtensions.cs | 12 +++ 2 files changed, 41 insertions(+), 56 deletions(-) diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 3466a8b95..963a975b2 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -9,10 +9,9 @@ // ReSharper disable MemberCanBeMadeStatic.Local using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; using MQTTnet.AspNetCore; using MQTTnet.Server; @@ -22,78 +21,52 @@ public static class Server_ASP_NET_Samples { public static Task Start_Server_With_WebSockets_Support() { - /* - * This sample starts a minimal ASP.NET Webserver including a hosted MQTT server. - */ - var host = Host.CreateDefaultBuilder(Array.Empty()) - .ConfigureWebHostDefaults( - webBuilder => - { - webBuilder.UseKestrel( - o => - { - // This will allow MQTT connections based on TCP port 1883. - o.ListenAnyIP(1883, l => l.UseMqtt()); + var builder = WebApplication.CreateBuilder(); + builder.Services.AddMqttServer(); + builder.Services.AddSingleton(); - // This will allow MQTT connections based on HTTP WebSockets with URI "localhost:5000/mqtt" - // See code below for URI configuration. - o.ListenAnyIP(5000); // Default HTTP pipeline - }); + builder.WebHost.UseKestrel(kestrel => + { + // mqtt over tcp + kestrel.ListenAnyIP(1883, l => l.UseMqtt()); + + // mqtt over tls over tcp + kestrel.ListenAnyIP(1884, l => l.UseHttps().UseMqtt()); - webBuilder.UseStartup(); - }); + // This will allow MQTT connections based on HTTP WebSockets with URI "localhost:5000/mqtt" + // See code below for URI configuration. + kestrel.ListenAnyIP(5000); // Default HTTP pipeline + }); - return host.RunConsoleAsync(); + var app = builder.Build(); + app.MapMqtt("/mqtt"); + app.UseMqttServer(); + return app.RunAsync(); } sealed class MqttController { - public MqttController() + private readonly ILogger _logger; + + public MqttController( + MqttServer mqttServer, + ILogger logger) { - // Inject other services via constructor. + mqttServer.ValidatingConnectionAsync += ValidateConnection; + mqttServer.ClientConnectedAsync += OnClientConnected; + _logger = logger; } public Task OnClientConnected(ClientConnectedEventArgs eventArgs) { - Console.WriteLine($"Client '{eventArgs.ClientId}' connected."); + _logger.LogInformation($"Client '{eventArgs.ClientId}' connected."); return Task.CompletedTask; } - public Task ValidateConnection(ValidatingConnectionEventArgs eventArgs) { - Console.WriteLine($"Client '{eventArgs.ClientId}' wants to connect. Accepting!"); + _logger.LogInformation($"Client '{eventArgs.ClientId}' wants to connect. Accepting!"); return Task.CompletedTask; } } - - sealed class Startup - { - public void Configure(IApplicationBuilder app, IWebHostEnvironment environment, MqttController mqttController) - { - app.UseRouting(); - app.UseEndpoints(endpoints => endpoints.MapMqtt("/mqtt")); - - app.UseMqttServer( - server => - { - /* - * Attach event handlers etc. if required. - */ - - server.ValidatingConnectionAsync += mqttController.ValidateConnection; - server.ClientConnectedAsync += mqttController.OnClientConnected; - }); - } - - public void ConfigureServices(IServiceCollection services) - { - services.AddMqttServer(optionsBuilder => - { - optionsBuilder.WithDefaultEndpoint(); - }); - - services.AddSingleton(); - } - } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index 96d15d000..2974c0c59 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -24,4 +24,16 @@ public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app, Ac configure(server); return app; } + + /// + /// Use MqttServer's wrapper service + /// + /// + /// + /// + public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app) + { + app.ApplicationServices.GetRequiredService(); + return app; + } } \ No newline at end of file From ad5c798647e8bec86a7c12bbf1e33773bd682618 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 20:39:08 +0800 Subject: [PATCH 10/85] Use ActivatorUtilities to create TMQttServerWrapper --- Samples/Server/Server_ASP_NET_Samples.cs | 2 -- Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 963a975b2..9cbe7c723 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -10,7 +10,6 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using MQTTnet.AspNetCore; using MQTTnet.Server; @@ -23,7 +22,6 @@ public static Task Start_Server_With_WebSockets_Support() { var builder = WebApplication.CreateBuilder(); builder.Services.AddMqttServer(); - builder.Services.AddSingleton(); builder.WebHost.UseKestrel(kestrel => { diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index 2974c0c59..918638843 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -26,14 +26,14 @@ public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app, Ac } /// - /// Use MqttServer's wrapper service + /// Active MqttServer's wrapper service /// - /// + /// /// /// - public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app) + public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app) { - app.ApplicationServices.GetRequiredService(); + ActivatorUtilities.GetServiceOrCreateInstance(app.ApplicationServices); return app; } } \ No newline at end of file From 7f95f02b01539ffd38a6b082bdd103f834d32f56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 21:03:50 +0800 Subject: [PATCH 11/85] Add IMqttServerBuilder --- .../MQTTnet.AspnetCore/IMqttServerBuilder.cs | 12 ++++ .../MqttServerBuilderExtensions.cs | 59 +++++++++++++++++++ .../ServiceCollectionExtensions.cs | 56 +++--------------- ...rocessingMqttConnectionContextBenchmark.cs | 15 +---- 4 files changed, 83 insertions(+), 59 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs create mode 100644 Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs diff --git a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs new file mode 100644 index 000000000..d57b10173 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs @@ -0,0 +1,12 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace MQTTnet.AspNetCore +{ + /// + /// Builder of MqttServer + /// + public interface IMqttServerBuilder + { + IServiceCollection Services { get; } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs new file mode 100644 index 000000000..849c4dae2 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -0,0 +1,59 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Server; +using MQTTnet.Server.Internal.Adapter; +using System; + +namespace MQTTnet.AspNetCore +{ + public static class MqttServerBuilderExtensions + { + /// + /// Disable logging + /// + /// + /// + public static IMqttServerBuilder UseNullLogger(this IMqttServerBuilder builder) + { + builder.Services.Replace(ServiceDescriptor.Singleton(MqttNetNullLogger.Instance)); + return builder; + } + + /// + /// Configure MqttServerOptionsBuilder + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder; + } + + /// + /// Configure MqttServerStopOptionsBuilder + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder; + } + + /// + /// Add MqttTcpServerAdapter to MqttServer + /// + /// We recommend using ListenOptions.UseMqtt() instead of using MqttTcpServerAdapter in an AspNetCore environment + /// + /// + public static IMqttServerBuilder AddMqttTcpServerAdapter(this IMqttServerBuilder builder) + { + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton()); + return builder; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index f08eebbbb..0ca9043d3 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -9,8 +9,6 @@ using MQTTnet.AspNetCore.Internal; using MQTTnet.Diagnostics.Logger; using MQTTnet.Server; -using MQTTnet.Server.Internal.Adapter; -using System; using System.Diagnostics.CodeAnalysis; using System.Reflection; @@ -22,41 +20,11 @@ public static class ServiceCollectionExtensions const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; /// - /// Register MqttServer as a service + /// Register MqttServer as a singleton service /// /// - /// serverOptions configure - /// server stop configure /// - public static IServiceCollection AddMqttServer( - this IServiceCollection services, - Action configure, - Action stopConfigure) - { - services.AddOptions().Configure(stopConfigure); - return services.AddMqttServer(configure); - } - - /// - /// Register MqttServer as a service - /// - /// serverOptions configure - /// - /// - public static IServiceCollection AddMqttServer( - this IServiceCollection services, - Action configure) - { - services.AddOptions().Configure(configure); - return services.AddMqttServer(); - } - - /// - /// Register MqttServer as a service - /// - /// - /// - public static IServiceCollection AddMqttServer(this IServiceCollection services) + public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) { services.AddOptions(); services.AddConnections(); @@ -70,19 +38,7 @@ public static IServiceCollection AddMqttServer(this IServiceCollection services) services.AddHostedService(); services.TryAddSingleton(s => s.GetRequiredService()); - return services; - } - - /// - /// Register MqttTcpServerAdapter as a IMqttServerAdapter - /// - /// We recommend using ListenOptions.UseMqtt() instead of using MqttTcpServerAdapter in an AspNetCore environment - /// - /// - public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection services) - { - services.TryAddEnumerable(ServiceDescriptor.Singleton()); - return services; + return new MqttServerBuilder(services); } /// @@ -98,4 +54,10 @@ public static IServiceCollection AddMqttClientAdapterFactory(this IServiceCollec services.TryAddSingleton(); return services; } + + + private class MqttServerBuilder(IServiceCollection services) : IMqttServerBuilder + { + public IServiceCollection Services { get; } = services; + } } \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs index 8f3eb3d2c..a9cb07e04 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -10,8 +10,6 @@ using MQTTnet.Adapter; using MQTTnet.AspNetCore; using MQTTnet.Diagnostics.Logger; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Builder; namespace MQTTnet.Benchmarks { @@ -31,16 +29,9 @@ public void Setup() .ConfigureServices(services => { services - .AddMqttServer(mqttServerOptions => mqttServerOptions.WithoutDefaultEndpoint()) - .AddMqttClientAdapterFactory(); - }) - .Configure(app => - { - app.UseMqttServer(s => - { - - }); - }) + .AddMqttClientAdapterFactory() + .AddMqttServer(); + }) .Build(); From 7b44ea2d70a982063ace382a3dff2183adb9c32f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 22:40:13 +0800 Subject: [PATCH 12/85] Add IMqttClientBuilder --- Samples/Server/Server_ASP_NET_Samples.cs | 35 +++++++++++++--- Source/MQTTnet.AspnetCore/IMqttBuilder.cs | 9 ++++ .../MQTTnet.AspnetCore/IMqttClientBuilder.cs | 9 ++++ .../MQTTnet.AspnetCore/IMqttClientFactory.cs | 7 ++++ .../MQTTnet.AspnetCore/IMqttServerBuilder.cs | 5 +-- .../Internal/AspNetCoreMqttClientFactory.cs | 24 +++++++++++ .../MqttBuilderExtensions.cs | 39 ++++++++++++++++++ .../MqttClientBuilderExtensions.cs | 41 +++++++++++++++++++ .../MqttServerBuilderExtensions.cs | 14 +------ .../ServiceCollectionExtensions.cs | 32 +++++++-------- ...rocessingMqttConnectionContextBenchmark.cs | 15 +++---- 11 files changed, 179 insertions(+), 51 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/IMqttBuilder.cs create mode 100644 Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs create mode 100644 Source/MQTTnet.AspnetCore/IMqttClientFactory.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs create mode 100644 Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs create mode 100644 Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 9cbe7c723..3fa1cb6aa 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -10,6 +10,8 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using MQTTnet.AspNetCore; using MQTTnet.Server; @@ -22,6 +24,8 @@ public static Task Start_Server_With_WebSockets_Support() { var builder = WebApplication.CreateBuilder(); builder.Services.AddMqttServer(); + builder.Services.AddMqttClient().UseAspNetCoreMqttClientAdapterFactory(); + builder.Services.AddHostedService(); builder.WebHost.UseKestrel(kestrel => { @@ -38,21 +42,22 @@ public static Task Start_Server_With_WebSockets_Support() var app = builder.Build(); app.MapMqtt("/mqtt"); - app.UseMqttServer(); + app.UseMqttServer(); return app.RunAsync(); } - sealed class MqttController + sealed class MqttServerController { - private readonly ILogger _logger; + private readonly ILogger _logger; - public MqttController( + public MqttServerController( MqttServer mqttServer, - ILogger logger) + ILogger logger) { + _logger = logger; + mqttServer.ValidatingConnectionAsync += ValidateConnection; mqttServer.ClientConnectedAsync += OnClientConnected; - _logger = logger; } public Task OnClientConnected(ClientConnectedEventArgs eventArgs) @@ -67,4 +72,22 @@ public Task ValidateConnection(ValidatingConnectionEventArgs eventArgs) return Task.CompletedTask; } } + + sealed class MqttClientController : BackgroundService + { + private readonly IMqttClientFactory _mqttClientFactory; + + public MqttClientController(IMqttClientFactory mqttClientFactory) + { + _mqttClientFactory = mqttClientFactory; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + await Task.Delay(3000); + using var client = _mqttClientFactory.CreateMqttClient(); + var options = new MqttClientOptionsBuilder().WithTcpServer("localhost").Build(); + await client.ConnectAsync(options, stoppingToken); + } + } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/IMqttBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttBuilder.cs new file mode 100644 index 000000000..41003a259 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttBuilder.cs @@ -0,0 +1,9 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace MQTTnet.AspNetCore +{ + public interface IMqttBuilder + { + IServiceCollection Services { get; } + } +} diff --git a/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs new file mode 100644 index 000000000..c38bba15b --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs @@ -0,0 +1,9 @@ +namespace MQTTnet.AspNetCore +{ + /// + /// Builder of IMqttClientFactory + /// + public interface IMqttClientBuilder: IMqttBuilder + { + } +} diff --git a/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs new file mode 100644 index 000000000..b80d72b56 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs @@ -0,0 +1,7 @@ +namespace MQTTnet.AspNetCore +{ + public interface IMqttClientFactory + { + IMqttClient CreateMqttClient(); + } +} diff --git a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs index d57b10173..337e5fd26 100644 --- a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs +++ b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs @@ -1,12 +1,9 @@ -using Microsoft.Extensions.DependencyInjection; - namespace MQTTnet.AspNetCore { /// /// Builder of MqttServer /// - public interface IMqttServerBuilder + public interface IMqttServerBuilder : IMqttBuilder { - IServiceCollection Services { get; } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs new file mode 100644 index 000000000..8fd104726 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs @@ -0,0 +1,24 @@ +using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; + +namespace MQTTnet.AspNetCore +{ + sealed class AspNetCoreMqttClientFactory : IMqttClientFactory + { + private readonly IMqttClientAdapterFactory _mqttClientAdapterFactory; + private readonly IMqttNetLogger _logger; + + public AspNetCoreMqttClientFactory( + IMqttClientAdapterFactory mqttClientAdapterFactory, + IMqttNetLogger logger) + { + _mqttClientAdapterFactory = mqttClientAdapterFactory; + _logger = logger; + } + + public IMqttClient CreateMqttClient() + { + return new MqttClient(_mqttClientAdapterFactory, _logger); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs new file mode 100644 index 000000000..25a92fee3 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -0,0 +1,39 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Diagnostics.Logger; +using System; +using System.Diagnostics.CodeAnalysis; + +namespace MQTTnet.AspNetCore +{ + public static class MqttBuilderExtensions + { + /// + /// Disable logging + /// + /// + /// + public static IMqttBuilder UseNullLogger(this IMqttBuilder builder) + { + return builder.UseLogger(); + } + + /// + /// Use a logger + /// + /// + /// + /// + public static IMqttBuilder UseLogger<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TLogger>(this IMqttBuilder builder) + where TLogger : class, IMqttNetLogger + { + builder.Services.Replace(ServiceDescriptor.Singleton()); + return builder; + } + + private class MqttBuilder(IServiceCollection services) : IMqttBuilder + { + public IServiceCollection Services { get; } = services; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs new file mode 100644 index 000000000..bbccddec1 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -0,0 +1,41 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Adapter; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace MQTTnet.AspNetCore +{ + public static class MqttClientBuilderExtensions + { + const string SocketConnectionFactoryTypeName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketConnectionFactory"; + const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; + + /// + /// Replace the implementation of IMqttClientAdapterFactory to AspNetCoreMqttClientAdapterFactory + /// + /// + /// + [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] + public static IMqttClientBuilder UseAspNetCoreMqttClientAdapterFactory(this IMqttClientBuilder builder) + { + var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); + builder.Services.TryAddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); + return builder.UseMqttClientAdapterFactory(); + } + + /// + /// Replace the implementation of IMqttClientAdapterFactory to TMqttClientAdapterFactory + /// + /// + /// + /// + public static IMqttClientBuilder UseMqttClientAdapterFactory<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TMqttClientAdapterFactory>(this IMqttClientBuilder builder) + where TMqttClientAdapterFactory : class, IMqttClientAdapterFactory + { + builder.Services.Replace(ServiceDescriptor.Singleton()); + return builder; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 849c4dae2..0bd77d601 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; -using MQTTnet.Diagnostics.Logger; using MQTTnet.Server; using MQTTnet.Server.Internal.Adapter; using System; @@ -8,18 +7,7 @@ namespace MQTTnet.AspNetCore { public static class MqttServerBuilderExtensions - { - /// - /// Disable logging - /// - /// - /// - public static IMqttServerBuilder UseNullLogger(this IMqttServerBuilder builder) - { - builder.Services.Replace(ServiceDescriptor.Singleton(MqttNetNullLogger.Instance)); - return builder; - } - + { /// /// Configure MqttServerOptionsBuilder /// diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 0ca9043d3..b12aeba70 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -2,23 +2,18 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; using MQTTnet.AspNetCore.Internal; using MQTTnet.Diagnostics.Logger; +using MQTTnet.Implementations; using MQTTnet.Server; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; namespace MQTTnet.AspNetCore; public static class ServiceCollectionExtensions { - const string SocketConnectionFactoryTypeName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketConnectionFactory"; - const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; - /// /// Register MqttServer as a singleton service /// @@ -28,9 +23,6 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) { services.AddOptions(); services.AddConnections(); - services.AddLogging(); - services.TryAddSingleton(); - services.TryAddSingleton(); services.TryAddEnumerable(ServiceDescriptor.Singleton()); @@ -38,25 +30,29 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) services.AddHostedService(); services.TryAddSingleton(s => s.GetRequiredService()); - return new MqttServerBuilder(services); + return services.AddMqtt(); } /// - /// Register IMqttClientAdapterFactory as a service + /// Register IMqttClientFactory as a singleton service /// /// /// - [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] - public static IServiceCollection AddMqttClientAdapterFactory(this IServiceCollection services) + public static IMqttClientBuilder AddMqttClient(this IServiceCollection services) { - var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); - services.TryAddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); - services.TryAddSingleton(); - return services; + services.TryAddSingleton(); + services.TryAddSingleton(); + return services.AddMqtt(); } + private static MqttBuilder AddMqtt(this IServiceCollection services) + { + services.AddLogging(); + services.TryAddSingleton(); + return new MqttBuilder(services); + } - private class MqttServerBuilder(IServiceCollection services) : IMqttServerBuilder + private class MqttBuilder(IServiceCollection services) : IMqttServerBuilder, IMqttClientBuilder { public IServiceCollection Services { get; } = services; } diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs index a9cb07e04..f76bd4058 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -7,9 +7,7 @@ using Microsoft.AspNetCore; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; -using MQTTnet.Adapter; using MQTTnet.AspNetCore; -using MQTTnet.Diagnostics.Logger; namespace MQTTnet.Benchmarks { @@ -28,16 +26,13 @@ public void Setup() .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) .ConfigureServices(services => { - services - .AddMqttClientAdapterFactory() - .AddMqttServer(); - }) + services.AddMqttServer(); + services.AddMqttClient().UseAspNetCoreMqttClientAdapterFactory(); + }) .Build(); - - var factory = new MqttClientFactory(); - var mqttClientAdapterFactory = _host.Services.GetRequiredService(); - _mqttClient = factory.CreateMqttClient(new MqttNetEventLogger(), mqttClientAdapterFactory); + var factory = _host.Services.GetRequiredService(); + _mqttClient = factory.CreateMqttClient(); _host.StartAsync().GetAwaiter().GetResult(); From 73c136599887df18ce7d381859a0d8f0b117e197 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 23:02:27 +0800 Subject: [PATCH 13/85] Conditionally load SocketConnectionFactoryAssembly --- Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs index bbccddec1..e0b858a3a 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -3,6 +3,7 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Reflection; namespace MQTTnet.AspNetCore @@ -20,8 +21,12 @@ public static class MqttClientBuilderExtensions [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] public static IMqttClientBuilder UseAspNetCoreMqttClientAdapterFactory(this IMqttClientBuilder builder) { - var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); - builder.Services.TryAddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); + if (!builder.Services.Any(s => s.ServiceType == typeof(IConnectionFactory))) + { + var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); + builder.Services.AddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); + } + return builder.UseMqttClientAdapterFactory(); } From 2c10d2cdd28e93e202ff226e31f61b4ada5e4efe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 23:07:32 +0800 Subject: [PATCH 14/85] Add LICENSE --- Source/MQTTnet.AspnetCore/IMqttBuilder.cs | 4 ++++ Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs | 4 ++++ Source/MQTTnet.AspnetCore/IMqttClientFactory.cs | 4 ++++ Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs | 4 ++++ .../Internal/AspNetCoreMqttClientFactory.cs | 4 ++++ .../MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs | 4 ++++ Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs | 4 ++++ Source/MQTTnet.AspnetCore/InternalsVisible.cs | 4 ++++ Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs | 4 ++++ Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs | 4 ++++ Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs | 4 ++++ 11 files changed, 44 insertions(+) diff --git a/Source/MQTTnet.AspnetCore/IMqttBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttBuilder.cs index 41003a259..a4438ff4f 100644 --- a/Source/MQTTnet.AspnetCore/IMqttBuilder.cs +++ b/Source/MQTTnet.AspnetCore/IMqttBuilder.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.Extensions.DependencyInjection; namespace MQTTnet.AspNetCore diff --git a/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs index c38bba15b..f7fcda33c 100644 --- a/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs +++ b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + namespace MQTTnet.AspNetCore { /// diff --git a/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs index b80d72b56..41a7ce551 100644 --- a/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs +++ b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + namespace MQTTnet.AspNetCore { public interface IMqttClientFactory diff --git a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs index 337e5fd26..1b6057bf5 100644 --- a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs +++ b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + namespace MQTTnet.AspNetCore { /// diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs index 8fd104726..68c8b0085 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index 643727a60..ffb62ecdd 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.Extensions.Hosting; using System.Threading; using System.Threading.Tasks; diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs index 72cd4a1d9..1564c4387 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.Extensions.Logging; using MQTTnet.Diagnostics.Logger; using System; diff --git a/Source/MQTTnet.AspnetCore/InternalsVisible.cs b/Source/MQTTnet.AspnetCore/InternalsVisible.cs index 92a58575a..b823bc96d 100644 --- a/Source/MQTTnet.AspnetCore/InternalsVisible.cs +++ b/Source/MQTTnet.AspnetCore/InternalsVisible.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + [assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.Tests")] [assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.AspTestApp")] [assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.Benchmarks")] \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs index 25a92fee3..7ed88a2c2 100644 --- a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Diagnostics.Logger; diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs index e0b858a3a..1547ddf9b 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 0bd77d601..487a8ce3a 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Server; From 6cda003993c762711243c6e46c052b3e95e919cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 13 Nov 2024 23:25:23 +0800 Subject: [PATCH 15/85] DynamicallyAccessedMembers --- Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index 918638843..4ab0eb653 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -6,6 +6,7 @@ using Microsoft.Extensions.DependencyInjection; using MQTTnet.Server; using System; +using System.Diagnostics.CodeAnalysis; namespace MQTTnet.AspNetCore; @@ -31,7 +32,7 @@ public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app, Ac /// /// /// - public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app) + public static IApplicationBuilder UseMqttServer<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TMQttServerWrapper>(this IApplicationBuilder app) { ActivatorUtilities.GetServiceOrCreateInstance(app.ApplicationServices); return app; From faaadbdcec93dbb78adff58ff81a7dda98ca071d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 14 Nov 2024 00:19:05 +0800 Subject: [PATCH 16/85] Inject IOptions --- .../Internal/MqttConnectionHandler.cs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 556bdaeb6..cf25eaca7 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -17,16 +17,16 @@ namespace MQTTnet.AspNetCore; sealed class MqttConnectionHandler : ConnectionHandler { readonly IMqttNetLogger _logger; - readonly IOptions _serverOptions; + readonly MqttServerOptions _serverOptions; public Func ClientHandler { get; set; } public MqttConnectionHandler( IMqttNetLogger logger, - IOptions serverOptions) + IOptions serverOptions) { _logger = logger; - _serverOptions = serverOptions; + _serverOptions = serverOptions.Value.Build(); } public override async Task OnConnectedAsync(ConnectionContext connection) @@ -35,7 +35,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) if (clientHandler == null) { connection.Abort(); - _logger.Publish(MqttNetLogLevel.Warning, nameof(MqttConnectionHandler), "MqttServer has not been started yet.", null, null); + _logger.Publish(MqttNetLogLevel.Warning, nameof(MqttConnectionHandler), $"{nameof(MqttServer)} has not been started yet.", null, null); return; } @@ -46,8 +46,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) transferFormatFeature.ActiveFormat = TransferFormat.Binary; } - var options = _serverOptions.Value; - var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax)); + var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); using var adapter = new AspNetCoreMqttChannelAdapter(formatter, connection); await clientHandler(adapter).ConfigureAwait(false); } From 13c9198772681f9f3dfec94df79955c32a1b9dcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 14 Nov 2024 00:30:10 +0800 Subject: [PATCH 17/85] Change the namespace to MQTTnet.AspNetCore --- .../MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs | 2 +- Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs | 2 +- Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index ffb62ecdd..1c310f7bd 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -6,7 +6,7 @@ using System.Threading; using System.Threading.Tasks; -namespace MQTTnet.AspNetCore.Internal +namespace MQTTnet.AspNetCore { sealed class AspNetCoreMqttHostedServer : IHostedService { diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs index 1564c4387..50e35964a 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -6,7 +6,7 @@ using MQTTnet.Diagnostics.Logger; using System; -namespace MQTTnet.AspNetCore.Internal +namespace MQTTnet.AspNetCore { sealed class AspNetCoreMqttNetLogger : IMqttNetLogger { diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index b12aeba70..ae4e613b9 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -5,7 +5,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; -using MQTTnet.AspNetCore.Internal; using MQTTnet.Diagnostics.Logger; using MQTTnet.Implementations; using MQTTnet.Server; From b8d8abb4210aad2e98be8ad7ab46e266743ec3f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 14 Nov 2024 08:56:08 +0800 Subject: [PATCH 18/85] await for_aspNetCoreMqttServer.StartAsync --- .../Internal/AspNetCoreMqttHostedServer.cs | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index 1c310f7bd..8877b037b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -8,29 +8,38 @@ namespace MQTTnet.AspNetCore { - sealed class AspNetCoreMqttHostedServer : IHostedService + sealed class AspNetCoreMqttHostedServer : BackgroundService { private readonly AspNetCoreMqttServer _aspNetCoreMqttServer; + private readonly Task _applicationStartedTask; public AspNetCoreMqttHostedServer( AspNetCoreMqttServer aspNetCoreMqttServer, IHostApplicationLifetime hostApplicationLifetime) { _aspNetCoreMqttServer = aspNetCoreMqttServer; - hostApplicationLifetime.ApplicationStarted.Register(ApplicationStarted); + _applicationStartedTask = WaitApplicationStartedAsync(hostApplicationLifetime); } - public Task StartAsync(CancellationToken cancellationToken) + private static Task WaitApplicationStartedAsync(IHostApplicationLifetime hostApplicationLifetime) { - return Task.CompletedTask; + var taskCompletionSource = new TaskCompletionSource(); + hostApplicationLifetime.ApplicationStarted.Register(OnApplicationStarted); + return taskCompletionSource.Task; + + void OnApplicationStarted() + { + taskCompletionSource.TrySetResult(); + } } - private void ApplicationStarted() + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - _ = _aspNetCoreMqttServer.StartAsync(); + await _applicationStartedTask.WaitAsync(stoppingToken); + await _aspNetCoreMqttServer.StartAsync(); } - public Task StopAsync(CancellationToken cancellationToken) + public override Task StopAsync(CancellationToken cancellationToken) { return _aspNetCoreMqttServer.StopAsync(); } From 64ed638e24dd57a27cf7a1c0c5d92fb579038cd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 14 Nov 2024 13:49:34 +0800 Subject: [PATCH 19/85] enable Nullable --- .../EndpointRouteBuilderExtensions.cs | 4 +--- .../Internal/AspNetCoreMqttChannelAdapter.cs | 23 ++++++++++--------- .../Internal/AspNetCoreMqttNetLogger.cs | 6 ++--- .../Internal/AspNetCoreMqttServerAdapter.cs | 2 +- .../Internal/MqttConnectionHandler.cs | 2 +- .../MqttPacketFormatterAdapterExtensions.cs | 3 ++- .../MQTTnet.AspNetCore.csproj | 4 +++- .../MqttClientBuilderExtensions.cs | 8 ++++++- 8 files changed, 30 insertions(+), 22 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index 6b7e790ff..d7fe5a2e2 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -21,13 +21,12 @@ public static class EndpointRouteBuilderExtensions /// public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern) { - ArgumentNullException.ThrowIfNull(endpoints); return endpoints.MapMqtt(pattern, options => options.WebSockets.SubProtocolSelector = SelectSubProtocol); static string SelectSubProtocol(IList requestedSubProtocolValues) { // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. - return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt")); + return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt"))!; } } @@ -40,7 +39,6 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) /// public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern, Action options) { - ArgumentNullException.ThrowIfNull(endpoints); return endpoints.MapConnectionHandler(pattern, options); } } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs index 0c362e1c3..abeeef347 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs @@ -13,7 +13,6 @@ using System; using System.Buffers; using System.IO.Pipelines; -using System.Net; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -27,7 +26,7 @@ sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter readonly PipeReader _input; readonly PipeWriter _output; - readonly IHttpContextFeature _httpContextFeature; + readonly IHttpContextFeature? _httpContextFeature; public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { @@ -37,13 +36,14 @@ public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAd _output = connection.Transport.Output; _httpContextFeature = connection.Features.Get(); } + public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } public long BytesReceived { get; private set; } public long BytesSent { get; private set; } - public X509Certificate2 ClientCertificate + public X509Certificate2? ClientCertificate { get { @@ -57,14 +57,15 @@ public X509Certificate2 ClientCertificate } } - public string Endpoint + public string? Endpoint { get { if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) { var httpConnection = _httpContextFeature.HttpContext.Connection; - return httpConnection == null ? null : new IPEndPoint(httpConnection.RemoteIpAddress, httpConnection.RemotePort).ToString(); + var remoteAddress = httpConnection.RemoteIpAddress; + return remoteAddress == null ? null : $"{remoteAddress}:{httpConnection.RemotePort}"; } return _connection.RemoteEndPoint?.ToString(); @@ -93,8 +94,8 @@ public Task ConnectAsync(CancellationToken cancellationToken) public Task DisconnectAsync(CancellationToken cancellationToken) { - _input?.Complete(); - _output?.Complete(); + _input.Complete(); + _output.Complete(); return Task.CompletedTask; } @@ -104,7 +105,7 @@ public void Dispose() _writerLock.Dispose(); } - public async Task ReceivePacketAsync(CancellationToken cancellationToken) + public async Task ReceivePacketAsync(CancellationToken cancellationToken) { try { @@ -153,8 +154,8 @@ public async Task ReceivePacketAsync(CancellationToken cancellationT catch (Exception exception) { // completing the channel makes sure that there is no more data read after a protocol error - _input?.Complete(exception); - _output?.Complete(exception); + _input.Complete(exception); + _output.Complete(exception); throw; } @@ -206,7 +207,7 @@ static void WritePacketBuffer(PipeWriter output, MqttPacketBuffer buffer) var span = output.GetSpan(buffer.Length); buffer.Packet.AsSpan().CopyTo(span); - int offset = buffer.Packet.Count; + var offset = buffer.Packet.Count; buffer.Payload.CopyTo(destination: span.Slice(offset)); output.Advance(buffer.Length); } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs index 50e35964a..23a80bad9 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -20,13 +20,13 @@ public AspNetCoreMqttNetLogger(ILoggerFactory loggerFactory) _loggerFactory = loggerFactory; } - public void Publish(MqttNetLogLevel logLevel, string source, string message, object[] parameters, Exception exception) + public void Publish(MqttNetLogLevel logLevel, string? source, string? message, object[]? parameters, Exception? exception) { var logger = _loggerFactory.CreateLogger($"{categoryNamePrefix}{source}"); - logger.Log(CastLogLevel(logLevel), exception, message, parameters); + logger.Log(ConvertLogLevel(logLevel), exception, message, parameters ?? []); } - private static LogLevel CastLogLevel(MqttNetLogLevel level) + private static LogLevel ConvertLogLevel(MqttNetLogLevel? level) { return level switch { diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs index c471f0d1d..b095d549a 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs @@ -14,7 +14,7 @@ sealed class AspNetCoreMqttServerAdapter : IMqttServerAdapter { readonly MqttConnectionHandler _connectionHandler; - public Func ClientHandler + public Func? ClientHandler { get => _connectionHandler.ClientHandler; set => _connectionHandler.ClientHandler = value; diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index cf25eaca7..08730a0ad 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -19,7 +19,7 @@ sealed class MqttConnectionHandler : ConnectionHandler readonly IMqttNetLogger _logger; readonly MqttServerOptions _serverOptions; - public Func ClientHandler { get; set; } + public Func? ClientHandler { get; set; } public MqttConnectionHandler( IMqttNetLogger logger, diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs index bd4c08b39..94d8e4dff 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs @@ -8,6 +8,7 @@ using MQTTnet.Packets; using System; using System.Buffers; +using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; namespace MQTTnet.AspNetCore; @@ -17,7 +18,7 @@ static class MqttPacketFormatterAdapterExtensions public static bool TryDecode( this MqttPacketFormatterAdapter formatter, in ReadOnlySequence input, - out MqttPacket packet, + [MaybeNullWhen(false)] out MqttPacket packet, out SequencePosition consumed, out SequencePosition observed, out int bytesRead) diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 741d3e5a3..537e5e258 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -1,4 +1,5 @@ - + + net8.0 @@ -37,6 +38,7 @@ true low low + enable diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs index 1547ddf9b..f9d2a9bd0 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -6,6 +6,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; +using System; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; @@ -27,7 +28,12 @@ public static IMqttClientBuilder UseAspNetCoreMqttClientAdapterFactory(this IMqt { if (!builder.Services.Any(s => s.ServiceType == typeof(IConnectionFactory))) { - var socketConnectionFactoryType = Assembly.Load(SocketConnectionFactoryAssemblyName).GetType(SocketConnectionFactoryTypeName); + var assembly = Assembly.Load(SocketConnectionFactoryAssemblyName); + var socketConnectionFactoryType = assembly.GetType(SocketConnectionFactoryTypeName); + if (socketConnectionFactoryType == null) + { + throw new TypeLoadException($"Cannot find type {SocketConnectionFactoryTypeName} in assembly {SocketConnectionFactoryAssemblyName}"); + } builder.Services.AddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); } From 352621205ba7481b3178dbfff720b63d2d35ab74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 14 Nov 2024 21:51:22 +0800 Subject: [PATCH 20/85] Always dispose _connection of AspNetCoreMqttChannelAdapter --- .../Internal/AspNetCoreMqttChannelAdapter.cs | 14 +++++++++++--- .../Internal/MqttConnectionHandler.cs | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs index abeeef347..dd64ea890 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs @@ -19,7 +19,7 @@ namespace MQTTnet.AspNetCore; -sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter +sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter, IAsyncDisposable { readonly ConnectionContext _connection; readonly AsyncLock _writerLock = new(); @@ -30,8 +30,9 @@ sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { - PacketFormatterAdapter = packetFormatterAdapter ?? throw new ArgumentNullException(nameof(packetFormatterAdapter)); - _connection = connection ?? throw new ArgumentNullException(nameof(connection)); + PacketFormatterAdapter = packetFormatterAdapter; + _connection = connection; + _input = connection.Transport.Input; _output = connection.Transport.Output; _httpContextFeature = connection.Features.Get(); @@ -103,6 +104,13 @@ public Task DisconnectAsync(CancellationToken cancellationToken) public void Dispose() { _writerLock.Dispose(); + _connection.DisposeAsync().GetAwaiter().GetResult(); + } + + public async ValueTask DisposeAsync() + { + _writerLock.Dispose(); + await _connection.DisposeAsync(); } public async Task ReceivePacketAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 08730a0ad..efc831df2 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -47,7 +47,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) } var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); - using var adapter = new AspNetCoreMqttChannelAdapter(formatter, connection); + await using var adapter = new AspNetCoreMqttChannelAdapter(formatter, connection); await clientHandler(adapter).ConfigureAwait(false); } } \ No newline at end of file From b4b0b04008a1fbd5c03dbf0e1d153fb31060c87a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 14 Nov 2024 22:59:20 +0800 Subject: [PATCH 21/85] UseTls --- Samples/Server/Server_ASP_NET_Samples.cs | 4 +- .../AspNetCoreMqttClientAdapterFactory.cs | 147 ++++++++++++++- .../Internal/DuplexPipeStream.cs | 173 ++++++++++++++++++ 3 files changed, 320 insertions(+), 4 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 3fa1cb6aa..4ee00e6cb 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -33,7 +33,7 @@ public static Task Start_Server_With_WebSockets_Support() kestrel.ListenAnyIP(1883, l => l.UseMqtt()); // mqtt over tls over tcp - kestrel.ListenAnyIP(1884, l => l.UseHttps().UseMqtt()); + kestrel.ListenLocalhost(1884, l => l.UseHttps().UseMqtt()); // This will allow MQTT connections based on HTTP WebSockets with URI "localhost:5000/mqtt" // See code below for URI configuration. @@ -86,7 +86,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { await Task.Delay(3000); using var client = _mqttClientFactory.CreateMqttClient(); - var options = new MqttClientOptionsBuilder().WithTcpServer("localhost").Build(); + var options = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://localhost:1883").Build(); await client.ConnectAsync(options, stoppingToken); } } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index 255b5f08b..3b8d0846f 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -3,13 +3,19 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using System; +using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Net; +using System.Net.Security; using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; namespace MQTTnet.AspNetCore @@ -32,9 +38,11 @@ public async ValueTask CreateClientAdapterAsync(MqttClientO case MqttClientTcpOptions tcpOptions: { var endPoint = await CreateIPEndPointAsync(tcpOptions.RemoteEndpoint); - var tcpConnection = await _connectionFactory.ConnectAsync(endPoint); + var connection = await _connectionFactory.ConnectAsync(endPoint); + await AuthenticateAsClientAsync(connection, tcpOptions); + var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); - return new AspNetCoreMqttChannelAdapter(formatter, tcpConnection); + return new AspNetCoreMqttChannelAdapter(formatter, connection); } default: { @@ -61,5 +69,140 @@ private static async ValueTask CreateIPEndPointAsync(EndPoint endpoi throw new NotSupportedException("Only supports IPEndPoint or DnsEndPoint for now."); } + + + private static async ValueTask AuthenticateAsClientAsync(ConnectionContext connection, MqttClientTcpOptions tcpOptions) + { + if (tcpOptions.TlsOptions?.UseTls != true) + { + return; + } + + var targetHost = tcpOptions.TlsOptions.TargetHost; + if (string.IsNullOrEmpty(targetHost)) + { + if (tcpOptions.RemoteEndpoint is DnsEndPoint dns) + { + targetHost = dns.Host; + } + } + + SslStream sslStream; + var networkStream = new DuplexPipeStream(connection.Transport); + if (tcpOptions.TlsOptions.CertificateSelectionHandler != null) + { + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: true, + InternalUserCertificateValidationCallback, + InternalUserCertificateSelectionCallback); + } + else + { + // Use a different constructor depending on the options for MQTTnet so that we do not have + // to copy the exact same behavior of the selection handler. + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: true, + InternalUserCertificateValidationCallback); + } + + var sslOptions = new SslClientAuthenticationOptions + { + ApplicationProtocols = tcpOptions.TlsOptions.ApplicationProtocols, + ClientCertificates = LoadCertificates(), + EnabledSslProtocols = tcpOptions.TlsOptions.SslProtocol, + CertificateRevocationCheckMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode, + TargetHost = targetHost, + CipherSuitesPolicy = tcpOptions.TlsOptions.CipherSuitesPolicy, + EncryptionPolicy = tcpOptions.TlsOptions.EncryptionPolicy, + AllowRenegotiation = tcpOptions.TlsOptions.AllowRenegotiation + }; + + if (tcpOptions.TlsOptions.TrustChain?.Count > 0) + { + sslOptions.CertificateChainPolicy = new X509ChainPolicy + { + TrustMode = X509ChainTrustMode.CustomRootTrust, + VerificationFlags = X509VerificationFlags.IgnoreEndRevocationUnknown, + RevocationMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode + }; + + sslOptions.CertificateChainPolicy.CustomTrustStore.AddRange(tcpOptions.TlsOptions.TrustChain); + } + + try + { + await sslStream.AuthenticateAsClientAsync(sslOptions).ConfigureAwait(false); + } + catch (Exception) + { + await sslStream.DisposeAsync(); + throw; + } + + connection.Transport = new StreamDuplexPipe(sslStream); + connection.ConnectionClosed.Register(() => + { + sslStream.Dispose(); + }); + connection.Features.Set(new TlsConnectionFeature()); + + X509Certificate InternalUserCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection? localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) + { + var certificateSelectionHandler = tcpOptions?.TlsOptions?.CertificateSelectionHandler; + if (certificateSelectionHandler != null) + { + var eventArgs = new MqttClientCertificateSelectionEventArgs(targetHost, localCertificates, remoteCertificate, acceptableIssuers, tcpOptions); + return certificateSelectionHandler(eventArgs); + } + + if (localCertificates?.Count > 0) + { + return localCertificates[0]; + } + + return null!; + } + + bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x509Certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) + { + var certificateValidationHandler = tcpOptions?.TlsOptions?.CertificateValidationHandler; + if (certificateValidationHandler != null) + { + var eventArgs = new MqttClientCertificateValidationEventArgs(x509Certificate, chain, sslPolicyErrors, tcpOptions); + return certificateValidationHandler(eventArgs); + } + + if (tcpOptions?.TlsOptions?.IgnoreCertificateChainErrors ?? false) + { + sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; + } + + return sslPolicyErrors == SslPolicyErrors.None; + } + + X509CertificateCollection? LoadCertificates() + { + return tcpOptions.TlsOptions.ClientCertificatesProvider?.GetCertificates(); + } + } + + private class StreamDuplexPipe(Stream stream) : IDuplexPipe + { + public PipeReader Input { get; } = PipeReader.Create(stream); + + public PipeWriter Output { get; } = PipeWriter.Create(stream); + } + + private class TlsConnectionFeature : ITlsConnectionFeature + { + public X509Certificate2? ClientCertificate { get; set; } + + public Task GetClientCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(ClientCertificate); + } + } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs b/Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs new file mode 100644 index 000000000..6044b486e --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs @@ -0,0 +1,173 @@ +using System; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + sealed class DuplexPipeStream : Stream + { + private readonly PipeReader input; + private readonly PipeWriter output; + private readonly bool throwOnCancelled; + private volatile bool cancelCalled; + + public DuplexPipeStream(IDuplexPipe duplexPipe, bool throwOnCancelled = false) + { + input = duplexPipe.Input; + output = duplexPipe.Output; + this.throwOnCancelled = throwOnCancelled; + } + + public void CancelPendingRead() + { + cancelCalled = true; + input.CancelPendingRead(); + } + + /// + public override bool CanRead => true; + + /// + public override bool CanSeek => false; + + /// + public override bool CanWrite => true; + + /// + public override long Length => throw new NotSupportedException(); + + /// + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + /// + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + /// + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + /// + public override int Read(byte[] buffer, int offset, int count) + { + var task = ReadAsyncInternal(new Memory(buffer, offset, count), default); + return task.IsCompleted ? task.Result : task.AsTask().GetAwaiter().GetResult(); + } + + /// + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } + + /// + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + return ReadAsyncInternal(destination, cancellationToken); + } + + /// + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } + + /// + public override async Task WriteAsync(byte[]? buffer, int offset, int count, CancellationToken cancellationToken) + { + await output.WriteAsync(buffer.AsMemory(offset, count), cancellationToken); + } + + /// + public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + await output.WriteAsync(source, cancellationToken); + } + + /// + public override void Flush() + { + FlushAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + /// + public override async Task FlushAsync(CancellationToken cancellationToken) + { + await output.FlushAsync(cancellationToken); + } + + + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] + private async ValueTask ReadAsyncInternal(Memory destination, CancellationToken cancellationToken) + { + while (true) + { + var result = await input.ReadAsync(cancellationToken); + var readableBuffer = result.Buffer; + try + { + if (throwOnCancelled && result.IsCanceled && cancelCalled) + { + // Reset the bool + cancelCalled = false; + throw new OperationCanceledException(); + } + + if (!readableBuffer.IsEmpty) + { + // buffer.Count is int + var count = (int)Math.Min(readableBuffer.Length, destination.Length); + readableBuffer = readableBuffer.Slice(0, count); + readableBuffer.CopyTo(destination.Span); + return count; + } + + if (result.IsCompleted) + { + return 0; + } + } + finally + { + input.AdvanceTo(readableBuffer.End, readableBuffer.End); + } + } + } + + /// + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + return TaskToAsyncResult.Begin(ReadAsync(buffer, offset, count), callback, state); + } + + /// + public override int EndRead(IAsyncResult asyncResult) + { + return TaskToAsyncResult.End(asyncResult); + } + + /// + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + return TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state); + } + + /// + public override void EndWrite(IAsyncResult asyncResult) + { + TaskToAsyncResult.End(asyncResult); + } + } +} From d9e02cec511a6abe283469efeb5215b36759e216 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 15 Nov 2024 13:42:53 +0800 Subject: [PATCH 22/85] Restore the IMqttClientAdapterFactory interface --- Samples/Server/Server_ASP_NET_Samples.cs | 3 +- .../AspNetCoreMqttClientAdapterFactory.cs | 180 +-------------- .../Internal/ClientConnectionContext.cs | 212 ++++++++++++++++++ .../Internal/DuplexPipeStream.cs | 173 -------------- ...reMqttChannelAdapter.cs => MqttChannel.cs} | 26 +-- .../Internal/MqttClientChannelAdapter.cs | 105 +++++++++ .../Internal/MqttConnectionHandler.cs | 2 +- .../Internal/MqttServerChannelAdapter.cs | 66 ++++++ .../MqttClientBuilderExtensions.cs | 18 -- .../ASP/Mockups/ConnectionHandlerMockup.cs | 4 +- .../ASP/MqttConnectionContextTest.cs | 6 +- .../Adapter/IMqttClientAdapterFactory.cs | 3 +- .../MqttClientAdapterFactory.cs | 6 +- .../LowLevelClient/LowLevelMqttClient.cs | 2 +- Source/MQTTnet/MqttClient.cs | 2 +- 15 files changed, 403 insertions(+), 405 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs delete mode 100644 Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs rename Source/MQTTnet.AspnetCore/Internal/{AspNetCoreMqttChannelAdapter.cs => MqttChannel.cs} (90%) create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 4ee00e6cb..9c101fdec 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -86,8 +86,9 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { await Task.Delay(3000); using var client = _mqttClientFactory.CreateMqttClient(); - var options = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://localhost:1883").Build(); + var options = new MqttClientOptionsBuilder().WithConnectionUri("mqtts://localhost:1884").WithTlsOptions(x => x.WithIgnoreCertificateChainErrors()).Build(); await client.ConnectAsync(options, stoppingToken); + await client.DisconnectAsync(); } } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index 3b8d0846f..f5b928fb9 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -2,34 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http.Features; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using System; -using System.IO; -using System.IO.Pipelines; -using System.Linq; -using System.Net; -using System.Net.Security; -using System.Net.Sockets; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Tasks; namespace MQTTnet.AspNetCore { sealed class AspNetCoreMqttClientAdapterFactory : IMqttClientAdapterFactory { - private readonly IConnectionFactory _connectionFactory; - - public AspNetCoreMqttClientAdapterFactory(IConnectionFactory connectionFactory) - { - _connectionFactory = connectionFactory; - } - - public async ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) + public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) { if (options == null) throw new ArgumentNullException(nameof(options)); @@ -37,12 +19,8 @@ public async ValueTask CreateClientAdapterAsync(MqttClientO { case MqttClientTcpOptions tcpOptions: { - var endPoint = await CreateIPEndPointAsync(tcpOptions.RemoteEndpoint); - var connection = await _connectionFactory.ConnectAsync(endPoint); - await AuthenticateAsClientAsync(connection, tcpOptions); - var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); - return new AspNetCoreMqttChannelAdapter(formatter, connection); + return new MqttClientChannelAdapter(formatter, tcpOptions); } default: { @@ -50,159 +28,5 @@ public async ValueTask CreateClientAdapterAsync(MqttClientO } } } - - private static async ValueTask CreateIPEndPointAsync(EndPoint endpoint) - { - if (endpoint is IPEndPoint ipEndPoint) - { - return ipEndPoint; - } - - if (endpoint is DnsEndPoint dnsEndPoint) - { - var hostEntry = await Dns.GetHostEntryAsync(dnsEndPoint.Host); - var address = hostEntry.AddressList.OrderBy(item => item.AddressFamily).FirstOrDefault(); - return address == null - ? throw new SocketException((int)SocketError.HostNotFound) - : new IPEndPoint(address, dnsEndPoint.Port); - } - - throw new NotSupportedException("Only supports IPEndPoint or DnsEndPoint for now."); - } - - - private static async ValueTask AuthenticateAsClientAsync(ConnectionContext connection, MqttClientTcpOptions tcpOptions) - { - if (tcpOptions.TlsOptions?.UseTls != true) - { - return; - } - - var targetHost = tcpOptions.TlsOptions.TargetHost; - if (string.IsNullOrEmpty(targetHost)) - { - if (tcpOptions.RemoteEndpoint is DnsEndPoint dns) - { - targetHost = dns.Host; - } - } - - SslStream sslStream; - var networkStream = new DuplexPipeStream(connection.Transport); - if (tcpOptions.TlsOptions.CertificateSelectionHandler != null) - { - sslStream = new SslStream( - networkStream, - leaveInnerStreamOpen: true, - InternalUserCertificateValidationCallback, - InternalUserCertificateSelectionCallback); - } - else - { - // Use a different constructor depending on the options for MQTTnet so that we do not have - // to copy the exact same behavior of the selection handler. - sslStream = new SslStream( - networkStream, - leaveInnerStreamOpen: true, - InternalUserCertificateValidationCallback); - } - - var sslOptions = new SslClientAuthenticationOptions - { - ApplicationProtocols = tcpOptions.TlsOptions.ApplicationProtocols, - ClientCertificates = LoadCertificates(), - EnabledSslProtocols = tcpOptions.TlsOptions.SslProtocol, - CertificateRevocationCheckMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode, - TargetHost = targetHost, - CipherSuitesPolicy = tcpOptions.TlsOptions.CipherSuitesPolicy, - EncryptionPolicy = tcpOptions.TlsOptions.EncryptionPolicy, - AllowRenegotiation = tcpOptions.TlsOptions.AllowRenegotiation - }; - - if (tcpOptions.TlsOptions.TrustChain?.Count > 0) - { - sslOptions.CertificateChainPolicy = new X509ChainPolicy - { - TrustMode = X509ChainTrustMode.CustomRootTrust, - VerificationFlags = X509VerificationFlags.IgnoreEndRevocationUnknown, - RevocationMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode - }; - - sslOptions.CertificateChainPolicy.CustomTrustStore.AddRange(tcpOptions.TlsOptions.TrustChain); - } - - try - { - await sslStream.AuthenticateAsClientAsync(sslOptions).ConfigureAwait(false); - } - catch (Exception) - { - await sslStream.DisposeAsync(); - throw; - } - - connection.Transport = new StreamDuplexPipe(sslStream); - connection.ConnectionClosed.Register(() => - { - sslStream.Dispose(); - }); - connection.Features.Set(new TlsConnectionFeature()); - - X509Certificate InternalUserCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection? localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) - { - var certificateSelectionHandler = tcpOptions?.TlsOptions?.CertificateSelectionHandler; - if (certificateSelectionHandler != null) - { - var eventArgs = new MqttClientCertificateSelectionEventArgs(targetHost, localCertificates, remoteCertificate, acceptableIssuers, tcpOptions); - return certificateSelectionHandler(eventArgs); - } - - if (localCertificates?.Count > 0) - { - return localCertificates[0]; - } - - return null!; - } - - bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x509Certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) - { - var certificateValidationHandler = tcpOptions?.TlsOptions?.CertificateValidationHandler; - if (certificateValidationHandler != null) - { - var eventArgs = new MqttClientCertificateValidationEventArgs(x509Certificate, chain, sslPolicyErrors, tcpOptions); - return certificateValidationHandler(eventArgs); - } - - if (tcpOptions?.TlsOptions?.IgnoreCertificateChainErrors ?? false) - { - sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; - } - - return sslPolicyErrors == SslPolicyErrors.None; - } - - X509CertificateCollection? LoadCertificates() - { - return tcpOptions.TlsOptions.ClientCertificatesProvider?.GetCertificates(); - } - } - - private class StreamDuplexPipe(Stream stream) : IDuplexPipe - { - public PipeReader Input { get; } = PipeReader.Create(stream); - - public PipeWriter Output { get; } = PipeWriter.Create(stream); - } - - private class TlsConnectionFeature : ITlsConnectionFeature - { - public X509Certificate2? ClientCertificate { get; set; } - - public Task GetClientCertificateAsync(CancellationToken cancellationToken) - { - return Task.FromResult(ClientCertificate); - } - } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs new file mode 100644 index 000000000..fde2d2bcd --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -0,0 +1,212 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore.Internal +{ + sealed class ClientConnectionContext : ConnectionContext + { + private readonly Stream _stream; + private readonly CancellationTokenSource _connectionCloseSource = new(); + + public override IDuplexPipe Transport { get; set; } + + public override CancellationToken ConnectionClosed + { + get => _connectionCloseSource.Token; + set => throw new InvalidOperationException(); + } + + public override string ConnectionId { get; set; } = Guid.NewGuid().ToString(); + + public override IFeatureCollection Features { get; } = new FeatureCollection(); + + public override IDictionary Items { get; set; } = new Dictionary(); + + public ClientConnectionContext(Stream stream) + { + _stream = stream; + Transport = new StreamTransport(stream); + } + + public override async ValueTask DisposeAsync() + { + await _stream.DisposeAsync(); + _connectionCloseSource.Cancel(); + _connectionCloseSource.Dispose(); + } + + public override void Abort() + { + _stream.Close(); + _connectionCloseSource.Cancel(); + } + + + public static async Task CreateAsync(MqttClientTcpOptions tcpOptions, CancellationToken cancellationToken) + { + var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + try + { + await socket.ConnectAsync(tcpOptions.RemoteEndpoint, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + socket.Dispose(); + throw; + } + + var networkStream = new NetworkStream(socket, ownsSocket: true); + if (tcpOptions.TlsOptions?.UseTls != true) + { + return new ClientConnectionContext(networkStream); + } + + var targetHost = tcpOptions.TlsOptions.TargetHost; + if (string.IsNullOrEmpty(targetHost)) + { + if (tcpOptions.RemoteEndpoint is DnsEndPoint dns) + { + targetHost = dns.Host; + } + } + + SslStream sslStream; + if (tcpOptions.TlsOptions.CertificateSelectionHandler != null) + { + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: false, + InternalUserCertificateValidationCallback, + InternalUserCertificateSelectionCallback); + } + else + { + // Use a different constructor depending on the options for MQTTnet so that we do not have + // to copy the exact same behavior of the selection handler. + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: false, + InternalUserCertificateValidationCallback); + } + + var sslOptions = new SslClientAuthenticationOptions + { + ApplicationProtocols = tcpOptions.TlsOptions.ApplicationProtocols, + ClientCertificates = LoadCertificates(), + EnabledSslProtocols = tcpOptions.TlsOptions.SslProtocol, + CertificateRevocationCheckMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode, + TargetHost = targetHost, + CipherSuitesPolicy = tcpOptions.TlsOptions.CipherSuitesPolicy, + EncryptionPolicy = tcpOptions.TlsOptions.EncryptionPolicy, + AllowRenegotiation = tcpOptions.TlsOptions.AllowRenegotiation + }; + + if (tcpOptions.TlsOptions.TrustChain?.Count > 0) + { + sslOptions.CertificateChainPolicy = new X509ChainPolicy + { + TrustMode = X509ChainTrustMode.CustomRootTrust, + VerificationFlags = X509VerificationFlags.IgnoreEndRevocationUnknown, + RevocationMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode + }; + + sslOptions.CertificateChainPolicy.CustomTrustStore.AddRange(tcpOptions.TlsOptions.TrustChain); + } + + try + { + await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + await sslStream.DisposeAsync(); + throw; + } + + var connection = new ClientConnectionContext(sslStream) + { + LocalEndPoint = socket.LocalEndPoint, + RemoteEndPoint = socket.RemoteEndPoint, + }; + connection.Features.Set(TlsConnectionFeature.Instance); + connection.Features.Set(new ConnectionSocketFeature(socket)); + return connection; + + + X509Certificate InternalUserCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection? localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) + { + var certificateSelectionHandler = tcpOptions?.TlsOptions?.CertificateSelectionHandler; + if (certificateSelectionHandler != null) + { + var eventArgs = new MqttClientCertificateSelectionEventArgs(targetHost, localCertificates, remoteCertificate, acceptableIssuers, tcpOptions); + return certificateSelectionHandler(eventArgs); + } + + if (localCertificates?.Count > 0) + { + return localCertificates[0]; + } + + return null!; + } + + bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x509Certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) + { + var certificateValidationHandler = tcpOptions?.TlsOptions?.CertificateValidationHandler; + if (certificateValidationHandler != null) + { + var eventArgs = new MqttClientCertificateValidationEventArgs(x509Certificate, chain, sslPolicyErrors, tcpOptions); + return certificateValidationHandler(eventArgs); + } + + if (tcpOptions?.TlsOptions?.IgnoreCertificateChainErrors ?? false) + { + sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; + } + + return sslPolicyErrors == SslPolicyErrors.None; + } + + X509CertificateCollection? LoadCertificates() + { + return tcpOptions.TlsOptions.ClientCertificatesProvider?.GetCertificates(); + } + } + + + private class StreamTransport(Stream stream) : IDuplexPipe + { + public PipeReader Input { get; } = PipeReader.Create(stream, new StreamPipeReaderOptions(leaveOpen: true)); + + public PipeWriter Output { get; } = PipeWriter.Create(stream, new StreamPipeWriterOptions(leaveOpen: true)); + } + + private class TlsConnectionFeature : ITlsConnectionFeature + { + public static readonly TlsConnectionFeature Instance = new(); + + public X509Certificate2? ClientCertificate { get; set; } + + public Task GetClientCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(ClientCertificate); + } + } + + private class ConnectionSocketFeature(Socket socket) : IConnectionSocketFeature + { + public Socket Socket { get; } = socket; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs b/Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs deleted file mode 100644 index 6044b486e..000000000 --- a/Source/MQTTnet.AspnetCore/Internal/DuplexPipeStream.cs +++ /dev/null @@ -1,173 +0,0 @@ -using System; -using System.Buffers; -using System.IO; -using System.IO.Pipelines; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.AspNetCore -{ - sealed class DuplexPipeStream : Stream - { - private readonly PipeReader input; - private readonly PipeWriter output; - private readonly bool throwOnCancelled; - private volatile bool cancelCalled; - - public DuplexPipeStream(IDuplexPipe duplexPipe, bool throwOnCancelled = false) - { - input = duplexPipe.Input; - output = duplexPipe.Output; - this.throwOnCancelled = throwOnCancelled; - } - - public void CancelPendingRead() - { - cancelCalled = true; - input.CancelPendingRead(); - } - - /// - public override bool CanRead => true; - - /// - public override bool CanSeek => false; - - /// - public override bool CanWrite => true; - - /// - public override long Length => throw new NotSupportedException(); - - /// - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - /// - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } - - /// - public override void SetLength(long value) - { - throw new NotSupportedException(); - } - - /// - public override int Read(byte[] buffer, int offset, int count) - { - var task = ReadAsyncInternal(new Memory(buffer, offset, count), default); - return task.IsCompleted ? task.Result : task.AsTask().GetAwaiter().GetResult(); - } - - /// - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - { - return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); - } - - /// - public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) - { - return ReadAsyncInternal(destination, cancellationToken); - } - - /// - public override void Write(byte[] buffer, int offset, int count) - { - WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - /// - public override async Task WriteAsync(byte[]? buffer, int offset, int count, CancellationToken cancellationToken) - { - await output.WriteAsync(buffer.AsMemory(offset, count), cancellationToken); - } - - /// - public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) - { - await output.WriteAsync(source, cancellationToken); - } - - /// - public override void Flush() - { - FlushAsync(CancellationToken.None).GetAwaiter().GetResult(); - } - - /// - public override async Task FlushAsync(CancellationToken cancellationToken) - { - await output.FlushAsync(cancellationToken); - } - - - [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] - private async ValueTask ReadAsyncInternal(Memory destination, CancellationToken cancellationToken) - { - while (true) - { - var result = await input.ReadAsync(cancellationToken); - var readableBuffer = result.Buffer; - try - { - if (throwOnCancelled && result.IsCanceled && cancelCalled) - { - // Reset the bool - cancelCalled = false; - throw new OperationCanceledException(); - } - - if (!readableBuffer.IsEmpty) - { - // buffer.Count is int - var count = (int)Math.Min(readableBuffer.Length, destination.Length); - readableBuffer = readableBuffer.Slice(0, count); - readableBuffer.CopyTo(destination.Span); - return count; - } - - if (result.IsCompleted) - { - return 0; - } - } - finally - { - input.AdvanceTo(readableBuffer.End, readableBuffer.End); - } - } - } - - /// - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) - { - return TaskToAsyncResult.Begin(ReadAsync(buffer, offset, count), callback, state); - } - - /// - public override int EndRead(IAsyncResult asyncResult) - { - return TaskToAsyncResult.End(asyncResult); - } - - /// - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) - { - return TaskToAsyncResult.Begin(WriteAsync(buffer, offset, count), callback, state); - } - - /// - public override void EndWrite(IAsyncResult asyncResult) - { - TaskToAsyncResult.End(asyncResult); - } - } -} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs similarity index 90% rename from Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs rename to Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index dd64ea890..ea6bbfeb3 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -5,7 +5,6 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Features; -using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Formatter; using MQTTnet.Internal; @@ -19,7 +18,7 @@ namespace MQTTnet.AspNetCore; -sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter, IAsyncDisposable +sealed class MqttChannel : IDisposable { readonly ConnectionContext _connection; readonly AsyncLock _writerLock = new(); @@ -28,7 +27,7 @@ sealed class AspNetCoreMqttChannelAdapter : IMqttChannelAdapter, IAsyncDisposabl readonly PipeWriter _output; readonly IHttpContextFeature? _httpContextFeature; - public AspNetCoreMqttChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) + public MqttChannel(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { PacketFormatterAdapter = packetFormatterAdapter; _connection = connection; @@ -87,30 +86,15 @@ public bool IsSecureConnection } } - - public Task ConnectAsync(CancellationToken cancellationToken) + public async Task DisconnectAsync() { - return Task.CompletedTask; - } - - public Task DisconnectAsync(CancellationToken cancellationToken) - { - _input.Complete(); - _output.Complete(); - - return Task.CompletedTask; + await _input.CompleteAsync(); + await _output.CompleteAsync(); } public void Dispose() { _writerLock.Dispose(); - _connection.DisposeAsync().GetAwaiter().GetResult(); - } - - public async ValueTask DisposeAsync() - { - _writerLock.Dispose(); - await _connection.DisposeAsync(); } public async Task ReceivePacketAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs new file mode 100644 index 000000000..05a8d643c --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; +using MQTTnet.AspNetCore.Internal; +using MQTTnet.Formatter; +using MQTTnet.Packets; +using System; +using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable +{ + private bool _disposed = false; + private ConnectionContext? _connection; + private MqttChannel? _channel; + private readonly MqttPacketFormatterAdapter _packetFormatterAdapter; + private readonly MqttClientTcpOptions _tcpOptions; + + public MqttClientChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, MqttClientTcpOptions tcpOptions) + { + _packetFormatterAdapter = packetFormatterAdapter; + _tcpOptions = tcpOptions; + } + + public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter; + + public long BytesReceived => GetChannel().BytesReceived; + + public long BytesSent => GetChannel().BytesSent; + + public X509Certificate2? ClientCertificate => GetChannel().ClientCertificate; + + public string? Endpoint => GetChannel().Endpoint; + + public bool IsSecureConnection => GetChannel().IsSecureConnection; + + + public async Task ConnectAsync(CancellationToken cancellationToken) + { + _connection = await ClientConnectionContext.CreateAsync(_tcpOptions, cancellationToken); + _channel = new MqttChannel(_packetFormatterAdapter, _connection); + } + + public async Task DisconnectAsync(CancellationToken cancellationToken) + { + if (_channel != null) + { + await _channel.DisconnectAsync(); + } + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + + if (_channel != null) + { + await _channel.DisconnectAsync(); + _channel.Dispose(); + } + + if (_connection != null) + { + await _connection.DisposeAsync(); + } + } + + public void Dispose() + { + DisposeAsync().GetAwaiter().GetResult(); + } + + public Task ReceivePacketAsync(CancellationToken cancellationToken) + { + return GetChannel().ReceivePacketAsync(cancellationToken); + } + + public void ResetStatistics() + { + GetChannel().ResetStatistics(); + } + + public Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) + { + return GetChannel().SendPacketAsync(packet, cancellationToken); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private MqttChannel GetChannel() + { + return _channel ?? throw new InvalidOperationException(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index efc831df2..9747e6d22 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -47,7 +47,7 @@ public override async Task OnConnectedAsync(ConnectionContext connection) } var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); - await using var adapter = new AspNetCoreMqttChannelAdapter(formatter, connection); + using var adapter = new MqttServerChannelAdapter(formatter, connection); await clientHandler(adapter).ConfigureAwait(false); } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs new file mode 100644 index 000000000..7ee44eab7 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; +using MQTTnet.Formatter; +using MQTTnet.Packets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class MqttServerChannelAdapter : IMqttChannelAdapter +{ + private readonly MqttChannel _channel; + + public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) + { + _channel = new MqttChannel(packetFormatterAdapter, connection); + } + + public MqttPacketFormatterAdapter PacketFormatterAdapter => _channel.PacketFormatterAdapter; + + public long BytesReceived => _channel.BytesReceived; + + public long BytesSent => _channel.BytesSent; + + public X509Certificate2? ClientCertificate => _channel.ClientCertificate; + + public string? Endpoint => _channel.Endpoint; + + public bool IsSecureConnection => _channel.IsSecureConnection; + + + public Task ConnectAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public Task DisconnectAsync(CancellationToken cancellationToken) + { + return _channel.DisconnectAsync(); + } + + public void Dispose() + { + _channel.Dispose(); + } + + public Task ReceivePacketAsync(CancellationToken cancellationToken) + { + return _channel.ReceivePacketAsync(cancellationToken); + } + + public void ResetStatistics() + { + _channel.ResetStatistics(); + } + + public Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) + { + return _channel.SendPacketAsync(packet, cancellationToken); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs index f9d2a9bd0..409d51a62 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -2,41 +2,23 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; using System; using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Reflection; namespace MQTTnet.AspNetCore { public static class MqttClientBuilderExtensions { - const string SocketConnectionFactoryTypeName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketConnectionFactory"; - const string SocketConnectionFactoryAssemblyName = "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"; - /// /// Replace the implementation of IMqttClientAdapterFactory to AspNetCoreMqttClientAdapterFactory /// /// /// - [DynamicDependency(DynamicallyAccessedMemberTypes.All, SocketConnectionFactoryTypeName, SocketConnectionFactoryAssemblyName)] public static IMqttClientBuilder UseAspNetCoreMqttClientAdapterFactory(this IMqttClientBuilder builder) { - if (!builder.Services.Any(s => s.ServiceType == typeof(IConnectionFactory))) - { - var assembly = Assembly.Load(SocketConnectionFactoryAssemblyName); - var socketConnectionFactoryType = assembly.GetType(SocketConnectionFactoryTypeName); - if (socketConnectionFactoryType == null) - { - throw new TypeLoadException($"Cannot find type {SocketConnectionFactoryTypeName} in assembly {SocketConnectionFactoryAssemblyName}"); - } - builder.Services.AddSingleton(typeof(IConnectionFactory), socketConnectionFactoryType); - } - return builder.UseMqttClientAdapterFactory(); } diff --git a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs index d221470ea..e95ce11ce 100644 --- a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs +++ b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs @@ -16,7 +16,7 @@ namespace MQTTnet.Tests.ASP.Mockups; public sealed class ConnectionHandlerMockup : IMqttServerAdapter { public Func ClientHandler { get; set; } - TaskCompletionSource Context { get; } = new(); + TaskCompletionSource Context { get; } = new(); public void Dispose() { @@ -27,7 +27,7 @@ public async Task OnConnectedAsync(ConnectionContext connection) try { var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(4096, 65535)); - var context = new AspNetCoreMqttChannelAdapter(formatter, connection); + var context = new MqttServerChannelAdapter(formatter, connection); Context.TrySetResult(context); await ClientHandler(context); diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs index b616392b7..77d3c36cd 100644 --- a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs @@ -30,7 +30,7 @@ public async Task TestCorruptedConnectPacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new AspNetCoreMqttChannelAdapter(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection); await pipe.Receive.Writer.WriteAsync(writer.AddMqttHeader(MqttControlPacketType.Connect, Array.Empty())); @@ -98,7 +98,7 @@ public async Task TestLargePacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new AspNetCoreMqttChannelAdapter(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection); await ctx.SendPacketAsync(new MqttPublishPacket { PayloadSegment = new byte[20_000] }, CancellationToken.None).ConfigureAwait(false); @@ -113,7 +113,7 @@ public async Task TestReceivePacketAsyncThrowsWhenReaderCompleted() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new AspNetCoreMqttChannelAdapter(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection); pipe.Receive.Writer.Complete(); diff --git a/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs b/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs index 6c181e4c6..3ea49381b 100644 --- a/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs +++ b/Source/MQTTnet/Adapter/IMqttClientAdapterFactory.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. using MQTTnet.Diagnostics.Logger; -using System.Threading.Tasks; namespace MQTTnet.Adapter; public interface IMqttClientAdapterFactory { - ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger); + IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger); } \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs b/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs index 16d1dd9d5..6d7e6b40e 100644 --- a/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs +++ b/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs @@ -7,13 +7,12 @@ using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using System; -using System.Threading.Tasks; namespace MQTTnet.Implementations { public sealed class MqttClientAdapterFactory : IMqttClientAdapterFactory { - public ValueTask CreateClientAdapterAsync(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) + public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) { ArgumentNullException.ThrowIfNull(options); @@ -41,12 +40,11 @@ public ValueTask CreateClientAdapterAsync(MqttClientOptions var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); var packetFormatterAdapter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); - IMqttChannelAdapter adapter = new MqttChannelAdapter(channel, packetFormatterAdapter, logger) + return new MqttChannelAdapter(channel, packetFormatterAdapter, logger) { AllowPacketFragmentation = options.AllowPacketFragmentation, PacketInspector = packetInspector }; - return ValueTask.FromResult(adapter); } } } diff --git a/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs b/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs index f72fc7ed7..186e8bfd1 100644 --- a/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs +++ b/Source/MQTTnet/LowLevelClient/LowLevelMqttClient.cs @@ -54,7 +54,7 @@ public async Task ConnectAsync(MqttClientOptions options, CancellationToken canc packetInspector = new MqttPacketInspector(_inspectPacketEvent, _rootLogger); } - var newAdapter = await _clientAdapterFactory.CreateClientAdapterAsync(options, packetInspector, _rootLogger); + var newAdapter = _clientAdapterFactory.CreateClientAdapter(options, packetInspector, _rootLogger); try { diff --git a/Source/MQTTnet/MqttClient.cs b/Source/MQTTnet/MqttClient.cs index fe8b50742..f89c6bdad 100644 --- a/Source/MQTTnet/MqttClient.cs +++ b/Source/MQTTnet/MqttClient.cs @@ -121,7 +121,7 @@ public async Task ConnectAsync(MqttClientOptions option _mqttClientAlive = new CancellationTokenSource(); var mqttClientAliveToken = _mqttClientAlive.Token; - var adapter = await _adapterFactory.CreateClientAdapterAsync(options, new MqttPacketInspector(_events.InspectPacketEvent, _rootLogger), _rootLogger); + var adapter = _adapterFactory.CreateClientAdapter(options, new MqttPacketInspector(_events.InspectPacketEvent, _rootLogger), _rootLogger); _adapter = adapter ?? throw new InvalidOperationException("The adapter factory did not provide an adapter."); _unexpectedDisconnectPacket = null; From b5019bd0bacf684e83ebfd0e910418e6dc6555ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 15 Nov 2024 14:52:08 +0800 Subject: [PATCH 23/85] =?UTF-8?q?Calculate=20the=20property=20values=20?= =?UTF-8?q?=E2=80=8B=E2=80=8Bwhen=20constructing=20MqttChannel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Internal/MqttChannel.cs | 78 +++++++++---------- 1 file changed, 35 insertions(+), 43 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index ea6bbfeb3..74dff72a4 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -12,6 +12,7 @@ using System; using System.Buffers; using System.IO.Pipelines; +using System.Net; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -20,22 +21,9 @@ namespace MQTTnet.AspNetCore; sealed class MqttChannel : IDisposable { - readonly ConnectionContext _connection; readonly AsyncLock _writerLock = new(); - readonly PipeReader _input; readonly PipeWriter _output; - readonly IHttpContextFeature? _httpContextFeature; - - public MqttChannel(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) - { - PacketFormatterAdapter = packetFormatterAdapter; - _connection = connection; - - _input = connection.Transport.Input; - _output = connection.Transport.Output; - _httpContextFeature = connection.Features.Get(); - } public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } @@ -43,49 +31,53 @@ public MqttChannel(MqttPacketFormatterAdapter packetFormatterAdapter, Connection public long BytesSent { get; private set; } - public X509Certificate2? ClientCertificate + public X509Certificate2? ClientCertificate { get; } + + public string? Endpoint { get; } + + public bool IsSecureConnection { get; } + + + public MqttChannel(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) { - get - { - if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) - { - return _httpContextFeature.HttpContext.Connection.ClientCertificate; - } + var httpContextFeature = connection.Features.Get(); + var tlsConnectionFeature = connection.Features.Get(); - var tlsFeature = _connection.Features.Get(); - return tlsFeature?.ClientCertificate; - } + PacketFormatterAdapter = packetFormatterAdapter; + Endpoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); + IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature); + ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature); + + _input = connection.Transport.Input; + _output = connection.Transport.Output; } - public string? Endpoint + private static string? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint) { - get + if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) { - if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) - { - var httpConnection = _httpContextFeature.HttpContext.Connection; - var remoteAddress = httpConnection.RemoteIpAddress; - return remoteAddress == null ? null : $"{remoteAddress}:{httpConnection.RemotePort}"; - } - - return _connection.RemoteEndPoint?.ToString(); + var httpConnection = _httpContextFeature.HttpContext.Connection; + var remoteAddress = httpConnection.RemoteIpAddress; + return remoteAddress == null ? null : $"{remoteAddress}:{httpConnection.RemotePort}"; } + return remoteEndPoint?.ToString(); } - public bool IsSecureConnection + private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) { - get - { - if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) - { - return _httpContextFeature.HttpContext.Request.IsHttps; - } + return _httpContextFeature != null && _httpContextFeature.HttpContext != null + ? _httpContextFeature.HttpContext.Request.IsHttps + : tlsConnectionFeature != null; + } - var tlsFeature = _connection.Features.Get(); - return tlsFeature != null; - } + private static X509Certificate2? GetClientCertificate(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) + { + return _httpContextFeature != null && _httpContextFeature.HttpContext != null + ? _httpContextFeature.HttpContext.Connection.ClientCertificate + : tlsConnectionFeature?.ClientCertificate; } + public async Task DisconnectAsync() { await _input.CompleteAsync(); From c206a4a4effa961a1fd1b536f3f012bba28d6e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 15 Nov 2024 15:08:50 +0800 Subject: [PATCH 24/85] MqttServerChannelAdapter is modified to inherit MqttChannel --- .../Internal/MqttChannel.cs | 2 +- .../Internal/MqttServerChannelAdapter.cs | 48 ++++--------------- 2 files changed, 9 insertions(+), 41 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 74dff72a4..1d1703cfe 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -19,7 +19,7 @@ namespace MQTTnet.AspNetCore; -sealed class MqttChannel : IDisposable +class MqttChannel : IDisposable { readonly AsyncLock _writerLock = new(); readonly PipeReader _input; diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index 7ee44eab7..7e6e482f2 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -5,35 +5,23 @@ using Microsoft.AspNetCore.Connections; using MQTTnet.Adapter; using MQTTnet.Formatter; -using MQTTnet.Packets; -using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; namespace MQTTnet.AspNetCore; -sealed class MqttServerChannelAdapter : IMqttChannelAdapter +sealed class MqttServerChannelAdapter : MqttChannel, IMqttChannelAdapter { - private readonly MqttChannel _channel; - public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) + : base(packetFormatterAdapter, connection) { - _channel = new MqttChannel(packetFormatterAdapter, connection); } - public MqttPacketFormatterAdapter PacketFormatterAdapter => _channel.PacketFormatterAdapter; - - public long BytesReceived => _channel.BytesReceived; - - public long BytesSent => _channel.BytesSent; - - public X509Certificate2? ClientCertificate => _channel.ClientCertificate; - - public string? Endpoint => _channel.Endpoint; - - public bool IsSecureConnection => _channel.IsSecureConnection; - - + /// + /// This method will never be called + /// + /// + /// public Task ConnectAsync(CancellationToken cancellationToken) { return Task.CompletedTask; @@ -41,26 +29,6 @@ public Task ConnectAsync(CancellationToken cancellationToken) public Task DisconnectAsync(CancellationToken cancellationToken) { - return _channel.DisconnectAsync(); - } - - public void Dispose() - { - _channel.Dispose(); - } - - public Task ReceivePacketAsync(CancellationToken cancellationToken) - { - return _channel.ReceivePacketAsync(cancellationToken); - } - - public void ResetStatistics() - { - _channel.ResetStatistics(); - } - - public Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) - { - return _channel.SendPacketAsync(packet, cancellationToken); + return base.DisconnectAsync(); } } \ No newline at end of file From 0a216a94c4452eae0a5c24f4a790813da1956131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 15 Nov 2024 19:56:56 +0800 Subject: [PATCH 25/85] Add ClientConnectionContext.WebSocket --- Samples/Server/Server_ASP_NET_Samples.cs | 4 +- .../AspNetCoreMqttClientAdapterFactory.cs | 18 +- .../Internal/ClientConnectionContext.Tcp.cs | 147 +++++++++++++ .../ClientConnectionContext.WebSocket.cs | 198 ++++++++++++++++++ .../Internal/ClientConnectionContext.cs | 138 +----------- .../Internal/MqttChannel.cs | 7 +- .../Internal/MqttClientChannelAdapter.cs | 18 +- 7 files changed, 368 insertions(+), 162 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 9c101fdec..8dd9b3238 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -84,9 +84,9 @@ public MqttClientController(IMqttClientFactory mqttClientFactory) protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - await Task.Delay(3000); + await Task.Delay(1000); using var client = _mqttClientFactory.CreateMqttClient(); - var options = new MqttClientOptionsBuilder().WithConnectionUri("mqtts://localhost:1884").WithTlsOptions(x => x.WithIgnoreCertificateChainErrors()).Build(); + var options = new MqttClientOptionsBuilder().WithConnectionUri("ws://localhost:5000/mqtt").Build(); await client.ConnectAsync(options, stoppingToken); await client.DisconnectAsync(); } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index f5b928fb9..702f85aee 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -13,20 +13,10 @@ sealed class AspNetCoreMqttClientAdapterFactory : IMqttClientAdapterFactory { public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) { - if (options == null) throw new ArgumentNullException(nameof(options)); - - switch (options.ChannelOptions) - { - case MqttClientTcpOptions tcpOptions: - { - var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); - return new MqttClientChannelAdapter(formatter, tcpOptions); - } - default: - { - throw new NotSupportedException(); - } - } + ArgumentNullException.ThrowIfNull(nameof(options)); + var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); + var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); + return new MqttClientChannelAdapter(formatter, options.ChannelOptions); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs new file mode 100644 index 000000000..e7385f13b --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -0,0 +1,147 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; +using System; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore.Internal +{ + partial class ClientConnectionContext : ConnectionContext + { + public static async Task CreateAsync(MqttClientTcpOptions options, CancellationToken cancellationToken) + { + var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + try + { + await socket.ConnectAsync(options.RemoteEndpoint, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + socket.Dispose(); + throw; + } + + var networkStream = new NetworkStream(socket, ownsSocket: true); + if (options.TlsOptions?.UseTls != true) + { + return new ClientConnectionContext(networkStream); + } + + var targetHost = options.TlsOptions.TargetHost; + if (string.IsNullOrEmpty(targetHost)) + { + if (options.RemoteEndpoint is DnsEndPoint dns) + { + targetHost = dns.Host; + } + } + + SslStream sslStream; + if (options.TlsOptions.CertificateSelectionHandler != null) + { + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: false, + InternalUserCertificateValidationCallback, + InternalUserCertificateSelectionCallback); + } + else + { + // Use a different constructor depending on the options for MQTTnet so that we do not have + // to copy the exact same behavior of the selection handler. + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: false, + InternalUserCertificateValidationCallback); + } + + var sslOptions = new SslClientAuthenticationOptions + { + ApplicationProtocols = options.TlsOptions.ApplicationProtocols, + ClientCertificates = LoadCertificates(), + EnabledSslProtocols = options.TlsOptions.SslProtocol, + CertificateRevocationCheckMode = options.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : options.TlsOptions.RevocationMode, + TargetHost = targetHost, + CipherSuitesPolicy = options.TlsOptions.CipherSuitesPolicy, + EncryptionPolicy = options.TlsOptions.EncryptionPolicy, + AllowRenegotiation = options.TlsOptions.AllowRenegotiation + }; + + if (options.TlsOptions.TrustChain?.Count > 0) + { + sslOptions.CertificateChainPolicy = new X509ChainPolicy + { + TrustMode = X509ChainTrustMode.CustomRootTrust, + VerificationFlags = X509VerificationFlags.IgnoreEndRevocationUnknown, + RevocationMode = options.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : options.TlsOptions.RevocationMode + }; + + sslOptions.CertificateChainPolicy.CustomTrustStore.AddRange(options.TlsOptions.TrustChain); + } + + try + { + await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + await sslStream.DisposeAsync(); + throw; + } + + var connection = new ClientConnectionContext(sslStream) + { + LocalEndPoint = socket.LocalEndPoint, + RemoteEndPoint = socket.RemoteEndPoint, + }; + connection.Features.Set(TlsConnectionFeature.Instance); + connection.Features.Set(new ConnectionSocketFeature(socket)); + return connection; + + + X509Certificate InternalUserCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection? localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) + { + var certificateSelectionHandler = options?.TlsOptions?.CertificateSelectionHandler; + if (certificateSelectionHandler != null) + { + var eventArgs = new MqttClientCertificateSelectionEventArgs(targetHost, localCertificates, remoteCertificate, acceptableIssuers, options); + return certificateSelectionHandler(eventArgs); + } + + if (localCertificates?.Count > 0) + { + return localCertificates[0]; + } + + return null!; + } + + bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x509Certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) + { + var certificateValidationHandler = options?.TlsOptions?.CertificateValidationHandler; + if (certificateValidationHandler != null) + { + var eventArgs = new MqttClientCertificateValidationEventArgs(x509Certificate, chain, sslPolicyErrors, options); + return certificateValidationHandler(eventArgs); + } + + if (options?.TlsOptions?.IgnoreCertificateChainErrors ?? false) + { + sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; + } + + return sslPolicyErrors == SslPolicyErrors.None; + } + + X509CertificateCollection? LoadCertificates() + { + return options.TlsOptions.ClientCertificatesProvider?.GetCertificates(); + } + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs new file mode 100644 index 000000000..4cf7a8add --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -0,0 +1,198 @@ +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using System; +using System.IO; +using System.Net; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore.Internal +{ + partial class ClientConnectionContext : ConnectionContext + { + public static async Task CreateAsync(MqttClientWebSocketOptions options, CancellationToken cancellationToken) + { + var uri = new UriBuilder(new Uri(options.Uri, UriKind.Absolute)) + { + Scheme = options.TlsOptions?.UseTls == true ? Uri.UriSchemeWss : Uri.UriSchemeWs + }.Uri; + + var clientWebSocket = new ClientWebSocket(); + try + { + SetupClientWebSocket(clientWebSocket, options); + await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); + } + catch + { + // Prevent a memory leak when always creating new instance which will fail while connecting. + clientWebSocket.Dispose(); + throw; + } + + var webSocketStream = new WebSocketStream(clientWebSocket); + var connection = new ClientConnectionContext(webSocketStream) + { + LocalEndPoint = null, + RemoteEndPoint = new DnsEndPoint(uri.Host, uri.Port), + }; + + if (uri.Scheme == Uri.UriSchemeWss) + { + connection.Features.Set(TlsConnectionFeature.Instance); + } + return connection; + } + + private static void SetupClientWebSocket(ClientWebSocket clientWebSocket, MqttClientWebSocketOptions options) + { + if (options.ProxyOptions != null) + { + clientWebSocket.Options.Proxy = CreateProxy(options); + } + + if (options.RequestHeaders != null) + { + foreach (var requestHeader in options.RequestHeaders) + { + clientWebSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); + } + } + + if (options.SubProtocols != null) + { + foreach (var subProtocol in options.SubProtocols) + { + clientWebSocket.Options.AddSubProtocol(subProtocol); + } + } + + if (options.CookieContainer != null) + { + clientWebSocket.Options.Cookies = options.CookieContainer; + } + + if (options.TlsOptions?.UseTls == true) + { + var certificates = options.TlsOptions?.ClientCertificatesProvider?.GetCertificates(); + if (certificates?.Count > 0) + { + clientWebSocket.Options.ClientCertificates = certificates; + } + } + + // Only set the value if it is actually true. This property is not supported on all platforms + // and will throw a _PlatformNotSupported_ (i.e. WASM) exception when being used regardless of the actual value. + if (options.UseDefaultCredentials) + { + clientWebSocket.Options.UseDefaultCredentials = options.UseDefaultCredentials; + } + + if (options.KeepAliveInterval != WebSocket.DefaultKeepAliveInterval) + { + clientWebSocket.Options.KeepAliveInterval = options.KeepAliveInterval; + } + + if (options.Credentials != null) + { + clientWebSocket.Options.Credentials = options.Credentials; + } + + var certificateValidationHandler = options.TlsOptions?.CertificateValidationHandler; + if (certificateValidationHandler != null) + { + clientWebSocket.Options.RemoteCertificateValidationCallback = (_, certificate, chain, sslPolicyErrors) => + { + // TODO: Find a way to add client options to same callback. Problem is that they have a different type. + var context = new MqttClientCertificateValidationEventArgs(certificate, chain, sslPolicyErrors, options); + return certificateValidationHandler(context); + }; + + var certificateSelectionHandler = options.TlsOptions?.CertificateSelectionHandler; + if (certificateSelectionHandler != null) + { + throw new NotSupportedException("Remote certificate selection callback is not supported for WebSocket connections."); + } + } + } + + private static IWebProxy? CreateProxy(MqttClientWebSocketOptions options) + { + if (!Uri.TryCreate(options.ProxyOptions?.Address, UriKind.Absolute, out var proxyUri)) + { + return null; + } + + + WebProxy webProxy; + if (!string.IsNullOrEmpty(options.ProxyOptions.Username) && !string.IsNullOrEmpty(options.ProxyOptions.Password)) + { + var credentials = new NetworkCredential(options.ProxyOptions.Username, options.ProxyOptions.Password, options.ProxyOptions.Domain); + webProxy = new WebProxy(proxyUri, options.ProxyOptions.BypassOnLocal, options.ProxyOptions.BypassList, credentials); + } + else + { + webProxy = new WebProxy(proxyUri, options.ProxyOptions.BypassOnLocal, options.ProxyOptions.BypassList); + } + + if (options.ProxyOptions.UseDefaultCredentials) + { + // Only update the property if required because setting it to false will alter + // the used credentials internally! + webProxy.UseDefaultCredentials = true; + } + + return webProxy; + } + + + private class WebSocketStream : Stream + { + private readonly WebSocket _webSocket; + + public WebSocketStream(WebSocket webSocket) + { + _webSocket = webSocket; + } + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() { } + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _webSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, cancellationToken); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var result = await _webSocket.ReceiveAsync(buffer, cancellationToken); + return result.MessageType == WebSocketMessageType.Close ? 0 : result.Count; + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + protected override void Dispose(bool disposing) + { + _webSocket.Dispose(); + base.Dispose(disposing); + } + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs index fde2d2bcd..547c7fa9d 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -5,8 +5,6 @@ using System.Collections.Generic; using System.IO; using System.IO.Pipelines; -using System.Net; -using System.Net.Security; using System.Net.Sockets; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -14,7 +12,7 @@ namespace MQTTnet.AspNetCore.Internal { - sealed class ClientConnectionContext : ConnectionContext + sealed partial class ClientConnectionContext : ConnectionContext { private readonly Stream _stream; private readonly CancellationTokenSource _connectionCloseSource = new(); @@ -52,139 +50,7 @@ public override void Abort() _connectionCloseSource.Cancel(); } - - public static async Task CreateAsync(MqttClientTcpOptions tcpOptions, CancellationToken cancellationToken) - { - var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); - try - { - await socket.ConnectAsync(tcpOptions.RemoteEndpoint, cancellationToken).ConfigureAwait(false); - } - catch (Exception) - { - socket.Dispose(); - throw; - } - - var networkStream = new NetworkStream(socket, ownsSocket: true); - if (tcpOptions.TlsOptions?.UseTls != true) - { - return new ClientConnectionContext(networkStream); - } - - var targetHost = tcpOptions.TlsOptions.TargetHost; - if (string.IsNullOrEmpty(targetHost)) - { - if (tcpOptions.RemoteEndpoint is DnsEndPoint dns) - { - targetHost = dns.Host; - } - } - - SslStream sslStream; - if (tcpOptions.TlsOptions.CertificateSelectionHandler != null) - { - sslStream = new SslStream( - networkStream, - leaveInnerStreamOpen: false, - InternalUserCertificateValidationCallback, - InternalUserCertificateSelectionCallback); - } - else - { - // Use a different constructor depending on the options for MQTTnet so that we do not have - // to copy the exact same behavior of the selection handler. - sslStream = new SslStream( - networkStream, - leaveInnerStreamOpen: false, - InternalUserCertificateValidationCallback); - } - - var sslOptions = new SslClientAuthenticationOptions - { - ApplicationProtocols = tcpOptions.TlsOptions.ApplicationProtocols, - ClientCertificates = LoadCertificates(), - EnabledSslProtocols = tcpOptions.TlsOptions.SslProtocol, - CertificateRevocationCheckMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode, - TargetHost = targetHost, - CipherSuitesPolicy = tcpOptions.TlsOptions.CipherSuitesPolicy, - EncryptionPolicy = tcpOptions.TlsOptions.EncryptionPolicy, - AllowRenegotiation = tcpOptions.TlsOptions.AllowRenegotiation - }; - - if (tcpOptions.TlsOptions.TrustChain?.Count > 0) - { - sslOptions.CertificateChainPolicy = new X509ChainPolicy - { - TrustMode = X509ChainTrustMode.CustomRootTrust, - VerificationFlags = X509VerificationFlags.IgnoreEndRevocationUnknown, - RevocationMode = tcpOptions.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : tcpOptions.TlsOptions.RevocationMode - }; - - sslOptions.CertificateChainPolicy.CustomTrustStore.AddRange(tcpOptions.TlsOptions.TrustChain); - } - - try - { - await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationToken).ConfigureAwait(false); - } - catch (Exception) - { - await sslStream.DisposeAsync(); - throw; - } - - var connection = new ClientConnectionContext(sslStream) - { - LocalEndPoint = socket.LocalEndPoint, - RemoteEndPoint = socket.RemoteEndPoint, - }; - connection.Features.Set(TlsConnectionFeature.Instance); - connection.Features.Set(new ConnectionSocketFeature(socket)); - return connection; - - - X509Certificate InternalUserCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection? localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) - { - var certificateSelectionHandler = tcpOptions?.TlsOptions?.CertificateSelectionHandler; - if (certificateSelectionHandler != null) - { - var eventArgs = new MqttClientCertificateSelectionEventArgs(targetHost, localCertificates, remoteCertificate, acceptableIssuers, tcpOptions); - return certificateSelectionHandler(eventArgs); - } - - if (localCertificates?.Count > 0) - { - return localCertificates[0]; - } - - return null!; - } - - bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x509Certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) - { - var certificateValidationHandler = tcpOptions?.TlsOptions?.CertificateValidationHandler; - if (certificateValidationHandler != null) - { - var eventArgs = new MqttClientCertificateValidationEventArgs(x509Certificate, chain, sslPolicyErrors, tcpOptions); - return certificateValidationHandler(eventArgs); - } - - if (tcpOptions?.TlsOptions?.IgnoreCertificateChainErrors ?? false) - { - sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; - } - - return sslPolicyErrors == SslPolicyErrors.None; - } - - X509CertificateCollection? LoadCertificates() - { - return tcpOptions.TlsOptions.ClientCertificatesProvider?.GetCertificates(); - } - } - - + private class StreamTransport(Stream stream) : IDuplexPipe { public PipeReader Input { get; } = PipeReader.Create(stream, new StreamPipeReaderOptions(leaveOpen: true)); diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 1d1703cfe..256438b84 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -60,7 +60,10 @@ public MqttChannel(MqttPacketFormatterAdapter packetFormatterAdapter, Connection var remoteAddress = httpConnection.RemoteIpAddress; return remoteAddress == null ? null : $"{remoteAddress}:{httpConnection.RemotePort}"; } - return remoteEndPoint?.ToString(); + + return remoteEndPoint is DnsEndPoint dnsEndPoint + ? $"{dnsEndPoint.Host}:{dnsEndPoint.Port}" + : remoteEndPoint?.ToString(); } private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) @@ -84,7 +87,7 @@ public async Task DisconnectAsync() await _output.CompleteAsync(); } - public void Dispose() + public virtual void Dispose() { _writerLock.Dispose(); } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 05a8d643c..ed1dc1f6a 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -21,12 +21,12 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable private ConnectionContext? _connection; private MqttChannel? _channel; private readonly MqttPacketFormatterAdapter _packetFormatterAdapter; - private readonly MqttClientTcpOptions _tcpOptions; + private readonly IMqttClientChannelOptions _channelOptions; - public MqttClientChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, MqttClientTcpOptions tcpOptions) + public MqttClientChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions) { _packetFormatterAdapter = packetFormatterAdapter; - _tcpOptions = tcpOptions; + _channelOptions = channelOptions; } public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter; @@ -44,16 +44,18 @@ public MqttClientChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapte public async Task ConnectAsync(CancellationToken cancellationToken) { - _connection = await ClientConnectionContext.CreateAsync(_tcpOptions, cancellationToken); + _connection = _channelOptions switch + { + MqttClientTcpOptions tcpOptions => await ClientConnectionContext.CreateAsync(tcpOptions, cancellationToken), + MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken), + _ => throw new NotSupportedException(), + }; _channel = new MqttChannel(_packetFormatterAdapter, _connection); } public async Task DisconnectAsync(CancellationToken cancellationToken) { - if (_channel != null) - { - await _channel.DisconnectAsync(); - } + await GetChannel().DisconnectAsync(); } public async ValueTask DisposeAsync() From 580257acb1298e51494b627ea6edf3461a085f4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 15 Nov 2024 20:23:14 +0800 Subject: [PATCH 26/85] Add LICENSE --- .../Internal/ClientConnectionContext.Tcp.cs | 4 ++++ .../ClientConnectionContext.WebSocket.cs | 15 +++++++-------- .../Internal/ClientConnectionContext.cs | 16 +++++++++++++--- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs index e7385f13b..b16d27bac 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs index 4cf7a8add..6bea4f91b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; using System; @@ -147,14 +151,9 @@ private static void SetupClientWebSocket(ClientWebSocket clientWebSocket, MqttCl } - private class WebSocketStream : Stream + private class WebSocketStream(WebSocket webSocket) : Stream { - private readonly WebSocket _webSocket; - - public WebSocketStream(WebSocket webSocket) - { - _webSocket = webSocket; - } + private readonly WebSocket _webSocket = webSocket; public override bool CanRead => true; public override bool CanSeek => false; @@ -174,7 +173,7 @@ public override void Flush() { } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - return _webSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, cancellationToken); + return _webSocket.SendAsync(buffer, WebSocketMessageType.Binary, false, cancellationToken); } public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs index 547c7fa9d..f578013a9 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; @@ -17,6 +21,8 @@ sealed partial class ClientConnectionContext : ConnectionContext private readonly Stream _stream; private readonly CancellationTokenSource _connectionCloseSource = new(); + private IDictionary? _items; + public override IDuplexPipe Transport { get; set; } public override CancellationToken ConnectionClosed @@ -25,11 +31,15 @@ public override CancellationToken ConnectionClosed set => throw new InvalidOperationException(); } - public override string ConnectionId { get; set; } = Guid.NewGuid().ToString(); + public override string ConnectionId { get; set; } = string.Empty; public override IFeatureCollection Features { get; } = new FeatureCollection(); - public override IDictionary Items { get; set; } = new Dictionary(); + public override IDictionary Items + { + get => _items ??= new Dictionary(); + set => _items = value; + } public ClientConnectionContext(Stream stream) { @@ -50,7 +60,7 @@ public override void Abort() _connectionCloseSource.Cancel(); } - + private class StreamTransport(Stream stream) : IDuplexPipe { public PipeReader Input { get; } = PipeReader.Create(stream, new StreamPipeReaderOptions(leaveOpen: true)); From 4f6b7ae645aabaaa324553404448e2d7cb09d9ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 15 Nov 2024 22:19:44 +0800 Subject: [PATCH 27/85] Add support for CreateLowLevelMqttClient --- Source/MQTTnet.AspnetCore/IMqttClientFactory.cs | 6 +++++- .../Internal/AspNetCoreMqttClientFactory.cs | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs index 41a7ce551..30472ce54 100644 --- a/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs +++ b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs @@ -2,10 +2,14 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.LowLevelClient; + namespace MQTTnet.AspNetCore { public interface IMqttClientFactory { IMqttClient CreateMqttClient(); + + ILowLevelMqttClient CreateLowLevelMqttClient(); } -} +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs index 68c8b0085..4c46755c6 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs @@ -4,6 +4,7 @@ using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; +using MQTTnet.LowLevelClient; namespace MQTTnet.AspNetCore { @@ -24,5 +25,10 @@ public IMqttClient CreateMqttClient() { return new MqttClient(_mqttClientAdapterFactory, _logger); } + + public ILowLevelMqttClient CreateLowLevelMqttClient() + { + return new LowLevelMqttClient(_mqttClientAdapterFactory, _logger); + } } } From b25159b915fd74dda207ca2adf14429bbac14e1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 11:42:58 +0800 Subject: [PATCH 28/85] AddMqttClient: Use AspNetCoreMqttClientAdapterFactory as the default IMqttClientAdapterFactory --- Samples/Server/Server_ASP_NET_Samples.cs | 2 +- .../MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs | 11 +++++++++++ .../MQTTnet.AspnetCore/ServiceCollectionExtensions.cs | 3 +-- ...MessageProcessingMqttConnectionContextBenchmark.cs | 2 +- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 8dd9b3238..73128a40d 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -24,7 +24,7 @@ public static Task Start_Server_With_WebSockets_Support() { var builder = WebApplication.CreateBuilder(); builder.Services.AddMqttServer(); - builder.Services.AddMqttClient().UseAspNetCoreMqttClientAdapterFactory(); + builder.Services.AddMqttClient(); builder.Services.AddHostedService(); builder.WebHost.UseKestrel(kestrel => diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs index 409d51a62..d41a20567 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -5,6 +5,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; +using MQTTnet.Implementations; using System; using System.Diagnostics.CodeAnalysis; @@ -12,6 +13,16 @@ namespace MQTTnet.AspNetCore { public static class MqttClientBuilderExtensions { + /// + /// Replace the implementation of IMqttClientAdapterFactory to MQTTnet.Implementations.MqttClientAdapterFactory + /// + /// + /// + public static IMqttClientBuilder UseMQTTnetMqttClientAdapterFactory(this IMqttClientBuilder builder) + { + return builder.UseMqttClientAdapterFactory(); + } + /// /// Replace the implementation of IMqttClientAdapterFactory to AspNetCoreMqttClientAdapterFactory /// diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index ae4e613b9..4d959c972 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -6,7 +6,6 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; -using MQTTnet.Implementations; using MQTTnet.Server; namespace MQTTnet.AspNetCore; @@ -39,7 +38,7 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) /// public static IMqttClientBuilder AddMqttClient(this IServiceCollection services) { - services.TryAddSingleton(); + services.TryAddSingleton(); services.TryAddSingleton(); return services.AddMqtt(); } diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs index f76bd4058..5da248e20 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -27,7 +27,7 @@ public void Setup() .ConfigureServices(services => { services.AddMqttServer(); - services.AddMqttClient().UseAspNetCoreMqttClientAdapterFactory(); + services.AddMqttClient(); }) .Build(); From ecb404fb358b0375a7bcbd2753304a0e0aeed51b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 12:56:42 +0800 Subject: [PATCH 29/85] Check that UseMqtt() and MapMqtt() are used. --- .../ConnectionBuilderExtensions.cs | 2 ++ .../EndpointRouteBuilderExtensions.cs | 2 ++ .../Internal/AspNetCoreMqttHostedServer.cs | 4 ++-- .../Internal/AspNetCoreMqttServer.cs | 22 ++++++++++++++++++- .../Internal/MqttConnectionHandler.cs | 4 ++++ 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 0060c0f63..914a8a68b 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.DependencyInjection; namespace MQTTnet.AspNetCore { @@ -15,6 +16,7 @@ public static class ConnectionBuilderExtensions /// public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder) { + builder.ApplicationServices.GetRequiredService().UseFlag = true; return builder.UseConnectionHandler(); } } diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index d7fe5a2e2..c6535a72c 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -5,6 +5,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; using System; using System.Collections.Generic; using System.Linq; @@ -39,6 +40,7 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) /// public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern, Action options) { + endpoints.ServiceProvider.GetRequiredService().MapFlag = true; return endpoints.MapConnectionHandler(pattern, options); } } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index 8877b037b..519721a4f 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -36,12 +36,12 @@ void OnApplicationStarted() protected override async Task ExecuteAsync(CancellationToken stoppingToken) { await _applicationStartedTask.WaitAsync(stoppingToken); - await _aspNetCoreMqttServer.StartAsync(); + await _aspNetCoreMqttServer.StartAsync(stoppingToken); } public override Task StopAsync(CancellationToken cancellationToken) { - return _aspNetCoreMqttServer.StopAsync(); + return _aspNetCoreMqttServer.StopAsync(cancellationToken); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs index 66f289b2d..4469105b4 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs @@ -4,26 +4,46 @@ using Microsoft.Extensions.Options; using MQTTnet.Diagnostics.Logger; +using MQTTnet.Exceptions; using MQTTnet.Server; using System.Collections.Generic; +using System.Linq; +using System.Threading; using System.Threading.Tasks; namespace MQTTnet.AspNetCore; sealed class AspNetCoreMqttServer : MqttServer { + private readonly MqttConnectionHandler _connectionHandler; private readonly IOptions _stopOptions; + private readonly IEnumerable _adapters; public AspNetCoreMqttServer( + MqttConnectionHandler connectionHandler, IOptions serverOptions, IOptions stopOptions, IEnumerable adapters, IMqttNetLogger logger) : base(serverOptions.Value.Build(), adapters, logger) { + _connectionHandler = connectionHandler; _stopOptions = stopOptions; + _adapters = adapters; } - public Task StopAsync() + public Task StartAsync(CancellationToken cancellationToken) + { + if (!_connectionHandler.UseFlag && + !_connectionHandler.MapFlag && + _adapters.All(item => item.GetType() == typeof(AspNetCoreMqttServerAdapter))) + { + throw new MqttConfigurationException("UseMqtt() or MapMqtt() must be called in at least one place"); + } + + return base.StartAsync(); + } + + public Task StopAsync(CancellationToken cancellationToken) { return base.StopAsync(_stopOptions.Value.Build()); } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 9747e6d22..bebde5c5f 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -19,6 +19,10 @@ sealed class MqttConnectionHandler : ConnectionHandler readonly IMqttNetLogger _logger; readonly MqttServerOptions _serverOptions; + public bool UseFlag { get; set; } + + public bool MapFlag { get; set; } + public Func? ClientHandler { get; set; } public MqttConnectionHandler( From 858a91421bc60da8616e75f72dd98b1158f1459e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 13:45:08 +0800 Subject: [PATCH 30/85] AspNetCoreMqttServerAdapter: Logging when MqttServerOptions are ignored --- .../Internal/AspNetCoreMqttServerAdapter.cs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs index b095d549a..edce7ab78 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs @@ -13,7 +13,6 @@ namespace MQTTnet.AspNetCore; sealed class AspNetCoreMqttServerAdapter : IMqttServerAdapter { readonly MqttConnectionHandler _connectionHandler; - public Func? ClientHandler { get => _connectionHandler.ClientHandler; @@ -27,6 +26,18 @@ public AspNetCoreMqttServerAdapter(MqttConnectionHandler connectionHandler) public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) { + if (options.DefaultEndpointOptions.IsEnabled) + { + var message = "DefaultEndpoint is ignored because the listener is implemented by the Asp.Net Core Server."; + logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + } + + if (options.TlsEndpointOptions.IsEnabled) + { + var message = "EncryptedEndpoint is ignored because the the listener and TLS middleware are implemented by Asp.NetCore's Server."; + logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + } + return Task.CompletedTask; } From fedb631007922cc1700a9c2f5c00dba774fad4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 14:13:07 +0800 Subject: [PATCH 31/85] Add IMqttServerBuilder.AddMqttServerAdapter() extensions --- .../MqttServerBuilderExtensions.cs | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 487a8ce3a..8b0f732a2 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -7,11 +7,12 @@ using MQTTnet.Server; using MQTTnet.Server.Internal.Adapter; using System; +using System.Diagnostics.CodeAnalysis; namespace MQTTnet.AspNetCore { public static class MqttServerBuilderExtensions - { + { /// /// Configure MqttServerOptionsBuilder /// @@ -39,12 +40,33 @@ public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder /// /// Add MqttTcpServerAdapter to MqttServer /// - /// We recommend using ListenOptions.UseMqtt() instead of using MqttTcpServerAdapter in an AspNetCore environment /// /// public static IMqttServerBuilder AddMqttTcpServerAdapter(this IMqttServerBuilder builder) { - builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton()); + return builder.AddMqttServerAdapter(); + } + + /// + /// Add AspNetCoreMqttServerAdapter to MqttServer + /// + /// + /// + public static IMqttServerBuilder AddAspNetCoreMqttServerAdapter(this IMqttServerBuilder builder) + { + return builder.AddMqttServerAdapter(); + } + + /// + /// Add an IMqttServerAdapter to MqttServer + /// + /// + /// + /// + public static IMqttServerBuilder AddMqttServerAdapter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TMqttServerAdapter>(this IMqttServerBuilder builder) + where TMqttServerAdapter : class, IMqttServerAdapter + { + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton()); return builder; } } From 67239cd3e1f50c82b275864fa2d52bd7e2a847e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 15:31:37 +0800 Subject: [PATCH 32/85] Register MqttServerOptions and MqttServerStopOptions as services --- .../Internal/AspNetCoreMqttOptionsBuilder.cs | 33 +++++++++++++++++++ .../Internal/AspNetCoreMqttServer.cs | 11 +++---- .../Internal/MqttConnectionHandler.cs | 5 ++- .../ServiceCollectionExtensions.cs | 3 ++ 4 files changed, 43 insertions(+), 9 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs new file mode 100644 index 000000000..5122529b0 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Server; + +namespace MQTTnet.AspNetCore +{ + sealed class AspNetCoreMqttOptionsBuilder + { + private readonly MqttServerOptionsBuilder _serverOptionsBuilder; + private readonly MqttServerStopOptionsBuilder _stopOptionsBuilder; + + public AspNetCoreMqttOptionsBuilder( + IOptions serverOptionsBuilderOptions, + IOptions stopOptionsBuilderOptions) + { + _serverOptionsBuilder = serverOptionsBuilderOptions.Value; + _stopOptionsBuilder = stopOptionsBuilderOptions.Value; + } + + public MqttServerOptions BuildServerOptions() + { + return _serverOptionsBuilder.Build(); + } + + public MqttServerStopOptions BuildServerStopOptions() + { + return _stopOptionsBuilder.Build(); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs index 4469105b4..41fff646e 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.Extensions.Options; using MQTTnet.Diagnostics.Logger; using MQTTnet.Exceptions; using MQTTnet.Server; @@ -16,15 +15,15 @@ namespace MQTTnet.AspNetCore; sealed class AspNetCoreMqttServer : MqttServer { private readonly MqttConnectionHandler _connectionHandler; - private readonly IOptions _stopOptions; + private readonly MqttServerStopOptions _stopOptions; private readonly IEnumerable _adapters; public AspNetCoreMqttServer( MqttConnectionHandler connectionHandler, - IOptions serverOptions, - IOptions stopOptions, + MqttServerOptions serverOptions, + MqttServerStopOptions stopOptions, IEnumerable adapters, - IMqttNetLogger logger) : base(serverOptions.Value.Build(), adapters, logger) + IMqttNetLogger logger) : base(serverOptions, adapters, logger) { _connectionHandler = connectionHandler; _stopOptions = stopOptions; @@ -45,6 +44,6 @@ public Task StartAsync(CancellationToken cancellationToken) public Task StopAsync(CancellationToken cancellationToken) { - return base.StopAsync(_stopOptions.Value.Build()); + return base.StopAsync(_stopOptions); } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index bebde5c5f..90d7f8dd8 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -4,7 +4,6 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; -using Microsoft.Extensions.Options; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; @@ -27,10 +26,10 @@ sealed class MqttConnectionHandler : ConnectionHandler public MqttConnectionHandler( IMqttNetLogger logger, - IOptions serverOptions) + MqttServerOptions serverOptions) { _logger = logger; - _serverOptions = serverOptions.Value.Build(); + _serverOptions = serverOptions; } public override async Task OnConnectedAsync(ConnectionContext connection) diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 4d959c972..b131337e7 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -22,6 +22,9 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) services.AddOptions(); services.AddConnections(); services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService().BuildServerOptions()); + services.TryAddSingleton(s => s.GetRequiredService().BuildServerStopOptions()); services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddSingleton(); From e03b9b1bac5ec30773407c6f57bfe97231179713 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 15:54:02 +0800 Subject: [PATCH 33/85] IMqttBuilder: Add IMqttBuilder.UseAspNetCoreMqttNetLogger() extension --- .../AspNetCoreMqttNetLoggerOptions.cs | 29 +++++++++++++++++++ .../Internal/AspNetCoreMqttNetLogger.cs | 26 +++++++---------- .../MqttBuilderExtensions.cs | 26 +++++++++++++++-- 3 files changed, 63 insertions(+), 18 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs diff --git a/Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs b/Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs new file mode 100644 index 000000000..273463f4d --- /dev/null +++ b/Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Logging; +using MQTTnet.Diagnostics.Logger; +using System; + +namespace MQTTnet.AspNetCore +{ + public sealed class AspNetCoreMqttNetLoggerOptions + { + public string? CategoryNamePrefix { get; set; } = "MQTTnet.AspNetCore."; + + public Func LogLevelConverter { get; set; } = ConvertLogLevel; + + private static LogLevel ConvertLogLevel(MqttNetLogLevel level) + { + return level switch + { + MqttNetLogLevel.Verbose => LogLevel.Trace, + MqttNetLogLevel.Info => LogLevel.Information, + MqttNetLogLevel.Warning => LogLevel.Warning, + MqttNetLogLevel.Error => LogLevel.Error, + _ => LogLevel.None + }; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs index 23a80bad9..bf7881bed 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using MQTTnet.Diagnostics.Logger; using System; @@ -11,31 +12,24 @@ namespace MQTTnet.AspNetCore sealed class AspNetCoreMqttNetLogger : IMqttNetLogger { private readonly ILoggerFactory _loggerFactory; - private const string categoryNamePrefix = "MQTTnet.AspNetCore."; + private readonly AspNetCoreMqttNetLoggerOptions _loggerOptions; public bool IsEnabled => true; - public AspNetCoreMqttNetLogger(ILoggerFactory loggerFactory) + public AspNetCoreMqttNetLogger( + ILoggerFactory loggerFactory, + IOptions loggerOptions) { _loggerFactory = loggerFactory; + _loggerOptions = loggerOptions.Value; } public void Publish(MqttNetLogLevel logLevel, string? source, string? message, object[]? parameters, Exception? exception) { - var logger = _loggerFactory.CreateLogger($"{categoryNamePrefix}{source}"); - logger.Log(ConvertLogLevel(logLevel), exception, message, parameters ?? []); - } - - private static LogLevel ConvertLogLevel(MqttNetLogLevel? level) - { - return level switch - { - MqttNetLogLevel.Verbose => LogLevel.Trace, - MqttNetLogLevel.Info => LogLevel.Information, - MqttNetLogLevel.Warning => LogLevel.Warning, - MqttNetLogLevel.Error => LogLevel.Error, - _ => LogLevel.None - }; + var categoryName = $"{_loggerOptions.CategoryNamePrefix}{source}"; + var logger = _loggerFactory.CreateLogger(categoryName); + var level = _loggerOptions.LogLevelConverter(logLevel); + logger.Log(level, exception, message, parameters ?? []); } } } diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs index 7ed88a2c2..6f5b48a35 100644 --- a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -13,11 +13,33 @@ namespace MQTTnet.AspNetCore public static class MqttBuilderExtensions { /// - /// Disable logging + /// Use AspNetCoreMqttNetLogger as IMqttNetLogger /// /// + /// /// - public static IMqttBuilder UseNullLogger(this IMqttBuilder builder) + public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder.UseAspNetCoreMqttNetLogger(); + } + + /// + /// Use AspNetCoreMqttNetLogger as IMqttNetLogger + /// + /// + /// + public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder) + { + return builder.UseLogger(); + } + + /// + /// Use MqttNetNullLogger as IMqttNetLogger + /// + /// + /// + public static IMqttBuilder UseMqttNetNullLogger(this IMqttBuilder builder) { return builder.UseLogger(); } From 014a50b42b0bbc9c162a6a48de4ee08a353a2bef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 16:21:52 +0800 Subject: [PATCH 34/85] Remove some dead code. --- Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs index 6f5b48a35..7cca5696b 100644 --- a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -55,11 +55,6 @@ public static IMqttBuilder UseMqttNetNullLogger(this IMqttBuilder builder) { builder.Services.Replace(ServiceDescriptor.Singleton()); return builder; - } - - private class MqttBuilder(IServiceCollection services) : IMqttBuilder - { - public IServiceCollection Services { get; } = services; - } + } } } From ad6287779132c2ac463fd9f44dedbb057c1eae6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 16 Nov 2024 19:36:13 +0800 Subject: [PATCH 35/85] Apply the properties of MqttClientTcpOptions to the Socket --- .../Internal/ClientConnectionContext.Tcp.cs | 53 ++++++++++++++++--- .../ClientConnectionContext.WebSocket.cs | 7 ++- .../Internal/ClientConnectionContext.cs | 2 +- .../Internal/MqttClientChannelAdapter.cs | 1 - 4 files changed, 50 insertions(+), 13 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs index b16d27bac..7027f2f49 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using System; @@ -13,13 +12,53 @@ using System.Threading; using System.Threading.Tasks; -namespace MQTTnet.AspNetCore.Internal +namespace MQTTnet.AspNetCore { - partial class ClientConnectionContext : ConnectionContext - { - public static async Task CreateAsync(MqttClientTcpOptions options, CancellationToken cancellationToken) + partial class ClientConnectionContext + { + public static async Task CreateAsync(MqttClientTcpOptions options, CancellationToken cancellationToken) { - var socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + Socket socket; + if (options.RemoteEndpoint is UnixDomainSocketEndPoint) + { + socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + } + else if (options.AddressFamily == AddressFamily.Unspecified) + { + socket = new Socket(SocketType.Stream, options.ProtocolType); + } + else + { + socket = new Socket(options.AddressFamily, SocketType.Stream, options.ProtocolType); + } + + if (options.LocalEndpoint != null) + { + socket.Bind(options.LocalEndpoint); + } + + socket.ReceiveBufferSize = options.BufferSize; + socket.SendBufferSize = options.BufferSize; + + if (options.ProtocolType == ProtocolType.Tcp && options.RemoteEndpoint is not UnixDomainSocketEndPoint) + { + // Other protocol types do not support the Nagle algorithm. + socket.NoDelay = options.NoDelay; + } + + if (options.LingerState != null) + { + socket.LingerState = options.LingerState; + } + + if (options.DualMode.HasValue) + { + // It is important to avoid setting the flag if no specific value is set by the user + // because on IPv4 only networks the setter will always throw an exception. Regardless + // of the actual value. + socket.DualMode = options.DualMode.Value; + } + try { await socket.ConnectAsync(options.RemoteEndpoint, cancellationToken).ConfigureAwait(false); @@ -146,6 +185,6 @@ bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x { return options.TlsOptions.ClientCertificatesProvider?.GetCertificates(); } - } + } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs index 6bea4f91b..a1d06ecdc 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; using System; using System.IO; @@ -11,11 +10,11 @@ using System.Threading; using System.Threading.Tasks; -namespace MQTTnet.AspNetCore.Internal +namespace MQTTnet.AspNetCore { - partial class ClientConnectionContext : ConnectionContext + partial class ClientConnectionContext { - public static async Task CreateAsync(MqttClientWebSocketOptions options, CancellationToken cancellationToken) + public static async Task CreateAsync(MqttClientWebSocketOptions options, CancellationToken cancellationToken) { var uri = new UriBuilder(new Uri(options.Uri, UriKind.Absolute)) { diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs index f578013a9..c4a24333a 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -14,7 +14,7 @@ using System.Threading; using System.Threading.Tasks; -namespace MQTTnet.AspNetCore.Internal +namespace MQTTnet.AspNetCore { sealed partial class ClientConnectionContext : ConnectionContext { diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index ed1dc1f6a..e66b1aa79 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -4,7 +4,6 @@ using Microsoft.AspNetCore.Connections; using MQTTnet.Adapter; -using MQTTnet.AspNetCore.Internal; using MQTTnet.Formatter; using MQTTnet.Packets; using System; From c49beb0b443b94ff237478c8c3b7aa6e7a941937 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sun, 17 Nov 2024 22:50:00 +0800 Subject: [PATCH 36/85] TlsConnectionFeature supports passing in ClientCertificate --- .../Internal/ClientConnectionContext.Tcp.cs | 3 ++- .../Internal/ClientConnectionContext.WebSocket.cs | 2 +- .../MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs | 7 ++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs index 7027f2f49..2a67dcb08 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -142,8 +142,9 @@ public static async Task CreateAsync(MqttClientTcpOptio LocalEndPoint = socket.LocalEndPoint, RemoteEndPoint = socket.RemoteEndPoint, }; - connection.Features.Set(TlsConnectionFeature.Instance); + connection.Features.Set(new ConnectionSocketFeature(socket)); + connection.Features.Set(new TlsConnectionFeature(sslStream.LocalCertificate)); return connection; diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs index a1d06ecdc..458d0d09f 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -43,7 +43,7 @@ public static async Task CreateAsync(MqttClientWebSocke if (uri.Scheme == Uri.UriSchemeWss) { - connection.Features.Set(TlsConnectionFeature.Instance); + connection.Features.Set(TlsConnectionFeature.Default); } return connection; } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs index c4a24333a..b6b34b9f3 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -70,7 +70,7 @@ private class StreamTransport(Stream stream) : IDuplexPipe private class TlsConnectionFeature : ITlsConnectionFeature { - public static readonly TlsConnectionFeature Instance = new(); + public static readonly TlsConnectionFeature Default = new(null); public X509Certificate2? ClientCertificate { get; set; } @@ -78,6 +78,11 @@ private class TlsConnectionFeature : ITlsConnectionFeature { return Task.FromResult(ClientCertificate); } + + public TlsConnectionFeature(X509Certificate? clientCertificate) + { + ClientCertificate = clientCertificate as X509Certificate2; + } } private class ConnectionSocketFeature(Socket socket) : IConnectionSocketFeature From 4c55368a3ac1c7860460ba1d00c322503b64f365 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sun, 17 Nov 2024 23:46:58 +0800 Subject: [PATCH 37/85] Add support for MqttPacketInspector --- .../AspNetCoreMqttClientAdapterFactory.cs | 2 +- .../Internal/MqttChannel.cs | 23 ++++++++++++++++--- .../Internal/MqttClientChannelAdapter.cs | 9 ++++++-- .../Internal/MqttConnectionHandler.cs | 3 ++- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index 702f85aee..2f5041144 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -16,7 +16,7 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa ArgumentNullException.ThrowIfNull(nameof(options)); var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); - return new MqttClientChannelAdapter(formatter, options.ChannelOptions); + return new MqttClientChannelAdapter(formatter, options.ChannelOptions, packetInspector); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 256438b84..d23f4839c 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -5,6 +5,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Formatter; using MQTTnet.Internal; @@ -24,6 +25,7 @@ class MqttChannel : IDisposable readonly AsyncLock _writerLock = new(); readonly PipeReader _input; readonly PipeWriter _output; + readonly MqttPacketInspector? _packetInspector; public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } @@ -38,12 +40,16 @@ class MqttChannel : IDisposable public bool IsSecureConnection { get; } - public MqttChannel(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) + public MqttChannel( + MqttPacketFormatterAdapter packetFormatterAdapter, + ConnectionContext connection, + MqttPacketInspector? packetInspector = null) { + PacketFormatterAdapter = packetFormatterAdapter; + _packetInspector = packetInspector; + var httpContextFeature = connection.Features.Get(); var tlsConnectionFeature = connection.Features.Get(); - - PacketFormatterAdapter = packetFormatterAdapter; Endpoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature); ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature); @@ -96,6 +102,8 @@ public virtual void Dispose() { try { + _packetInspector?.BeginReceivePacket(); + while (!cancellationToken.IsCancellationRequested) { ReadResult readResult; @@ -121,6 +129,11 @@ public virtual void Dispose() if (PacketFormatterAdapter.TryDecode(buffer, out var packet, out consumed, out observed, out var received)) { BytesReceived += received; + + if (_packetInspector != null) + { + await _packetInspector.EndReceivePacket().ConfigureAwait(false); + } return packet; } } @@ -164,6 +177,10 @@ public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellat try { var buffer = PacketFormatterAdapter.Encode(packet); + if (_packetInspector != null) + { + await _packetInspector.BeginSendPacket(buffer).ConfigureAwait(false); + } if (buffer.Payload.Length == 0) { diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index e66b1aa79..431ab8ae6 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -21,11 +21,16 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable private MqttChannel? _channel; private readonly MqttPacketFormatterAdapter _packetFormatterAdapter; private readonly IMqttClientChannelOptions _channelOptions; + private readonly MqttPacketInspector? _packetInspector; - public MqttClientChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions) + public MqttClientChannelAdapter( + MqttPacketFormatterAdapter packetFormatterAdapter, + IMqttClientChannelOptions channelOptions, + MqttPacketInspector? packetInspector) { _packetFormatterAdapter = packetFormatterAdapter; _channelOptions = channelOptions; + _packetInspector = packetInspector; } public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter; @@ -49,7 +54,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken), _ => throw new NotSupportedException(), }; - _channel = new MqttChannel(_packetFormatterAdapter, _connection); + _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector); } public async Task DisconnectAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 90d7f8dd8..0c3d49097 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -49,7 +49,8 @@ public override async Task OnConnectedAsync(ConnectionContext connection) transferFormatFeature.ActiveFormat = TransferFormat.Binary; } - var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); + var bufferWriter = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax); + var formatter = new MqttPacketFormatterAdapter(bufferWriter); using var adapter = new MqttServerChannelAdapter(formatter, connection); await clientHandler(adapter).ConfigureAwait(false); } From 34a4db249eee0d97031fcab6d62ee2e3e8d8baad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sun, 17 Nov 2024 23:58:26 +0800 Subject: [PATCH 38/85] Add route syntax for pattern parameter. --- Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index c6535a72c..faf64ad72 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -8,6 +8,7 @@ using Microsoft.Extensions.DependencyInjection; using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; namespace MQTTnet.AspNetCore @@ -20,7 +21,7 @@ public static class EndpointRouteBuilderExtensions /// /// /// - public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern) + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern) { return endpoints.MapMqtt(pattern, options => options.WebSockets.SubProtocolSelector = SelectSubProtocol); @@ -38,7 +39,7 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) /// /// /// - public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, string pattern, Action options) + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern, Action options) { endpoints.ServiceProvider.GetRequiredService().MapFlag = true; return endpoints.MapConnectionHandler(pattern, options); From 6dc18e1b3ca22b79d438f4652c8f6bc8bdb4c6a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 02:04:34 +0800 Subject: [PATCH 39/85] Add KestrelServerOptions.ListenMqtt() extensions. --- Samples/Server/Server_ASP_NET_Samples.cs | 13 +-- Source/MQTTnet.AspTestApp/Program.cs | 12 +-- Source/MQTTnet.AspTestApp/appsettings.json | 12 +-- .../Internal/AspNetCoreMqttServer.cs | 5 +- .../Internal/AspNetCoreMqttServerAdapter.cs | 21 +++-- .../Internal/MqttConnectionHandler.cs | 2 + .../KestrelServerOptionsExtensions.cs | 85 +++++++++++++++++++ .../MqttServerBuilderExtensions.cs | 23 +---- 8 files changed, 118 insertions(+), 55 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 73128a40d..bfe222530 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -23,17 +23,18 @@ public static class Server_ASP_NET_Samples public static Task Start_Server_With_WebSockets_Support() { var builder = WebApplication.CreateBuilder(); - builder.Services.AddMqttServer(); + builder.Services.AddMqttServer().ConfigureMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()); builder.Services.AddMqttClient(); builder.Services.AddHostedService(); builder.WebHost.UseKestrel(kestrel => { - // mqtt over tcp - kestrel.ListenAnyIP(1883, l => l.UseMqtt()); + // Need ConfigureMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()) + kestrel.ListenMqtt(); - // mqtt over tls over tcp - kestrel.ListenLocalhost(1884, l => l.UseHttps().UseMqtt()); + // We can also manually listen to a specific port without ConfigureMqttServer() + // kestrel.ListenAnyIP(1883, l => l.UseMqtt()); // mqtt over tcp + // kestrel.ListenAnyIP(8883, l => l.UseHttps().UseMqtt()); // mqtt over tls over tcp // This will allow MQTT connections based on HTTP WebSockets with URI "localhost:5000/mqtt" // See code below for URI configuration. @@ -86,7 +87,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { await Task.Delay(1000); using var client = _mqttClientFactory.CreateMqttClient(); - var options = new MqttClientOptionsBuilder().WithConnectionUri("ws://localhost:5000/mqtt").Build(); + var options = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://127.0.0.1:1883").Build(); await client.ConnectAsync(options, stoppingToken); await client.DisconnectAsync(); } diff --git a/Source/MQTTnet.AspTestApp/Program.cs b/Source/MQTTnet.AspTestApp/Program.cs index ba64ccf79..df8d32184 100644 --- a/Source/MQTTnet.AspTestApp/Program.cs +++ b/Source/MQTTnet.AspTestApp/Program.cs @@ -11,13 +11,15 @@ builder.Services.AddRazorPages(); // Setup MQTT stuff. -builder.Services.AddMqttServer(); +builder.Services.AddMqttServer().ConfigureMqttServer(s => s.WithDefaultEndpoint()); -// UseMqttEndPoint -builder.WebHost.ConfigureKestrel((context, serverOptions) => +// ListenMqtt +builder.WebHost.UseKestrel(kestrel => { - var kestrelSection = context.Configuration.GetSection("Kestrel"); - serverOptions.Configure(kestrelSection).Endpoint("Mqtt", mqtt => mqtt.ListenOptions.UseMqtt()); + kestrel.ListenMqtt(); + + // mqtt over WebSocket + kestrel.ListenAnyIP(5000); // Default HTTP pipeline }); var app = builder.Build(); diff --git a/Source/MQTTnet.AspTestApp/appsettings.json b/Source/MQTTnet.AspTestApp/appsettings.json index 9d11d79ad..0f22ea11d 100644 --- a/Source/MQTTnet.AspTestApp/appsettings.json +++ b/Source/MQTTnet.AspTestApp/appsettings.json @@ -1,14 +1,4 @@ -{ - "Kestrel": { - "Endpoints": { - "Http": { - "Url": "http://localhost:5000" - }, - "Mqtt": { - "Url": "http://localhost:1883" - } - } - }, +{ "Logging": { "LogLevel": { "Default": "Information", diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs index 41fff646e..e2966b2f0 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs @@ -32,11 +32,12 @@ public AspNetCoreMqttServer( public Task StartAsync(CancellationToken cancellationToken) { - if (!_connectionHandler.UseFlag && + if (!_connectionHandler.ListenFlag && + !_connectionHandler.UseFlag && !_connectionHandler.MapFlag && _adapters.All(item => item.GetType() == typeof(AspNetCoreMqttServerAdapter))) { - throw new MqttConfigurationException("UseMqtt() or MapMqtt() must be called in at least one place"); + throw new MqttConfigurationException("ListenMqtt() or UseMqtt() or MapMqtt() must be called in at least one place"); } return base.StartAsync(); diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs index edce7ab78..f4e6e167e 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs @@ -26,16 +26,19 @@ public AspNetCoreMqttServerAdapter(MqttConnectionHandler connectionHandler) public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) { - if (options.DefaultEndpointOptions.IsEnabled) + if (!_connectionHandler.ListenFlag) { - var message = "DefaultEndpoint is ignored because the listener is implemented by the Asp.Net Core Server."; - logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); - } - - if (options.TlsEndpointOptions.IsEnabled) - { - var message = "EncryptedEndpoint is ignored because the the listener and TLS middleware are implemented by Asp.NetCore's Server."; - logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + if (options.DefaultEndpointOptions.IsEnabled) + { + var message = "DefaultEndpointOptions has been ignored because the user called UseMqtt() on the specified listener."; + logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + } + + if (options.TlsEndpointOptions.IsEnabled) + { + var message = "TlsEndpointOptions has been ignored because the user called UseMqtt() on the specified listener."; + logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + } } return Task.CompletedTask; diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 0c3d49097..b52caae63 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -22,6 +22,8 @@ sealed class MqttConnectionHandler : ConnectionHandler public bool MapFlag { get; set; } + public bool ListenFlag { get; set; } + public Func? ClientHandler { get; set; } public MqttConnectionHandler( diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs new file mode 100644 index 000000000..b8f49f733 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -0,0 +1,85 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Exceptions; +using MQTTnet.Server; +using System; +using System.Net; +using System.Security.Cryptography.X509Certificates; + +namespace MQTTnet.AspNetCore +{ + public static class KestrelServerOptionsExtensions + { + /// + /// Listen all endponts in MqttServerOptions + /// • The properties of the DefaultEndpointOptions will be ignored except for the BoundInterNetworkAddress and port + /// • The properties of the TlsEndpointOptions will be ignored except for the BoundInterNetworkAddress and port + /// + /// + /// + /// + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel) + { + return kestrel.ListenMqtt(tls => { }); + } + + /// + /// Listen all endponts in MqttServerOptions + /// • The properties of the DefaultEndpointOptions will be ignored except for the BoundInterNetworkAddress and port + /// • The properties of the TlsEndpointOptions will be ignored except for the BoundInterNetworkAddress and port + /// + /// + /// + /// + /// + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, X509Certificate2 serverCertificate) + { + return kestrel.ListenMqtt(tls => tls.ServerCertificate = serverCertificate); + } + + /// + /// Listen all endponts in MqttServerOptions + /// • The properties of the DefaultEndpointOptions will be ignored except for the BoundInterNetworkAddress and port + /// • The properties of the TlsEndpointOptions will be ignored except for the BoundInterNetworkAddress and port + /// + /// + /// + /// + /// + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, Action tlsConfigure) + { + var serverOptions = kestrel.ApplicationServices.GetRequiredService(); + var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); + + if (serverOptions.DefaultEndpointOptions.IsEnabled) + { + var endpoint = serverOptions.DefaultEndpointOptions; + kestrel.Listen(endpoint.BoundInterNetworkV6Address, endpoint.Port, o => o.UseMqtt()); + if (!IPAddress.IPv6Any.Equals(endpoint.BoundInterNetworkV6Address)) + { + kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, o => o.UseMqtt()); + } + connectionHandler.ListenFlag = true; + } + + if (serverOptions.TlsEndpointOptions.IsEnabled) + { + var endpoint = serverOptions.TlsEndpointOptions; + kestrel.Listen(endpoint.BoundInterNetworkV6Address, endpoint.Port, o => o.UseHttps(tlsConfigure).UseMqtt()); + if (!IPAddress.IPv6Any.Equals(endpoint.BoundInterNetworkV6Address)) + { + kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, o => o.UseHttps(tlsConfigure).UseMqtt()); + } + connectionHandler.ListenFlag = true; + } + + if (!connectionHandler.ListenFlag) + { + throw new MqttConfigurationException("None of the MqttServerOptions Endpoints are enabled."); + } + return kestrel; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 8b0f732a2..813313ad8 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -5,7 +5,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Server; -using MQTTnet.Server.Internal.Adapter; using System; using System.Diagnostics.CodeAnalysis; @@ -36,27 +35,7 @@ public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder.Services.Configure(configure); return builder; } - - /// - /// Add MqttTcpServerAdapter to MqttServer - /// - /// - /// - public static IMqttServerBuilder AddMqttTcpServerAdapter(this IMqttServerBuilder builder) - { - return builder.AddMqttServerAdapter(); - } - - /// - /// Add AspNetCoreMqttServerAdapter to MqttServer - /// - /// - /// - public static IMqttServerBuilder AddAspNetCoreMqttServerAdapter(this IMqttServerBuilder builder) - { - return builder.AddMqttServerAdapter(); - } - + /// /// Add an IMqttServerAdapter to MqttServer /// From 67dfda4e3a05108d2d3d7604222fc5b9242b48a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 02:27:00 +0800 Subject: [PATCH 40/85] Optimize the implementation of ListenMqtt. --- .../KestrelServerOptionsExtensions.cs | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index b8f49f733..179576e7f 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -53,33 +53,38 @@ public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, var serverOptions = kestrel.ApplicationServices.GetRequiredService(); var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); - if (serverOptions.DefaultEndpointOptions.IsEnabled) + Listen(serverOptions.DefaultEndpointOptions, useTls: false); + Listen(serverOptions.TlsEndpointOptions, useTls: true); + + return connectionHandler.ListenFlag + ? kestrel + : throw new MqttConfigurationException("None of the MqttServerOptions Endpoints are enabled."); + + void Listen(MqttServerTcpEndpointBaseOptions endpoint, bool useTls) { - var endpoint = serverOptions.DefaultEndpointOptions; - kestrel.Listen(endpoint.BoundInterNetworkV6Address, endpoint.Port, o => o.UseMqtt()); - if (!IPAddress.IPv6Any.Equals(endpoint.BoundInterNetworkV6Address)) + if (!endpoint.IsEnabled) { - kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, o => o.UseMqtt()); + return; } - connectionHandler.ListenFlag = true; - } - if (serverOptions.TlsEndpointOptions.IsEnabled) - { - var endpoint = serverOptions.TlsEndpointOptions; - kestrel.Listen(endpoint.BoundInterNetworkV6Address, endpoint.Port, o => o.UseHttps(tlsConfigure).UseMqtt()); + // No need to listen any IPv4 when has IPv6Any if (!IPAddress.IPv6Any.Equals(endpoint.BoundInterNetworkV6Address)) { - kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, o => o.UseHttps(tlsConfigure).UseMqtt()); + kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, UseMiddleware); } + kestrel.Listen(endpoint.BoundInterNetworkV6Address, endpoint.Port, UseMiddleware); connectionHandler.ListenFlag = true; - } - if (!connectionHandler.ListenFlag) - { - throw new MqttConfigurationException("None of the MqttServerOptions Endpoints are enabled."); + + void UseMiddleware(ListenOptions listenOptions) + { + if (useTls) + { + listenOptions.UseHttps(tlsConfigure); + } + listenOptions.UseMqtt(); + } } - return kestrel; } } -} +} \ No newline at end of file From 5813a138f779ba28389d81039937bba49c7f8ab9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 03:33:58 +0800 Subject: [PATCH 41/85] Adapt MqttServerTlsTcpEndpointOptions to HttpsConnectionAdapterOptions --- .../KestrelServerOptionsExtensions.cs | 54 ++++++++++++++----- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 179576e7f..a47ae0326 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -1,3 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; @@ -14,53 +18,47 @@ public static class KestrelServerOptionsExtensions { /// /// Listen all endponts in MqttServerOptions - /// • The properties of the DefaultEndpointOptions will be ignored except for the BoundInterNetworkAddress and port - /// • The properties of the TlsEndpointOptions will be ignored except for the BoundInterNetworkAddress and port /// /// /// /// public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel) { - return kestrel.ListenMqtt(tls => { }); + return kestrel.ListenMqtt(default(Action)); } /// /// Listen all endponts in MqttServerOptions - /// • The properties of the DefaultEndpointOptions will be ignored except for the BoundInterNetworkAddress and port - /// • The properties of the TlsEndpointOptions will be ignored except for the BoundInterNetworkAddress and port /// /// /// /// /// - public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, X509Certificate2 serverCertificate) + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, X509Certificate2? serverCertificate) { return kestrel.ListenMqtt(tls => tls.ServerCertificate = serverCertificate); } /// /// Listen all endponts in MqttServerOptions - /// • The properties of the DefaultEndpointOptions will be ignored except for the BoundInterNetworkAddress and port - /// • The properties of the TlsEndpointOptions will be ignored except for the BoundInterNetworkAddress and port /// /// /// /// /// - public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, Action tlsConfigure) + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, Action? tlsConfigure) { var serverOptions = kestrel.ApplicationServices.GetRequiredService(); var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); - Listen(serverOptions.DefaultEndpointOptions, useTls: false); - Listen(serverOptions.TlsEndpointOptions, useTls: true); + Listen(serverOptions.DefaultEndpointOptions); + Listen(serverOptions.TlsEndpointOptions); return connectionHandler.ListenFlag ? kestrel : throw new MqttConfigurationException("None of the MqttServerOptions Endpoints are enabled."); - void Listen(MqttServerTcpEndpointBaseOptions endpoint, bool useTls) + void Listen(MqttServerTcpEndpointBaseOptions endpoint) { if (!endpoint.IsEnabled) { @@ -78,13 +76,41 @@ void Listen(MqttServerTcpEndpointBaseOptions endpoint, bool useTls) void UseMiddleware(ListenOptions listenOptions) { - if (useTls) + if (endpoint is MqttServerTlsTcpEndpointOptions tlsEndPoint) { - listenOptions.UseHttps(tlsConfigure); + var httpsOptions = CreateHttpsOptions(tlsEndPoint); + tlsConfigure?.Invoke(httpsOptions); + listenOptions.UseHttps(httpsOptions); } listenOptions.UseMqtt(); } } } + + private static HttpsConnectionAdapterOptions CreateHttpsOptions(MqttServerTlsTcpEndpointOptions tlsEndPoint) + { + var options = new HttpsConnectionAdapterOptions + { + SslProtocols = tlsEndPoint.SslProtocol, + CheckCertificateRevocation = tlsEndPoint.CheckCertificateRevocation, + }; + + if (tlsEndPoint.ClientCertificateRequired) + { + options.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + } + + if (tlsEndPoint.CertificateProvider != null) + { + options.ServerCertificateSelector = (context, host) => tlsEndPoint.CertificateProvider.GetCertificate(); + } + + if (tlsEndPoint.RemoteCertificateValidationCallback != null) + { + options.ClientCertificateValidation = (cert, chain, errors) => tlsEndPoint.RemoteCertificateValidationCallback(tlsEndPoint, cert, chain, errors); + } + + return options; + } } } \ No newline at end of file From 7628614540f6a0854804ec0ffd836c0307ddd7d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 03:46:16 +0800 Subject: [PATCH 42/85] Compatible with the default server certificate. --- .../KestrelServerOptionsExtensions.cs | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index a47ae0326..181a3afa4 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -78,39 +78,36 @@ void UseMiddleware(ListenOptions listenOptions) { if (endpoint is MqttServerTlsTcpEndpointOptions tlsEndPoint) { - var httpsOptions = CreateHttpsOptions(tlsEndPoint); - tlsConfigure?.Invoke(httpsOptions); - listenOptions.UseHttps(httpsOptions); + listenOptions.UseHttps(httpsOptions => + { + tlsEndPoint.AdaptTo(httpsOptions); + tlsConfigure?.Invoke(httpsOptions); + }); } listenOptions.UseMqtt(); } } } - private static HttpsConnectionAdapterOptions CreateHttpsOptions(MqttServerTlsTcpEndpointOptions tlsEndPoint) + private static void AdaptTo(this MqttServerTlsTcpEndpointOptions tlsEndPoint, HttpsConnectionAdapterOptions httpsOptions) { - var options = new HttpsConnectionAdapterOptions - { - SslProtocols = tlsEndPoint.SslProtocol, - CheckCertificateRevocation = tlsEndPoint.CheckCertificateRevocation, - }; + httpsOptions.SslProtocols = tlsEndPoint.SslProtocol; + httpsOptions.CheckCertificateRevocation = tlsEndPoint.CheckCertificateRevocation; if (tlsEndPoint.ClientCertificateRequired) { - options.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + httpsOptions.ClientCertificateMode = ClientCertificateMode.RequireCertificate; } if (tlsEndPoint.CertificateProvider != null) { - options.ServerCertificateSelector = (context, host) => tlsEndPoint.CertificateProvider.GetCertificate(); + httpsOptions.ServerCertificateSelector = (context, host) => tlsEndPoint.CertificateProvider.GetCertificate(); } if (tlsEndPoint.RemoteCertificateValidationCallback != null) { - options.ClientCertificateValidation = (cert, chain, errors) => tlsEndPoint.RemoteCertificateValidationCallback(tlsEndPoint, cert, chain, errors); + httpsOptions.ClientCertificateValidation = (cert, chain, errors) => tlsEndPoint.RemoteCertificateValidationCallback(tlsEndPoint, cert, chain, errors); } - - return options; } } } \ No newline at end of file From 7d8562aa66c92af535504197eebf0429394d25c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 11:18:23 +0800 Subject: [PATCH 43/85] Supports both MQTT and MQTT over WebSocket on a ConnectionContext. --- Samples/Server/Server_ASP_NET_Samples.cs | 21 +++++--- Source/MQTTnet.AspTestApp/Program.cs | 5 +- .../ConnectionBuilderExtensions.cs | 16 +++++- .../ClientConnectionContext.WebSocket.cs | 32 ++++++----- .../Internal/MqttConnectionMiddleware.cs | 54 +++++++++++++++++++ .../KestrelServerOptionsExtensions.cs | 17 +++--- Source/MQTTnet.AspnetCore/MqttProtocols.cs | 19 +++++++ .../ServiceCollectionExtensions.cs | 14 +++++ 8 files changed, 140 insertions(+), 38 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs create mode 100644 Source/MQTTnet.AspnetCore/MqttProtocols.cs diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index bfe222530..b8250ab12 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -23,22 +23,18 @@ public static class Server_ASP_NET_Samples public static Task Start_Server_With_WebSockets_Support() { var builder = WebApplication.CreateBuilder(); - builder.Services.AddMqttServer().ConfigureMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()); + builder.Services.AddMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()); builder.Services.AddMqttClient(); builder.Services.AddHostedService(); builder.WebHost.UseKestrel(kestrel => { - // Need ConfigureMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()) + // Need ConfigureMqttServer(s => ...) to enable the endpoints kestrel.ListenMqtt(); // We can also manually listen to a specific port without ConfigureMqttServer() // kestrel.ListenAnyIP(1883, l => l.UseMqtt()); // mqtt over tcp - // kestrel.ListenAnyIP(8883, l => l.UseHttps().UseMqtt()); // mqtt over tls over tcp - - // This will allow MQTT connections based on HTTP WebSockets with URI "localhost:5000/mqtt" - // See code below for URI configuration. - kestrel.ListenAnyIP(5000); // Default HTTP pipeline + // kestrel.ListenLocalhost(8883, l => l.UseHttps().UseMqtt()); // mqtt over tls over tcp }); var app = builder.Build(); @@ -87,7 +83,16 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) { await Task.Delay(1000); using var client = _mqttClientFactory.CreateMqttClient(); - var options = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://127.0.0.1:1883").Build(); + + // var mqttUri = "mqtt://localhost:1883"; + // var mqttsUri = "mqtt://localhost:8883"; + // var wsMqttUri = "ws://localhost:1883/mqtt"; + var wssMqttUri = "wss://localhost:8883/mqtt"; + + var options = new MqttClientOptionsBuilder() + .WithConnectionUri(wssMqttUri) + .Build(); + await client.ConnectAsync(options, stoppingToken); await client.DisconnectAsync(); } diff --git a/Source/MQTTnet.AspTestApp/Program.cs b/Source/MQTTnet.AspTestApp/Program.cs index df8d32184..6d4de4df7 100644 --- a/Source/MQTTnet.AspTestApp/Program.cs +++ b/Source/MQTTnet.AspTestApp/Program.cs @@ -11,15 +11,12 @@ builder.Services.AddRazorPages(); // Setup MQTT stuff. -builder.Services.AddMqttServer().ConfigureMqttServer(s => s.WithDefaultEndpoint()); +builder.Services.AddMqttServer(s => s.WithDefaultEndpoint().WithDefaultEndpointPort(5000)); // ListenMqtt builder.WebHost.UseKestrel(kestrel => { kestrel.ListenMqtt(); - - // mqtt over WebSocket - kestrel.ListenAnyIP(5000); // Default HTTP pipeline }); var app = builder.Build(); diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 914a8a68b..d9a11da6d 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; +using System; namespace MQTTnet.AspNetCore { @@ -13,11 +14,22 @@ public static class ConnectionBuilderExtensions /// Treat the obtained connection as an mqtt connection /// /// + /// /// - public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder) + public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndHttp) { builder.ApplicationServices.GetRequiredService().UseFlag = true; - return builder.UseConnectionHandler(); + if (protocols == MqttProtocols.Mqtt) + { + return builder.UseConnectionHandler(); + } + else if (protocols == MqttProtocols.MqttAndHttp) + { + var middleware = builder.ApplicationServices.GetRequiredService(); + return builder.Use(next => context => middleware.InvokeAsync(next, context)); + } + + throw new NotSupportedException(protocols.ToString()); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs index 458d0d09f..5aea0e324 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -16,15 +16,13 @@ partial class ClientConnectionContext { public static async Task CreateAsync(MqttClientWebSocketOptions options, CancellationToken cancellationToken) { - var uri = new UriBuilder(new Uri(options.Uri, UriKind.Absolute)) - { - Scheme = options.TlsOptions?.UseTls == true ? Uri.UriSchemeWss : Uri.UriSchemeWs - }.Uri; - var clientWebSocket = new ClientWebSocket(); + var uri = new Uri(options.Uri, UriKind.Absolute); + var useTls = options.TlsOptions?.UseTls == true || uri.Scheme == Uri.UriSchemeWss; + try { - SetupClientWebSocket(clientWebSocket, options); + SetupClientWebSocket(clientWebSocket.Options, options, useTls); await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); } catch @@ -48,18 +46,18 @@ public static async Task CreateAsync(MqttClientWebSocke return connection; } - private static void SetupClientWebSocket(ClientWebSocket clientWebSocket, MqttClientWebSocketOptions options) + private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions, MqttClientWebSocketOptions options, bool useTls) { if (options.ProxyOptions != null) { - clientWebSocket.Options.Proxy = CreateProxy(options); + webSocketOptions.Proxy = CreateProxy(options); } if (options.RequestHeaders != null) { foreach (var requestHeader in options.RequestHeaders) { - clientWebSocket.Options.SetRequestHeader(requestHeader.Key, requestHeader.Value); + webSocketOptions.SetRequestHeader(requestHeader.Key, requestHeader.Value); } } @@ -67,21 +65,21 @@ private static void SetupClientWebSocket(ClientWebSocket clientWebSocket, MqttCl { foreach (var subProtocol in options.SubProtocols) { - clientWebSocket.Options.AddSubProtocol(subProtocol); + webSocketOptions.AddSubProtocol(subProtocol); } } if (options.CookieContainer != null) { - clientWebSocket.Options.Cookies = options.CookieContainer; + webSocketOptions.Cookies = options.CookieContainer; } - if (options.TlsOptions?.UseTls == true) + if (useTls) { var certificates = options.TlsOptions?.ClientCertificatesProvider?.GetCertificates(); if (certificates?.Count > 0) { - clientWebSocket.Options.ClientCertificates = certificates; + webSocketOptions.ClientCertificates = certificates; } } @@ -89,23 +87,23 @@ private static void SetupClientWebSocket(ClientWebSocket clientWebSocket, MqttCl // and will throw a _PlatformNotSupported_ (i.e. WASM) exception when being used regardless of the actual value. if (options.UseDefaultCredentials) { - clientWebSocket.Options.UseDefaultCredentials = options.UseDefaultCredentials; + webSocketOptions.UseDefaultCredentials = options.UseDefaultCredentials; } if (options.KeepAliveInterval != WebSocket.DefaultKeepAliveInterval) { - clientWebSocket.Options.KeepAliveInterval = options.KeepAliveInterval; + webSocketOptions.KeepAliveInterval = options.KeepAliveInterval; } if (options.Credentials != null) { - clientWebSocket.Options.Credentials = options.Credentials; + webSocketOptions.Credentials = options.Credentials; } var certificateValidationHandler = options.TlsOptions?.CertificateValidationHandler; if (certificateValidationHandler != null) { - clientWebSocket.Options.RemoteCertificateValidationCallback = (_, certificate, chain, sslPolicyErrors) => + webSocketOptions.RemoteCertificateValidationCallback = (_, certificate, chain, sslPolicyErrors) => { // TODO: Find a way to add client options to same callback. Problem is that they have a different type. var context = new MqttClientCertificateValidationEventArgs(certificate, chain, sslPolicyErrors, options); diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs new file mode 100644 index 000000000..77a9dd870 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using System; +using System.IO.Pipelines; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +/// +/// Middleware that allows connections to be either HTTP or MQTT +/// +sealed class MqttConnectionMiddleware +{ + private static readonly byte[] _mqtt = "MQTT"u8.ToArray(); + private static readonly byte[] _MQIsdp = "MQIsdp"u8.ToArray(); + private readonly MqttConnectionHandler _connectionHandler; + + public MqttConnectionMiddleware(MqttConnectionHandler connectionHandler) + { + _connectionHandler = connectionHandler; + } + + public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connection) + { + var input = connection.Transport.Input; + var readResult = await input.ReadAsync(); + var isMqtt = IsMqttRequest(readResult); + input.AdvanceTo(readResult.Buffer.Start); + + if (isMqtt) + { + await _connectionHandler.OnConnectedAsync(connection); + } + else + { + await next(connection); + } + } + + private static bool IsMqttRequest(ReadResult readResult) + { + var span = readResult.Buffer.FirstSpan; + if (span.Length > 4) + { + span = span[4..]; + return span.StartsWith(_mqtt) || span.StartsWith(_MQIsdp); + } + + return false; + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 181a3afa4..3d905e7b0 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -20,36 +20,39 @@ public static class KestrelServerOptionsExtensions /// Listen all endponts in MqttServerOptions /// /// + /// /// /// - public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel) + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols = MqttProtocols.MqttAndHttp) { - return kestrel.ListenMqtt(default(Action)); + return kestrel.ListenMqtt(protocols, default(Action)); } /// /// Listen all endponts in MqttServerOptions /// /// + /// /// /// /// - public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, X509Certificate2? serverCertificate) + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols, X509Certificate2? serverCertificate) { - return kestrel.ListenMqtt(tls => tls.ServerCertificate = serverCertificate); + return kestrel.ListenMqtt(protocols, tls => tls.ServerCertificate = serverCertificate); } /// /// Listen all endponts in MqttServerOptions /// /// + /// /// /// /// - public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, Action? tlsConfigure) + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols, Action? tlsConfigure) { - var serverOptions = kestrel.ApplicationServices.GetRequiredService(); var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); + var serverOptions = kestrel.ApplicationServices.GetRequiredService(); Listen(serverOptions.DefaultEndpointOptions); Listen(serverOptions.TlsEndpointOptions); @@ -84,7 +87,7 @@ void UseMiddleware(ListenOptions listenOptions) tlsConfigure?.Invoke(httpsOptions); }); } - listenOptions.UseMqtt(); + listenOptions.UseMqtt(protocols); } } } diff --git a/Source/MQTTnet.AspnetCore/MqttProtocols.cs b/Source/MQTTnet.AspnetCore/MqttProtocols.cs new file mode 100644 index 000000000..55d2648d4 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttProtocols.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.AspNetCore +{ + public enum MqttProtocols + { + /// + /// Only support Mqtt connection + /// + Mqtt, + + /// + /// Support both Mqtt and Mqtt over WebSocket connection + /// + MqttAndHttp + } +} diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index b131337e7..ed5dec25a 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -7,11 +7,24 @@ using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Server; +using System; namespace MQTTnet.AspNetCore; public static class ServiceCollectionExtensions { + /// + /// Register MqttServer as a singleton service + /// + /// + /// + /// + public static IMqttServerBuilder AddMqttServer(this IServiceCollection services, Action configure) + { + services.Configure(configure); + return services.AddMqttServer(); + } + /// /// Register MqttServer as a singleton service /// @@ -22,6 +35,7 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) services.AddOptions(); services.AddConnections(); services.TryAddSingleton(); + services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(s => s.GetRequiredService().BuildServerOptions()); services.TryAddSingleton(s => s.GetRequiredService().BuildServerStopOptions()); From 1bd594991328230e92743c176aee5f66f9a43e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 11:48:21 +0800 Subject: [PATCH 44/85] MqttProtocols adds WebSocket item. --- Samples/Server/Server_ASP_NET_Samples.cs | 2 +- Source/MQTTnet.AspTestApp/Program.cs | 2 +- .../MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs | 8 ++++++-- .../EndpointRouteBuilderExtensions.cs | 4 ++-- .../KestrelServerOptionsExtensions.cs | 2 +- Source/MQTTnet.AspnetCore/MqttProtocols.cs | 11 ++++++++--- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index b8250ab12..d704dc1eb 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -85,7 +85,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) using var client = _mqttClientFactory.CreateMqttClient(); // var mqttUri = "mqtt://localhost:1883"; - // var mqttsUri = "mqtt://localhost:8883"; + // var mqttsUri = "mqtts://localhost:8883"; // var wsMqttUri = "ws://localhost:1883/mqtt"; var wssMqttUri = "wss://localhost:8883/mqtt"; diff --git a/Source/MQTTnet.AspTestApp/Program.cs b/Source/MQTTnet.AspTestApp/Program.cs index 6d4de4df7..0b2b48571 100644 --- a/Source/MQTTnet.AspTestApp/Program.cs +++ b/Source/MQTTnet.AspTestApp/Program.cs @@ -16,7 +16,7 @@ // ListenMqtt builder.WebHost.UseKestrel(kestrel => { - kestrel.ListenMqtt(); + kestrel.ListenMqtt(MqttProtocols.WebSocket); }); var app = builder.Build(); diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index d9a11da6d..a9bf0c589 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -16,14 +16,18 @@ public static class ConnectionBuilderExtensions /// /// /// - public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndHttp) + public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket) { builder.ApplicationServices.GetRequiredService().UseFlag = true; if (protocols == MqttProtocols.Mqtt) { return builder.UseConnectionHandler(); } - else if (protocols == MqttProtocols.MqttAndHttp) + else if (protocols == MqttProtocols.WebSocket) + { + return builder; + } + else if (protocols == MqttProtocols.MqttAndWebSocket) { var middleware = builder.ApplicationServices.GetRequiredService(); return builder.Use(next => context => middleware.InvokeAsync(next, context)); diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index faf64ad72..511610907 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -16,7 +16,7 @@ namespace MQTTnet.AspNetCore public static class EndpointRouteBuilderExtensions { /// - /// Treat the obtained WebSocket as an mqtt connection + /// Specify the matching path for mqtt-over-websocket /// /// /// @@ -33,7 +33,7 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) } /// - /// Treat the obtained WebSocket as an mqtt connection + /// Specify the matching path for mqtt-over-websocket /// /// /// diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 3d905e7b0..15868846d 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -23,7 +23,7 @@ public static class KestrelServerOptionsExtensions /// /// /// - public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols = MqttProtocols.MqttAndHttp) + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket) { return kestrel.ListenMqtt(protocols, default(Action)); } diff --git a/Source/MQTTnet.AspnetCore/MqttProtocols.cs b/Source/MQTTnet.AspnetCore/MqttProtocols.cs index 55d2648d4..f1701445c 100644 --- a/Source/MQTTnet.AspnetCore/MqttProtocols.cs +++ b/Source/MQTTnet.AspnetCore/MqttProtocols.cs @@ -7,13 +7,18 @@ namespace MQTTnet.AspNetCore public enum MqttProtocols { /// - /// Only support Mqtt connection + /// Only support Mqtt /// Mqtt, /// - /// Support both Mqtt and Mqtt over WebSocket connection + /// Only support Mqtt-over-WebSocket /// - MqttAndHttp + WebSocket, + + /// + /// Support both Mqtt and Mqtt-over-WebSocket + /// + MqttAndWebSocket } } From c4d9a2246af1359aa4a2475a11a9ad5507bc46f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 12:30:59 +0800 Subject: [PATCH 45/85] Make sure services.AddMqttServer() has been called before operating Mqtt. --- Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs | 6 +++++- Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs | 4 ++++ Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs | 3 +++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index a9bf0c589..5f423076b 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Server; using System; namespace MQTTnet.AspNetCore @@ -11,13 +12,16 @@ namespace MQTTnet.AspNetCore public static class ConnectionBuilderExtensions { /// - /// Treat the obtained connection as an mqtt connection + /// Handle the connection using the specified MQTT protocols /// /// /// /// public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket) { + // check services.AddMqttServer() + builder.ApplicationServices.GetRequiredService(); + builder.ApplicationServices.GetRequiredService().UseFlag = true; if (protocols == MqttProtocols.Mqtt) { diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index 511610907..dc1c73098 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Server; using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -41,6 +42,9 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) /// public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern, Action options) { + // check services.AddMqttServer() + endpoints.ServiceProvider.GetRequiredService(); + endpoints.ServiceProvider.GetRequiredService().MapFlag = true; return endpoints.MapConnectionHandler(pattern, options); } diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 15868846d..2c8a4988e 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -51,6 +51,9 @@ public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, /// public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols, Action? tlsConfigure) { + // check services.AddMqttServer() + kestrel.ApplicationServices.GetRequiredService(); + var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); var serverOptions = kestrel.ApplicationServices.GetRequiredService(); From 62313c8f5d20250dda718540415c3f39350499e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 13:44:04 +0800 Subject: [PATCH 46/85] Simplify MqttConnectionMiddleware --- .../ConnectionBuilderExtensions.cs | 18 ++-------- .../Internal/MqttConnectionMiddleware.cs | 33 ++++++++++++------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 5f423076b..0dd7dc28b 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -5,7 +5,6 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using MQTTnet.Server; -using System; namespace MQTTnet.AspNetCore { @@ -21,23 +20,10 @@ public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttPr { // check services.AddMqttServer() builder.ApplicationServices.GetRequiredService(); - builder.ApplicationServices.GetRequiredService().UseFlag = true; - if (protocols == MqttProtocols.Mqtt) - { - return builder.UseConnectionHandler(); - } - else if (protocols == MqttProtocols.WebSocket) - { - return builder; - } - else if (protocols == MqttProtocols.MqttAndWebSocket) - { - var middleware = builder.ApplicationServices.GetRequiredService(); - return builder.Use(next => context => middleware.InvokeAsync(next, context)); - } - throw new NotSupportedException(protocols.ToString()); + var middleware = builder.ApplicationServices.GetRequiredService(); + return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols)); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs index 77a9dd870..ffbb5848c 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -10,34 +10,43 @@ namespace MQTTnet.AspNetCore; /// -/// Middleware that allows connections to be either HTTP or MQTT +/// Middleware that connection using the specified MQTT protocols /// sealed class MqttConnectionMiddleware { private static readonly byte[] _mqtt = "MQTT"u8.ToArray(); - private static readonly byte[] _MQIsdp = "MQIsdp"u8.ToArray(); + private static readonly byte[] _mqisdp = "MQIsdp"u8.ToArray(); private readonly MqttConnectionHandler _connectionHandler; public MqttConnectionMiddleware(MqttConnectionHandler connectionHandler) - { + { _connectionHandler = connectionHandler; } - public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connection) + public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connection, MqttProtocols protocols) { - var input = connection.Transport.Input; - var readResult = await input.ReadAsync(); - var isMqtt = IsMqttRequest(readResult); - input.AdvanceTo(readResult.Buffer.Start); + if (protocols == MqttProtocols.MqttAndWebSocket) + { + var input = connection.Transport.Input; + var readResult = await input.ReadAsync(); + var isMqtt = IsMqttRequest(readResult); + input.AdvanceTo(readResult.Buffer.Start); + + protocols = isMqtt ? MqttProtocols.Mqtt : MqttProtocols.WebSocket; + } - if (isMqtt) + if (protocols == MqttProtocols.Mqtt) { await _connectionHandler.OnConnectedAsync(connection); } - else + else if (protocols == MqttProtocols.WebSocket) { await next(connection); } + else + { + throw new NotSupportedException(protocols.ToString()); + } } private static bool IsMqttRequest(ReadResult readResult) @@ -45,8 +54,8 @@ private static bool IsMqttRequest(ReadResult readResult) var span = readResult.Buffer.FirstSpan; if (span.Length > 4) { - span = span[4..]; - return span.StartsWith(_mqtt) || span.StartsWith(_MQIsdp); + var protocol = span[4..]; + return protocol.StartsWith(_mqtt) || protocol.StartsWith(_mqisdp); } return false; From e087acb2267f458da37baee6031cb4a130269955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 15:49:44 +0800 Subject: [PATCH 47/85] Improve the compatibility of wss connections. --- .../ClientConnectionContext.WebSocket.cs | 27 ++++++++++++------- .../MQTTnet.AspNetCore.csproj | 6 ----- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs index 5aea0e324..4ac414ebd 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Exceptions; using System; using System.IO; using System.Net; @@ -16,13 +17,21 @@ partial class ClientConnectionContext { public static async Task CreateAsync(MqttClientWebSocketOptions options, CancellationToken cancellationToken) { - var clientWebSocket = new ClientWebSocket(); var uri = new Uri(options.Uri, UriKind.Absolute); - var useTls = options.TlsOptions?.UseTls == true || uri.Scheme == Uri.UriSchemeWss; + if (uri.Scheme != Uri.UriSchemeWs && uri.Scheme != Uri.UriSchemeWss) + { + throw new MqttConfigurationException("The scheme of the WebSocket Uri must be ws or wss."); + } + // Patching TlsOptions + options.TlsOptions ??= new MqttClientTlsOptions(); + // Scheme decides whether to use TLS + options.TlsOptions.UseTls = uri.Scheme == Uri.UriSchemeWss; + + var clientWebSocket = new ClientWebSocket(); try { - SetupClientWebSocket(clientWebSocket.Options, options, useTls); + SetupClientWebSocket(clientWebSocket.Options, options); await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); } catch @@ -46,7 +55,7 @@ public static async Task CreateAsync(MqttClientWebSocke return connection; } - private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions, MqttClientWebSocketOptions options, bool useTls) + private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions, MqttClientWebSocketOptions options) { if (options.ProxyOptions != null) { @@ -74,9 +83,9 @@ private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions webSocketOptions.Cookies = options.CookieContainer; } - if (useTls) + if (options.TlsOptions.UseTls) { - var certificates = options.TlsOptions?.ClientCertificatesProvider?.GetCertificates(); + var certificates = options.TlsOptions.ClientCertificatesProvider?.GetCertificates(); if (certificates?.Count > 0) { webSocketOptions.ClientCertificates = certificates; @@ -100,7 +109,7 @@ private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions webSocketOptions.Credentials = options.Credentials; } - var certificateValidationHandler = options.TlsOptions?.CertificateValidationHandler; + var certificateValidationHandler = options.TlsOptions.CertificateValidationHandler; if (certificateValidationHandler != null) { webSocketOptions.RemoteCertificateValidationCallback = (_, certificate, chain, sslPolicyErrors) => @@ -110,7 +119,7 @@ private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions return certificateValidationHandler(context); }; - var certificateSelectionHandler = options.TlsOptions?.CertificateSelectionHandler; + var certificateSelectionHandler = options.TlsOptions.CertificateSelectionHandler; if (certificateSelectionHandler != null) { throw new NotSupportedException("Remote certificate selection callback is not supported for WebSocket connections."); @@ -148,7 +157,7 @@ private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions } - private class WebSocketStream(WebSocket webSocket) : Stream + private sealed class WebSocketStream(WebSocket webSocket) : Stream { private readonly WebSocket _webSocket = webSocket; diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 537e5e258..8864320ac 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -1,4 +1,3 @@ - @@ -47,7 +46,6 @@ True \ - @@ -55,12 +53,8 @@ - - - - From db4614c5250bedf07c7f14cc1c1819298c38f35a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 19:22:15 +0800 Subject: [PATCH 48/85] Update UnixSocket sample. --- Samples/Server/Server_ASP_NET_Samples.cs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index d704dc1eb..5ee50b357 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -20,8 +20,12 @@ namespace MQTTnet.Samples.Server; public static class Server_ASP_NET_Samples { + static readonly string unixSocketPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "mqtt.socks"); + public static Task Start_Server_With_WebSockets_Support() { + File.Delete(unixSocketPath); + var builder = WebApplication.CreateBuilder(); builder.Services.AddMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()); builder.Services.AddMqttClient(); @@ -33,8 +37,9 @@ public static Task Start_Server_With_WebSockets_Support() kestrel.ListenMqtt(); // We can also manually listen to a specific port without ConfigureMqttServer() + kestrel.ListenUnixSocket(unixSocketPath, l => l.UseMqtt()); // kestrel.ListenAnyIP(1883, l => l.UseMqtt()); // mqtt over tcp - // kestrel.ListenLocalhost(8883, l => l.UseHttps().UseMqtt()); // mqtt over tls over tcp + // kestrel.ListenAnyIP(8883, l => l.UseHttps().UseMqtt()); // mqtt over tls over tcp }); var app = builder.Build(); @@ -90,6 +95,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) var wssMqttUri = "wss://localhost:8883/mqtt"; var options = new MqttClientOptionsBuilder() + //.WithEndPoint(new UnixDomainSocketEndPoint(unixSocketPath)) .WithConnectionUri(wssMqttUri) .Build(); From acf8ed77b758b3811c29dcf79d4bba57dc547825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 18 Nov 2024 21:01:17 +0800 Subject: [PATCH 49/85] Update benchmark --- ...rocessingMqttConnectionContextBenchmark.cs | 73 ++++++++++++------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs index 5da248e20..66aa12581 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs @@ -4,10 +4,13 @@ using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Jobs; -using Microsoft.AspNetCore; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; using MQTTnet.AspNetCore; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Server.Internal.Adapter; +using System.Threading.Tasks; namespace MQTTnet.Benchmarks { @@ -15,53 +18,67 @@ namespace MQTTnet.Benchmarks [MemoryDiagnoser] public class MessageProcessingMqttConnectionContextBenchmark : BaseBenchmark { - IWebHost _host; - IMqttClient _mqttClient; + WebApplication _app; + IMqttClient _aspNetCoreMqttClient; + IMqttClient _mqttNetMqttClient; MqttApplicationMessage _message; - [GlobalSetup] - public void Setup() - { - _host = WebHost.CreateDefaultBuilder() - .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) - .ConfigureServices(services => - { - services.AddMqttServer(); - services.AddMqttClient(); - }) - .Build(); + [Params(1 * 1024, 8 * 1024, 64 * 1024)] + public int PayloadSize { get; set; } - var factory = _host.Services.GetRequiredService(); - _mqttClient = factory.CreateMqttClient(); - _host.StartAsync().GetAwaiter().GetResult(); + [GlobalSetup] + public async Task Setup() + { + var builder = WebApplication.CreateBuilder(); - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost").Build(); + builder.Services.AddMqttServer(s => s.WithDefaultEndpoint()).AddMqttServerAdapter().UseMqttNetNullLogger(); + builder.Services.AddMqttClient(); + builder.WebHost.UseKestrel(o => + { + o.ListenAnyIP(1884, l => l.UseMqtt(MqttProtocols.Mqtt)); + }); - _mqttClient.ConnectAsync(clientOptions).GetAwaiter().GetResult(); + _app = builder.Build(); + await _app.StartAsync(); _message = new MqttApplicationMessageBuilder() .WithTopic("A") + .WithPayload(new byte[PayloadSize]) .Build(); + + _aspNetCoreMqttClient = _app.Services.GetRequiredService().CreateMqttClient(); + var clientOptions = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://localhost:1884").Build(); + await _aspNetCoreMqttClient.ConnectAsync(clientOptions); + + clientOptions = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://localhost:1883").Build(); + _mqttNetMqttClient = new MqttClientFactory().CreateMqttClient(MqttNetNullLogger.Instance); + await _mqttNetMqttClient.ConnectAsync(clientOptions); } [GlobalCleanup] - public void Cleanup() + public async Task Cleanup() { - _mqttClient.DisconnectAsync().GetAwaiter().GetResult(); - _mqttClient.Dispose(); + await _aspNetCoreMqttClient.DisconnectAsync(); + _aspNetCoreMqttClient.Dispose(); + await _app.StopAsync(); + } - _host.StopAsync().GetAwaiter().GetResult(); - _host.Dispose(); + [Benchmark(Baseline = true)] + public async Task AspNetCore_Send_1000_Messages() + { + for (var i = 0; i < 1000; i++) + { + await _aspNetCoreMqttClient.PublishAsync(_message); + } } [Benchmark] - public void Send_10000_Messages() + public async Task MQTTnet_Send_1000_Messages() { - for (var i = 0; i < 10000; i++) + for (var i = 0; i < 1000; i++) { - _mqttClient.PublishAsync(_message).GetAwaiter().GetResult(); + await _mqttNetMqttClient.PublishAsync(_message); } } } From 457dcc1ceda975e27f886cd851457842ff18d9c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Tue, 19 Nov 2024 09:57:01 +0800 Subject: [PATCH 50/85] Enhanced IOptions of MQTT ServiceOptions --- .../Internal/AspNetCoreMqttOptionsBuilder.cs | 33 ------------ .../Internal/MqttOptionsFactory.cs | 53 +++++++++++++++++++ .../Internal/MqttServerOptionsFactory.cs | 30 +++++++++++ .../Internal/MqttServerStopOptionsFactory.cs | 26 +++++++++ .../MqttServerBuilderExtensions.cs | 40 +++++++++++--- .../ServiceCollectionExtensions.cs | 7 +-- 6 files changed, 146 insertions(+), 43 deletions(-) delete mode 100644 Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs deleted file mode 100644 index 5122529b0..000000000 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttOptionsBuilder.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.Extensions.Options; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore -{ - sealed class AspNetCoreMqttOptionsBuilder - { - private readonly MqttServerOptionsBuilder _serverOptionsBuilder; - private readonly MqttServerStopOptionsBuilder _stopOptionsBuilder; - - public AspNetCoreMqttOptionsBuilder( - IOptions serverOptionsBuilderOptions, - IOptions stopOptionsBuilderOptions) - { - _serverOptionsBuilder = serverOptionsBuilderOptions.Value; - _stopOptionsBuilder = stopOptionsBuilderOptions.Value; - } - - public MqttServerOptions BuildServerOptions() - { - return _serverOptionsBuilder.Build(); - } - - public MqttServerStopOptions BuildServerStopOptions() - { - return _stopOptionsBuilder.Build(); - } - } -} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs new file mode 100644 index 000000000..b7bdc292f --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using System.Collections.Generic; + +namespace MQTTnet.AspNetCore +{ + abstract class MqttOptionsFactory + where TOptionsBuilder : class + where TOptions : class + { + private readonly IEnumerable> _setups; + private readonly IEnumerable> _postConfigures; + protected TOptionsBuilder OptionsBuilder { get; } + + public MqttOptionsFactory( + IOptions optionsBuilderOptions, + IEnumerable> setups, + IEnumerable> postConfigures) + { + OptionsBuilder = optionsBuilderOptions.Value; + _setups = setups; + _postConfigures = postConfigures; + } + + public TOptions Build() + { + var options = CreateOptions(); + var name = Options.DefaultName; + + foreach (var setup in _setups) + { + if (setup is IConfigureNamedOptions namedSetup) + { + namedSetup.Configure(name, options); + } + else if (name == Options.DefaultName) + { + setup.Configure(options); + } + } + foreach (var post in _postConfigures) + { + post.PostConfigure(name, options); + } + return options; + } + + protected abstract TOptions CreateOptions(); + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs new file mode 100644 index 000000000..1d0fee3d1 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Server; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; + +namespace MQTTnet.AspNetCore +{ + sealed class MqttServerOptionsFactory : MqttOptionsFactory + { + public MqttServerOptionsFactory( + IOptions optionsBuilderOptions, + IEnumerable> setups, + IEnumerable> postConfigures) + : base(optionsBuilderOptions, setups, postConfigures) + { + } + + protected override MqttServerOptions CreateOptions() + { + return base.OptionsBuilder.Build(); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs new file mode 100644 index 000000000..e26d3615e --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Server; +using System.Collections.Generic; + +namespace MQTTnet.AspNetCore +{ + sealed class MqttServerStopOptionsFactory : MqttOptionsFactory + { + public MqttServerStopOptionsFactory( + IOptions optionsBuilderOptions, + IEnumerable> setups, + IEnumerable> postConfigures) + : base(optionsBuilderOptions, setups, postConfigures) + { + } + + protected override MqttServerStopOptions CreateOptions() + { + return OptionsBuilder.Build(); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 813313ad8..869fe19e1 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -16,11 +16,24 @@ public static class MqttServerBuilderExtensions /// Configure MqttServerOptionsBuilder /// /// - /// + /// /// - public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder builder, Action configure) + public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder builder, Action builderConfigure) { - builder.Services.Configure(configure); + builder.Services.Configure(builderConfigure); + return builder; + } + + /// + /// Configure MqttServerOptionsBuilder and MqttServerOptions + /// + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder builder, Action builderConfigure, Action optionsConfigure) + { + builder.Services.Configure(builderConfigure).Configure(optionsConfigure); return builder; } @@ -28,14 +41,27 @@ public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder bui /// Configure MqttServerStopOptionsBuilder /// /// - /// + /// /// - public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder, Action configure) + public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder, Action builderConfigure) { - builder.Services.Configure(configure); + builder.Services.Configure(builderConfigure); return builder; } - + + /// + /// Configure MqttServerStopOptionsBuilder and MqttServerStopOptions + /// + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder, Action builderConfigure, Action optionsConfigure) + { + builder.Services.Configure(builderConfigure).Configure(optionsConfigure); + return builder; + } + /// /// Add an IMqttServerAdapter to MqttServer /// diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index ed5dec25a..987d4deb7 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -36,9 +36,10 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) services.AddConnections(); services.TryAddSingleton(); services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(s => s.GetRequiredService().BuildServerOptions()); - services.TryAddSingleton(s => s.GetRequiredService().BuildServerStopOptions()); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService().Build()); + services.TryAddSingleton(s => s.GetRequiredService().Build()); services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddSingleton(); From 358c2c04fc8732c01ac4a51ddaf21adc0ed00c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Tue, 19 Nov 2024 10:16:06 +0800 Subject: [PATCH 51/85] add IMqttServerBuilder.ConfigureMqttSocketTransport extension. --- .../MqttServerBuilderExtensions.cs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 869fe19e1..f8cf714a1 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using MQTTnet.Server; @@ -62,6 +63,18 @@ public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder return builder; } + /// + /// Configure the socket of mqtt listener + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttSocketTransport(this IMqttServerBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder; + } + /// /// Add an IMqttServerAdapter to MqttServer /// From 02778aa81846aca5ca494933a0e85791546b316c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Tue, 19 Nov 2024 10:37:12 +0800 Subject: [PATCH 52/85] Simplifying MqttOptionsFactory --- .../Internal/MqttOptionsFactory.cs | 15 ++++++--------- .../Internal/MqttServerOptionsFactory.cs | 13 ++----------- .../Internal/MqttServerStopOptionsFactory.cs | 9 ++------- 3 files changed, 10 insertions(+), 27 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs index b7bdc292f..cad7822ce 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs @@ -3,31 +3,30 @@ // See the LICENSE file in the project root for more information. using Microsoft.Extensions.Options; +using System; using System.Collections.Generic; namespace MQTTnet.AspNetCore { - abstract class MqttOptionsFactory - where TOptionsBuilder : class - where TOptions : class + class MqttOptionsFactory where TOptions : class { + private readonly Func _defaultOptionsFactory; private readonly IEnumerable> _setups; private readonly IEnumerable> _postConfigures; - protected TOptionsBuilder OptionsBuilder { get; } public MqttOptionsFactory( - IOptions optionsBuilderOptions, + Func defaultOptionsFactory, IEnumerable> setups, IEnumerable> postConfigures) { - OptionsBuilder = optionsBuilderOptions.Value; + _defaultOptionsFactory = defaultOptionsFactory; _setups = setups; _postConfigures = postConfigures; } public TOptions Build() { - var options = CreateOptions(); + var options = _defaultOptionsFactory(); var name = Options.DefaultName; foreach (var setup in _setups) @@ -47,7 +46,5 @@ public TOptions Build() } return options; } - - protected abstract TOptions CreateOptions(); } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs index 1d0fee3d1..452d29e3c 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs @@ -4,27 +4,18 @@ using Microsoft.Extensions.Options; using MQTTnet.Server; -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - using System.Collections.Generic; namespace MQTTnet.AspNetCore { - sealed class MqttServerOptionsFactory : MqttOptionsFactory + sealed class MqttServerOptionsFactory : MqttOptionsFactory { public MqttServerOptionsFactory( IOptions optionsBuilderOptions, IEnumerable> setups, IEnumerable> postConfigures) - : base(optionsBuilderOptions, setups, postConfigures) - { - } - - protected override MqttServerOptions CreateOptions() + : base(optionsBuilderOptions.Value.Build, setups, postConfigures) { - return base.OptionsBuilder.Build(); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs index e26d3615e..b49570e85 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs @@ -8,19 +8,14 @@ namespace MQTTnet.AspNetCore { - sealed class MqttServerStopOptionsFactory : MqttOptionsFactory + sealed class MqttServerStopOptionsFactory : MqttOptionsFactory { public MqttServerStopOptionsFactory( IOptions optionsBuilderOptions, IEnumerable> setups, IEnumerable> postConfigures) - : base(optionsBuilderOptions, setups, postConfigures) + : base(optionsBuilderOptions.Value.Build, setups, postConfigures) { } - - protected override MqttServerStopOptions CreateOptions() - { - return OptionsBuilder.Build(); - } } } From b46b3571a78e559b6d77a5f4e76c5e2afc0c56f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 20 Nov 2024 00:34:01 +0800 Subject: [PATCH 53/85] Optimizing MqttChannel.SendPacketAsync --- .../Internal/MqttChannel.cs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index d23f4839c..3ed995792 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -26,6 +26,7 @@ class MqttChannel : IDisposable readonly PipeReader _input; readonly PipeWriter _output; readonly MqttPacketInspector? _packetInspector; + readonly bool _serverModeWebSocket; public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } @@ -53,11 +54,18 @@ public MqttChannel( Endpoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature); ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature); + _serverModeWebSocket = IsServerModeWebSocket(httpContextFeature); _input = connection.Transport.Input; _output = connection.Transport.Output; } + private static bool IsServerModeWebSocket(IHttpContextFeature? _httpContextFeature) + { + return _httpContextFeature != null && _httpContextFeature.HttpContext != null && _httpContextFeature.HttpContext.WebSockets.IsWebSocketRequest; + } + + private static string? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint) { if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) @@ -188,11 +196,20 @@ public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellat // https://github.com/dotnet/runtime/blob/e31ddfdc4f574b26231233dc10c9a9c402f40590/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeWriter.cs#L279 await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); } - else + else if (_serverModeWebSocket) // server channel, and client is MQTT over WebSocket { + // Make sure the MQTT packet is in a WebSocket frame to be compatible with JavaScript WebSocket WritePacketBuffer(_output, buffer); await _output.FlushAsync(cancellationToken).ConfigureAwait(false); } + else + { + await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); + foreach (var block in buffer.Payload) + { + await _output.WriteAsync(block, cancellationToken).ConfigureAwait(false); + } + } BytesSent += buffer.Length; } From 38eb1c388e4c8449c5fa55c4b8d5110fd51c9e18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 20 Nov 2024 09:57:19 +0800 Subject: [PATCH 54/85] Check buffer IsEmpty. --- .../Internal/MqttConnectionMiddleware.cs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs index ffbb5848c..4a57d6ab2 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -4,7 +4,7 @@ using Microsoft.AspNetCore.Connections; using System; -using System.IO.Pipelines; +using System.Buffers; using System.Threading.Tasks; namespace MQTTnet.AspNetCore; @@ -29,7 +29,7 @@ public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connect { var input = connection.Transport.Input; var readResult = await input.ReadAsync(); - var isMqtt = IsMqttRequest(readResult); + var isMqtt = IsMqttRequest(readResult.Buffer); input.AdvanceTo(readResult.Buffer.Start); protocols = isMqtt ? MqttProtocols.Mqtt : MqttProtocols.WebSocket; @@ -49,13 +49,16 @@ public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connect } } - private static bool IsMqttRequest(ReadResult readResult) + private static bool IsMqttRequest(ReadOnlySequence buffer) { - var span = readResult.Buffer.FirstSpan; - if (span.Length > 4) + if (!buffer.IsEmpty) { - var protocol = span[4..]; - return protocol.StartsWith(_mqtt) || protocol.StartsWith(_mqisdp); + var span = buffer.FirstSpan; + if (span.Length > 4) + { + var protocol = span[4..]; + return protocol.StartsWith(_mqtt) || protocol.StartsWith(_mqisdp); + } } return false; From ef838f8167389e2e85d0c17f3727a98d5778c2b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 30 Nov 2024 22:47:40 +0800 Subject: [PATCH 55/85] Register MqttClientFactory as a service. --- .../ApplicationBuilderExtensions.cs | 4 ++-- .../MQTTnet.AspnetCore/IMqttClientBuilder.cs | 2 +- .../MQTTnet.AspnetCore/IMqttServerBuilder.cs | 4 +++- .../Internal/AspNetCoreMqttClientFactory.cs | 22 +++---------------- .../KestrelServerOptionsExtensions.cs | 6 ++--- .../MqttBuilderExtensions.cs | 6 ++--- .../MqttClientBuilderExtensions.cs | 6 ++--- .../MqttServerBuilderExtensions.cs | 10 ++++----- .../ServiceCollectionExtensions.cs | 10 +++++---- 9 files changed, 29 insertions(+), 41 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index 4ab0eb653..e01a6556c 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -13,9 +13,9 @@ namespace MQTTnet.AspNetCore; public static class ApplicationBuilderExtensions { /// - /// Get and use MqttServer + /// Get and use /// - /// Also, you can inject MqttServer into your service + /// Also, you can inject into your service /// /// /// diff --git a/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs index f7fcda33c..575ce7c61 100644 --- a/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs +++ b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs @@ -5,7 +5,7 @@ namespace MQTTnet.AspNetCore { /// - /// Builder of IMqttClientFactory + /// Builder of /// public interface IMqttClientBuilder: IMqttBuilder { diff --git a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs index 1b6057bf5..28c71acab 100644 --- a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs +++ b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs @@ -2,10 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Server; + namespace MQTTnet.AspNetCore { /// - /// Builder of MqttServer + /// Builder of /// public interface IMqttServerBuilder : IMqttBuilder { diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs index 4c46755c6..7d38129dc 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs @@ -4,31 +4,15 @@ using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; -using MQTTnet.LowLevelClient; namespace MQTTnet.AspNetCore { - sealed class AspNetCoreMqttClientFactory : IMqttClientFactory + sealed class AspNetCoreMqttClientFactory : MqttClientFactory, IMqttClientFactory { - private readonly IMqttClientAdapterFactory _mqttClientAdapterFactory; - private readonly IMqttNetLogger _logger; - public AspNetCoreMqttClientFactory( - IMqttClientAdapterFactory mqttClientAdapterFactory, - IMqttNetLogger logger) - { - _mqttClientAdapterFactory = mqttClientAdapterFactory; - _logger = logger; - } - - public IMqttClient CreateMqttClient() - { - return new MqttClient(_mqttClientAdapterFactory, _logger); - } - - public ILowLevelMqttClient CreateLowLevelMqttClient() + IMqttNetLogger logger, + IMqttClientAdapterFactory clientAdapterFactory) : base(logger, clientAdapterFactory) { - return new LowLevelMqttClient(_mqttClientAdapterFactory, _logger); } } } diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 2c8a4988e..1ca1aa8dd 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -17,7 +17,7 @@ namespace MQTTnet.AspNetCore public static class KestrelServerOptionsExtensions { /// - /// Listen all endponts in MqttServerOptions + /// Listen all endponts in /// /// /// @@ -29,7 +29,7 @@ public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, } /// - /// Listen all endponts in MqttServerOptions + /// Listen all endponts in /// /// /// @@ -42,7 +42,7 @@ public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, } /// - /// Listen all endponts in MqttServerOptions + /// Listen all endponts in /// /// /// diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs index 7cca5696b..d104e3562 100644 --- a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -13,7 +13,7 @@ namespace MQTTnet.AspNetCore public static class MqttBuilderExtensions { /// - /// Use AspNetCoreMqttNetLogger as IMqttNetLogger + /// Use as /// /// /// @@ -25,7 +25,7 @@ public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder, } /// - /// Use AspNetCoreMqttNetLogger as IMqttNetLogger + /// Use as /// /// /// @@ -35,7 +35,7 @@ public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder) } /// - /// Use MqttNetNullLogger as IMqttNetLogger + /// Use as /// /// /// diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs index d41a20567..1c79663c4 100644 --- a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -14,7 +14,7 @@ namespace MQTTnet.AspNetCore public static class MqttClientBuilderExtensions { /// - /// Replace the implementation of IMqttClientAdapterFactory to MQTTnet.Implementations.MqttClientAdapterFactory + /// Replace the implementation of to /// /// /// @@ -24,7 +24,7 @@ public static IMqttClientBuilder UseMQTTnetMqttClientAdapterFactory(this IMqttCl } /// - /// Replace the implementation of IMqttClientAdapterFactory to AspNetCoreMqttClientAdapterFactory + /// Replace the implementation of to /// /// /// @@ -34,7 +34,7 @@ public static IMqttClientBuilder UseAspNetCoreMqttClientAdapterFactory(this IMqt } /// - /// Replace the implementation of IMqttClientAdapterFactory to TMqttClientAdapterFactory + /// Replace the implementation of to TMqttClientAdapterFactory /// /// /// diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index f8cf714a1..82a0dd00b 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -14,7 +14,7 @@ namespace MQTTnet.AspNetCore public static class MqttServerBuilderExtensions { /// - /// Configure MqttServerOptionsBuilder + /// Configure /// /// /// @@ -26,7 +26,7 @@ public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder bui } /// - /// Configure MqttServerOptionsBuilder and MqttServerOptions + /// Configure and /// /// /// @@ -39,7 +39,7 @@ public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder bui } /// - /// Configure MqttServerStopOptionsBuilder + /// Configure /// /// /// @@ -51,7 +51,7 @@ public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder } /// - /// Configure MqttServerStopOptionsBuilder and MqttServerStopOptions + /// Configure and /// /// /// @@ -76,7 +76,7 @@ public static IMqttServerBuilder ConfigureMqttSocketTransport(this IMqttServerBu } /// - /// Add an IMqttServerAdapter to MqttServer + /// Add an to /// /// /// diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 987d4deb7..6d8c427e0 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -14,7 +14,7 @@ namespace MQTTnet.AspNetCore; public static class ServiceCollectionExtensions { /// - /// Register MqttServer as a singleton service + /// Register as a singleton service /// /// /// @@ -26,7 +26,7 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services, } /// - /// Register MqttServer as a singleton service + /// Register as a singleton service /// /// /// @@ -50,14 +50,16 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) } /// - /// Register IMqttClientFactory as a singleton service + /// Register and as singleton service /// /// /// public static IMqttClientBuilder AddMqttClient(this IServiceCollection services) { services.TryAddSingleton(); - services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService()); + services.TryAddSingleton(s => s.GetRequiredService()); return services.AddMqtt(); } From 72b42d5f8a809fd9f701b0976318dec03f8b6497 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 2 Dec 2024 13:25:33 +0800 Subject: [PATCH 56/85] MqttOptionsFactory.Build() -> MqttOptionsFactory.CreateOptions() --- Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs | 2 +- Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs index cad7822ce..18d21c9a1 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs @@ -24,7 +24,7 @@ public MqttOptionsFactory( _postConfigures = postConfigures; } - public TOptions Build() + public TOptions CreateOptions() { var options = _defaultOptionsFactory(); var name = Options.DefaultName; diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 6d8c427e0..f1ec9c55f 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -38,8 +38,8 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); - services.TryAddSingleton(s => s.GetRequiredService().Build()); - services.TryAddSingleton(s => s.GetRequiredService().Build()); + services.TryAddSingleton(s => s.GetRequiredService().CreateOptions()); + services.TryAddSingleton(s => s.GetRequiredService().CreateOptions()); services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddSingleton(); From b40c8a18f98e090255a953f5f4b31ba68c7246f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 2 Dec 2024 16:03:48 +0800 Subject: [PATCH 57/85] Add MqttBufferWriterPool --- .../Internal/MqttBufferWriterPool.cs | 64 +++++++++++++++++++ .../Internal/MqttConnectionHandler.cs | 22 +++++-- .../ServiceCollectionExtensions.cs | 1 + 3 files changed, 80 insertions(+), 7 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs new file mode 100644 index 000000000..59baef5b8 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Formatter; +using MQTTnet.Server; +using System; +using System.Collections.Concurrent; + +namespace MQTTnet.AspNetCore +{ + sealed class MqttBufferWriterPool(MqttServerOptions serverOptions) + { + private readonly MqttServerOptions _serverOptions = serverOptions; + private readonly ConcurrentQueue _queue = new(); + + public RecyclableMqttBufferWriter Rent() + { + if (_queue.TryDequeue(out var bufferWriter)) + { + bufferWriter.Reset(); + } + else + { + var writer = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax); + bufferWriter = new RecyclableMqttBufferWriter(writer); + } + return bufferWriter; + } + + public void Return(RecyclableMqttBufferWriter bufferWriter) + { + if (bufferWriter.CanRecycle) + { + _queue.Enqueue(bufferWriter); + } + } + + + public sealed class RecyclableMqttBufferWriter(MqttBufferWriter bufferWriter) + { + private long _tickCount = Environment.TickCount64; + private readonly MqttBufferWriter _bufferWriter = bufferWriter; + private static readonly TimeSpan _maxLifeTime = TimeSpan.FromMinutes(1d); + + /// + /// We only recycle the MqttBufferWriter created by channels that are frequently offline. + /// This ensures that the MqttBufferWriter cache hit rate is high and does not cause the problem of too many MqttBufferWriters being pooled when the number of channels is reduced. + /// + /// + public bool CanRecycle => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount) < _maxLifeTime; + + public void Reset() + { + _tickCount = Environment.TickCount64; + } + + public static implicit operator MqttBufferWriter(RecyclableMqttBufferWriter bufferWriterItem) + { + return bufferWriterItem._bufferWriter; + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index b52caae63..3f0033cb8 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -16,7 +16,7 @@ namespace MQTTnet.AspNetCore; sealed class MqttConnectionHandler : ConnectionHandler { readonly IMqttNetLogger _logger; - readonly MqttServerOptions _serverOptions; + readonly MqttBufferWriterPool _bufferWriterPool; public bool UseFlag { get; set; } @@ -28,10 +28,10 @@ sealed class MqttConnectionHandler : ConnectionHandler public MqttConnectionHandler( IMqttNetLogger logger, - MqttServerOptions serverOptions) + MqttBufferWriterPool bufferWriterPool) { _logger = logger; - _serverOptions = serverOptions; + _bufferWriterPool = bufferWriterPool; } public override async Task OnConnectedAsync(ConnectionContext connection) @@ -51,9 +51,17 @@ public override async Task OnConnectedAsync(ConnectionContext connection) transferFormatFeature.ActiveFormat = TransferFormat.Binary; } - var bufferWriter = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax); - var formatter = new MqttPacketFormatterAdapter(bufferWriter); - using var adapter = new MqttServerChannelAdapter(formatter, connection); - await clientHandler(adapter).ConfigureAwait(false); + var bufferWriter = _bufferWriterPool.Rent(); + + try + { + var formatter = new MqttPacketFormatterAdapter(bufferWriter); + using var adapter = new MqttServerChannelAdapter(formatter, connection); + await clientHandler(adapter).ConfigureAwait(false); + } + finally + { + _bufferWriterPool.Return(bufferWriter); + } } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index f1ec9c55f..a028f1b65 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -34,6 +34,7 @@ public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) { services.AddOptions(); services.AddConnections(); + services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); From a2c014fa8b997461a2c4c4b398dce9d2495c00df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 2 Dec 2024 21:29:31 +0800 Subject: [PATCH 58/85] Adapt the RemoteEndPoint property. --- .../Internal/MqttChannel.cs | 20 ++++++++----------- .../Internal/MqttClientChannelAdapter.cs | 3 ++- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index e92b14e17..90a44bd29 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -36,11 +36,7 @@ class MqttChannel : IDisposable public X509Certificate2? ClientCertificate { get; } - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - return httpFeature?.HttpContext?.Connection.ClientCertificate; - } - } + public EndPoint? RemoteEndPoint { get; private set; } public bool IsSecureConnection { get; } @@ -49,14 +45,13 @@ public MqttChannel( MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, MqttPacketInspector? packetInspector = null) - public EndPoint RemoteEndPoint { PacketFormatterAdapter = packetFormatterAdapter; _packetInspector = packetInspector; var httpContextFeature = connection.Features.Get(); var tlsConnectionFeature = connection.Features.Get(); - Endpoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); + RemoteEndPoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature); ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature); _serverModeWebSocket = IsServerModeWebSocket(httpContextFeature); @@ -71,18 +66,19 @@ private static bool IsServerModeWebSocket(IHttpContextFeature? _httpContextFeatu } - private static string? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint) + private static EndPoint? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint) { if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) { var httpConnection = _httpContextFeature.HttpContext.Connection; var remoteAddress = httpConnection.RemoteIpAddress; - return remoteAddress == null ? null : $"{remoteAddress}:{httpConnection.RemotePort}"; + if (remoteAddress != null) + { + return new IPEndPoint(remoteAddress, httpConnection.RemotePort); + } } - return remoteEndPoint is DnsEndPoint dnsEndPoint - ? $"{dnsEndPoint.Host}:{dnsEndPoint.Port}" - : remoteEndPoint?.ToString(); + return null; } private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 431ab8ae6..70cf0c132 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -7,6 +7,7 @@ using MQTTnet.Formatter; using MQTTnet.Packets; using System; +using System.Net; using System.Runtime.CompilerServices; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -41,7 +42,7 @@ public MqttClientChannelAdapter( public X509Certificate2? ClientCertificate => GetChannel().ClientCertificate; - public string? Endpoint => GetChannel().Endpoint; + public EndPoint? RemoteEndPoint => GetChannel().RemoteEndPoint; public bool IsSecureConnection => GetChannel().IsSecureConnection; From dad4faa025a70b0bf6e17dedd605d58974c531d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Tue, 3 Dec 2024 09:09:16 +0800 Subject: [PATCH 59/85] Add MqttBufferWriterPoolOptions --- .../Internal/MqttBufferWriterPool.cs | 43 +++++++++++-------- .../MqttBufferWriterPoolOptions.cs | 18 ++++++++ .../MqttServerBuilderExtensions.cs | 13 ++++++ 3 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs index 59baef5b8..8cbb05667 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.Extensions.Options; using MQTTnet.Formatter; using MQTTnet.Server; using System; @@ -9,55 +10,59 @@ namespace MQTTnet.AspNetCore { - sealed class MqttBufferWriterPool(MqttServerOptions serverOptions) + sealed class MqttBufferWriterPool { - private readonly MqttServerOptions _serverOptions = serverOptions; - private readonly ConcurrentQueue _queue = new(); + private readonly MqttServerOptions _serverOptions; + private readonly IOptionsMonitor _poolOptions; + private readonly ConcurrentQueue _bufferWriterQueue = new(); - public RecyclableMqttBufferWriter Rent() + public MqttBufferWriterPool( + MqttServerOptions serverOptions, + IOptionsMonitor poolOptions) { - if (_queue.TryDequeue(out var bufferWriter)) + _serverOptions = serverOptions; + _poolOptions = poolOptions; + } + + public ResettableMqttBufferWriter Rent() + { + if (_bufferWriterQueue.TryDequeue(out var bufferWriter)) { bufferWriter.Reset(); } else { var writer = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax); - bufferWriter = new RecyclableMqttBufferWriter(writer); + bufferWriter = new ResettableMqttBufferWriter(writer); } return bufferWriter; } - public void Return(RecyclableMqttBufferWriter bufferWriter) + public void Return(ResettableMqttBufferWriter bufferWriter) { - if (bufferWriter.CanRecycle) + var options = _poolOptions.CurrentValue; + if (options.Enable && bufferWriter.LifeTime < options.MaxLifeTime) { - _queue.Enqueue(bufferWriter); + _bufferWriterQueue.Enqueue(bufferWriter); } } - public sealed class RecyclableMqttBufferWriter(MqttBufferWriter bufferWriter) + public sealed class ResettableMqttBufferWriter(MqttBufferWriter bufferWriter) { private long _tickCount = Environment.TickCount64; private readonly MqttBufferWriter _bufferWriter = bufferWriter; - private static readonly TimeSpan _maxLifeTime = TimeSpan.FromMinutes(1d); - /// - /// We only recycle the MqttBufferWriter created by channels that are frequently offline. - /// This ensures that the MqttBufferWriter cache hit rate is high and does not cause the problem of too many MqttBufferWriters being pooled when the number of channels is reduced. - /// - /// - public bool CanRecycle => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount) < _maxLifeTime; + public TimeSpan LifeTime => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount); public void Reset() { _tickCount = Environment.TickCount64; } - public static implicit operator MqttBufferWriter(RecyclableMqttBufferWriter bufferWriterItem) + public static implicit operator MqttBufferWriter(ResettableMqttBufferWriter resettableMqttBufferWriter) { - return bufferWriterItem._bufferWriter; + return resettableMqttBufferWriter._bufferWriter; } } } diff --git a/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs new file mode 100644 index 000000000..a455d35de --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace MQTTnet.AspNetCore +{ + public sealed class MqttBufferWriterPoolOptions + { + public bool Enable { get; set; } = true; + + /// + /// When the lifecycle of the channel associated with MqttBufferWriter is less than this value, MqttBufferWriter is pooled. + /// + public TimeSpan MaxLifeTime { get; set; } = TimeSpan.FromMinutes(1d); + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs index 82a0dd00b..8dbce7773 100644 --- a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -5,6 +5,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Formatter; using MQTTnet.Server; using System; using System.Diagnostics.CodeAnalysis; @@ -63,6 +64,18 @@ public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder return builder; } + /// + /// Configure the pool of + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttBufferWriterPool(this IMqttServerBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder; + } + /// /// Configure the socket of mqtt listener /// From 6f60eaea08feadb0b7965d3fb0756ae251d67077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Tue, 3 Dec 2024 15:57:32 +0800 Subject: [PATCH 60/85] Add more conditions to the pool of MqttBufferWriterPoolOptions. --- .../Internal/MqttBufferWriterPool.cs | 41 +++++++++++++++---- .../MqttBufferWriterPoolOptions.cs | 9 +++- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs index 8cbb05667..438cae44d 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs @@ -14,7 +14,7 @@ sealed class MqttBufferWriterPool { private readonly MqttServerOptions _serverOptions; private readonly IOptionsMonitor _poolOptions; - private readonly ConcurrentQueue _bufferWriterQueue = new(); + private readonly ConcurrentQueue _bufferWriterQueue = new(); public MqttBufferWriterPool( MqttServerOptions serverOptions, @@ -24,7 +24,7 @@ public MqttBufferWriterPool( _poolOptions = poolOptions; } - public ResettableMqttBufferWriter Rent() + public ChannelMqttBufferWriter Rent() { if (_bufferWriterQueue.TryDequeue(out var bufferWriter)) { @@ -33,26 +33,49 @@ public ResettableMqttBufferWriter Rent() else { var writer = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax); - bufferWriter = new ResettableMqttBufferWriter(writer); + bufferWriter = new ChannelMqttBufferWriter(writer); } return bufferWriter; } - public void Return(ResettableMqttBufferWriter bufferWriter) + public void Return(ChannelMqttBufferWriter bufferWriter) { - var options = _poolOptions.CurrentValue; - if (options.Enable && bufferWriter.LifeTime < options.MaxLifeTime) + if (CanReturn(bufferWriter)) { _bufferWriterQueue.Enqueue(bufferWriter); } } + private bool CanReturn(ChannelMqttBufferWriter bufferWriter) + { + var options = _poolOptions.CurrentValue; + if (!options.Enable) + { + return false; + } + + if (bufferWriter.LifeTime < options.PoolingItemMaxLifeTime) + { + return true; + } + + if (options.PoolingLargeBufferSizeItem && + bufferWriter.BufferSize > _serverOptions.WriterBufferSize) + { + return true; + } + + return false; + } + + - public sealed class ResettableMqttBufferWriter(MqttBufferWriter bufferWriter) + public sealed class ChannelMqttBufferWriter(MqttBufferWriter bufferWriter) { private long _tickCount = Environment.TickCount64; private readonly MqttBufferWriter _bufferWriter = bufferWriter; + public int BufferSize => _bufferWriter.GetBuffer().Length; public TimeSpan LifeTime => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount); public void Reset() @@ -60,9 +83,9 @@ public void Reset() _tickCount = Environment.TickCount64; } - public static implicit operator MqttBufferWriter(ResettableMqttBufferWriter resettableMqttBufferWriter) + public static implicit operator MqttBufferWriter(ChannelMqttBufferWriter channelMqttBufferWriter) { - return resettableMqttBufferWriter._bufferWriter; + return channelMqttBufferWriter._bufferWriter; } } } diff --git a/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs index a455d35de..8f725d943 100644 --- a/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs @@ -11,8 +11,13 @@ public sealed class MqttBufferWriterPoolOptions public bool Enable { get; set; } = true; /// - /// When the lifecycle of the channel associated with MqttBufferWriter is less than this value, MqttBufferWriter is pooled. + /// When the life time of the MqttBufferWriter is less than this value, MqttBufferWriter is pooled. /// - public TimeSpan MaxLifeTime { get; set; } = TimeSpan.FromMinutes(1d); + public TimeSpan PoolingItemMaxLifeTime { get; set; } = TimeSpan.FromMinutes(1d); + + /// + /// Whether to pool MqttBufferWriter with BufferSize greater than the default buffer size. + /// + public bool PoolingLargeBufferSizeItem { get; set; } = true; } } From 54f0b4dd758e6ab6dbd1acda99e06d4d8c9753ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 4 Dec 2024 08:56:50 +0800 Subject: [PATCH 61/85] MqttBufferWriterPool: Implementing the IReadOnlyCollection interface. --- .../Internal/MqttBufferWriterPool.cs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs index 438cae44d..398297d45 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs @@ -6,16 +6,22 @@ using MQTTnet.Formatter; using MQTTnet.Server; using System; +using System.Collections; using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; namespace MQTTnet.AspNetCore { - sealed class MqttBufferWriterPool + [DebuggerDisplay("Count = {Count}")] + sealed class MqttBufferWriterPool : IReadOnlyCollection { private readonly MqttServerOptions _serverOptions; private readonly IOptionsMonitor _poolOptions; private readonly ConcurrentQueue _bufferWriterQueue = new(); + public int Count => _bufferWriterQueue.Count; + public MqttBufferWriterPool( MqttServerOptions serverOptions, IOptionsMonitor poolOptions) @@ -68,8 +74,18 @@ private bool CanReturn(ChannelMqttBufferWriter bufferWriter) return false; } + public IEnumerator GetEnumerator() + { + return _bufferWriterQueue.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _bufferWriterQueue.GetEnumerator(); + } + [DebuggerDisplay("BufferSize = {BufferSize}, LifeTime = {LifeTime}")] public sealed class ChannelMqttBufferWriter(MqttBufferWriter bufferWriter) { private long _tickCount = Environment.TickCount64; From ee4de820aff8861acaefbe3860f4124f00919394 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 4 Dec 2024 11:15:43 +0800 Subject: [PATCH 62/85] ConfigureAwait(false) --- .../Internal/AspNetCoreMqttHostedServer.cs | 4 ++-- .../Internal/ClientConnectionContext.Tcp.cs | 2 +- .../Internal/ClientConnectionContext.cs | 2 +- Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs | 4 ++-- .../Internal/MqttClientChannelAdapter.cs | 14 +++++++------- .../Internal/MqttConnectionMiddleware.cs | 4 ++-- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs index 519721a4f..58b3bc387 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -35,8 +35,8 @@ void OnApplicationStarted() protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - await _applicationStartedTask.WaitAsync(stoppingToken); - await _aspNetCoreMqttServer.StartAsync(stoppingToken); + await _applicationStartedTask.WaitAsync(stoppingToken).ConfigureAwait(false); + await _aspNetCoreMqttServer.StartAsync(stoppingToken).ConfigureAwait(false) ; } public override Task StopAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs index 2a67dcb08..4c03d0c69 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -133,7 +133,7 @@ public static async Task CreateAsync(MqttClientTcpOptio } catch (Exception) { - await sslStream.DisposeAsync(); + await sslStream.DisposeAsync().ConfigureAwait(false); throw; } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs index b6b34b9f3..1e6b63d8b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -49,7 +49,7 @@ public ClientConnectionContext(Stream stream) public override async ValueTask DisposeAsync() { - await _stream.DisposeAsync(); + await _stream.DisposeAsync().ConfigureAwait(false); _connectionCloseSource.Cancel(); _connectionCloseSource.Dispose(); } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 90a44bd29..e603e4d56 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -98,8 +98,8 @@ private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, IT public async Task DisconnectAsync() { - await _input.CompleteAsync(); - await _output.CompleteAsync(); + await _input.CompleteAsync().ConfigureAwait(false); + await _output.CompleteAsync().ConfigureAwait(false); } public virtual void Dispose() diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 70cf0c132..e817eb4bb 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -51,16 +51,16 @@ public async Task ConnectAsync(CancellationToken cancellationToken) { _connection = _channelOptions switch { - MqttClientTcpOptions tcpOptions => await ClientConnectionContext.CreateAsync(tcpOptions, cancellationToken), - MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken), + MqttClientTcpOptions tcpOptions => await ClientConnectionContext.CreateAsync(tcpOptions, cancellationToken).ConfigureAwait(false), + MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false), _ => throw new NotSupportedException(), }; _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector); } - public async Task DisconnectAsync(CancellationToken cancellationToken) + public Task DisconnectAsync(CancellationToken cancellationToken) { - await GetChannel().DisconnectAsync(); + return GetChannel().DisconnectAsync(); } public async ValueTask DisposeAsync() @@ -74,19 +74,19 @@ public async ValueTask DisposeAsync() if (_channel != null) { - await _channel.DisconnectAsync(); + await _channel.DisconnectAsync().ConfigureAwait(false); _channel.Dispose(); } if (_connection != null) { - await _connection.DisposeAsync(); + await _connection.DisposeAsync().ConfigureAwait(false); } } public void Dispose() { - DisposeAsync().GetAwaiter().GetResult(); + DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); } public Task ReceivePacketAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs index 4a57d6ab2..2ed66d3ef 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -37,11 +37,11 @@ public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connect if (protocols == MqttProtocols.Mqtt) { - await _connectionHandler.OnConnectedAsync(connection); + await _connectionHandler.OnConnectedAsync(connection).ConfigureAwait(false); } else if (protocols == MqttProtocols.WebSocket) { - await next(connection); + await next(connection).ConfigureAwait(false); } else { From b708344f7d3cb4661bf1525a5991167b215072e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 4 Dec 2024 13:44:09 +0800 Subject: [PATCH 63/85] MqttChannel: adapt AllowPacketFragmentation option. --- .../AspNetCoreMqttClientAdapterFactory.cs | 2 +- .../Internal/MqttChannel.cs | 36 +++++++++++-------- .../Internal/MqttClientChannelAdapter.cs | 7 ++-- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index 2f5041144..1cd50bec3 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -16,7 +16,7 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa ArgumentNullException.ThrowIfNull(nameof(options)); var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); - return new MqttClientChannelAdapter(formatter, options.ChannelOptions, packetInspector); + return new MqttClientChannelAdapter(formatter, options.ChannelOptions, packetInspector, options.AllowPacketFragmentation); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index e603e4d56..b3d0b360b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -26,7 +26,7 @@ class MqttChannel : IDisposable readonly PipeReader _input; readonly PipeWriter _output; readonly MqttPacketInspector? _packetInspector; - readonly bool _serverModeWebSocket; + readonly bool _allowPacketFragmentation; public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } @@ -44,7 +44,8 @@ class MqttChannel : IDisposable public MqttChannel( MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, - MqttPacketInspector? packetInspector = null) + MqttPacketInspector? packetInspector = null, + bool? allowPacketFragmentation = null) { PacketFormatterAdapter = packetFormatterAdapter; _packetInspector = packetInspector; @@ -54,15 +55,22 @@ public MqttChannel( RemoteEndPoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature); ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature); - _serverModeWebSocket = IsServerModeWebSocket(httpContextFeature); _input = connection.Transport.Input; _output = connection.Transport.Output; + + _allowPacketFragmentation = allowPacketFragmentation == null + ? AllowPacketFragmentation(httpContextFeature) + : allowPacketFragmentation.Value; } - private static bool IsServerModeWebSocket(IHttpContextFeature? _httpContextFeature) + private static bool AllowPacketFragmentation(IHttpContextFeature? _httpContextFeature) { - return _httpContextFeature != null && _httpContextFeature.HttpContext != null && _httpContextFeature.HttpContext.WebSockets.IsWebSocketRequest; + var serverModeWebSocket = _httpContextFeature != null && + _httpContextFeature.HttpContext != null && + _httpContextFeature.HttpContext.WebSockets.IsWebSocketRequest; + + return !serverModeWebSocket; } @@ -197,20 +205,20 @@ public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellat // https://github.com/dotnet/runtime/blob/e31ddfdc4f574b26231233dc10c9a9c402f40590/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeWriter.cs#L279 await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); } - else if (_serverModeWebSocket) // server channel, and client is MQTT over WebSocket - { - // Make sure the MQTT packet is in a WebSocket frame to be compatible with JavaScript WebSocket - WritePacketBuffer(_output, buffer); - await _output.FlushAsync(cancellationToken).ConfigureAwait(false); - } - else + else if (_allowPacketFragmentation) { await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); - foreach (var block in buffer.Payload) + foreach (var memory in buffer.Payload) { - await _output.WriteAsync(block, cancellationToken).ConfigureAwait(false); + await _output.WriteAsync(memory, cancellationToken).ConfigureAwait(false); } } + else + { + // Make sure the MQTT packet is in a WebSocket frame to be compatible with JavaScript WebSocket + WritePacketBuffer(_output, buffer); + await _output.FlushAsync(cancellationToken).ConfigureAwait(false); + } BytesSent += buffer.Length; } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index e817eb4bb..671043d8f 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -23,15 +23,18 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable private readonly MqttPacketFormatterAdapter _packetFormatterAdapter; private readonly IMqttClientChannelOptions _channelOptions; private readonly MqttPacketInspector? _packetInspector; + private readonly bool? _allowPacketFragmentation; public MqttClientChannelAdapter( MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions, - MqttPacketInspector? packetInspector) + MqttPacketInspector? packetInspector, + bool? allowPacketFragmentation) { _packetFormatterAdapter = packetFormatterAdapter; _channelOptions = channelOptions; _packetInspector = packetInspector; + _allowPacketFragmentation = allowPacketFragmentation; } public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter; @@ -55,7 +58,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false), _ => throw new NotSupportedException(), }; - _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector); + _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector, _allowPacketFragmentation); } public Task DisconnectAsync(CancellationToken cancellationToken) From 13e819052b8d20b39f3154b18d43388f96812ddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 4 Dec 2024 13:59:19 +0800 Subject: [PATCH 64/85] Fixed the issue that GetRemoteEndPoint did not use the remoteEndPoint parameter. --- Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index b3d0b360b..b709838c0 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -86,7 +86,7 @@ private static bool AllowPacketFragmentation(IHttpContextFeature? _httpContextFe } } - return null; + return remoteEndPoint; } private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) From 4bfd77dc26196df0d425a72d8855a02f4db129ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 4 Dec 2024 15:04:48 +0800 Subject: [PATCH 65/85] Add some unit tests. --- .../ASP/MqttBufferWriterPoolTest.cs | 49 +++++++++++ Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs | 31 +++++++ .../ASP/MqttClientBuilderTest.cs | 50 ++++++++++++ ...ttPacketFormatterAdapterExtensionsTest.cs} | 2 +- .../ASP/MqttServerBuilderTest.cs | 81 +++++++++++++++++++ 5 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs create mode 100644 Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs create mode 100644 Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs rename Source/MQTTnet.Tests/ASP/{ReaderExtensionsTest.cs => MqttPacketFormatterAdapterExtensionsTest.cs} (96%) create mode 100644 Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs diff --git a/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs b/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs new file mode 100644 index 000000000..f4491be13 --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs @@ -0,0 +1,49 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Formatter; +using MQTTnet.Server; +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttBufferWriterPoolTest + { + [TestMethod] + public async Task RentReturnTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().ConfigureMqttBufferWriterPool(p => + { + p.PoolingItemMaxLifeTime = TimeSpan.FromSeconds(1d); + }); + + var s = services.BuildServiceProvider(); + var pool = s.GetRequiredService(); + var options = s.GetRequiredService(); + + var bufferWriter = pool.Rent(); + Assert.AreEqual(0, pool.Count); + + pool.Return(bufferWriter); + Assert.AreEqual(1, pool.Count); + + bufferWriter = pool.Rent(); + Assert.AreEqual(0, pool.Count); + + await Task.Delay(TimeSpan.FromSeconds(2d)); + + pool.Return(bufferWriter); + Assert.AreEqual(0, pool.Count); + + MqttBufferWriter writer = bufferWriter; + writer.Seek(options.WriterBufferSize + 1); + Assert.IsTrue(bufferWriter.BufferSize > options.WriterBufferSize); + + pool.Return(bufferWriter); + Assert.AreEqual(1, pool.Count); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs b/Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs new file mode 100644 index 000000000..a8e4c00a5 --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs @@ -0,0 +1,31 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Diagnostics.Logger; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttBuilderTest + { + [TestMethod] + public void UseMqttNetNullLoggerTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().UseMqttNetNullLogger(); + var s = services.BuildServiceProvider(); + var logger = s.GetRequiredService(); + Assert.IsInstanceOfType(logger); + } + + [TestMethod] + public void UseAspNetCoreMqttNetLoggerTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().UseAspNetCoreMqttNetLogger(); + var s = services.BuildServiceProvider(); + var logger = s.GetRequiredService(); + Assert.IsInstanceOfType(logger); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs b/Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs new file mode 100644 index 000000000..5258acd3e --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs @@ -0,0 +1,50 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Adapter; +using MQTTnet.AspNetCore; +using MQTTnet.Implementations; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttClientBuilderTest + { + [TestMethod] + public void AddMqttClientTest() + { + var services = new ServiceCollection(); + services.AddMqttClient(); + var s = services.BuildServiceProvider(); + + var mqttClientFactory1 = s.GetRequiredService(); + var mqttClientFactory2 = s.GetRequiredService(); + Assert.IsTrue(ReferenceEquals(mqttClientFactory2, mqttClientFactory2)); + + Assert.IsInstanceOfType(mqttClientFactory1); + Assert.IsInstanceOfType(mqttClientFactory1); + } + + [TestMethod] + public void UseMQTTnetMqttClientAdapterFactoryTest() + { + var services = new ServiceCollection(); + services.AddMqttClient().UseMQTTnetMqttClientAdapterFactory(); + var s = services.BuildServiceProvider(); + var adapterFactory = s.GetRequiredService(); + + Assert.IsInstanceOfType(adapterFactory); + } + + + [TestMethod] + public void UseAspNetCoreMqttClientAdapterFactoryTest() + { + var services = new ServiceCollection(); + services.AddMqttClient().UseAspNetCoreMqttClientAdapterFactory(); + var s = services.BuildServiceProvider(); + var adapterFactory = s.GetRequiredService(); + + Assert.IsInstanceOfType(adapterFactory); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs b/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs similarity index 96% rename from Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs rename to Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs index 6c9cac8f8..7517d559d 100644 --- a/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Tests.ASP; [TestClass] -public sealed class ReaderExtensionsTest +public sealed class MqttPacketFormatterAdapterExtensionsTest { [TestMethod] public void TestTryDeserialize() diff --git a/Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs b/Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs new file mode 100644 index 000000000..cd162bab9 --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs @@ -0,0 +1,81 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Server; +using MQTTnet.Server.Internal.Adapter; +using System.Collections.Generic; +using System.Linq; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttServerBuilderTest + { + [TestMethod] + public void AddMqttServerTest() + { + var services = new ServiceCollection(); + services.AddMqttServer(); + var s = services.BuildServiceProvider(); + + var mqttServer1 = s.GetRequiredService(); + var mqttServer2 = s.GetRequiredService(); + Assert.IsInstanceOfType(mqttServer1); + Assert.AreEqual(mqttServer1, mqttServer2); + } + + [TestMethod] + public void ConfigureMqttServerTest() + { + const int TcpKeepAliveTime1 = 19; + const int TcpKeepAliveTime2 = 20; + + var services = new ServiceCollection(); + services.AddMqttServer().ConfigureMqttServer( + b => b.WithTcpKeepAliveTime(TcpKeepAliveTime1), + o => + { + Assert.AreEqual(TcpKeepAliveTime1, o.DefaultEndpointOptions.TcpKeepAliveTime); + o.DefaultEndpointOptions.TcpKeepAliveTime = TcpKeepAliveTime2; + }); + + var s = services.BuildServiceProvider(); + var options = s.GetRequiredService(); + Assert.AreEqual(TcpKeepAliveTime2, options.DefaultEndpointOptions.TcpKeepAliveTime); + } + + [TestMethod] + public void ConfigureMqttServerStopTest() + { + const string ReasonString1 = "ReasonString1"; + const string ReasonString2 = "ReasonString2"; + + var services = new ServiceCollection(); + services.AddMqttServer().ConfigureMqttServerStop( + b => b.WithDefaultClientDisconnectOptions(c => c.WithReasonString(ReasonString1)), + o => + { + Assert.AreEqual(ReasonString1, o.DefaultClientDisconnectOptions.ReasonString); + o.DefaultClientDisconnectOptions.ReasonString = ReasonString2; + }); + + var s = services.BuildServiceProvider(); + var options = s.GetRequiredService(); + Assert.AreEqual(ReasonString2, options.DefaultClientDisconnectOptions.ReasonString); + } + + [TestMethod] + public void AddMqttServerAdapterTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().AddMqttServerAdapter(); + services.AddMqttServer().AddMqttServerAdapter(); + + var s = services.BuildServiceProvider(); + var adapters = s.GetRequiredService>().ToArray(); + Assert.AreEqual(2, adapters.Length); + Assert.IsInstanceOfType(adapters[0]); + Assert.IsInstanceOfType(adapters[1]); + } + } +} From 9d9dd4468690204ff78483153ac0220f590db3b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Wed, 4 Dec 2024 15:20:24 +0800 Subject: [PATCH 66/85] Add more unit test. --- MQTTnet.sln | 4 +--- .../Internal/MqttConnectionMiddleware.cs | 2 +- .../ASP/MqttConnectionMiddlewareTest.cs | 24 +++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) create mode 100644 Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs diff --git a/MQTTnet.sln b/MQTTnet.sln index 15a0d2323..984482f18 100644 --- a/MQTTnet.sln +++ b/MQTTnet.sln @@ -7,11 +7,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet", "Source\MQTTnet\M EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{B3F60ECB-45BA-4C66-8903-8BB89CA67998}" ProjectSection(SolutionItems) = preProject + .github\workflows\ci.yml = .github\workflows\ci.yml CODE-OF-CONDUCT.md = CODE-OF-CONDUCT.md LICENSE = LICENSE README.md = README.md Source\ReleaseNotes.md = Source\ReleaseNotes.md - .github\workflows\ci.yml = .github\workflows\ci.yml EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.AspNetCore", "Source\MQTTnet.AspnetCore\MQTTnet.AspNetCore.csproj", "{F10C4060-F7EE-4A83-919F-FF723E72F94A}" @@ -85,6 +85,4 @@ Global GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {07536672-5CBC-4BE3-ACE0-708A431A7894} EndGlobalSection - GlobalSection(NestedProjects) = preSolution - EndGlobalSection EndGlobal diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs index 2ed66d3ef..69254cb7e 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -49,7 +49,7 @@ public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connect } } - private static bool IsMqttRequest(ReadOnlySequence buffer) + public static bool IsMqttRequest(ReadOnlySequence buffer) { if (!buffer.IsEmpty) { diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs new file mode 100644 index 000000000..a25faf88c --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using System; +using System.Buffers; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttConnectionMiddlewareTest + { + [TestMethod] + public void IsMqttRequestTest() + { + var mqttv31Request = Convert.FromHexString("102800044d51545404c0003c0008636c69656e7469640008757365726e616d650008706173736f777264"); + var mqttv50Request = Convert.FromHexString("102900044d51545405c0003c000008636c69656e7469640008757365726e616d650008706173736f777264"); + + var isMqttv31 = MqttConnectionMiddleware.IsMqttRequest(new ReadOnlySequence(mqttv31Request)); + var isMqttv50 = MqttConnectionMiddleware.IsMqttRequest(new ReadOnlySequence(mqttv50Request)); + + Assert.IsTrue(isMqttv31); + Assert.IsTrue(isMqttv50); + } + } +} From 8b7c41109a6b8791782917dc0aad3434c0d439bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 5 Dec 2024 00:07:31 +0800 Subject: [PATCH 67/85] MqttBufferWriterPoolOptions: Renaming properties. --- .../Internal/MqttBufferWriterPool.cs | 16 ++++++---------- .../MqttBufferWriterPoolOptions.cs | 11 +++++------ .../ASP/MqttBufferWriterPoolTest.cs | 8 ++++---- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs index 398297d45..c9cf437f5 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs @@ -44,29 +44,25 @@ public ChannelMqttBufferWriter Rent() return bufferWriter; } - public void Return(ChannelMqttBufferWriter bufferWriter) + public bool Return(ChannelMqttBufferWriter bufferWriter) { if (CanReturn(bufferWriter)) { _bufferWriterQueue.Enqueue(bufferWriter); + return true; } + return false; } private bool CanReturn(ChannelMqttBufferWriter bufferWriter) { var options = _poolOptions.CurrentValue; - if (!options.Enable) - { - return false; - } - - if (bufferWriter.LifeTime < options.PoolingItemMaxLifeTime) + if (bufferWriter.Lifetime < options.MaxLifetime) { return true; } - if (options.PoolingLargeBufferSizeItem && - bufferWriter.BufferSize > _serverOptions.WriterBufferSize) + if (options.LargeBufferSizeEnabled && bufferWriter.BufferSize > _serverOptions.WriterBufferSize) { return true; } @@ -92,7 +88,7 @@ public sealed class ChannelMqttBufferWriter(MqttBufferWriter bufferWriter) private readonly MqttBufferWriter _bufferWriter = bufferWriter; public int BufferSize => _bufferWriter.GetBuffer().Length; - public TimeSpan LifeTime => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount); + public TimeSpan Lifetime => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount); public void Reset() { diff --git a/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs index 8f725d943..9927404a6 100644 --- a/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs @@ -2,22 +2,21 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Formatter; using System; namespace MQTTnet.AspNetCore { public sealed class MqttBufferWriterPoolOptions { - public bool Enable { get; set; } = true; - /// - /// When the life time of the MqttBufferWriter is less than this value, MqttBufferWriter is pooled. + /// When the lifetime of the is less than this value, is pooled. /// - public TimeSpan PoolingItemMaxLifeTime { get; set; } = TimeSpan.FromMinutes(1d); + public TimeSpan MaxLifetime { get; set; } = TimeSpan.FromMinutes(1d); /// - /// Whether to pool MqttBufferWriter with BufferSize greater than the default buffer size. + /// Whether to pool with BufferSize greater than the default buffer size. /// - public bool PoolingLargeBufferSizeItem { get; set; } = true; + public bool LargeBufferSizeEnabled { get; set; } = true; } } diff --git a/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs b/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs index f4491be13..8adb9e771 100644 --- a/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs @@ -17,7 +17,7 @@ public async Task RentReturnTest() var services = new ServiceCollection(); services.AddMqttServer().ConfigureMqttBufferWriterPool(p => { - p.PoolingItemMaxLifeTime = TimeSpan.FromSeconds(1d); + p.MaxLifetime = TimeSpan.FromSeconds(1d); }); var s = services.BuildServiceProvider(); @@ -27,7 +27,7 @@ public async Task RentReturnTest() var bufferWriter = pool.Rent(); Assert.AreEqual(0, pool.Count); - pool.Return(bufferWriter); + Assert.IsTrue(pool.Return(bufferWriter)); Assert.AreEqual(1, pool.Count); bufferWriter = pool.Rent(); @@ -35,14 +35,14 @@ public async Task RentReturnTest() await Task.Delay(TimeSpan.FromSeconds(2d)); - pool.Return(bufferWriter); + Assert.IsFalse(pool.Return(bufferWriter)); Assert.AreEqual(0, pool.Count); MqttBufferWriter writer = bufferWriter; writer.Seek(options.WriterBufferSize + 1); Assert.IsTrue(bufferWriter.BufferSize > options.WriterBufferSize); - pool.Return(bufferWriter); + Assert.IsTrue(pool.Return(bufferWriter)); Assert.AreEqual(1, pool.Count); } } From 8a3624da7ae5346c28b3e3c2ec994070aa9387d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Thu, 5 Dec 2024 21:34:53 +0800 Subject: [PATCH 68/85] Rename and update benchmark. --- .../MessageProcessingAspNetCoreBenchmark.cs | 57 +++++++++++++ ...rocessingMqttConnectionContextBenchmark.cs | 85 ------------------- 2 files changed, 57 insertions(+), 85 deletions(-) create mode 100644 Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs delete mode 100644 Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs new file mode 100644 index 000000000..6d8e85d4d --- /dev/null +++ b/Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using MQTTnet.AspNetCore; +using System.Threading.Tasks; + +namespace MQTTnet.Benchmarks; + +[SimpleJob(RuntimeMoniker.Net80)] +[RPlotExporter] +[RankColumn] +[MemoryDiagnoser] +public class MessageProcessingAspNetCoreBenchmark : BaseBenchmark +{ + IMqttClient _mqttClient; + string _payload = string.Empty; + + [Params(1 * 1024, 4 * 1024, 8 * 1024)] + public int PayloadSize { get; set; } + + [Benchmark] + public async Task Send_1000_Messages_AspNetCore() + { + for (var i = 0; i < 1000; i++) + { + await _mqttClient.PublishStringAsync("A", _payload); + } + } + + [GlobalSetup] + public async Task Setup() + { + var builder = WebApplication.CreateBuilder(); + + builder.Services.AddMqttServer(s => s.WithDefaultEndpoint()); + builder.Services.AddMqttClient(); + builder.WebHost.UseKestrel(k => k.ListenMqtt()); + + var app = builder.Build(); + await app.StartAsync(); + + _mqttClient = app.Services.GetRequiredService().CreateMqttClient(); + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); + + await _mqttClient.ConnectAsync(clientOptions); + + _payload = string.Empty.PadLeft(PayloadSize, '0'); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs deleted file mode 100644 index 66aa12581..000000000 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ /dev/null @@ -1,85 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Hosting; -using Microsoft.Extensions.DependencyInjection; -using MQTTnet.AspNetCore; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Server.Internal.Adapter; -using System.Threading.Tasks; - -namespace MQTTnet.Benchmarks -{ - [SimpleJob(RuntimeMoniker.Net80)] - [MemoryDiagnoser] - public class MessageProcessingMqttConnectionContextBenchmark : BaseBenchmark - { - WebApplication _app; - IMqttClient _aspNetCoreMqttClient; - IMqttClient _mqttNetMqttClient; - MqttApplicationMessage _message; - - [Params(1 * 1024, 8 * 1024, 64 * 1024)] - public int PayloadSize { get; set; } - - - [GlobalSetup] - public async Task Setup() - { - var builder = WebApplication.CreateBuilder(); - - builder.Services.AddMqttServer(s => s.WithDefaultEndpoint()).AddMqttServerAdapter().UseMqttNetNullLogger(); - builder.Services.AddMqttClient(); - builder.WebHost.UseKestrel(o => - { - o.ListenAnyIP(1884, l => l.UseMqtt(MqttProtocols.Mqtt)); - }); - - _app = builder.Build(); - await _app.StartAsync(); - - _message = new MqttApplicationMessageBuilder() - .WithTopic("A") - .WithPayload(new byte[PayloadSize]) - .Build(); - - _aspNetCoreMqttClient = _app.Services.GetRequiredService().CreateMqttClient(); - var clientOptions = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://localhost:1884").Build(); - await _aspNetCoreMqttClient.ConnectAsync(clientOptions); - - clientOptions = new MqttClientOptionsBuilder().WithConnectionUri("mqtt://localhost:1883").Build(); - _mqttNetMqttClient = new MqttClientFactory().CreateMqttClient(MqttNetNullLogger.Instance); - await _mqttNetMqttClient.ConnectAsync(clientOptions); - } - - [GlobalCleanup] - public async Task Cleanup() - { - await _aspNetCoreMqttClient.DisconnectAsync(); - _aspNetCoreMqttClient.Dispose(); - await _app.StopAsync(); - } - - [Benchmark(Baseline = true)] - public async Task AspNetCore_Send_1000_Messages() - { - for (var i = 0; i < 1000; i++) - { - await _aspNetCoreMqttClient.PublishAsync(_message); - } - } - - [Benchmark] - public async Task MQTTnet_Send_1000_Messages() - { - for (var i = 0; i < 1000; i++) - { - await _mqttNetMqttClient.PublishAsync(_message); - } - } - } -} From 06f03cb8b4a9a91e527e35cf6c23ce39a7b13fbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 6 Dec 2024 14:23:39 +0800 Subject: [PATCH 69/85] Add AspNetCoreTestEnvironment to test environments. --- .../Internal/AspNetCoreMqttNetLogger.cs | 14 +- .../Internal/MqttChannel.cs | 2 +- .../MqttPacketFormatterAdapterExtensions.cs | 7 + .../ReaderExtensionsBenchmark.cs | 2 +- ...qttPacketFormatterAdapterExtensionsTest.cs | 8 +- Source/MQTTnet.Tests/BaseTestClass.cs | 24 ++- .../LowLevelMqttClient_Tests.cs | 15 +- .../MqttClient/MqttClient_Connection_Tests.cs | 15 +- .../Clients/MqttClient/MqttClient_Tests.cs | 78 +++++--- .../Diagnostics/PacketInspection_Tests.cs | 3 +- Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs | 21 ++- Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs | 45 +++-- Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs | 31 ++-- .../Mockups/AspNetCoreTestEnvironment.cs | 168 ++++++++++++++++++ .../MQTTnet.Tests/Mockups/TestEnvironment.cs | 23 ++- .../Mockups/TestEnvironmentCollection.cs | 39 ++++ Source/MQTTnet.Tests/RoundtripTime_Tests.cs | 6 +- .../Server/Assigned_Client_ID_Tests.cs | 3 +- .../MQTTnet.Tests/Server/Connection_Tests.cs | 6 +- .../Server/Cross_Version_Tests.cs | 12 +- Source/MQTTnet.Tests/Server/Events_Tests.cs | 72 ++++---- Source/MQTTnet.Tests/Server/General.cs | 89 ++++++---- .../MQTTnet.Tests/Server/Injection_Tests.cs | 9 +- .../MQTTnet.Tests/Server/Keep_Alive_Tests.cs | 3 +- Source/MQTTnet.Tests/Server/Load_Tests.cs | 9 +- Source/MQTTnet.Tests/Server/No_Local_Tests.cs | 3 +- .../MQTTnet.Tests/Server/Publishing_Tests.cs | 15 +- Source/MQTTnet.Tests/Server/QoS_Tests.cs | 12 +- .../Server/Retain_As_Published_Tests.cs | 3 +- .../Server/Retain_Handling_Tests.cs | 3 +- .../Server/Retained_Messages_Tests.cs | 30 ++-- Source/MQTTnet.Tests/Server/Security_Tests.cs | 15 +- .../Server/Server_Reference_Tests.cs | 3 +- Source/MQTTnet.Tests/Server/Session_Tests.cs | 58 +++--- .../Server/Shared_Subscriptions_Tests.cs | 6 +- Source/MQTTnet.Tests/Server/Status_Tests.cs | 18 +- .../MQTTnet.Tests/Server/Subscribe_Tests.cs | 36 ++-- .../Server/Subscription_Identifier_Tests.cs | 9 +- Source/MQTTnet.Tests/Server/Tls_Tests.cs | 2 +- .../MQTTnet.Tests/Server/Topic_Alias_Tests.cs | 6 +- .../MQTTnet.Tests/Server/Unsubscribe_Tests.cs | 7 +- .../Server/User_Properties_Tests.cs | 7 +- .../Wildcard_Subscription_Available_Tests.cs | 6 +- Source/MQTTnet.Tests/Server/Will_Tests.cs | 9 +- 44 files changed, 689 insertions(+), 263 deletions(-) create mode 100644 Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs create mode 100644 Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs index bf7881bed..caa1c07ca 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -26,10 +26,16 @@ public AspNetCoreMqttNetLogger( public void Publish(MqttNetLogLevel logLevel, string? source, string? message, object[]? parameters, Exception? exception) { - var categoryName = $"{_loggerOptions.CategoryNamePrefix}{source}"; - var logger = _loggerFactory.CreateLogger(categoryName); - var level = _loggerOptions.LogLevelConverter(logLevel); - logger.Log(level, exception, message, parameters ?? []); + try + { + var categoryName = $"{_loggerOptions.CategoryNamePrefix}{source}"; + var logger = _loggerFactory.CreateLogger(categoryName); + var level = _loggerOptions.LogLevelConverter(logLevel); + logger.Log(level, exception, message, parameters ?? []); + } + catch (ObjectDisposedException) + { + } } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index b709838c0..7b93d5430 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -143,7 +143,7 @@ public virtual void Dispose() { if (!buffer.IsEmpty) { - if (PacketFormatterAdapter.TryDecode(buffer, out var packet, out consumed, out observed, out var received)) + if (PacketFormatterAdapter.TryDecode(buffer,_packetInspector, out var packet, out consumed, out observed, out var received)) { BytesReceived += received; diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs index 94d8e4dff..796bdc01f 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs @@ -18,6 +18,7 @@ static class MqttPacketFormatterAdapterExtensions public static bool TryDecode( this MqttPacketFormatterAdapter formatter, in ReadOnlySequence input, + MqttPacketInspector? packetInspector, [MaybeNullWhen(false)] out MqttPacket packet, out SequencePosition consumed, out SequencePosition observed, @@ -51,6 +52,12 @@ public static bool TryDecode( var bodySlice = copy.Slice(0, bodyLength); var bodySegment = GetArraySegment(ref bodySlice); + if (packetInspector != null) + { + packetInspector.FillReceiveBuffer(input.Slice(input.Start, headerLength).ToArray()); + packetInspector.FillReceiveBuffer(bodySegment.ToArray()); + } + var receivedMqttPacket = new ReceivedMqttPacket(fixedHeader, bodySegment, headerLength + bodyLength); if (formatter.ProtocolVersion == MqttProtocolVersion.Unknown) { diff --git a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs index 7b9cb19f3..3a65dfa99 100644 --- a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs @@ -116,7 +116,7 @@ public async Task After() { if (!buffer.IsEmpty) { - if (MqttPacketFormatterAdapterExtensions.TryDecode(mqttPacketFormatter, buffer, out var packet, out consumed, out observed, out var received)) + if (MqttPacketFormatterAdapterExtensions.TryDecode(mqttPacketFormatter, buffer, null, out var packet, out consumed, out observed, out var received)) { break; } diff --git a/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs b/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs index 7517d559d..7f810feb3 100644 --- a/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs @@ -28,19 +28,19 @@ public void TestTryDeserialize() var read = 0; part = sequence.Slice(sequence.Start, 0); // empty message should fail - var result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + var result = serializer.TryDecode(part,null, out _, out consumed, out observed, out read); Assert.IsFalse(result); part = sequence.Slice(sequence.Start, 1); // partial fixed header should fail - result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + result = serializer.TryDecode(part, null, out _, out consumed, out observed, out read); Assert.IsFalse(result); part = sequence.Slice(sequence.Start, 4); // partial body should fail - result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + result = serializer.TryDecode(part, null, out _, out consumed, out observed, out read); Assert.IsFalse(result); part = sequence; // complete msg should work - result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + result = serializer.TryDecode(part, null, out _, out consumed, out observed, out read); Assert.IsTrue(result); } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/BaseTestClass.cs b/Source/MQTTnet.Tests/BaseTestClass.cs index 8e5248e7f..cd7447636 100644 --- a/Source/MQTTnet.Tests/BaseTestClass.cs +++ b/Source/MQTTnet.Tests/BaseTestClass.cs @@ -2,21 +2,35 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Formatter; using MQTTnet.Tests.Mockups; +using System; +using System.Threading.Tasks; namespace MQTTnet.Tests { public abstract class BaseTestClass { public TestContext TestContext { get; set; } - - protected TestEnvironment CreateTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + + protected TestEnvironmentCollection CreateTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + { + var mqttnet = new TestEnvironment(TestContext, protocolVersion); + return new TestEnvironmentCollection(mqttnet); + } + + protected TestEnvironmentCollection CreateAspNetCoreTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + { + var aspnetcore = new AspNetCoreTestEnvironment(TestContext, protocolVersion); + return new TestEnvironmentCollection(aspnetcore); + } + + protected TestEnvironmentCollection CreateMixedTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { - return new TestEnvironment(TestContext, protocolVersion); + var mqttnet = new TestEnvironment(TestContext, protocolVersion); + var aspnetcore = new AspNetCoreTestEnvironment(TestContext, protocolVersion); + return new TestEnvironmentCollection(mqttnet, aspnetcore); } protected Task LongTestDelay() diff --git a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs index 4c8a44ad8..14d01fcf6 100644 --- a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs @@ -22,7 +22,8 @@ public sealed class LowLevelMqttClient_Tests : BaseTestClass [TestMethod] public async Task Authenticate() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -43,7 +44,8 @@ public async Task Authenticate() [TestMethod] public async Task Connect_And_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -78,7 +80,8 @@ public async Task Connect_To_Wrong_Host() [TestMethod] public async Task Loose_Connection() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreServerLogErrors = true; @@ -116,7 +119,8 @@ public async Task Loose_Connection() [TestMethod] public async Task Maintain_IsConnected_Property() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreServerLogErrors = true; @@ -161,7 +165,8 @@ public async Task Maintain_IsConnected_Property() [TestMethod] public async Task Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs index 75b00d39d..f8173bca6 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs @@ -82,7 +82,8 @@ public async Task ConnectTimeout_Throws_Exception() [TestMethod] public async Task Disconnect_Clean() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -110,7 +111,8 @@ public async Task Disconnect_Clean() [TestMethod] public async Task Disconnect_Clean_With_Custom_Reason() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -138,7 +140,8 @@ public async Task Disconnect_Clean_With_Custom_Reason() [TestMethod] public async Task Disconnect_Clean_With_User_Properties() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -169,7 +172,8 @@ public async Task Disconnect_Clean_With_User_Properties() [TestMethod] public async Task No_Unobserved_Exception() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -201,7 +205,8 @@ public async Task No_Unobserved_Exception() [TestMethod] public async Task Return_Non_Success() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs index b8b1ba9e2..5c87a4621 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs @@ -36,7 +36,8 @@ public async Task Concurrent_Processing(MqttQualityOfServiceLevel qos) long concurrency = 0; var success = false; - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var publisher = await testEnvironment.ConnectClient(); @@ -74,7 +75,8 @@ async Task InvokeInternal() [TestMethod] public async Task Connect_Disconnect_Connect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -91,7 +93,8 @@ public async Task Connect_Disconnect_Connect() [ExpectedException(typeof(InvalidOperationException))] public async Task Connect_Multiple_Times_Should_Fail() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -136,7 +139,8 @@ public async Task Disconnect_Event_Contains_Exception() [TestMethod] public async Task Ensure_Queue_Drain() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectLowLevelClient(); @@ -177,7 +181,8 @@ await client.SendAsync( [TestMethod] public async Task Fire_Disconnected_Event_On_Server_Shutdown() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -200,7 +205,8 @@ public async Task Fire_Disconnected_Event_On_Server_Shutdown() [TestMethod] public async Task Frequent_Connects() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -268,7 +274,8 @@ public async Task Invalid_Connect_Throws_Exception() [TestMethod] public async Task No_Payload() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -304,7 +311,8 @@ await receiver.SubscribeAsync( [TestMethod] public async Task NoConnectedHandler_Connect_DoesNotThrowException() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -317,7 +325,8 @@ public async Task NoConnectedHandler_Connect_DoesNotThrowException() [TestMethod] public async Task NoDisconnectedHandler_Disconnect_DoesNotThrowException() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -332,7 +341,8 @@ public async Task NoDisconnectedHandler_Disconnect_DoesNotThrowException() [TestMethod] public async Task PacketIdentifier_In_Publish_Result() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -365,7 +375,8 @@ public async Task Preserve_Message_Order() // is an issue). const int MessagesCount = 50; - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -410,7 +421,8 @@ public async Task Preserve_Message_Order_With_Delayed_Acknowledgement() // is an issue). const int MessagesCount = 50; - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -454,7 +466,8 @@ Task Handler1(MqttApplicationMessageReceivedEventArgs eventArgs) [TestMethod] public async Task Publish_QoS_0_Over_Period_Exceeding_KeepAlive() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { const int KeepAlivePeriodSecs = 3; @@ -486,7 +499,8 @@ public async Task Publish_QoS_0_Over_Period_Exceeding_KeepAlive() [TestMethod] public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -530,7 +544,8 @@ public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler() [TestMethod] public async Task Publish_With_Correct_Retain_Flag() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -563,7 +578,8 @@ public async Task Publish_With_Correct_Retain_Flag() [TestMethod] public async Task Reconnect() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -586,7 +602,8 @@ public async Task Reconnect() [TestMethod] public async Task Reconnect_From_Disconnected_Event() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -627,7 +644,8 @@ public async Task Reconnect_From_Disconnected_Event() [TestMethod] public async Task Reconnect_While_Server_Offline() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -665,7 +683,8 @@ public async Task Reconnect_While_Server_Offline() [TestMethod] public async Task Send_Manual_Ping() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -677,7 +696,8 @@ public async Task Send_Manual_Ping() [TestMethod] public async Task Send_Reply_For_Any_Received_Message() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -723,7 +743,8 @@ Task Handler2(MqttApplicationMessageReceivedEventArgs eventArgs) [TestMethod] public async Task Send_Reply_In_Message_Handler() { - using (var testEnvironment = new TestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client1 = await testEnvironment.ConnectClient(); @@ -770,7 +791,8 @@ public async Task Send_Reply_In_Message_Handler() [TestMethod] public async Task Send_Reply_In_Message_Handler_For_Same_Client() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -806,7 +828,8 @@ public async Task Send_Reply_In_Message_Handler_For_Same_Client() [TestMethod] public async Task Set_ClientWasConnected_On_ClientDisconnect() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -826,7 +849,8 @@ public async Task Set_ClientWasConnected_On_ClientDisconnect() [TestMethod] public async Task Set_ClientWasConnected_On_ServerDisconnect() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -847,7 +871,8 @@ public async Task Set_ClientWasConnected_On_ServerDisconnect() [TestMethod] public async Task Subscribe_In_Callback_Events() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -887,7 +912,8 @@ public async Task Subscribe_In_Callback_Events() [TestMethod] public async Task Subscribe_With_QoS2() { - using (var testEnvironment = new TestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client1 = await testEnvironment.ConnectClient(o => o.WithProtocolVersion(MqttProtocolVersion.V500)); diff --git a/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs b/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs index 7db513f58..71e17c68e 100644 --- a/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs +++ b/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs @@ -18,7 +18,8 @@ public sealed class PacketInspection_Tests : BaseTestClass [TestMethod] public async Task Inspect_Client_Packets() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs index 02dddbfa8..209b51ae1 100644 --- a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs +++ b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs @@ -22,7 +22,8 @@ public sealed class Rpc_Tests : BaseTestClass [TestMethod] public async Task Execute_Success_MQTT_V5_Mixed_Clients() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(); @@ -54,7 +55,8 @@ public async Task Execute_Success_Parameters_Propagated_Correctly() { TestParametersTopicGenerationStrategy.ExpectedParamName, "123" } }; - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -132,7 +134,8 @@ public Task Execute_Success_With_QoS_2_MQTT_V5_Use_ResponseTopic() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -147,7 +150,8 @@ public async Task Execute_Timeout() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(); @@ -172,7 +176,8 @@ public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_With_Custom_Topic_Names() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -196,7 +201,8 @@ public void Use_Factory() async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersion protocolVersion) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(protocolVersion)); @@ -217,7 +223,8 @@ async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersi async Task Execute_Success_MQTT_V5(MqttQualityOfServiceLevel qosLevel) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); diff --git a/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs index 7b7570a99..d2d2fdf93 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs @@ -21,7 +21,8 @@ public sealed class Client_Tests : BaseTestClass [TestMethod] public async Task Connect_With_New_Mqtt_Features() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -66,7 +67,8 @@ await client.PublishAsync(new MqttApplicationMessageBuilder() [TestMethod] public async Task Connect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); await testEnvironment.ConnectClient(o => o.WithProtocolVersion(MqttProtocolVersion.V500).Build()); @@ -76,7 +78,8 @@ public async Task Connect() [TestMethod] public async Task Connect_And_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -88,7 +91,8 @@ public async Task Connect_And_Disconnect() [TestMethod] public async Task Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -114,7 +118,8 @@ public async Task Subscribe() [TestMethod] public async Task Unsubscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -133,7 +138,8 @@ public async Task Unsubscribe() public async Task Publish_QoS_0_LargeBuffer() { using var recyclableMemoryStream = GetLargePayload(); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -149,7 +155,8 @@ public async Task Publish_QoS_0_LargeBuffer() public async Task Publish_QoS_1_LargeBuffer() { using var recyclableMemoryStream = GetLargePayload(); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -165,7 +172,8 @@ public async Task Publish_QoS_1_LargeBuffer() public async Task Publish_QoS_2_LargeBuffer() { using var recyclableMemoryStream = GetLargePayload(); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -180,7 +188,8 @@ public async Task Publish_QoS_2_LargeBuffer() [TestMethod] public async Task Publish_QoS_0() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -195,7 +204,8 @@ public async Task Publish_QoS_0() [TestMethod] public async Task Publish_QoS_1() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -210,7 +220,8 @@ public async Task Publish_QoS_1() [TestMethod] public async Task Publish_QoS_2() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -226,7 +237,8 @@ public async Task Publish_QoS_2() public async Task Publish_With_RecyclableMemoryStream() { var memoryManager = new RecyclableMemoryStreamManager(options: new RecyclableMemoryStreamManager.Options { BlockSize = 4096 }); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -264,7 +276,8 @@ public async Task Publish_With_RecyclableMemoryStream() [TestMethod] public async Task Publish_With_Properties() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -293,7 +306,8 @@ public async Task Publish_With_Properties() [TestMethod] public async Task Subscribe_And_Publish() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -321,7 +335,8 @@ public async Task Subscribe_And_Publish() [TestMethod] public async Task Publish_And_Receive_New_Properties() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs index 067fed791..b6edba296 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs @@ -19,7 +19,8 @@ public sealed class Server_Tests : BaseTestClass [TestMethod] public async Task Will_Message_Send() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -28,9 +29,11 @@ public async Task Will_Message_Send() var c1 = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); var receivedMessagesCount = 0; + var taskSource = new TaskCompletionSource(); c1.ApplicationMessageReceivedAsync += e => { Interlocked.Increment(ref receivedMessagesCount); + taskSource.TrySetResult(); return CompletedTask.Instance; }; @@ -39,7 +42,7 @@ public async Task Will_Message_Send() var c2 = await testEnvironment.ConnectClient(clientOptions); c2.Dispose(); // Dispose will not send a DISCONNECT packet first so the will message must be sent. - await LongTestDelay(); + await taskSource.Task; Assert.AreEqual(1, receivedMessagesCount); } @@ -48,7 +51,8 @@ public async Task Will_Message_Send() [TestMethod] public async Task Validate_IsSessionPresent() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -85,7 +89,8 @@ public async Task Validate_IsSessionPresent() [TestMethod] public async Task Connect_with_Undefined_SessionExpiryInterval() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -125,7 +130,8 @@ public async Task Connect_with_Undefined_SessionExpiryInterval() [TestMethod] public async Task Reconnect_with_different_SessionExpiryInterval() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -177,17 +183,18 @@ public async Task Reconnect_with_different_SessionExpiryInterval() [TestMethod] public async Task Disconnect_with_Reason() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { - var disconnectReason = MqttClientDisconnectReason.UnspecifiedError; + var disconnectReasonTaskSource = new TaskCompletionSource(); - string testClientId = null; + var testClientIdTaskSource = new TaskCompletionSource(); await testEnvironment.StartServer(); testEnvironment.Server.ClientConnectedAsync += e => { - testClientId = e.ClientId; + testClientIdTaskSource.TrySetResult(e.ClientId); return CompletedTask.Instance; }; @@ -195,7 +202,7 @@ public async Task Disconnect_with_Reason() client.DisconnectedAsync += e => { - disconnectReason = e.Reason; + disconnectReasonTaskSource.TrySetResult(e.Reason); return CompletedTask.Instance; }; @@ -203,14 +210,14 @@ public async Task Disconnect_with_Reason() // Test client should be connected now + var testClientId = await testClientIdTaskSource.Task; Assert.IsTrue(testClientId != null); // Have the server disconnect the client with AdministrativeAction reason await testEnvironment.Server.DisconnectClientAsync(testClientId, MqttDisconnectReasonCode.AdministrativeAction); - await LongTestDelay(); - + var disconnectReason = await disconnectReasonTaskSource.Task; // The reason should be returned to the client in the DISCONNECT packet Assert.AreEqual(MqttClientDisconnectReason.AdministrativeAction, disconnectReason); diff --git a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs new file mode 100644 index 000000000..85d434023 --- /dev/null +++ b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs @@ -0,0 +1,168 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Formatter; +using MQTTnet.Internal; +using MQTTnet.Server; +using System; +using System.Linq; +using System.Net.NetworkInformation; +using System.Threading.Tasks; + +namespace MQTTnet.Tests.Mockups +{ + public sealed class AspNetCoreTestEnvironment : TestEnvironment + { + private WebApplication _app; + + public AspNetCoreTestEnvironment() + : this(null) + { + } + + public AspNetCoreTestEnvironment(TestContext testContext, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + : base(testContext, protocolVersion) + { + } + + public override IMqttClient CreateClient() + { + var services = new ServiceCollection(); + var clientBuilder = services.AddMqttClient(); + if (EnableLogger) + { + clientBuilder.UseAspNetCoreMqttNetLogger(); + } + else + { + clientBuilder.UseMqttNetNullLogger(); + } + + var s = services.BuildServiceProvider(); + var client = s.GetRequiredService().CreateMqttClient(); + + client.ConnectingAsync += e => + { + if (TestContext != null) + { + var clientOptions = e.ClientOptions; + var existingClientId = clientOptions.ClientId; + if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) + { + clientOptions.ClientId = TestContext.TestName + "_" + existingClientId; + } + } + + return CompletedTask.Instance; + }; + + lock (_clients) + { + _clients.Add(client); + } + + return client; + } + + public override MqttServer CreateServer(MqttServerOptions options) + { + throw new NotSupportedException("Can not create MqttServer in AspNetCoreTestEnvironment."); + } + + public override async Task StartServer(Action configure) + { + if (Server != null) + { + throw new InvalidOperationException("Server already started."); + } + + var appBuilder = WebApplication.CreateBuilder(); + appBuilder.Services.AddMqttServer(optionsBuilder => + { + optionsBuilder.WithDefaultEndpoint(); + optionsBuilder.WithDefaultEndpointPort(ServerPort); + optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); + }).ConfigureMqttServer(configure, o => + { + if (o.DefaultEndpointOptions.Port == 0) + { + o.DefaultEndpointOptions.Port = GetServerPort(); + } + }); + + appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); + appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); + + _app = appBuilder.Build(); + + // The OS has chosen the port to we have to properly expose it to the tests. + ServerPort = _app.Services.GetRequiredService().DefaultEndpointOptions.Port; + + await _app.StartAsync(); + Server = _app.Services.GetRequiredService(); + return Server; + } + + public override async Task StartServer(MqttServerOptionsBuilder optionsBuilder) + { + if (Server != null) + { + throw new InvalidOperationException("Server already started."); + } + + if (ServerPort == 0) + { + ServerPort = GetServerPort(); + } + + optionsBuilder.WithDefaultEndpoint(); + optionsBuilder.WithDefaultEndpointPort(ServerPort); + optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); + + var options = optionsBuilder.Build(); + + var appBuilder = WebApplication.CreateBuilder(); + appBuilder.Services.AddMqttServer(); + appBuilder.Services.AddSingleton(options); + + appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); + appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); + + _app = appBuilder.Build(); + await _app.StartAsync(); + Server = _app.Services.GetRequiredService(); + return Server; + } + + public override void Dispose() + { + if (_app != null) + { + _app.StopAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + _app.DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + _app = null; + } + base.Dispose(); + } + + private static int GetServerPort() + { + var listeners = IPGlobalProperties.GetIPGlobalProperties().GetActiveTcpListeners(); + var portSet = listeners.Select(i => i.Port).ToHashSet(); + + var port = 1883; + while (!portSet.Add(port)) + { + port += 1; + } + return port; + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs index 4f1391f15..daedba802 100644 --- a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs @@ -19,10 +19,11 @@ namespace MQTTnet.Tests.Mockups { - public sealed class TestEnvironment : IDisposable + public class TestEnvironment : IDisposable { + bool _disposed = false; readonly List _clientErrors = new(); - readonly List _clients = new(); + protected readonly List _clients = new(); readonly List _exceptions = new(); readonly List _lowLevelClients = new(); readonly MqttProtocolVersion _protocolVersion; @@ -87,7 +88,7 @@ public TestEnvironment(TestContext testContext, MqttProtocolVersion protocolVers public bool IgnoreServerLogErrors { get; set; } - public MqttServer Server { get; private set; } + public MqttServer Server { get; protected set; } public MqttNetEventLogger ServerLogger { get; } = new("server"); @@ -195,7 +196,7 @@ public TestApplicationMessageReceivedHandler CreateApplicationMessageHandler(IMq return new TestApplicationMessageReceivedHandler(mqttClient); } - public IMqttClient CreateClient() + public virtual IMqttClient CreateClient() { var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; @@ -249,7 +250,7 @@ public ILowLevelMqttClient CreateLowLevelClient() return client; } - public MqttServer CreateServer(MqttServerOptions options) + public virtual MqttServer CreateServer(MqttServerOptions options) { if (Server != null) { @@ -278,8 +279,14 @@ public MqttServer CreateServer(MqttServerOptions options) return Server; } - public void Dispose() + public virtual void Dispose() { + if (_disposed) + { + return; + } + _disposed = true; + try { lock (_clients) @@ -350,7 +357,7 @@ public Task StartServer() return StartServer(ServerFactory.CreateServerOptionsBuilder()); } - public async Task StartServer(MqttServerOptionsBuilder optionsBuilder) + public virtual async Task StartServer(MqttServerOptionsBuilder optionsBuilder) { optionsBuilder.WithDefaultEndpoint(); optionsBuilder.WithDefaultEndpointPort(ServerPort); @@ -365,7 +372,7 @@ public async Task StartServer(MqttServerOptionsBuilder optionsBuilde return server; } - public async Task StartServer(Action configure) + public virtual async Task StartServer(Action configure) { var optionsBuilder = ServerFactory.CreateServerOptionsBuilder(); diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs new file mode 100644 index 000000000..2fc57593a --- /dev/null +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace MQTTnet.Tests.Mockups +{ + public class TestEnvironmentCollection : IReadOnlyCollection, IDisposable + { + private readonly TestEnvironment[] _testEnvironments; + + public int Count => _testEnvironments.Length; + + public TestEnvironmentCollection(params TestEnvironment[] testEnvironments) + { + _testEnvironments = testEnvironments; + } + + public IEnumerator GetEnumerator() + { + foreach (var environment in _testEnvironments) + { + yield return environment; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public void Dispose() + { + foreach (var environment in _testEnvironments) + { + environment.Dispose(); + } + } + } +} diff --git a/Source/MQTTnet.Tests/RoundtripTime_Tests.cs b/Source/MQTTnet.Tests/RoundtripTime_Tests.cs index b1359bc3d..b8b0fe25c 100644 --- a/Source/MQTTnet.Tests/RoundtripTime_Tests.cs +++ b/Source/MQTTnet.Tests/RoundtripTime_Tests.cs @@ -13,14 +13,14 @@ namespace MQTTnet.Tests { [TestClass] - public class RoundtripTime_Tests + public class RoundtripTime_Tests : BaseTestClass { - public TestContext TestContext { get; set; } [TestMethod] public async Task Round_Trip_Time() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs b/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs index 69ed38875..14f8b00ac 100644 --- a/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs @@ -28,7 +28,8 @@ public Task Connect_With_Client_Id() async Task Connect_With_Client_Id(string expectedClientId, string expectedReturnedClientId, string usedClientId, string assignedClientId) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { string serverConnectedClientId = null; string serverDisconnectedClientId = null; diff --git a/Source/MQTTnet.Tests/Server/Connection_Tests.cs b/Source/MQTTnet.Tests/Server/Connection_Tests.cs index ef2a482c4..26d180269 100644 --- a/Source/MQTTnet.Tests/Server/Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Connection_Tests.cs @@ -20,7 +20,8 @@ public sealed class Connection_Tests : BaseTestClass [TestMethod] public async Task Close_Idle_Connection_On_Connect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); @@ -49,7 +50,8 @@ public async Task Close_Idle_Connection_On_Connect() [TestMethod] public async Task Send_Garbage() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); diff --git a/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs b/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs index 83a30ecfe..f3b946d7f 100644 --- a/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs @@ -12,7 +12,8 @@ public sealed class Cross_Version_Tests : BaseTestClass [TestMethod] public async Task Send_V311_Receive_V500() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -36,7 +37,8 @@ public async Task Send_V311_Receive_V500() [TestMethod] public async Task Send_V500_Receive_V311() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -48,9 +50,9 @@ public async Task Send_V500_Receive_V311() var applicationMessage = new MqttApplicationMessageBuilder().WithTopic("My/Message") .WithPayload("My_Payload") - .WithUserProperty("A", "B") - .WithResponseTopic("Response") - .WithCorrelationData(Encoding.UTF8.GetBytes("Correlation")) + //.WithUserProperty("A", "B") + //.WithResponseTopic("Response") + //.WithCorrelationData(Encoding.UTF8.GetBytes("Correlation")) .Build(); await sender.PublishAsync(applicationMessage); diff --git a/Source/MQTTnet.Tests/Server/Events_Tests.cs b/Source/MQTTnet.Tests/Server/Events_Tests.cs index 415a2a2d2..5e726979a 100644 --- a/Source/MQTTnet.Tests/Server/Events_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Events_Tests.cs @@ -9,6 +9,7 @@ using MQTTnet.Internal; using MQTTnet.Protocol; using MQTTnet.Server; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests.Server { @@ -18,20 +19,21 @@ public sealed class Events_Tests : BaseTestClass [TestMethod] public async Task Fire_Client_Connected_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - - ClientConnectedEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientConnectedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); @@ -45,21 +47,23 @@ public async Task Fire_Client_Connected_Event() [TestMethod] public async Task Fire_Client_Disconnected_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - ClientDisconnectedEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientDisconnectedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); await client.DisconnectAsync(); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); @@ -72,21 +76,23 @@ public async Task Fire_Client_Disconnected_Event() [TestMethod] public async Task Fire_Client_Subscribed_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - ClientSubscribedTopicEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientSubscribedTopicAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(); await client.SubscribeAsync("The/Topic", MqttQualityOfServiceLevel.AtLeastOnce); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); @@ -99,21 +105,23 @@ public async Task Fire_Client_Subscribed_Event() [TestMethod] public async Task Fire_Client_Unsubscribed_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - ClientUnsubscribedTopicEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientUnsubscribedTopicAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(); await client.UnsubscribeAsync("The/Topic"); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); @@ -125,21 +133,23 @@ public async Task Fire_Client_Unsubscribed_Event() [TestMethod] public async Task Fire_Application_Message_Received_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - InterceptingPublishEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.InterceptingPublishAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(); await client.PublishStringAsync("The_Topic", "The_Payload"); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); @@ -152,20 +162,21 @@ public async Task Fire_Application_Message_Received_Event() [TestMethod] public async Task Fire_Started_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = testEnvironment.CreateServer(new MqttServerOptions()); - EventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); server.StartedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; await server.StartAsync(); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); } @@ -174,20 +185,21 @@ public async Task Fire_Started_Event() [TestMethod] public async Task Fire_Stopped_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - - EventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.StoppedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; await server.StopAsync(); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task; Assert.IsNotNull(eventArgs); } diff --git a/Source/MQTTnet.Tests/Server/General.cs b/Source/MQTTnet.Tests/Server/General.cs index 45cff1983..3228f0a6f 100644 --- a/Source/MQTTnet.Tests/Server/General.cs +++ b/Source/MQTTnet.Tests/Server/General.cs @@ -25,7 +25,8 @@ public sealed class General_Tests : BaseTestClass [TestMethod] public async Task Client_Disconnect_Without_Errors() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { bool clientWasConnected; @@ -54,7 +55,8 @@ public async Task Client_Disconnect_Without_Errors() [TestMethod] public async Task Collect_Messages_In_Disconnected_Session() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -88,7 +90,8 @@ public async Task Collect_Messages_In_Disconnected_Session() [TestMethod] public async Task Deny_Connection() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -110,7 +113,8 @@ public async Task Deny_Connection() [TestMethod] public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -162,7 +166,8 @@ public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() [TestMethod] public async Task Handle_Clean_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder()); @@ -204,7 +209,8 @@ public async Task Handle_Lots_Of_Parallel_Retained_Messages() { const int clientCount = 50; - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -268,7 +274,8 @@ await client.PublishAsync( [TestMethod] public async Task Intercept_Application_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -299,7 +306,8 @@ public async Task Intercept_Application_Message() [TestMethod] public async Task Intercept_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); server.InterceptingPublishAsync += e => @@ -331,7 +339,8 @@ public async Task Intercept_Message() [TestMethod] public async Task Intercept_Undelivered() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var undeliverd = string.Empty; @@ -357,7 +366,8 @@ public async Task Intercept_Undelivered() [TestMethod] public async Task No_Messages_If_No_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -392,7 +402,8 @@ public async Task No_Messages_If_No_Subscription() [TestMethod] public async Task Persist_Retained_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { List savedRetainedMessages = null; @@ -416,7 +427,8 @@ public async Task Persist_Retained_Message() [TestMethod] public async Task Publish_After_Client_Connects() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); server.ClientConnectedAsync += async e => @@ -477,7 +489,8 @@ public async Task Publish_Exactly_Once_0x02() [TestMethod] public async Task Publish_From_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -508,10 +521,10 @@ await server.InjectApplicationMessage( [TestMethod] public async Task Publish_Multiple_Clients() { - var receivedMessagesCount = 0; - - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { + var receivedMessagesCount = 0; await testEnvironment.StartServer(); var c1 = await testEnvironment.ConnectClient(); @@ -549,7 +562,8 @@ public async Task Publish_Multiple_Clients() [TestMethod] public async Task Remove_Session() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder()); @@ -568,7 +582,8 @@ public async Task Remove_Session() [TestMethod] public async Task Same_Client_Id_Connect_Disconnect_Event_Order() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -655,7 +670,8 @@ public async Task Same_Client_Id_Connect_Disconnect_Event_Order() [TestMethod] public async Task Same_Client_Id_Refuse_Connection() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -754,7 +770,8 @@ public async Task Same_Client_Id_Refuse_Connection() [TestMethod] public async Task Send_Long_Body() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { const int PayloadSizeInMB = 30; const int CharCount = PayloadSizeInMB * 1024 * 1024; @@ -774,15 +791,16 @@ public async Task Send_Long_Body() } } - byte[] receivedBody = null; + TaskCompletionSource> receivedBodyTaskSource = new(); await testEnvironment.StartServer(); var client1 = await testEnvironment.ConnectClient(); client1.ApplicationMessageReceivedAsync += e => { - receivedBody = e.ApplicationMessage.Payload.ToArray(); - return CompletedTask.Instance; + var payload = e.ApplicationMessage.Payload; + receivedBodyTaskSource.TrySetResult(payload); + return Task.CompletedTask; }; await client1.SubscribeAsync("string"); @@ -790,16 +808,17 @@ public async Task Send_Long_Body() var client2 = await testEnvironment.ConnectClient(); await client2.PublishBinaryAsync("string", longBody); - await Task.Delay(TimeSpan.FromSeconds(5)); + var receivedBody = await receivedBodyTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); - Assert.IsTrue(longBody.SequenceEqual(receivedBody ?? new byte[0])); + Assert.IsTrue(MqttMemoryHelper.SequenceEqual(receivedBody, new ReadOnlySequence(longBody))); } } [TestMethod] public async Task Set_Subscription_At_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -837,7 +856,8 @@ public async Task Set_Subscription_At_Server() [TestMethod] public async Task Shutdown_Disconnects_Clients_Gracefully() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder()); @@ -863,7 +883,8 @@ public async Task Shutdown_Disconnects_Clients_Gracefully() [TestMethod] public async Task Stop_And_Restart() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -895,7 +916,8 @@ public async Task Stop_And_Restart() [DataRow(null, null)] public async Task Use_Admissible_Credentials(string username, string password) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -913,7 +935,8 @@ public async Task Use_Admissible_Credentials(string username, string password) [TestMethod] public async Task Use_Empty_Client_ID() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client2 = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithClientId("b").WithCleanSession(false)); @@ -936,7 +959,8 @@ public async Task Use_Empty_Client_ID() [TestMethod] public async Task Disconnect_Client_with_Reason() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var disconnectPacketReceived = false; @@ -1007,7 +1031,8 @@ async Task TestPublishAsync( MqttQualityOfServiceLevel filterQualityOfServiceLevel, int expectedReceivedMessagesCount) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Injection_Tests.cs b/Source/MQTTnet.Tests/Server/Injection_Tests.cs index cefbc34dd..9c6db90d0 100644 --- a/Source/MQTTnet.Tests/Server/Injection_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Injection_Tests.cs @@ -11,7 +11,8 @@ public sealed class Injection_Tests : BaseTestClass [TestMethod] public async Task Inject_Application_Message_At_Session_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var receiver1 = await testEnvironment.ConnectClient(); @@ -40,7 +41,8 @@ public async Task Inject_Application_Message_At_Session_Level() [TestMethod] public async Task Inject_ApplicationMessage_At_Server_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -64,7 +66,8 @@ public async Task Inject_ApplicationMessage_At_Server_Level() [TestMethod] public async Task Intercept_Injected_Application_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs index 8bcd8b0a8..a0a1216fc 100644 --- a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs @@ -18,7 +18,8 @@ public sealed class KeepAlive_Tests : BaseTestClass [TestMethod] public async Task Disconnect_Client_DueTo_KeepAlive() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Load_Tests.cs b/Source/MQTTnet.Tests/Server/Load_Tests.cs index 466e1cfa0..019fa21cb 100644 --- a/Source/MQTTnet.Tests/Server/Load_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Load_Tests.cs @@ -14,7 +14,8 @@ public sealed class Load_Tests : BaseTestClass [TestMethod] public async Task Handle_100_000_Messages_In_Receiving_Client() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -64,7 +65,8 @@ await client.PublishAsync(message) [TestMethod] public async Task Handle_100_000_Messages_In_Low_Level_Client() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -127,7 +129,8 @@ await client.SendAsync(publishPacket, CancellationToken.None) [TestMethod] public async Task Handle_100_000_Messages_In_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/No_Local_Tests.cs b/Source/MQTTnet.Tests/Server/No_Local_Tests.cs index 5a8627d2f..208d3979c 100644 --- a/Source/MQTTnet.Tests/Server/No_Local_Tests.cs +++ b/Source/MQTTnet.Tests/Server/No_Local_Tests.cs @@ -27,7 +27,8 @@ async Task ExecuteTest( bool noLocal, int expectedCountAfterPublish) { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Publishing_Tests.cs b/Source/MQTTnet.Tests/Server/Publishing_Tests.cs index e829e2573..469b03fd4 100644 --- a/Source/MQTTnet.Tests/Server/Publishing_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Publishing_Tests.cs @@ -19,7 +19,8 @@ public sealed class Publishing_Tests : BaseTestClass [ExpectedException(typeof(MqttClientDisconnectedException))] public async Task Disconnect_While_Publishing() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -34,7 +35,8 @@ public async Task Disconnect_While_Publishing() [TestMethod] public async Task Return_NoMatchingSubscribers_When_Not_Subscribed() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -55,7 +57,8 @@ public async Task Return_NoMatchingSubscribers_When_Not_Subscribed() [TestMethod] public async Task Return_Success_When_Subscribed() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -76,7 +79,8 @@ public async Task Return_Success_When_Subscribed() [TestMethod] public async Task Intercept_Client_Enqueue() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -110,7 +114,8 @@ public async Task Intercept_Client_Enqueue() [TestMethod] public async Task Intercept_Client_Enqueue_Multiple_Clients_Subscribed_Messages_Are_Filtered() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/QoS_Tests.cs b/Source/MQTTnet.Tests/Server/QoS_Tests.cs index b48a1a99c..b52c159d8 100644 --- a/Source/MQTTnet.Tests/Server/QoS_Tests.cs +++ b/Source/MQTTnet.Tests/Server/QoS_Tests.cs @@ -17,7 +17,8 @@ public sealed class QoS_Tests : BaseTestClass [TestMethod] public async Task Fire_Event_On_Client_Acknowledges_QoS_0() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -44,7 +45,8 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_0() [TestMethod] public async Task Fire_Event_On_Client_Acknowledges_QoS_1() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -75,7 +77,8 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_1() [TestMethod] public async Task Fire_Event_On_Client_Acknowledges_QoS_2() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -114,7 +117,8 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_2() [TestMethod] public async Task Preserve_Message_Order_For_Queued_Messages() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); diff --git a/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs b/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs index 43e706247..4dcc495fa 100644 --- a/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs @@ -25,7 +25,8 @@ public Task Subscribe_Without_Retain_As_Published() async Task ExecuteTest(bool retainAsPublished) { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs b/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs index aa4b7f227..78d1aa4e6 100644 --- a/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs @@ -36,7 +36,8 @@ async Task ExecuteTest( int expectedCountAfterSecondPublish, int expectedCountAfterSecondSubscribe) { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs index fbbe500d8..583d72643 100644 --- a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs @@ -17,7 +17,8 @@ public sealed class Retained_Messages_Tests : BaseTestClass [TestMethod] public async Task Clear_Retained_Message_With_Empty_Payload() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -42,7 +43,8 @@ public async Task Clear_Retained_Message_With_Empty_Payload() [TestMethod] public async Task Clear_Retained_Message_With_Null_Payload() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -67,7 +69,8 @@ public async Task Clear_Retained_Message_With_Null_Payload() [TestMethod] public async Task Downgrade_QoS_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -97,7 +100,8 @@ await c1.PublishAsync( [TestMethod] public async Task No_Upgrade_QoS_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -127,7 +131,8 @@ await c1.PublishAsync( [TestMethod] public async Task Receive_No_Retained_Message_After_Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -148,7 +153,8 @@ public async Task Receive_No_Retained_Message_After_Subscribe() [TestMethod] public async Task Receive_Retained_Message_After_Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -171,7 +177,8 @@ public async Task Receive_Retained_Message_After_Subscribe() [TestMethod] public async Task Receive_Retained_Messages_From_Higher_Qos_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -202,7 +209,8 @@ await c1.PublishAsync( [TestMethod] public async Task Retained_Messages_Flow() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); @@ -234,7 +242,8 @@ public async Task Retained_Messages_Flow() [TestMethod] public async Task Server_Reports_Retained_Messages_Supported_V3() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -253,7 +262,8 @@ public async Task Server_Reports_Retained_Messages_Supported_V3() [TestMethod] public async Task Server_Reports_Retained_Messages_Supported_V5() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Security_Tests.cs b/Source/MQTTnet.Tests/Server/Security_Tests.cs index 404314c43..5717dadc3 100644 --- a/Source/MQTTnet.Tests/Server/Security_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Security_Tests.cs @@ -18,7 +18,8 @@ public sealed class Security_Tests : BaseTestClass [TestMethod] public async Task Do_Not_Affect_Authorized_Clients() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -116,11 +117,12 @@ public Task Handle_Wrong_UserName_And_Password() [TestMethod] public async Task Use_Username_Null_Password_Empty() { - string username = null; - var password = string.Empty; - - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { + string username = null; + var password = string.Empty; + testEnvironment.IgnoreClientLogErrors = true; await testEnvironment.StartServer(); @@ -137,7 +139,8 @@ public async Task Use_Username_Null_Password_Empty() async Task TestCredentials(string userName, string password) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; diff --git a/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs b/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs index 2b4f96c1c..5b53f2c95 100644 --- a/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs @@ -16,7 +16,8 @@ public sealed class Server_Reference_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_With_Reference_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index 91345da56..d626a2f23 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -21,7 +21,8 @@ public sealed class Session_Tests : BaseTestClass [TestMethod] public async Task Clean_Session_Persistence() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -72,7 +73,8 @@ public async Task Clean_Session_Persistence() [TestMethod] public async Task Do_Not_Use_Expired_Session() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(o => o.WithPersistentSessions()); @@ -94,15 +96,16 @@ public async Task Do_Not_Use_Expired_Session() [TestMethod] public async Task Fire_Deleted_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Arrange client and server. var server = await testEnvironment.StartServer(o => o.WithPersistentSessions(false)); - var deletedEventFired = false; + var deletedEventFiredTaskSource = new TaskCompletionSource(); server.SessionDeletedAsync += e => { - deletedEventFired = true; + deletedEventFiredTaskSource.TrySetResult(true); return CompletedTask.Instance; }; @@ -111,7 +114,7 @@ public async Task Fire_Deleted_Event() // Act: Disconnect the client -> Event must be fired. await client.DisconnectAsync(); - await LongTestDelay(); + var deletedEventFired = await deletedEventFiredTaskSource.Task; // Assert that the event was fired properly. Assert.IsTrue(deletedEventFired); @@ -121,7 +124,8 @@ public async Task Fire_Deleted_Event() [TestMethod] public async Task Get_Session_Items_In_Status() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -149,7 +153,8 @@ public async Task Get_Session_Items_In_Status() [DataRow(MqttProtocolVersion.V500)] public async Task Handle_Parallel_Connection_Attempts(MqttProtocolVersion protocolVersion) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -157,10 +162,11 @@ public async Task Handle_Parallel_Connection_Attempts(MqttProtocolVersion protoc var options = new MqttClientOptionsBuilder().WithClientId("1").WithTimeout(TimeSpan.FromSeconds(10)).WithProtocolVersion(protocolVersion); - var hasReceive = false; + + var hasReceiveTaskSource = new TaskCompletionSource(); void OnReceive() { - hasReceive = true; + hasReceiveTaskSource.TrySetResult(true); } // Try to connect 50 clients at the same time. @@ -176,7 +182,7 @@ void OnReceive() var sendClient = await testEnvironment.ConnectClient(option2); await sendClient.PublishStringAsync("aaa", "1"); - await LongTestDelay(); + var hasReceive = await hasReceiveTaskSource.Task; Assert.AreEqual(true, hasReceive); } @@ -187,9 +193,10 @@ void OnReceive() [DataRow(MqttQualityOfServiceLevel.AtLeastOnce)] public async Task Retry_If_Not_PubAck(MqttQualityOfServiceLevel qos) { - long count = 0; - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { + long count = 0; await testEnvironment.StartServer(o => o.WithPersistentSessions()); var publisher = await testEnvironment.ConnectClient(); @@ -223,7 +230,8 @@ public async Task Retry_If_Not_PubAck(MqttQualityOfServiceLevel qos) [TestMethod] public async Task Session_Takeover() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -234,11 +242,11 @@ public async Task Session_Takeover() var client1 = await testEnvironment.ConnectClient(options); await Task.Delay(500); - var disconnectReason = MqttClientDisconnectReason.NormalDisconnection; + var disconnectReasonTaskSource = new TaskCompletionSource(); client1.DisconnectedAsync += c => { - disconnectReason = c.Reason; - return CompletedTask.Instance; + disconnectReasonTaskSource.TrySetResult(c.Reason); + return Task.CompletedTask; ; }; var client2 = await testEnvironment.ConnectClient(options); @@ -247,6 +255,7 @@ public async Task Session_Takeover() Assert.IsFalse(client1.IsConnected); Assert.IsTrue(client2.IsConnected); + var disconnectReason = await disconnectReasonTaskSource.Task; Assert.AreEqual(MqttClientDisconnectReason.SessionTakenOver, disconnectReason); } } @@ -254,16 +263,15 @@ public async Task Session_Takeover() [TestMethod] public async Task Set_Session_Item() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - server.ValidatingConnectionAsync += e => { // Don't validate anything. Just set some session items. e.SessionItems["can_subscribe_x"] = true; e.SessionItems["default_payload"] = "Hello World"; - return CompletedTask.Instance; }; @@ -311,7 +319,8 @@ public async Task Set_Session_Item() [TestMethod] public async Task Use_Clean_Session() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -331,7 +340,8 @@ public async Task Use_Clean_Session() [TestMethod] public async Task Will_Message_Do_Not_Send_On_Takeover() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -339,9 +349,11 @@ public async Task Will_Message_Do_Not_Send_On_Takeover() // C1 will receive the last will! var c1 = await testEnvironment.ConnectClient(); + var taskCompletionSource = new TaskCompletionSource(); c1.ApplicationMessageReceivedAsync += e => { Interlocked.Increment(ref receivedMessagesCount); + taskCompletionSource.TrySetResult(); return CompletedTask.Instance; }; @@ -355,7 +367,7 @@ public async Task Will_Message_Do_Not_Send_On_Takeover() // C3 will do the connection takeover. await testEnvironment.ConnectClient(clientOptions); - await Task.Delay(1000); + await taskCompletionSource.Task; Assert.AreEqual(0, receivedMessagesCount); } diff --git a/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs b/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs index c1a39289d..0dac081fc 100644 --- a/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs @@ -15,7 +15,8 @@ public sealed class Shared_Subscriptions_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Shared_Subscriptions_Not_Supported() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -31,7 +32,8 @@ public async Task Server_Reports_Shared_Subscriptions_Not_Supported() [TestMethod] public async Task Subscription_Of_Shared_Subscription_Is_Denied() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Status_Tests.cs b/Source/MQTTnet.Tests/Server/Status_Tests.cs index ef9419eb3..ec4fe6c3e 100644 --- a/Source/MQTTnet.Tests/Server/Status_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Status_Tests.cs @@ -18,7 +18,8 @@ public sealed class Status_Tests : BaseTestClass [TestMethod] public async Task Disconnect_Client() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -46,7 +47,8 @@ public async Task Disconnect_Client() [TestMethod] public async Task Keep_Persistent_Session_Version311() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); @@ -80,7 +82,8 @@ public async Task Keep_Persistent_Session_Version311() [TestMethod] public async Task Keep_Persistent_Session_Version500() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); @@ -116,7 +119,8 @@ public async Task Keep_Persistent_Session_Version500() [TestMethod] public async Task Show_Client_And_Session_Statistics() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -150,7 +154,8 @@ public async Task Show_Client_And_Session_Statistics() [TestMethod] public async Task Track_Sent_Application_Messages() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -171,7 +176,8 @@ public async Task Track_Sent_Application_Messages() [TestMethod] public async Task Track_Sent_Packets() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithPersistentSessions()); diff --git a/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs b/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs index 9dd48c317..ff89c3afb 100644 --- a/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs @@ -39,7 +39,8 @@ public sealed class Subscribe_Tests : BaseTestClass [DataRow("A/B1/B2/C", "A/+/C", false)] public async Task Subscription_Roundtrip(string topic, string filter, bool shouldWork) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer().ConfigureAwait(false); @@ -66,7 +67,8 @@ public async Task Subscription_Roundtrip(string topic, string filter, bool shoul [TestMethod] public async Task Deny_Invalid_Topic() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -93,7 +95,8 @@ public async Task Deny_Invalid_Topic() [TestMethod] public async Task Intercept_Subscribe_With_User_Properties() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -117,7 +120,8 @@ public async Task Intercept_Subscribe_With_User_Properties() [ExpectedException(typeof(MqttClientDisconnectedException))] public async Task Disconnect_While_Subscribing() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -132,7 +136,8 @@ public async Task Disconnect_While_Subscribing() [TestMethod] public async Task Enqueue_Message_After_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -156,7 +161,8 @@ public async Task Enqueue_Message_After_Subscription() [TestMethod] public async Task Intercept_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -199,7 +205,8 @@ public async Task Intercept_Subscription() [TestMethod] public async Task Response_Contains_Equal_Reason_Codes() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -219,7 +226,8 @@ public async Task Response_Contains_Equal_Reason_Codes() [TestMethod] public async Task Subscribe_Lots_In_Multiple_Requests() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -262,7 +270,8 @@ public async Task Subscribe_Lots_In_Multiple_Requests() [TestMethod] public async Task Subscribe_Lots_In_Single_Request() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -302,7 +311,8 @@ public async Task Subscribe_Lots_In_Single_Request() [TestMethod] public async Task Subscribe_Multiple_In_Multiple_Request() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -340,7 +350,8 @@ public async Task Subscribe_Multiple_In_Multiple_Request() [TestMethod] public async Task Subscribe_Multiple_In_Single_Request() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -374,7 +385,8 @@ public async Task Subscribe_Multiple_In_Single_Request() [TestMethod] public async Task Subscribe_Unsubscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; diff --git a/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs b/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs index 24b370d15..fa1ee652e 100644 --- a/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs @@ -14,7 +14,8 @@ public sealed class Subscription_Identifier_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Subscription_Identifiers_Supported() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -30,7 +31,8 @@ public async Task Server_Reports_Subscription_Identifiers_Supported() [TestMethod] public async Task Subscribe_With_Subscription_Identifier() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -57,7 +59,8 @@ public async Task Subscribe_With_Subscription_Identifier() [TestMethod] public async Task Subscribe_With_Multiple_Subscription_Identifiers() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Tls_Tests.cs b/Source/MQTTnet.Tests/Server/Tls_Tests.cs index a87b0dc4c..57b9a46b1 100644 --- a/Source/MQTTnet.Tests/Server/Tls_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Tls_Tests.cs @@ -50,7 +50,7 @@ static X509Certificate2 CreateCertificate(string oid) [TestMethod] public async Task Tls_Swap_Test() { - var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500); + using var testEnvironment = new TestEnvironment(TestContext, MqttProtocolVersion.V500); var serverOptionsBuilder = testEnvironment.ServerFactory.CreateServerOptionsBuilder(); var firstOid = "1.3.6.1.5.5.7.3.1"; diff --git a/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs b/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs index 024fca16c..e104e58b2 100644 --- a/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs @@ -17,7 +17,8 @@ public sealed class Topic_Alias_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Topic_Alias_Supported() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -35,7 +36,8 @@ public async Task Server_Reports_Topic_Alias_Supported() [TestMethod] public async Task Publish_With_Topic_Alias() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs b/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs index 9f97ad88e..5fc7a3035 100644 --- a/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs @@ -10,6 +10,7 @@ using MQTTnet.Internal; using MQTTnet.Protocol; using MQTTnet.Server; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests.Server { @@ -20,7 +21,8 @@ public sealed class Unsubscribe_Tests : BaseTestClass [ExpectedException(typeof(MqttClientDisconnectedException))] public async Task Disconnect_While_Unsubscribing() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -36,7 +38,8 @@ public async Task Disconnect_While_Unsubscribing() [TestMethod] public async Task Intercept_Unsubscribe_With_User_Properties() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs b/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs index 2a6f07137..84cd8e122 100644 --- a/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs +++ b/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs @@ -15,14 +15,13 @@ namespace MQTTnet.Tests.Server { [TestClass] - public class Feature_Tests + public class Feature_Tests : BaseTestClass { - public TestContext TestContext { get; set; } - [TestMethod] public async Task Use_User_Properties() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs b/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs index 85b224d63..fcdcea44c 100644 --- a/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs @@ -14,7 +14,8 @@ public sealed class Wildcard_Subscription_Available_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Wildcard_Subscription_Available_Tests_Supported_V3() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -30,7 +31,8 @@ public async Task Server_Reports_Wildcard_Subscription_Available_Tests_Supported [TestMethod] public async Task Server_Reports_Wildcard_Subscription_Available_Tests_Supported_V5() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Will_Tests.cs b/Source/MQTTnet.Tests/Server/Will_Tests.cs index 0e823bcde..b0f860036 100644 --- a/Source/MQTTnet.Tests/Server/Will_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Will_Tests.cs @@ -11,7 +11,8 @@ public sealed class Will_Tests : BaseTestClass [TestMethod] public async Task Intercept_Will_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer().ConfigureAwait(false); @@ -36,7 +37,8 @@ public async Task Intercept_Will_Message() [TestMethod] public async Task Will_Message_Do_Not_Send_On_Clean_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -59,7 +61,8 @@ public async Task Will_Message_Do_Not_Send_On_Clean_Disconnect() [TestMethod] public async Task Will_Message_Send() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); From 146b161e1f62af3fe92a81132bf48d16e8a8f21a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 6 Dec 2024 15:48:32 +0800 Subject: [PATCH 70/85] CreateTestEnvironment -> CreateMQTTnetTestEnvironment --- Source/MQTTnet.Tests/BaseTestClass.cs | 2 +- .../Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs | 2 +- Source/MQTTnet.Tests/Server/Events_Tests.cs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.Tests/BaseTestClass.cs b/Source/MQTTnet.Tests/BaseTestClass.cs index cd7447636..9441b09a0 100644 --- a/Source/MQTTnet.Tests/BaseTestClass.cs +++ b/Source/MQTTnet.Tests/BaseTestClass.cs @@ -14,7 +14,7 @@ public abstract class BaseTestClass { public TestContext TestContext { get; set; } - protected TestEnvironmentCollection CreateTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + protected TestEnvironmentCollection CreateMQTTnetTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { var mqttnet = new TestEnvironment(TestContext, protocolVersion); return new TestEnvironmentCollection(mqttnet); diff --git a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs index 14d01fcf6..9e0125404 100644 --- a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs @@ -80,7 +80,7 @@ public async Task Connect_To_Wrong_Host() [TestMethod] public async Task Loose_Connection() { - using var testEnvironments = CreateTestEnvironment(); + using var testEnvironments = CreateMQTTnetTestEnvironment(); foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreServerLogErrors = true; diff --git a/Source/MQTTnet.Tests/Server/Events_Tests.cs b/Source/MQTTnet.Tests/Server/Events_Tests.cs index 5e726979a..0ea75cc63 100644 --- a/Source/MQTTnet.Tests/Server/Events_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Events_Tests.cs @@ -162,7 +162,7 @@ public async Task Fire_Application_Message_Received_Event() [TestMethod] public async Task Fire_Started_Event() { - using var testEnvironments = CreateTestEnvironment(); + using var testEnvironments = CreateMQTTnetTestEnvironment(); foreach (var testEnvironment in testEnvironments) { var server = testEnvironment.CreateServer(new MqttServerOptions()); From 83148ca61588485aff13d09afd8fc4aebd7736c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 6 Dec 2024 16:29:52 +0800 Subject: [PATCH 71/85] Remove Google's connection test to avoid Google being blocked by SNI in certain areas. --- Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs index 7476ffea8..5f09069b1 100644 --- a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs +++ b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs @@ -20,9 +20,9 @@ public class CrossPlatformSocket_Tests public async Task Connect_Send_Receive() { var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp); - await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.google.de", 80), CancellationToken.None); + await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.microsoft.com", 80), CancellationToken.None); - var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n"); + var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.microsoft.com\r\n\r\n"); await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); var buffer = new byte[1024]; @@ -31,7 +31,7 @@ public async Task Connect_Send_Receive() var responseText = Encoding.UTF8.GetString(buffer, 0, length); - Assert.IsTrue(responseText.Contains("HTTP/1.1 200 OK")); + Assert.IsTrue(responseText.Contains("HTTP/1.1 200")|| responseText.Contains("HTTP/1.1 302")); } [TestMethod] From 41af5c57877240a4e65e5a16050bf9f2f940441d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 6 Dec 2024 17:57:24 +0800 Subject: [PATCH 72/85] AspNetCoreTestEnvironment: Adapt logger. --- .../Mockups/AspNetCoreTestEnvironment.cs | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs index 85d434023..1342a7943 100644 --- a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs @@ -36,14 +36,7 @@ public override IMqttClient CreateClient() { var services = new ServiceCollection(); var clientBuilder = services.AddMqttClient(); - if (EnableLogger) - { - clientBuilder.UseAspNetCoreMqttNetLogger(); - } - else - { - clientBuilder.UseMqttNetNullLogger(); - } + UseLogger(clientBuilder); var s = services.BuildServiceProvider(); var client = s.GetRequiredService().CreateMqttClient(); @@ -84,7 +77,8 @@ public override async Task StartServer(Action + + var serverBuilder = appBuilder.Services.AddMqttServer(optionsBuilder => { optionsBuilder.WithDefaultEndpoint(); optionsBuilder.WithDefaultEndpointPort(ServerPort); @@ -93,19 +87,20 @@ public override async Task StartServer(Action k.ListenMqtt()); appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); _app = appBuilder.Build(); - - // The OS has chosen the port to we have to properly expose it to the tests. - ServerPort = _app.Services.GetRequiredService().DefaultEndpointOptions.Port; - await _app.StartAsync(); + Server = _app.Services.GetRequiredService(); return Server; } @@ -126,17 +121,18 @@ public override async Task StartServer(MqttServerOptionsBuilder opti optionsBuilder.WithDefaultEndpointPort(ServerPort); optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); - var options = optionsBuilder.Build(); - var appBuilder = WebApplication.CreateBuilder(); - appBuilder.Services.AddMqttServer(); - appBuilder.Services.AddSingleton(options); + appBuilder.Services.AddSingleton(optionsBuilder.Build()); + var serverBuilder = appBuilder.Services.AddMqttServer(); + + UseLogger(serverBuilder); appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); _app = appBuilder.Build(); await _app.StartAsync(); + Server = _app.Services.GetRequiredService(); return Server; } @@ -164,5 +160,17 @@ private static int GetServerPort() } return port; } + + private void UseLogger(IMqttBuilder builder) + { + if (EnableLogger) + { + builder.UseAspNetCoreMqttNetLogger(); + } + else + { + builder.UseMqttNetNullLogger(); + } + } } } \ No newline at end of file From f94e17ad6054e5856efac3c239c4b7b19ecc5f12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 6 Dec 2024 18:27:35 +0800 Subject: [PATCH 73/85] wait with timeout. --- Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs | 7 ++++--- Source/MQTTnet.Tests/Server/Events_Tests.cs | 14 +++++++------- Source/MQTTnet.Tests/Server/Session_Tests.cs | 16 +++++++--------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs index b6edba296..bdf3cf43c 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs @@ -10,6 +10,7 @@ using MQTTnet.Internal; using MQTTnet.Protocol; using MQTTnet.Server; +using System; namespace MQTTnet.Tests.MQTTv5 { @@ -42,7 +43,7 @@ public async Task Will_Message_Send() var c2 = await testEnvironment.ConnectClient(clientOptions); c2.Dispose(); // Dispose will not send a DISCONNECT packet first so the will message must be sent. - await taskSource.Task; + await taskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.AreEqual(1, receivedMessagesCount); } @@ -210,14 +211,14 @@ public async Task Disconnect_with_Reason() // Test client should be connected now - var testClientId = await testClientIdTaskSource.Task; + var testClientId = await testClientIdTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsTrue(testClientId != null); // Have the server disconnect the client with AdministrativeAction reason await testEnvironment.Server.DisconnectClientAsync(testClientId, MqttDisconnectReasonCode.AdministrativeAction); - var disconnectReason = await disconnectReasonTaskSource.Task; + var disconnectReason = await disconnectReasonTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; // The reason should be returned to the client in the DISCONNECT packet Assert.AreEqual(MqttClientDisconnectReason.AdministrativeAction, disconnectReason); diff --git a/Source/MQTTnet.Tests/Server/Events_Tests.cs b/Source/MQTTnet.Tests/Server/Events_Tests.cs index 0ea75cc63..9e67e1f86 100644 --- a/Source/MQTTnet.Tests/Server/Events_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Events_Tests.cs @@ -33,7 +33,7 @@ public async Task Fire_Client_Connected_Event() await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); @@ -63,7 +63,7 @@ public async Task Fire_Client_Disconnected_Event() var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); await client.DisconnectAsync(); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); @@ -92,7 +92,7 @@ public async Task Fire_Client_Subscribed_Event() var client = await testEnvironment.ConnectClient(); await client.SubscribeAsync("The/Topic", MqttQualityOfServiceLevel.AtLeastOnce); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); @@ -121,7 +121,7 @@ public async Task Fire_Client_Unsubscribed_Event() var client = await testEnvironment.ConnectClient(); await client.UnsubscribeAsync("The/Topic"); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); @@ -149,7 +149,7 @@ public async Task Fire_Application_Message_Received_Event() var client = await testEnvironment.ConnectClient(); await client.PublishStringAsync("The_Topic", "The_Payload"); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); @@ -176,7 +176,7 @@ public async Task Fire_Started_Event() await server.StartAsync(); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); } @@ -199,7 +199,7 @@ public async Task Fire_Stopped_Event() await server.StopAsync(); - var eventArgs = await eventArgsTaskSource.Task; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.IsNotNull(eventArgs); } diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index d626a2f23..262f9abff 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -114,7 +114,7 @@ public async Task Fire_Deleted_Event() // Act: Disconnect the client -> Event must be fired. await client.DisconnectAsync(); - var deletedEventFired = await deletedEventFiredTaskSource.Task; + var deletedEventFired = await deletedEventFiredTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; // Assert that the event was fired properly. Assert.IsTrue(deletedEventFired); @@ -182,7 +182,7 @@ void OnReceive() var sendClient = await testEnvironment.ConnectClient(option2); await sendClient.PublishStringAsync("aaa", "1"); - var hasReceive = await hasReceiveTaskSource.Task; + var hasReceive = await hasReceiveTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.AreEqual(true, hasReceive); } @@ -255,7 +255,7 @@ public async Task Session_Takeover() Assert.IsFalse(client1.IsConnected); Assert.IsTrue(client2.IsConnected); - var disconnectReason = await disconnectReasonTaskSource.Task; + var disconnectReason = await disconnectReasonTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; Assert.AreEqual(MqttClientDisconnectReason.SessionTakenOver, disconnectReason); } } @@ -318,7 +318,7 @@ public async Task Set_Session_Item() [TestMethod] public async Task Use_Clean_Session() - { + { using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { @@ -348,12 +348,10 @@ public async Task Will_Message_Do_Not_Send_On_Takeover() await testEnvironment.StartServer(); // C1 will receive the last will! - var c1 = await testEnvironment.ConnectClient(); - var taskCompletionSource = new TaskCompletionSource(); + var c1 = await testEnvironment.ConnectClient(); c1.ApplicationMessageReceivedAsync += e => { - Interlocked.Increment(ref receivedMessagesCount); - taskCompletionSource.TrySetResult(); + Interlocked.Increment(ref receivedMessagesCount); return CompletedTask.Instance; }; @@ -367,7 +365,7 @@ public async Task Will_Message_Do_Not_Send_On_Takeover() // C3 will do the connection takeover. await testEnvironment.ConnectClient(clientOptions); - await taskCompletionSource.Task; + await LongTestDelay(); Assert.AreEqual(0, receivedMessagesCount); } From efcd6ab11f9b292eff94ef5083653568a34c733c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 6 Dec 2024 21:19:22 +0800 Subject: [PATCH 74/85] AspNetCoreTestEnvironment: Adjust all configurations to be consistent with TestEnvironment. --- Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs | 20 ++- .../Mockups/AspNetCoreTestEnvironment.cs | 144 ++++++++---------- .../MQTTnet.Tests/Mockups/TestEnvironment.cs | 22 ++- Source/MQTTnet.Tests/Server/Events_Tests.cs | 14 +- .../MQTTnet.Tests/Server/Keep_Alive_Tests.cs | 21 ++- Source/MQTTnet.Tests/Server/Session_Tests.cs | 37 +++-- 6 files changed, 145 insertions(+), 113 deletions(-) diff --git a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs index bdf3cf43c..57792e43c 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs @@ -182,9 +182,21 @@ public async Task Reconnect_with_different_SessionExpiryInterval() } [TestMethod] - public async Task Disconnect_with_Reason() + public async Task Disconnect_with_Reason_MQTTnet() + { + using var testEnvironments = CreateMQTTnetTestEnvironment(); + await Disconnect_with_Reason(testEnvironments); + } + + [TestMethod] + public async Task Disconnect_with_Reason_AspNetCore() + { + using var testEnvironments = CreateAspNetCoreTestEnvironment(); + await Disconnect_with_Reason(testEnvironments); + } + + private async Task Disconnect_with_Reason(TestEnvironmentCollection testEnvironments) { - using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { var disconnectReasonTaskSource = new TaskCompletionSource(); @@ -211,14 +223,14 @@ public async Task Disconnect_with_Reason() // Test client should be connected now - var testClientId = await testClientIdTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var testClientId = await testClientIdTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsTrue(testClientId != null); // Have the server disconnect the client with AdministrativeAction reason await testEnvironment.Server.DisconnectClientAsync(testClientId, MqttDisconnectReasonCode.AdministrativeAction); - var disconnectReason = await disconnectReasonTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var disconnectReason = await disconnectReasonTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); // The reason should be returned to the client in the DISCONNECT packet Assert.AreEqual(MqttClientDisconnectReason.AdministrativeAction, disconnectReason); diff --git a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs index 1342a7943..ecc221701 100644 --- a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs @@ -10,6 +10,8 @@ using MQTTnet.AspNetCore; using MQTTnet.Formatter; using MQTTnet.Internal; +using MQTTnet.LowLevelClient; +using MQTTnet.Protocol; using MQTTnet.Server; using System; using System.Linq; @@ -32,36 +34,22 @@ public AspNetCoreTestEnvironment(TestContext testContext, MqttProtocolVersion pr { } - public override IMqttClient CreateClient() + protected override IMqttClient CreateClientCore() { - var services = new ServiceCollection(); - var clientBuilder = services.AddMqttClient(); - UseLogger(clientBuilder); - - var s = services.BuildServiceProvider(); - var client = s.GetRequiredService().CreateMqttClient(); - - client.ConnectingAsync += e => - { - if (TestContext != null) - { - var clientOptions = e.ClientOptions; - var existingClientId = clientOptions.ClientId; - if (existingClientId != null && !existingClientId.StartsWith(TestContext.TestName)) - { - clientOptions.ClientId = TestContext.TestName + "_" + existingClientId; - } - } - - return CompletedTask.Instance; - }; + return CreateClientFactory().CreateMqttClient(); + } - lock (_clients) - { - _clients.Add(client); - } + protected override ILowLevelMqttClient CreateLowLevelClientCore() + { + return CreateClientFactory().CreateLowLevelMqttClient(); + } - return client; + private IMqttClientFactory CreateClientFactory() + { + var services = new ServiceCollection(); + var clientBuilder = services.AddMqttClient(); + UseMqttLogger(clientBuilder, "[CLIENT]=>"); + return services.BuildServiceProvider().GetRequiredService(); } public override MqttServer CreateServer(MqttServerOptions options) @@ -69,63 +57,39 @@ public override MqttServer CreateServer(MqttServerOptions options) throw new NotSupportedException("Can not create MqttServer in AspNetCoreTestEnvironment."); } - public override async Task StartServer(Action configure) + public override Task StartServer(Action configure) { - if (Server != null) - { - throw new InvalidOperationException("Server already started."); - } - - var appBuilder = WebApplication.CreateBuilder(); - - var serverBuilder = appBuilder.Services.AddMqttServer(optionsBuilder => - { - optionsBuilder.WithDefaultEndpoint(); - optionsBuilder.WithDefaultEndpointPort(ServerPort); - optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); - }).ConfigureMqttServer(configure, o => - { - if (o.DefaultEndpointOptions.Port == 0) - { - var serverPort = GetServerPort(); - o.DefaultEndpointOptions.Port = serverPort; - ServerPort = serverPort; - } - }); - - UseLogger(serverBuilder); - - appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); - appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); - - _app = appBuilder.Build(); - await _app.StartAsync(); + var optionsBuilder = new MqttServerOptionsBuilder(); + configure?.Invoke(optionsBuilder); + return StartServer(optionsBuilder); + } - Server = _app.Services.GetRequiredService(); - return Server; + public override Task StartServer(MqttServerOptionsBuilder optionsBuilder) + { + optionsBuilder.WithDefaultEndpoint(); + optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); + var serverOptions = optionsBuilder.Build(); + return StartServer(serverOptions); } - public override async Task StartServer(MqttServerOptionsBuilder optionsBuilder) + private async Task StartServer(MqttServerOptions serverOptions) { if (Server != null) { throw new InvalidOperationException("Server already started."); } - if (ServerPort == 0) + if (serverOptions.DefaultEndpointOptions.Port == 0) { - ServerPort = GetServerPort(); + var serverPort = ServerPort > 0 ? ServerPort : GetServerPort(); + serverOptions.DefaultEndpointOptions.Port = serverPort; } - optionsBuilder.WithDefaultEndpoint(); - optionsBuilder.WithDefaultEndpointPort(ServerPort); - optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); - var appBuilder = WebApplication.CreateBuilder(); - appBuilder.Services.AddSingleton(optionsBuilder.Build()); - var serverBuilder = appBuilder.Services.AddMqttServer(); + appBuilder.Services.AddSingleton(serverOptions); - UseLogger(serverBuilder); + var serverBuilder = appBuilder.Services.AddMqttServer(); + UseMqttLogger(serverBuilder, "[SERVER]=>"); appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); @@ -134,20 +98,27 @@ public override async Task StartServer(MqttServerOptionsBuilder opti await _app.StartAsync(); Server = _app.Services.GetRequiredService(); - return Server; - } + ServerPort = serverOptions.DefaultEndpointOptions.Port; - public override void Dispose() - { - if (_app != null) + Server.ValidatingConnectionAsync += e => { - _app.StopAsync().ConfigureAwait(false).GetAwaiter().GetResult(); - _app.DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); - _app = null; - } - base.Dispose(); + if (TestContext != null) + { + // Null is used when the client id is assigned from the server! + if (!string.IsNullOrEmpty(e.ClientId) && !e.ClientId.StartsWith(TestContext.TestName)) + { + TrackException(new InvalidOperationException($"Invalid client ID used ({e.ClientId}). It must start with UnitTest name.")); + e.ReasonCode = MqttConnectReasonCode.ClientIdentifierNotValid; + } + } + + return CompletedTask.Instance; + }; + + return Server; } + private static int GetServerPort() { var listeners = IPGlobalProperties.GetIPGlobalProperties().GetActiveTcpListeners(); @@ -161,16 +132,27 @@ private static int GetServerPort() return port; } - private void UseLogger(IMqttBuilder builder) + private void UseMqttLogger(IMqttBuilder builder, string categoryNamePrefix) { if (EnableLogger) { - builder.UseAspNetCoreMqttNetLogger(); + builder.UseAspNetCoreMqttNetLogger(l => l.CategoryNamePrefix = categoryNamePrefix); } else { builder.UseMqttNetNullLogger(); } } + + public override void Dispose() + { + base.Dispose(); + if (_app != null) + { + _app.StopAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + _app.DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + _app = null; + } + } } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs index daedba802..a11ebda0e 100644 --- a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs @@ -23,7 +23,7 @@ public class TestEnvironment : IDisposable { bool _disposed = false; readonly List _clientErrors = new(); - protected readonly List _clients = new(); + readonly List _clients = new(); readonly List _exceptions = new(); readonly List _lowLevelClients = new(); readonly MqttProtocolVersion _protocolVersion; @@ -196,11 +196,9 @@ public TestApplicationMessageReceivedHandler CreateApplicationMessageHandler(IMq return new TestApplicationMessageReceivedHandler(mqttClient); } - public virtual IMqttClient CreateClient() + public IMqttClient CreateClient() { - var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; - - var client = ClientFactory.CreateMqttClient(logger); + var client = CreateClientCore(); client.ConnectingAsync += e => { @@ -225,6 +223,12 @@ public virtual IMqttClient CreateClient() return client; } + protected virtual IMqttClient CreateClientCore() + { + var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; + return ClientFactory.CreateMqttClient(logger); + } + public MqttClientOptions CreateDefaultClientOptions() { return CreateDefaultClientOptionsBuilder().Build(); @@ -240,7 +244,7 @@ public MqttClientOptionsBuilder CreateDefaultClientOptionsBuilder() public ILowLevelMqttClient CreateLowLevelClient() { - var client = ClientFactory.CreateLowLevelMqttClient(ClientLogger); + var client = CreateLowLevelClientCore(); lock (_lowLevelClients) { @@ -250,6 +254,12 @@ public ILowLevelMqttClient CreateLowLevelClient() return client; } + protected virtual ILowLevelMqttClient CreateLowLevelClientCore() + { + return ClientFactory.CreateLowLevelMqttClient(ClientLogger); + } + + public virtual MqttServer CreateServer(MqttServerOptions options) { if (Server != null) diff --git a/Source/MQTTnet.Tests/Server/Events_Tests.cs b/Source/MQTTnet.Tests/Server/Events_Tests.cs index 9e67e1f86..383b6e581 100644 --- a/Source/MQTTnet.Tests/Server/Events_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Events_Tests.cs @@ -33,7 +33,7 @@ public async Task Fire_Client_Connected_Event() await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -63,7 +63,7 @@ public async Task Fire_Client_Disconnected_Event() var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); await client.DisconnectAsync(); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -92,7 +92,7 @@ public async Task Fire_Client_Subscribed_Event() var client = await testEnvironment.ConnectClient(); await client.SubscribeAsync("The/Topic", MqttQualityOfServiceLevel.AtLeastOnce); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -121,7 +121,7 @@ public async Task Fire_Client_Unsubscribed_Event() var client = await testEnvironment.ConnectClient(); await client.UnsubscribeAsync("The/Topic"); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -149,7 +149,7 @@ public async Task Fire_Application_Message_Received_Event() var client = await testEnvironment.ConnectClient(); await client.PublishStringAsync("The_Topic", "The_Payload"); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -176,7 +176,7 @@ public async Task Fire_Started_Event() await server.StartAsync(); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); } @@ -199,7 +199,7 @@ public async Task Fire_Stopped_Event() await server.StopAsync(); - var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); } diff --git a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs index a0a1216fc..e0d3d82b8 100644 --- a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs @@ -9,6 +9,7 @@ using MQTTnet.Formatter; using MQTTnet.Packets; using MQTTnet.Protocol; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests.Server { @@ -16,9 +17,21 @@ namespace MQTTnet.Tests.Server public sealed class KeepAlive_Tests : BaseTestClass { [TestMethod] - public async Task Disconnect_Client_DueTo_KeepAlive() + public async Task Disconnect_Client_DueTo_KeepAlive_MQTTnet() + { + using var testEnvironments = CreateMQTTnetTestEnvironment(); + await Disconnect_Client_DueTo_KeepAlive(testEnvironments); + } + + [TestMethod] + public async Task Disconnect_Client_DueTo_KeepAlive_AspNetCore() + { + using var testEnvironments = CreateAspNetCoreTestEnvironment(); + await Disconnect_Client_DueTo_KeepAlive(testEnvironments); + } + + private async Task Disconnect_Client_DueTo_KeepAlive(TestEnvironmentCollection testEnvironments) { - using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -41,12 +54,12 @@ await client.SendAsync(new MqttConnectPacket for (var i = 0; i < 6; i++) { await Task.Delay(500); - + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); responsePacket = await client.ReceiveAsync(CancellationToken.None); Assert.IsTrue(responsePacket is MqttPingRespPacket); } - + // If we reach this point everything works as expected (server did not close the connection // due to proper ping messages. // Now we will wait 1.1 seconds because the server MUST wait 1.5 seconds in total (See spec). diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index 262f9abff..9733f8a34 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -114,7 +114,7 @@ public async Task Fire_Deleted_Event() // Act: Disconnect the client -> Event must be fired. await client.DisconnectAsync(); - var deletedEventFired = await deletedEventFiredTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var deletedEventFired = await deletedEventFiredTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); // Assert that the event was fired properly. Assert.IsTrue(deletedEventFired); @@ -182,7 +182,7 @@ void OnReceive() var sendClient = await testEnvironment.ConnectClient(option2); await sendClient.PublishStringAsync("aaa", "1"); - var hasReceive = await hasReceiveTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + var hasReceive = await hasReceiveTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.AreEqual(true, hasReceive); } @@ -228,9 +228,22 @@ public async Task Retry_If_Not_PubAck(MqttQualityOfServiceLevel qos) } [TestMethod] - public async Task Session_Takeover() + public async Task Session_Takeover_MQTTnet() + { + using var testEnvironments = CreateMQTTnetTestEnvironment(); + await Session_Takeover(testEnvironments); + } + + [TestMethod] + public async Task Session_Takeover_AspNetCore() + { + using var testEnvironments = CreateAspNetCoreTestEnvironment(); + await Session_Takeover(testEnvironments); + } + + + private async Task Session_Takeover(TestEnvironmentCollection testEnvironments) { - using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -242,11 +255,13 @@ public async Task Session_Takeover() var client1 = await testEnvironment.ConnectClient(options); await Task.Delay(500); - var disconnectReasonTaskSource = new TaskCompletionSource(); + var disconnectReason = MqttClientDisconnectReason.NormalDisconnection; + var disconnectTaskSource = new TaskCompletionSource(); client1.DisconnectedAsync += c => { - disconnectReasonTaskSource.TrySetResult(c.Reason); - return Task.CompletedTask; ; + disconnectReason = c.Reason; + disconnectTaskSource.TrySetResult(); + return CompletedTask.Instance; }; var client2 = await testEnvironment.ConnectClient(options); @@ -255,7 +270,7 @@ public async Task Session_Takeover() Assert.IsFalse(client1.IsConnected); Assert.IsTrue(client2.IsConnected); - var disconnectReason = await disconnectReasonTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); ; + await disconnectTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.AreEqual(MqttClientDisconnectReason.SessionTakenOver, disconnectReason); } } @@ -318,7 +333,7 @@ public async Task Set_Session_Item() [TestMethod] public async Task Use_Clean_Session() - { + { using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { @@ -348,10 +363,10 @@ public async Task Will_Message_Do_Not_Send_On_Takeover() await testEnvironment.StartServer(); // C1 will receive the last will! - var c1 = await testEnvironment.ConnectClient(); + var c1 = await testEnvironment.ConnectClient(); c1.ApplicationMessageReceivedAsync += e => { - Interlocked.Increment(ref receivedMessagesCount); + Interlocked.Increment(ref receivedMessagesCount); return CompletedTask.Instance; }; From b0356594ce644c577c0fda483f1ba4d77b94f37d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 7 Dec 2024 00:06:08 +0800 Subject: [PATCH 75/85] MqttChannel: Exception handling remains consistent with MqttChannelAdapter. --- .../Internal/MqttChannel.cs | 103 +++++++++++++++++- .../Internal/MqttClientChannelAdapter.cs | 22 +++- .../ASP/MqttConnectionContextTest.cs | 2 +- .../LowLevelMqttClient_Tests.cs | 2 +- Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs | 16 +-- .../Mockups/AspNetCoreTestEnvironment.cs | 33 +++--- .../MQTTnet.Tests/Server/Keep_Alive_Tests.cs | 16 +-- Source/MQTTnet.Tests/Server/Session_Tests.cs | 17 +-- 8 files changed, 137 insertions(+), 74 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 7b93d5430..56bdc3163 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -12,8 +12,11 @@ using MQTTnet.Packets; using System; using System.Buffers; +using System.IO; using System.IO.Pipelines; using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -106,8 +109,18 @@ private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, IT public async Task DisconnectAsync() { - await _input.CompleteAsync().ConfigureAwait(false); - await _output.CompleteAsync().ConfigureAwait(false); + try + { + await _input.CompleteAsync().ConfigureAwait(false); + await _output.CompleteAsync().ConfigureAwait(false); + } + catch (Exception exception) + { + if (!WrapAndThrowException(exception)) + { + throw; + } + } } public virtual void Dispose() @@ -116,6 +129,29 @@ public virtual void Dispose() } public async Task ReceivePacketAsync(CancellationToken cancellationToken) + { + try + { + return await ReceivePacketCoreAsync(cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + } + catch (ObjectDisposedException) + { + } + catch (Exception exception) + { + if (!WrapAndThrowException(exception)) + { + throw; + } + } + + return null; + } + + private async Task ReceivePacketCoreAsync(CancellationToken cancellationToken) { try { @@ -143,7 +179,7 @@ public virtual void Dispose() { if (!buffer.IsEmpty) { - if (PacketFormatterAdapter.TryDecode(buffer,_packetInspector, out var packet, out consumed, out observed, out var received)) + if (PacketFormatterAdapter.TryDecode(buffer, _packetInspector, out var packet, out consumed, out observed, out var received)) { BytesReceived += received; @@ -168,11 +204,11 @@ public virtual void Dispose() } } } - catch (Exception exception) + catch (Exception) { // completing the channel makes sure that there is no more data read after a protocol error - _input.Complete(exception); - _output.Complete(exception); + await _input.CompleteAsync().ConfigureAwait(false); + await _output.CompleteAsync().ConfigureAwait(false); throw; } @@ -188,6 +224,21 @@ public void ResetStatistics() } public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) + { + try + { + await SendPacketCoreAsync(packet, cancellationToken).ConfigureAwait(false); + } + catch (Exception exception) + { + if (!WrapAndThrowException(exception)) + { + throw; + } + } + } + + private async Task SendPacketCoreAsync(MqttPacket packet, CancellationToken cancellationToken) { using (await _writerLock.EnterAsync(cancellationToken).ConfigureAwait(false)) { @@ -241,4 +292,44 @@ static void WritePacketBuffer(PipeWriter output, MqttPacketBuffer buffer) buffer.Payload.CopyTo(destination: span.Slice(offset)); output.Advance(buffer.Length); } + + public static bool WrapAndThrowException(Exception exception) + { + if (exception is OperationCanceledException || + exception is MqttCommunicationTimedOutException || + exception is MqttCommunicationException || + exception is MqttProtocolViolationException) + { + return false; + } + + if (exception is IOException && exception.InnerException is SocketException innerException) + { + exception = innerException; + } + + if (exception is SocketException socketException) + { + if (socketException.SocketErrorCode == SocketError.OperationAborted) + { + throw new OperationCanceledException(); + } + + if (socketException.SocketErrorCode == SocketError.ConnectionAborted) + { + throw new MqttCommunicationException(socketException); + } + } + + if (exception is COMException comException) + { + const uint ErrorOperationAborted = 0x800703E3; + if ((uint)comException.HResult == ErrorOperationAborted) + { + throw new OperationCanceledException(); + } + } + + throw new MqttCommunicationException(exception); + } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 671043d8f..bc382d276 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -52,13 +52,23 @@ public MqttClientChannelAdapter( public async Task ConnectAsync(CancellationToken cancellationToken) { - _connection = _channelOptions switch + try { - MqttClientTcpOptions tcpOptions => await ClientConnectionContext.CreateAsync(tcpOptions, cancellationToken).ConfigureAwait(false), - MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false), - _ => throw new NotSupportedException(), - }; - _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector, _allowPacketFragmentation); + _connection = _channelOptions switch + { + MqttClientTcpOptions tcpOptions => await ClientConnectionContext.CreateAsync(tcpOptions, cancellationToken).ConfigureAwait(false), + MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false), + _ => throw new NotSupportedException(), + }; + _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector, _allowPacketFragmentation); + } + catch (Exception ex) + { + if (!MqttChannel.WrapAndThrowException(ex)) + { + throw; + } + } } public Task DisconnectAsync(CancellationToken cancellationToken) diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs index 77d3c36cd..8d90eaa24 100644 --- a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs @@ -37,7 +37,7 @@ public async Task TestCorruptedConnectPacket() await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(CancellationToken.None)); // the first exception should complete the pipes so if someone tries to use the connection after that it should throw immidiatly - await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(CancellationToken.None)); + await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(CancellationToken.None)); } // TODO: Fix test diff --git a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs index 9e0125404..86d0477bc 100644 --- a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs @@ -80,7 +80,7 @@ public async Task Connect_To_Wrong_Host() [TestMethod] public async Task Loose_Connection() { - using var testEnvironments = CreateMQTTnetTestEnvironment(); + using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreServerLogErrors = true; diff --git a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs index 57792e43c..6a2849cb3 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs @@ -182,21 +182,9 @@ public async Task Reconnect_with_different_SessionExpiryInterval() } [TestMethod] - public async Task Disconnect_with_Reason_MQTTnet() - { - using var testEnvironments = CreateMQTTnetTestEnvironment(); - await Disconnect_with_Reason(testEnvironments); - } - - [TestMethod] - public async Task Disconnect_with_Reason_AspNetCore() - { - using var testEnvironments = CreateAspNetCoreTestEnvironment(); - await Disconnect_with_Reason(testEnvironments); - } - - private async Task Disconnect_with_Reason(TestEnvironmentCollection testEnvironments) + public async Task Disconnect_with_Reason() { + using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { var disconnectReasonTaskSource = new TaskCompletionSource(); diff --git a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs index ecc221701..09801a20d 100644 --- a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs @@ -8,6 +8,7 @@ using Microsoft.Extensions.Hosting; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.AspNetCore; +using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; using MQTTnet.Internal; using MQTTnet.LowLevelClient; @@ -47,8 +48,11 @@ protected override ILowLevelMqttClient CreateLowLevelClientCore() private IMqttClientFactory CreateClientFactory() { var services = new ServiceCollection(); - var clientBuilder = services.AddMqttClient(); - UseMqttLogger(clientBuilder, "[CLIENT]=>"); + + var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; + services.AddSingleton(logger); + services.AddMqttClient(); + return services.BuildServiceProvider().GetRequiredService(); } @@ -67,6 +71,7 @@ public override Task StartServer(Action co public override Task StartServer(MqttServerOptionsBuilder optionsBuilder) { optionsBuilder.WithDefaultEndpoint(); + optionsBuilder.WithDefaultEndpointPort(ServerPort); optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); var serverOptions = optionsBuilder.Build(); return StartServer(serverOptions); @@ -88,14 +93,14 @@ private async Task StartServer(MqttServerOptions serverOptions) var appBuilder = WebApplication.CreateBuilder(); appBuilder.Services.AddSingleton(serverOptions); - var serverBuilder = appBuilder.Services.AddMqttServer(); - UseMqttLogger(serverBuilder, "[SERVER]=>"); + var logger = EnableLogger ? (IMqttNetLogger)ServerLogger : new MqttNetNullLogger(); + appBuilder.Services.AddSingleton(logger); + appBuilder.Services.AddMqttServer(); appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); _app = appBuilder.Build(); - await _app.StartAsync(); Server = _app.Services.GetRequiredService(); ServerPort = serverOptions.DefaultEndpointOptions.Port; @@ -115,6 +120,12 @@ private async Task StartServer(MqttServerOptions serverOptions) return CompletedTask.Instance; }; + var appStartedSource = new TaskCompletionSource(); + _app.Lifetime.ApplicationStarted.Register(() => appStartedSource.TrySetResult()); + + await _app.StartAsync(); + await appStartedSource.Task; + return Server; } @@ -132,18 +143,6 @@ private static int GetServerPort() return port; } - private void UseMqttLogger(IMqttBuilder builder, string categoryNamePrefix) - { - if (EnableLogger) - { - builder.UseAspNetCoreMqttNetLogger(l => l.CategoryNamePrefix = categoryNamePrefix); - } - else - { - builder.UseMqttNetNullLogger(); - } - } - public override void Dispose() { base.Dispose(); diff --git a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs index e0d3d82b8..43b781bbf 100644 --- a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs @@ -17,21 +17,9 @@ namespace MQTTnet.Tests.Server public sealed class KeepAlive_Tests : BaseTestClass { [TestMethod] - public async Task Disconnect_Client_DueTo_KeepAlive_MQTTnet() - { - using var testEnvironments = CreateMQTTnetTestEnvironment(); - await Disconnect_Client_DueTo_KeepAlive(testEnvironments); - } - - [TestMethod] - public async Task Disconnect_Client_DueTo_KeepAlive_AspNetCore() - { - using var testEnvironments = CreateAspNetCoreTestEnvironment(); - await Disconnect_Client_DueTo_KeepAlive(testEnvironments); - } - - private async Task Disconnect_Client_DueTo_KeepAlive(TestEnvironmentCollection testEnvironments) + public async Task Disconnect_Client_DueTo_KeepAlive() { + using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index 9733f8a34..ab28c896f 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -228,22 +228,9 @@ public async Task Retry_If_Not_PubAck(MqttQualityOfServiceLevel qos) } [TestMethod] - public async Task Session_Takeover_MQTTnet() - { - using var testEnvironments = CreateMQTTnetTestEnvironment(); - await Session_Takeover(testEnvironments); - } - - [TestMethod] - public async Task Session_Takeover_AspNetCore() - { - using var testEnvironments = CreateAspNetCoreTestEnvironment(); - await Session_Takeover(testEnvironments); - } - - - private async Task Session_Takeover(TestEnvironmentCollection testEnvironments) + public async Task Session_Takeover() { + using var testEnvironments = CreateMixedTestEnvironment(); foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); From 6171c818b1d0d61dfddac217d09fdd7ef5f4d2f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sat, 7 Dec 2024 00:40:22 +0800 Subject: [PATCH 76/85] CrossPlatformSocket_Tests: create a localhost web server for remote http server. --- .../Internal/CrossPlatformSocket_Tests.cs | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs index 5f09069b1..765429e6b 100644 --- a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs +++ b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs @@ -2,14 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; using System; +using System.Linq; using System.Net; +using System.Net.NetworkInformation; using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Implementations; namespace MQTTnet.Tests.Internal { @@ -19,19 +24,47 @@ public class CrossPlatformSocket_Tests [TestMethod] public async Task Connect_Send_Receive() { + var serverPort = GetServerPort(); + var responseContent = "Connect_Send_Receive"; + + // create a localhost web server. + var builder = WebApplication.CreateSlimBuilder(); + builder.WebHost.UseKestrel(k => k.ListenLocalhost(serverPort)); + + await using var webApp = builder.Build(); + var webAppStartedSource = new TaskCompletionSource(); + webApp.Lifetime.ApplicationStarted.Register(() => webAppStartedSource.TrySetResult()); + webApp.Use(next => context => context.Response.WriteAsync(responseContent)); + await webApp.StartAsync(); + await webAppStartedSource.Task; + + var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp); - await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.microsoft.com", 80), CancellationToken.None); + await crossPlatformSocket.ConnectAsync(new DnsEndPoint("localhost", serverPort), CancellationToken.None); - var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.microsoft.com\r\n\r\n"); - await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); + var requestBuffer = Encoding.UTF8.GetBytes($"GET /test/path HTTP/1.1\r\nHost: localhost:{serverPort}\r\n\r\n"); + await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), SocketFlags.None); var buffer = new byte[1024]; - var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment(buffer), System.Net.Sockets.SocketFlags.None); + var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment(buffer), SocketFlags.None); crossPlatformSocket.Dispose(); var responseText = Encoding.UTF8.GetString(buffer, 0, length); - Assert.IsTrue(responseText.Contains("HTTP/1.1 200")|| responseText.Contains("HTTP/1.1 302")); + Assert.IsTrue(responseText.Contains(responseContent)); + + + static int GetServerPort(int defaultPort = 9999) + { + var listeners = IPGlobalProperties.GetIPGlobalProperties().GetActiveTcpListeners(); + var portSet = listeners.Select(i => i.Port).ToHashSet(); + + while (!portSet.Add(defaultPort)) + { + defaultPort += 1; + } + return defaultPort; + } } [TestMethod] From 9a7a8bd5fe910cb723a80098b22fd3e2d88fbeb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sun, 8 Dec 2024 04:46:35 +0800 Subject: [PATCH 77/85] Server-side adaptation of AllowPacketFragmentation options. --- .../ConnectionBuilderExtensions.cs | 7 ++- .../Features/PacketFragmentationFeature.cs | 14 +++++ .../Features/TlsConnectionFeature.cs | 28 +++++++++ .../Features/WebSocketConnectionFeature.cs | 14 +++++ .../AspNetCoreMqttClientAdapterFactory.cs | 2 +- .../Internal/ClientConnectionContext.Tcp.cs | 8 ++- .../ClientConnectionContext.WebSocket.cs | 3 +- .../Internal/ClientConnectionContext.cs | 25 -------- .../Internal/MqttChannel.cs | 63 +++++++++---------- .../Internal/MqttClientChannelAdapter.cs | 13 ++-- .../Internal/MqttConnectionHandler.cs | 12 +++- .../Internal/MqttConnectionMiddleware.cs | 12 +++- .../Internal/MqttServerChannelAdapter.cs | 27 +++++++- .../KestrelServerOptionsExtensions.cs | 15 ++++- .../ASP/Mockups/ConnectionHandlerMockup.cs | 3 +- .../ASP/MqttConnectionContextTest.cs | 7 ++- 16 files changed, 172 insertions(+), 81 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs create mode 100644 Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs create mode 100644 Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 0dd7dc28b..f45380831 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -4,7 +4,9 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Adapter; using MQTTnet.Server; +using System; namespace MQTTnet.AspNetCore { @@ -15,15 +17,16 @@ public static class ConnectionBuilderExtensions /// /// /// + /// /// - public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket) + public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket, Func? allowPacketFragmentationSelector = null) { // check services.AddMqttServer() builder.ApplicationServices.GetRequiredService(); builder.ApplicationServices.GetRequiredService().UseFlag = true; var middleware = builder.ApplicationServices.GetRequiredService(); - return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols)); + return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols, allowPacketFragmentationSelector)); } } } diff --git a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs new file mode 100644 index 000000000..caf3451f1 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Adapter; +using System; + +namespace MQTTnet.AspNetCore +{ + sealed class PacketFragmentationFeature(Func allowPacketFragmentationSelector) + { + public Func AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector; + } +} diff --git a/Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs b/Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs new file mode 100644 index 000000000..5fd7fd6e3 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http.Features; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + sealed class TlsConnectionFeature : ITlsConnectionFeature + { + public static readonly TlsConnectionFeature WithoutClientCertificate = new(null); + + public X509Certificate2? ClientCertificate { get; set; } + + public Task GetClientCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(ClientCertificate); + } + + public TlsConnectionFeature(X509Certificate? clientCertificate) + { + ClientCertificate = clientCertificate as X509Certificate2; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs b/Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs new file mode 100644 index 000000000..872a76515 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.AspNetCore +{ + sealed class WebSocketConnectionFeature(string path) + { + /// + /// The path of WebSocket request. + /// + public string Path { get; } = path; + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs index 1cd50bec3..f739de225 100644 --- a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -16,7 +16,7 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa ArgumentNullException.ThrowIfNull(nameof(options)); var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); - return new MqttClientChannelAdapter(formatter, options.ChannelOptions, packetInspector, options.AllowPacketFragmentation); + return new MqttClientChannelAdapter(formatter, options.ChannelOptions, options.AllowPacketFragmentation, packetInspector); } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs index 4c03d0c69..fe3caf6fa 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using System; using System.Net; @@ -72,7 +71,11 @@ public static async Task CreateAsync(MqttClientTcpOptio var networkStream = new NetworkStream(socket, ownsSocket: true); if (options.TlsOptions?.UseTls != true) { - return new ClientConnectionContext(networkStream); + return new ClientConnectionContext(networkStream) + { + LocalEndPoint = socket.LocalEndPoint, + RemoteEndPoint = socket.RemoteEndPoint, + }; } var targetHost = options.TlsOptions.TargetHost; @@ -143,7 +146,6 @@ public static async Task CreateAsync(MqttClientTcpOptio RemoteEndPoint = socket.RemoteEndPoint, }; - connection.Features.Set(new ConnectionSocketFeature(socket)); connection.Features.Set(new TlsConnectionFeature(sslStream.LocalCertificate)); return connection; diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs index 4ac414ebd..4965278fc 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -48,9 +48,10 @@ public static async Task CreateAsync(MqttClientWebSocke RemoteEndPoint = new DnsEndPoint(uri.Host, uri.Port), }; + connection.Features.Set(new WebSocketConnectionFeature(uri.AbsolutePath)); if (uri.Scheme == Uri.UriSchemeWss) { - connection.Features.Set(TlsConnectionFeature.Default); + connection.Features.Set(TlsConnectionFeature.WithoutClientCertificate); } return connection; } diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs index 1e6b63d8b..e9ef720d6 100644 --- a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -3,14 +3,11 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using System; using System.Collections.Generic; using System.IO; using System.IO.Pipelines; -using System.Net.Sockets; -using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; @@ -67,27 +64,5 @@ private class StreamTransport(Stream stream) : IDuplexPipe public PipeWriter Output { get; } = PipeWriter.Create(stream, new StreamPipeWriterOptions(leaveOpen: true)); } - - private class TlsConnectionFeature : ITlsConnectionFeature - { - public static readonly TlsConnectionFeature Default = new(null); - - public X509Certificate2? ClientCertificate { get; set; } - - public Task GetClientCertificateAsync(CancellationToken cancellationToken) - { - return Task.FromResult(ClientCertificate); - } - - public TlsConnectionFeature(X509Certificate? clientCertificate) - { - ClientCertificate = clientCertificate as X509Certificate2; - } - } - - private class ConnectionSocketFeature(Socket socket) : IConnectionSocketFeature - { - public Socket Socket { get; } = socket; - } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 56bdc3163..11a04222b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -3,7 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http.Connections.Features; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using MQTTnet.Adapter; using MQTTnet.Exceptions; @@ -29,7 +29,7 @@ class MqttChannel : IDisposable readonly PipeReader _input; readonly PipeWriter _output; readonly MqttPacketInspector? _packetInspector; - readonly bool _allowPacketFragmentation; + bool _allowPacketFragmentation = false; public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } @@ -43,45 +43,38 @@ class MqttChannel : IDisposable public bool IsSecureConnection { get; } + public bool IsWebSocketConnection { get; } + public MqttChannel( MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, - MqttPacketInspector? packetInspector = null, - bool? allowPacketFragmentation = null) + HttpContext? httpContext, + MqttPacketInspector? packetInspector) { PacketFormatterAdapter = packetFormatterAdapter; - _packetInspector = packetInspector; - var httpContextFeature = connection.Features.Get(); var tlsConnectionFeature = connection.Features.Get(); - RemoteEndPoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint); - IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature); - ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature); + RemoteEndPoint = GetRemoteEndPoint(connection.RemoteEndPoint, httpContext); + ClientCertificate = GetClientCertificate(tlsConnectionFeature, httpContext); + IsSecureConnection = IsTlsConnection(tlsConnectionFeature, httpContext); + IsWebSocketConnection = connection.Features.Get() != null; + _packetInspector = packetInspector; _input = connection.Transport.Input; _output = connection.Transport.Output; - - _allowPacketFragmentation = allowPacketFragmentation == null - ? AllowPacketFragmentation(httpContextFeature) - : allowPacketFragmentation.Value; } - private static bool AllowPacketFragmentation(IHttpContextFeature? _httpContextFeature) + private static EndPoint? GetRemoteEndPoint(EndPoint? remoteEndPoint, HttpContext? httpContext) { - var serverModeWebSocket = _httpContextFeature != null && - _httpContextFeature.HttpContext != null && - _httpContextFeature.HttpContext.WebSockets.IsWebSocketRequest; - - return !serverModeWebSocket; - } - + if (remoteEndPoint != null) + { + return remoteEndPoint; + } - private static EndPoint? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint) - { - if (_httpContextFeature != null && _httpContextFeature.HttpContext != null) + if (httpContext != null) { - var httpConnection = _httpContextFeature.HttpContext.Connection; + var httpConnection = httpContext.Connection; var remoteAddress = httpConnection.RemoteIpAddress; if (remoteAddress != null) { @@ -89,23 +82,25 @@ private static bool AllowPacketFragmentation(IHttpContextFeature? _httpContextFe } } - return remoteEndPoint; + return null; } - private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) + private static bool IsTlsConnection(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext) { - return _httpContextFeature != null && _httpContextFeature.HttpContext != null - ? _httpContextFeature.HttpContext.Request.IsHttps - : tlsConnectionFeature != null; + return tlsConnectionFeature != null || (httpContext != null && httpContext.Request.IsHttps); } - private static X509Certificate2? GetClientCertificate(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature) + private static X509Certificate2? GetClientCertificate(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext) { - return _httpContextFeature != null && _httpContextFeature.HttpContext != null - ? _httpContextFeature.HttpContext.Connection.ClientCertificate - : tlsConnectionFeature?.ClientCertificate; + return tlsConnectionFeature != null + ? tlsConnectionFeature.ClientCertificate + : httpContext?.Connection.ClientCertificate; } + public void SetAllowPacketFragmentation(bool value) + { + _allowPacketFragmentation = value; + } public async Task DisconnectAsync() { diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index bc382d276..832df3f0b 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -22,19 +22,19 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable private MqttChannel? _channel; private readonly MqttPacketFormatterAdapter _packetFormatterAdapter; private readonly IMqttClientChannelOptions _channelOptions; + private readonly bool _allowPacketFragmentation; private readonly MqttPacketInspector? _packetInspector; - private readonly bool? _allowPacketFragmentation; public MqttClientChannelAdapter( MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions, - MqttPacketInspector? packetInspector, - bool? allowPacketFragmentation) + bool allowPacketFragmentation, + MqttPacketInspector? packetInspector) { _packetFormatterAdapter = packetFormatterAdapter; _channelOptions = channelOptions; - _packetInspector = packetInspector; _allowPacketFragmentation = allowPacketFragmentation; + _packetInspector = packetInspector; } public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter; @@ -49,6 +49,8 @@ public MqttClientChannelAdapter( public bool IsSecureConnection => GetChannel().IsSecureConnection; + public bool IsWebSocketConnection => GetChannel().IsSecureConnection; + public async Task ConnectAsync(CancellationToken cancellationToken) { @@ -60,7 +62,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken) MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false), _ => throw new NotSupportedException(), }; - _channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector, _allowPacketFragmentation); + _channel = new MqttChannel(_packetFormatterAdapter, _connection, httpContext: null, _packetInspector); + _channel.SetAllowPacketFragmentation(_allowPacketFragmentation); } catch (Exception ex) { diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs index 3f0033cb8..c5675b18a 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Connections; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Formatter; @@ -51,12 +52,19 @@ public override async Task OnConnectedAsync(ConnectionContext connection) transferFormatFeature.ActiveFormat = TransferFormat.Binary; } - var bufferWriter = _bufferWriterPool.Rent(); + // WebSocketConnectionFeature will be accessed in MqttChannel + var httpContext = connection.GetHttpContext(); + if (httpContext != null && httpContext.WebSockets.IsWebSocketRequest) + { + var path = httpContext.Request.Path; + connection.Features.Set(new WebSocketConnectionFeature(path)); + } + var bufferWriter = _bufferWriterPool.Rent(); try { var formatter = new MqttPacketFormatterAdapter(bufferWriter); - using var adapter = new MqttServerChannelAdapter(formatter, connection); + using var adapter = new MqttServerChannelAdapter(formatter, connection, httpContext); await clientHandler(adapter).ConfigureAwait(false); } finally diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs index 69254cb7e..58b7aef8d 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; using System; using System.Buffers; using System.Threading.Tasks; @@ -23,8 +24,17 @@ public MqttConnectionMiddleware(MqttConnectionHandler connectionHandler) _connectionHandler = connectionHandler; } - public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext connection, MqttProtocols protocols) + public async Task InvokeAsync( + ConnectionDelegate next, + ConnectionContext connection, + MqttProtocols protocols, + Func? allowPacketFragmentationSelector) { + if (allowPacketFragmentationSelector != null) + { + connection.Features.Set(new PacketFragmentationFeature(allowPacketFragmentationSelector)); + } + if (protocols == MqttProtocols.MqttAndWebSocket) { var input = connection.Transport.Input; diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index 7e6e482f2..0561ee81c 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; using MQTTnet.Adapter; using MQTTnet.Formatter; using System.Threading; @@ -12,9 +13,31 @@ namespace MQTTnet.AspNetCore; sealed class MqttServerChannelAdapter : MqttChannel, IMqttChannelAdapter { - public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) - : base(packetFormatterAdapter, connection) + public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, HttpContext? httpContext) + : base(packetFormatterAdapter, connection, httpContext, packetInspector: null) { + SetAllowPacketFragmentation(connection, httpContext); + } + + private void SetAllowPacketFragmentation(ConnectionContext connection, HttpContext? httpContext) + { + // When connection is from MapMqtt(), + // the PacketFragmentationFeature instance is copied from kestrel's ConnectionContext.Features to HttpContext.Features, + // but no longer from HttpContext.Features to connection.Features. + var feature = httpContext == null + ? connection.Features.Get() + : httpContext.Features.Get(); + + if (feature == null) + { + var value = !IsWebSocketConnection; + SetAllowPacketFragmentation(value); + } + else + { + var value = feature.AllowPacketFragmentationSelector(this); + SetAllowPacketFragmentation(value); + } } /// diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 1ca1aa8dd..d692a3088 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Server; using System; @@ -90,7 +91,19 @@ void UseMiddleware(ListenOptions listenOptions) tlsConfigure?.Invoke(httpsOptions); }); } - listenOptions.UseMqtt(protocols); + listenOptions.UseMqtt(protocols, AllowPacketFragmentationSelector); + } + + bool AllowPacketFragmentationSelector(IMqttChannelAdapter channelAdapter) + { + if (channelAdapter is MqttServerChannelAdapter serverChannelAdapter) + { + if (serverChannelAdapter.IsWebSocketConnection) + { + return false; + } + } + return endpoint.AllowPacketFragmentation; } } } diff --git a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs index e95ce11ce..8e2e73020 100644 --- a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs +++ b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs @@ -5,6 +5,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections; using MQTTnet.Adapter; using MQTTnet.AspNetCore; using MQTTnet.Diagnostics.Logger; @@ -27,7 +28,7 @@ public async Task OnConnectedAsync(ConnectionContext connection) try { var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(4096, 65535)); - var context = new MqttServerChannelAdapter(formatter, connection); + var context = new MqttServerChannelAdapter(formatter, connection, connection.GetHttpContext()); Context.TrySetResult(context); await ClientHandler(context); diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs index 8d90eaa24..83493671c 100644 --- a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.AspNetCore; using MQTTnet.Exceptions; @@ -30,7 +31,7 @@ public async Task TestCorruptedConnectPacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttServerChannelAdapter(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection, connection.GetHttpContext()); await pipe.Receive.Writer.WriteAsync(writer.AddMqttHeader(MqttControlPacketType.Connect, Array.Empty())); @@ -98,7 +99,7 @@ public async Task TestLargePacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttServerChannelAdapter(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection, connection.GetHttpContext()); await ctx.SendPacketAsync(new MqttPublishPacket { PayloadSegment = new byte[20_000] }, CancellationToken.None).ConfigureAwait(false); @@ -113,7 +114,7 @@ public async Task TestReceivePacketAsyncThrowsWhenReaderCompleted() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttServerChannelAdapter(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection, connection.GetHttpContext()); pipe.Receive.Writer.Complete(); From e01e5f37cb20865c2422af5b6f9e1d66e47991d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sun, 8 Dec 2024 15:11:54 +0800 Subject: [PATCH 78/85] Merge implementation of IsAllowPacketFragmentation. --- .../Features/PacketFragmentationFeature.cs | 22 ++++++++++++++++++- .../Internal/MqttServerChannelAdapter.cs | 8 +++---- .../KestrelServerOptionsExtensions.cs | 15 +------------ 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs index caf3451f1..b5773c08e 100644 --- a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs +++ b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs @@ -3,12 +3,32 @@ // See the LICENSE file in the project root for more information. using MQTTnet.Adapter; +using MQTTnet.Server; using System; namespace MQTTnet.AspNetCore { sealed class PacketFragmentationFeature(Func allowPacketFragmentationSelector) { - public Func AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector; + public Func AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector; + + public static bool IsAllowPacketFragmentation(IMqttChannelAdapter channelAdapter, MqttServerTcpEndpointBaseOptions? endpointOptions) + { + //if (endpointOptions != null && endpointOptions.AllowPacketFragmentationSelector != null) + //{ + // return endpointOptions.AllowPacketFragmentationSelector(channelAdapter); + //} + + // In the AspNetCore environment, we need to exclude WebSocket before AllowPacketFragmentation. + if (channelAdapter is MqttServerChannelAdapter serverChannelAdapter) + { + if (serverChannelAdapter.IsWebSocketConnection) + { + return false; + } + } + + return endpointOptions == null || endpointOptions.AllowPacketFragmentation; + } } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index 0561ee81c..bbf53db70 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -24,18 +24,18 @@ private void SetAllowPacketFragmentation(ConnectionContext connection, HttpConte // When connection is from MapMqtt(), // the PacketFragmentationFeature instance is copied from kestrel's ConnectionContext.Features to HttpContext.Features, // but no longer from HttpContext.Features to connection.Features. - var feature = httpContext == null + var packetFragmentationFeature = httpContext == null ? connection.Features.Get() : httpContext.Features.Get(); - if (feature == null) + if (packetFragmentationFeature == null) { - var value = !IsWebSocketConnection; + var value = PacketFragmentationFeature.IsAllowPacketFragmentation(this, null); SetAllowPacketFragmentation(value); } else { - var value = feature.AllowPacketFragmentationSelector(this); + var value = packetFragmentationFeature.AllowPacketFragmentationSelector(this); SetAllowPacketFragmentation(value); } } diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index d692a3088..93869e4f1 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -6,7 +6,6 @@ using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; using Microsoft.Extensions.DependencyInjection; -using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Server; using System; @@ -91,19 +90,7 @@ void UseMiddleware(ListenOptions listenOptions) tlsConfigure?.Invoke(httpsOptions); }); } - listenOptions.UseMqtt(protocols, AllowPacketFragmentationSelector); - } - - bool AllowPacketFragmentationSelector(IMqttChannelAdapter channelAdapter) - { - if (channelAdapter is MqttServerChannelAdapter serverChannelAdapter) - { - if (serverChannelAdapter.IsWebSocketConnection) - { - return false; - } - } - return endpoint.AllowPacketFragmentation; + listenOptions.UseMqtt(protocols, channelAdapter => PacketFragmentationFeature.IsAllowPacketFragmentation(channelAdapter, endpoint)); } } } From bffe06580a4f623d2d10195b50e44168ad519178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Sun, 8 Dec 2024 23:33:15 +0800 Subject: [PATCH 79/85] Added UseLogger overloaded method and renamed an internal method. --- .../Features/PacketFragmentationFeature.cs | 2 +- .../Internal/MqttServerChannelAdapter.cs | 2 +- .../KestrelServerOptionsExtensions.cs | 2 +- .../MQTTnet.AspnetCore/MqttBuilderExtensions.cs | 17 +++++++++++++++-- .../Mockups/AspNetCoreTestEnvironment.cs | 6 ++---- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs index b5773c08e..d8e9f7458 100644 --- a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs +++ b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs @@ -12,7 +12,7 @@ sealed class PacketFragmentationFeature(Func allowPac { public Func AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector; - public static bool IsAllowPacketFragmentation(IMqttChannelAdapter channelAdapter, MqttServerTcpEndpointBaseOptions? endpointOptions) + public static bool CanAllowPacketFragmentation(IMqttChannelAdapter channelAdapter, MqttServerTcpEndpointBaseOptions? endpointOptions) { //if (endpointOptions != null && endpointOptions.AllowPacketFragmentationSelector != null) //{ diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index bbf53db70..657a8e237 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -30,7 +30,7 @@ private void SetAllowPacketFragmentation(ConnectionContext connection, HttpConte if (packetFragmentationFeature == null) { - var value = PacketFragmentationFeature.IsAllowPacketFragmentation(this, null); + var value = PacketFragmentationFeature.CanAllowPacketFragmentation(this, null); SetAllowPacketFragmentation(value); } else diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 93869e4f1..e91f7f754 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -90,7 +90,7 @@ void UseMiddleware(ListenOptions listenOptions) tlsConfigure?.Invoke(httpsOptions); }); } - listenOptions.UseMqtt(protocols, channelAdapter => PacketFragmentationFeature.IsAllowPacketFragmentation(channelAdapter, endpoint)); + listenOptions.UseMqtt(protocols, channelAdapter => PacketFragmentationFeature.CanAllowPacketFragmentation(channelAdapter, endpoint)); } } } diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs index d104e3562..c0623a048 100644 --- a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -41,7 +41,7 @@ public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder) /// public static IMqttBuilder UseMqttNetNullLogger(this IMqttBuilder builder) { - return builder.UseLogger(); + return builder.UseLogger(MqttNetNullLogger.Instance); } /// @@ -55,6 +55,19 @@ public static IMqttBuilder UseMqttNetNullLogger(this IMqttBuilder builder) { builder.Services.Replace(ServiceDescriptor.Singleton()); return builder; - } + } + + /// + /// Use a logger + /// + /// + /// + /// + public static IMqttBuilder UseLogger(this IMqttBuilder builder, IMqttNetLogger logger) + { + ArgumentNullException.ThrowIfNull(logger); + builder.Services.Replace(ServiceDescriptor.Singleton(logger)); + return builder; + } } } diff --git a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs index 09801a20d..739f318b8 100644 --- a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs @@ -50,8 +50,7 @@ private IMqttClientFactory CreateClientFactory() var services = new ServiceCollection(); var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; - services.AddSingleton(logger); - services.AddMqttClient(); + services.AddMqttClient().UseLogger(logger); return services.BuildServiceProvider().GetRequiredService(); } @@ -94,8 +93,7 @@ private async Task StartServer(MqttServerOptions serverOptions) appBuilder.Services.AddSingleton(serverOptions); var logger = EnableLogger ? (IMqttNetLogger)ServerLogger : new MqttNetNullLogger(); - appBuilder.Services.AddSingleton(logger); - appBuilder.Services.AddMqttServer(); + appBuilder.Services.AddMqttServer().UseLogger(logger); appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); From ca5d13d9c9a1d2146e6d13102c47aa342e9c9017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 9 Dec 2024 01:18:08 +0800 Subject: [PATCH 80/85] Add some extension methods to IMqttChannelAdapter. --- .../Features/PacketFragmentationFeature.cs | 7 +-- .../Internal/IAspNetCoreMqttChannelAdapter.cs | 16 +++++++ .../Internal/MqttClientChannelAdapter.cs | 7 ++- .../Internal/MqttServerChannelAdapter.cs | 10 +++- .../MqttChannelAdapterExtensions.cs | 48 +++++++++++++++++++ 5 files changed, 80 insertions(+), 8 deletions(-) create mode 100644 Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs create mode 100644 Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs diff --git a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs index d8e9f7458..eb05414a4 100644 --- a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs +++ b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs @@ -20,12 +20,9 @@ public static bool CanAllowPacketFragmentation(IMqttChannelAdapter channelAdapte //} // In the AspNetCore environment, we need to exclude WebSocket before AllowPacketFragmentation. - if (channelAdapter is MqttServerChannelAdapter serverChannelAdapter) + if (channelAdapter.IsWebSocketConnection() == true) { - if (serverChannelAdapter.IsWebSocketConnection) - { - return false; - } + return false; } return endpointOptions == null || endpointOptions.AllowPacketFragmentation; diff --git a/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs new file mode 100644 index 000000000..323f29999 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Adapter; + +namespace MQTTnet.AspNetCore +{ + interface IAspNetCoreMqttChannelAdapter : IMqttChannelAdapter + { + HttpContext? HttpContext { get; } + IFeatureCollection? Features { get; } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 832df3f0b..1fec39feb 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using MQTTnet.Adapter; using MQTTnet.Formatter; using MQTTnet.Packets; @@ -15,7 +17,7 @@ namespace MQTTnet.AspNetCore; -sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable +sealed class MqttClientChannelAdapter : IAspNetCoreMqttChannelAdapter, IAsyncDisposable { private bool _disposed = false; private ConnectionContext? _connection; @@ -25,6 +27,9 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable private readonly bool _allowPacketFragmentation; private readonly MqttPacketInspector? _packetInspector; + public HttpContext? HttpContext => null; + public IFeatureCollection? Features => _connection?.Features; + public MqttClientChannelAdapter( MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions, diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index 657a8e237..3703931b1 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -4,18 +4,24 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; -using MQTTnet.Adapter; +using Microsoft.AspNetCore.Http.Features; using MQTTnet.Formatter; using System.Threading; using System.Threading.Tasks; namespace MQTTnet.AspNetCore; -sealed class MqttServerChannelAdapter : MqttChannel, IMqttChannelAdapter +sealed class MqttServerChannelAdapter : MqttChannel, IAspNetCoreMqttChannelAdapter { + public HttpContext? HttpContext { get; } + public IFeatureCollection? Features { get; } + public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, HttpContext? httpContext) : base(packetFormatterAdapter, connection, httpContext, packetInspector: null) { + HttpContext = httpContext; + Features = connection.Features; + SetAllowPacketFragmentation(connection, httpContext); } diff --git a/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs new file mode 100644 index 000000000..17e834c77 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http; +using MQTTnet.Adapter; +using System; + +namespace MQTTnet.AspNetCore +{ + public static class MqttChannelAdapterExtensions + { + public static bool? IsWebSocketConnection(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannelAdapter adapter + ? adapter.Features != null && adapter.Features.Get() != null + : null; + } + + /// + /// Retrieves the requested feature from the feature collection of channelAdapter. + /// + /// + /// + /// + public static TFeature? GetFeature(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannelAdapter adapter && adapter.Features != null + ? adapter.Features.Get() + : default; + } + + /// + /// When the channelAdapter is a WebSocket connection, it can get an associated . + /// + /// + /// + public static HttpContext? GetHttpContext(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannelAdapter adapter + ? adapter.HttpContext + : null; + } + } +} From 8ceba5ffe36704c7fa177f10447b66b3d004dcba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 9 Dec 2024 09:53:18 +0800 Subject: [PATCH 81/85] Add IAspNetCoreMqttChannel and remove IAspNetCoreMqttChannelAdapter; --- ...elAdapter.cs => IAspNetCoreMqttChannel.cs} | 9 ++-- .../Internal/MqttChannel.cs | 53 ++++++++++++------- .../Internal/MqttClientChannelAdapter.cs | 13 +++-- .../Internal/MqttServerChannelAdapter.cs | 23 ++------ .../MqttChannelAdapterExtensions.cs | 12 ++--- 5 files changed, 55 insertions(+), 55 deletions(-) rename Source/MQTTnet.AspnetCore/Internal/{IAspNetCoreMqttChannelAdapter.cs => IAspNetCoreMqttChannel.cs} (64%) diff --git a/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannel.cs similarity index 64% rename from Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs rename to Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannel.cs index 323f29999..5eb5358ae 100644 --- a/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannel.cs @@ -3,14 +3,15 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; -using MQTTnet.Adapter; namespace MQTTnet.AspNetCore { - interface IAspNetCoreMqttChannelAdapter : IMqttChannelAdapter + interface IAspNetCoreMqttChannel { HttpContext? HttpContext { get; } - IFeatureCollection? Features { get; } + + bool IsWebSocketConnection { get; } + + TFeature? GetFeature(); } } diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs index 11a04222b..dc6cdacdd 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -23,8 +23,11 @@ namespace MQTTnet.AspNetCore; -class MqttChannel : IDisposable +class MqttChannel : IAspNetCoreMqttChannel, IDisposable { + readonly ConnectionContext _connection; + readonly HttpContext? _httpContext; + readonly AsyncLock _writerLock = new(); readonly PipeReader _input; readonly PipeWriter _output; @@ -45,6 +48,7 @@ class MqttChannel : IDisposable public bool IsWebSocketConnection { get; } + public HttpContext? HttpContext => _httpContext; public MqttChannel( MqttPacketFormatterAdapter packetFormatterAdapter, @@ -53,16 +57,37 @@ public MqttChannel( MqttPacketInspector? packetInspector) { PacketFormatterAdapter = packetFormatterAdapter; - - var tlsConnectionFeature = connection.Features.Get(); - RemoteEndPoint = GetRemoteEndPoint(connection.RemoteEndPoint, httpContext); - ClientCertificate = GetClientCertificate(tlsConnectionFeature, httpContext); - IsSecureConnection = IsTlsConnection(tlsConnectionFeature, httpContext); - IsWebSocketConnection = connection.Features.Get() != null; - + _connection = connection; + _httpContext = httpContext; _packetInspector = packetInspector; + _input = connection.Transport.Input; _output = connection.Transport.Output; + + var tlsConnectionFeature = GetFeature(); + var webSocketConnectionFeature = GetFeature(); + + IsWebSocketConnection = webSocketConnectionFeature != null; + IsSecureConnection = tlsConnectionFeature != null; + ClientCertificate = tlsConnectionFeature?.ClientCertificate; + RemoteEndPoint = GetRemoteEndPoint(connection.RemoteEndPoint, httpContext); + } + + + public TFeature? GetFeature() + { + var feature = _connection.Features.Get(); + if (feature != null) + { + return feature; + } + + if (_httpContext != null) + { + return _httpContext.Features.Get(); + } + + return default; } private static EndPoint? GetRemoteEndPoint(EndPoint? remoteEndPoint, HttpContext? httpContext) @@ -85,18 +110,6 @@ public MqttChannel( return null; } - private static bool IsTlsConnection(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext) - { - return tlsConnectionFeature != null || (httpContext != null && httpContext.Request.IsHttps); - } - - private static X509Certificate2? GetClientCertificate(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext) - { - return tlsConnectionFeature != null - ? tlsConnectionFeature.ClientCertificate - : httpContext?.Connection.ClientCertificate; - } - public void SetAllowPacketFragmentation(bool value) { _allowPacketFragmentation = value; diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs index 1fec39feb..311a469ba 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -4,7 +4,6 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; using MQTTnet.Adapter; using MQTTnet.Formatter; using MQTTnet.Packets; @@ -17,7 +16,7 @@ namespace MQTTnet.AspNetCore; -sealed class MqttClientChannelAdapter : IAspNetCoreMqttChannelAdapter, IAsyncDisposable +sealed class MqttClientChannelAdapter : IAsyncDisposable, IMqttChannelAdapter, IAspNetCoreMqttChannel { private bool _disposed = false; private ConnectionContext? _connection; @@ -27,9 +26,6 @@ sealed class MqttClientChannelAdapter : IAspNetCoreMqttChannelAdapter, IAsyncDis private readonly bool _allowPacketFragmentation; private readonly MqttPacketInspector? _packetInspector; - public HttpContext? HttpContext => null; - public IFeatureCollection? Features => _connection?.Features; - public MqttClientChannelAdapter( MqttPacketFormatterAdapter packetFormatterAdapter, IMqttClientChannelOptions channelOptions, @@ -56,6 +52,13 @@ public MqttClientChannelAdapter( public bool IsWebSocketConnection => GetChannel().IsSecureConnection; + public HttpContext? HttpContext => GetChannel().HttpContext; + + public TFeature? GetFeature() + { + return GetChannel().GetFeature(); + } + public async Task ConnectAsync(CancellationToken cancellationToken) { diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs index 3703931b1..c6b565ac6 100644 --- a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -4,36 +4,19 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Adapter; using MQTTnet.Formatter; using System.Threading; using System.Threading.Tasks; namespace MQTTnet.AspNetCore; -sealed class MqttServerChannelAdapter : MqttChannel, IAspNetCoreMqttChannelAdapter +sealed class MqttServerChannelAdapter : MqttChannel, IMqttChannelAdapter, IAspNetCoreMqttChannel { - public HttpContext? HttpContext { get; } - public IFeatureCollection? Features { get; } - public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, HttpContext? httpContext) : base(packetFormatterAdapter, connection, httpContext, packetInspector: null) { - HttpContext = httpContext; - Features = connection.Features; - - SetAllowPacketFragmentation(connection, httpContext); - } - - private void SetAllowPacketFragmentation(ConnectionContext connection, HttpContext? httpContext) - { - // When connection is from MapMqtt(), - // the PacketFragmentationFeature instance is copied from kestrel's ConnectionContext.Features to HttpContext.Features, - // but no longer from HttpContext.Features to connection.Features. - var packetFragmentationFeature = httpContext == null - ? connection.Features.Get() - : httpContext.Features.Get(); - + var packetFragmentationFeature = GetFeature(); if (packetFragmentationFeature == null) { var value = PacketFragmentationFeature.CanAllowPacketFragmentation(this, null); diff --git a/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs index 17e834c77..ff96dd0fc 100644 --- a/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs +++ b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs @@ -13,8 +13,8 @@ public static class MqttChannelAdapterExtensions public static bool? IsWebSocketConnection(this IMqttChannelAdapter channelAdapter) { ArgumentNullException.ThrowIfNull(channelAdapter); - return channelAdapter is IAspNetCoreMqttChannelAdapter adapter - ? adapter.Features != null && adapter.Features.Get() != null + return channelAdapter is IAspNetCoreMqttChannel channel + ? channel.IsWebSocketConnection : null; } @@ -27,8 +27,8 @@ public static class MqttChannelAdapterExtensions public static TFeature? GetFeature(this IMqttChannelAdapter channelAdapter) { ArgumentNullException.ThrowIfNull(channelAdapter); - return channelAdapter is IAspNetCoreMqttChannelAdapter adapter && adapter.Features != null - ? adapter.Features.Get() + return channelAdapter is IAspNetCoreMqttChannel channel + ? channel.GetFeature() : default; } @@ -40,8 +40,8 @@ public static class MqttChannelAdapterExtensions public static HttpContext? GetHttpContext(this IMqttChannelAdapter channelAdapter) { ArgumentNullException.ThrowIfNull(channelAdapter); - return channelAdapter is IAspNetCoreMqttChannelAdapter adapter - ? adapter.HttpContext + return channelAdapter is IAspNetCoreMqttChannel channel + ? channel.HttpContext : null; } } From b65c176fc585390897b057961b9126e63fa36427 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 9 Dec 2024 14:19:04 +0800 Subject: [PATCH 82/85] MapMqtt: Restricted to WebSocket transport protocol. --- .../EndpointRouteBuilderExtensions.cs | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs index dc1c73098..0448c3154 100644 --- a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -24,13 +24,7 @@ public static class EndpointRouteBuilderExtensions /// public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern) { - return endpoints.MapMqtt(pattern, options => options.WebSockets.SubProtocolSelector = SelectSubProtocol); - - static string SelectSubProtocol(IList requestedSubProtocolValues) - { - // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. - return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt"))!; - } + return endpoints.MapMqtt(pattern, null); } /// @@ -38,15 +32,29 @@ static string SelectSubProtocol(IList requestedSubProtocolValues) /// /// /// - /// + /// /// - public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern, Action options) + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern, Action? configureOptions) { // check services.AddMqttServer() endpoints.ServiceProvider.GetRequiredService(); endpoints.ServiceProvider.GetRequiredService().MapFlag = true; - return endpoints.MapConnectionHandler(pattern, options); + return endpoints.MapConnectionHandler(pattern, ConfigureOptions); + + + void ConfigureOptions(HttpConnectionDispatcherOptions options) + { + options.Transports = HttpTransportType.WebSockets; + options.WebSockets.SubProtocolSelector = SelectSubProtocol; + configureOptions?.Invoke(options); + } + + static string SelectSubProtocol(IList requestedSubProtocolValues) + { + // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. + return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt"))!; + } } } } From 51fb18564776c56da122860cf4ab5318b831bcfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 9 Dec 2024 14:56:27 +0800 Subject: [PATCH 83/85] Adapt MqttServerTcpEndpointBaseOptions to the Socket accepted by kestrel. --- .../KestrelServerOptionsExtensions.cs | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index e91f7f754..6f33a3053 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; @@ -10,6 +11,7 @@ using MQTTnet.Server; using System; using System.Net; +using System.Net.Sockets; using System.Security.Cryptography.X509Certificates; namespace MQTTnet.AspNetCore @@ -82,6 +84,16 @@ void Listen(MqttServerTcpEndpointBaseOptions endpoint) void UseMiddleware(ListenOptions listenOptions) { + listenOptions.Use(next => context => + { + var socketFeature = context.Features.Get(); + if (socketFeature != null) + { + endpoint.AdaptTo(socketFeature.Socket); + } + return next(context); + }); + if (endpoint is MqttServerTlsTcpEndpointOptions tlsEndPoint) { listenOptions.UseHttps(httpsOptions => @@ -90,11 +102,54 @@ void UseMiddleware(ListenOptions listenOptions) tlsConfigure?.Invoke(httpsOptions); }); } + listenOptions.UseMqtt(protocols, channelAdapter => PacketFragmentationFeature.CanAllowPacketFragmentation(channelAdapter, endpoint)); } } } + private static void AdaptTo(this MqttServerTcpEndpointBaseOptions endpoint, Socket socket) + { + if (endpoint.NoDelay) + { + socket.NoDelay = true; + } + + if (endpoint.LingerState != null) + { + socket.LingerState = endpoint.LingerState; + } + + if (endpoint.ReuseAddress) + { + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); + } + + if (endpoint.KeepAlive.HasValue) + { + var value = endpoint.KeepAlive.Value; + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, value); + } + + if (endpoint.TcpKeepAliveInterval.HasValue) + { + var value = endpoint.TcpKeepAliveInterval.Value; + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveInterval, value); + } + + if (endpoint.TcpKeepAliveRetryCount.HasValue) + { + var value = endpoint.TcpKeepAliveRetryCount.Value; + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveRetryCount, value); + } + + if (endpoint.TcpKeepAliveTime.HasValue) + { + var value = endpoint.TcpKeepAliveTime.Value; + socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveTime, value); + } + } + private static void AdaptTo(this MqttServerTlsTcpEndpointOptions tlsEndPoint, HttpsConnectionAdapterOptions httpsOptions) { httpsOptions.SslProtocols = tlsEndPoint.SslProtocol; From 79f4f683cbd3c2963cb75d96b389c7b70e61c0e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 9 Dec 2024 15:18:50 +0800 Subject: [PATCH 84/85] SocketOptionName.ReuseAddress can only be used for listening Socket settings. --- Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index 6f33a3053..a051e01cf 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -120,11 +120,6 @@ private static void AdaptTo(this MqttServerTcpEndpointBaseOptions endpoint, Sock socket.LingerState = endpoint.LingerState; } - if (endpoint.ReuseAddress) - { - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, true); - } - if (endpoint.KeepAlive.HasValue) { var value = endpoint.KeepAlive.Value; From 7139431b27a167c79016225d0b1f4318f660d0e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Mon, 9 Dec 2024 19:42:18 +0800 Subject: [PATCH 85/85] Accurately detect the DualMode value of listenSocket. --- .../KestrelServerOptionsExtensions.cs | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs index a051e01cf..afbbed92b 100644 --- a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -6,7 +6,9 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; using MQTTnet.Exceptions; using MQTTnet.Server; using System; @@ -56,8 +58,9 @@ public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, // check services.AddMqttServer() kestrel.ApplicationServices.GetRequiredService(); - var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); var serverOptions = kestrel.ApplicationServices.GetRequiredService(); + var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); + var listenSocketFactory = kestrel.ApplicationServices.GetRequiredService>().Value.CreateBoundListenSocket ?? SocketTransportOptions.CreateDefaultBoundListenSocket; Listen(serverOptions.DefaultEndpointOptions); Listen(serverOptions.TlsEndpointOptions); @@ -73,16 +76,19 @@ void Listen(MqttServerTcpEndpointBaseOptions endpoint) return; } - // No need to listen any IPv4 when has IPv6Any - if (!IPAddress.IPv6Any.Equals(endpoint.BoundInterNetworkV6Address)) + // No need to listen IPv4EndPoint when IPv6EndPoint's DualMode is true. + var ipV6EndPoint = new IPEndPoint(endpoint.BoundInterNetworkV6Address, endpoint.Port); + using var listenSocket = listenSocketFactory.Invoke(ipV6EndPoint); + if (!listenSocket.DualMode) { - kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, UseMiddleware); + kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, UseMiddlewares); } - kestrel.Listen(endpoint.BoundInterNetworkV6Address, endpoint.Port, UseMiddleware); + + kestrel.Listen(ipV6EndPoint, UseMiddlewares); connectionHandler.ListenFlag = true; - void UseMiddleware(ListenOptions listenOptions) + void UseMiddlewares(ListenOptions listenOptions) { listenOptions.Use(next => context => { @@ -108,40 +114,36 @@ void UseMiddleware(ListenOptions listenOptions) } } - private static void AdaptTo(this MqttServerTcpEndpointBaseOptions endpoint, Socket socket) + private static void AdaptTo(this MqttServerTcpEndpointBaseOptions endpoint, Socket acceptSocket) { - if (endpoint.NoDelay) - { - socket.NoDelay = true; - } - + acceptSocket.NoDelay = endpoint.NoDelay; if (endpoint.LingerState != null) { - socket.LingerState = endpoint.LingerState; + acceptSocket.LingerState = endpoint.LingerState; } if (endpoint.KeepAlive.HasValue) { var value = endpoint.KeepAlive.Value; - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, value); + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, value); } if (endpoint.TcpKeepAliveInterval.HasValue) { var value = endpoint.TcpKeepAliveInterval.Value; - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveInterval, value); + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveInterval, value); } if (endpoint.TcpKeepAliveRetryCount.HasValue) { var value = endpoint.TcpKeepAliveRetryCount.Value; - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveRetryCount, value); + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveRetryCount, value); } if (endpoint.TcpKeepAliveTime.HasValue) { var value = endpoint.TcpKeepAliveTime.Value; - socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveTime, value); + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveTime, value); } }