Skip to content

Commit 9a7a8bd

Browse files
committed
Server-side adaptation of AllowPacketFragmentation options.
1 parent 6171c81 commit 9a7a8bd

16 files changed

+172
-81
lines changed

Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
using Microsoft.AspNetCore.Connections;
66
using Microsoft.Extensions.DependencyInjection;
7+
using MQTTnet.Adapter;
78
using MQTTnet.Server;
9+
using System;
810

911
namespace MQTTnet.AspNetCore
1012
{
@@ -15,15 +17,16 @@ public static class ConnectionBuilderExtensions
1517
/// </summary>
1618
/// <param name="builder"></param>
1719
/// <param name="protocols"></param>
20+
/// <param name="allowPacketFragmentationSelector"></param>
1821
/// <returns></returns>
19-
public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket)
22+
public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket, Func<IMqttChannelAdapter, bool>? allowPacketFragmentationSelector = null)
2023
{
2124
// check services.AddMqttServer()
2225
builder.ApplicationServices.GetRequiredService<MqttServer>();
2326
builder.ApplicationServices.GetRequiredService<MqttConnectionHandler>().UseFlag = true;
2427

2528
var middleware = builder.ApplicationServices.GetRequiredService<MqttConnectionMiddleware>();
26-
return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols));
29+
return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols, allowPacketFragmentationSelector));
2730
}
2831
}
2932
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using MQTTnet.Adapter;
6+
using System;
7+
8+
namespace MQTTnet.AspNetCore
9+
{
10+
sealed class PacketFragmentationFeature(Func<IMqttChannelAdapter, bool> allowPacketFragmentationSelector)
11+
{
12+
public Func<IMqttChannelAdapter, bool> AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector;
13+
}
14+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.AspNetCore.Http.Features;
6+
using System.Security.Cryptography.X509Certificates;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
10+
namespace MQTTnet.AspNetCore
11+
{
12+
sealed class TlsConnectionFeature : ITlsConnectionFeature
13+
{
14+
public static readonly TlsConnectionFeature WithoutClientCertificate = new(null);
15+
16+
public X509Certificate2? ClientCertificate { get; set; }
17+
18+
public Task<X509Certificate2?> GetClientCertificateAsync(CancellationToken cancellationToken)
19+
{
20+
return Task.FromResult(ClientCertificate);
21+
}
22+
23+
public TlsConnectionFeature(X509Certificate? clientCertificate)
24+
{
25+
ClientCertificate = clientCertificate as X509Certificate2;
26+
}
27+
}
28+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace MQTTnet.AspNetCore
6+
{
7+
sealed class WebSocketConnectionFeature(string path)
8+
{
9+
/// <summary>
10+
/// The path of WebSocket request.
11+
/// </summary>
12+
public string Path { get; } = path;
13+
}
14+
}

Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPa
1616
ArgumentNullException.ThrowIfNull(nameof(options));
1717
var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax);
1818
var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter);
19-
return new MqttClientChannelAdapter(formatter, options.ChannelOptions, packetInspector, options.AllowPacketFragmentation);
19+
return new MqttClientChannelAdapter(formatter, options.ChannelOptions, options.AllowPacketFragmentation, packetInspector);
2020
}
2121
}
2222
}

Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Microsoft.AspNetCore.Connections.Features;
65
using Microsoft.AspNetCore.Http.Features;
76
using System;
87
using System.Net;
@@ -72,7 +71,11 @@ public static async Task<ClientConnectionContext> CreateAsync(MqttClientTcpOptio
7271
var networkStream = new NetworkStream(socket, ownsSocket: true);
7372
if (options.TlsOptions?.UseTls != true)
7473
{
75-
return new ClientConnectionContext(networkStream);
74+
return new ClientConnectionContext(networkStream)
75+
{
76+
LocalEndPoint = socket.LocalEndPoint,
77+
RemoteEndPoint = socket.RemoteEndPoint,
78+
};
7679
}
7780

7881
var targetHost = options.TlsOptions.TargetHost;
@@ -143,7 +146,6 @@ public static async Task<ClientConnectionContext> CreateAsync(MqttClientTcpOptio
143146
RemoteEndPoint = socket.RemoteEndPoint,
144147
};
145148

146-
connection.Features.Set<IConnectionSocketFeature>(new ConnectionSocketFeature(socket));
147149
connection.Features.Set<ITlsConnectionFeature>(new TlsConnectionFeature(sslStream.LocalCertificate));
148150
return connection;
149151

Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ public static async Task<ClientConnectionContext> CreateAsync(MqttClientWebSocke
4848
RemoteEndPoint = new DnsEndPoint(uri.Host, uri.Port),
4949
};
5050

51+
connection.Features.Set(new WebSocketConnectionFeature(uri.AbsolutePath));
5152
if (uri.Scheme == Uri.UriSchemeWss)
5253
{
53-
connection.Features.Set<ITlsConnectionFeature>(TlsConnectionFeature.Default);
54+
connection.Features.Set<ITlsConnectionFeature>(TlsConnectionFeature.WithoutClientCertificate);
5455
}
5556
return connection;
5657
}

Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.AspNetCore.Connections;
6-
using Microsoft.AspNetCore.Connections.Features;
76
using Microsoft.AspNetCore.Http.Features;
87
using System;
98
using System.Collections.Generic;
109
using System.IO;
1110
using System.IO.Pipelines;
12-
using System.Net.Sockets;
13-
using System.Security.Cryptography.X509Certificates;
1411
using System.Threading;
1512
using System.Threading.Tasks;
1613

@@ -67,27 +64,5 @@ private class StreamTransport(Stream stream) : IDuplexPipe
6764

6865
public PipeWriter Output { get; } = PipeWriter.Create(stream, new StreamPipeWriterOptions(leaveOpen: true));
6966
}
70-
71-
private class TlsConnectionFeature : ITlsConnectionFeature
72-
{
73-
public static readonly TlsConnectionFeature Default = new(null);
74-
75-
public X509Certificate2? ClientCertificate { get; set; }
76-
77-
public Task<X509Certificate2?> GetClientCertificateAsync(CancellationToken cancellationToken)
78-
{
79-
return Task.FromResult(ClientCertificate);
80-
}
81-
82-
public TlsConnectionFeature(X509Certificate? clientCertificate)
83-
{
84-
ClientCertificate = clientCertificate as X509Certificate2;
85-
}
86-
}
87-
88-
private class ConnectionSocketFeature(Socket socket) : IConnectionSocketFeature
89-
{
90-
public Socket Socket { get; } = socket;
91-
}
9267
}
9368
}

Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.AspNetCore.Connections;
6-
using Microsoft.AspNetCore.Http.Connections.Features;
6+
using Microsoft.AspNetCore.Http;
77
using Microsoft.AspNetCore.Http.Features;
88
using MQTTnet.Adapter;
99
using MQTTnet.Exceptions;
@@ -29,7 +29,7 @@ class MqttChannel : IDisposable
2929
readonly PipeReader _input;
3030
readonly PipeWriter _output;
3131
readonly MqttPacketInspector? _packetInspector;
32-
readonly bool _allowPacketFragmentation;
32+
bool _allowPacketFragmentation = false;
3333

3434
public MqttPacketFormatterAdapter PacketFormatterAdapter { get; }
3535

@@ -43,69 +43,64 @@ class MqttChannel : IDisposable
4343

4444
public bool IsSecureConnection { get; }
4545

46+
public bool IsWebSocketConnection { get; }
47+
4648

4749
public MqttChannel(
4850
MqttPacketFormatterAdapter packetFormatterAdapter,
4951
ConnectionContext connection,
50-
MqttPacketInspector? packetInspector = null,
51-
bool? allowPacketFragmentation = null)
52+
HttpContext? httpContext,
53+
MqttPacketInspector? packetInspector)
5254
{
5355
PacketFormatterAdapter = packetFormatterAdapter;
54-
_packetInspector = packetInspector;
5556

56-
var httpContextFeature = connection.Features.Get<IHttpContextFeature>();
5757
var tlsConnectionFeature = connection.Features.Get<ITlsConnectionFeature>();
58-
RemoteEndPoint = GetRemoteEndPoint(httpContextFeature, connection.RemoteEndPoint);
59-
IsSecureConnection = IsTlsConnection(httpContextFeature, tlsConnectionFeature);
60-
ClientCertificate = GetClientCertificate(httpContextFeature, tlsConnectionFeature);
58+
RemoteEndPoint = GetRemoteEndPoint(connection.RemoteEndPoint, httpContext);
59+
ClientCertificate = GetClientCertificate(tlsConnectionFeature, httpContext);
60+
IsSecureConnection = IsTlsConnection(tlsConnectionFeature, httpContext);
61+
IsWebSocketConnection = connection.Features.Get<WebSocketConnectionFeature>() != null;
6162

63+
_packetInspector = packetInspector;
6264
_input = connection.Transport.Input;
6365
_output = connection.Transport.Output;
64-
65-
_allowPacketFragmentation = allowPacketFragmentation == null
66-
? AllowPacketFragmentation(httpContextFeature)
67-
: allowPacketFragmentation.Value;
6866
}
6967

70-
private static bool AllowPacketFragmentation(IHttpContextFeature? _httpContextFeature)
68+
private static EndPoint? GetRemoteEndPoint(EndPoint? remoteEndPoint, HttpContext? httpContext)
7169
{
72-
var serverModeWebSocket = _httpContextFeature != null &&
73-
_httpContextFeature.HttpContext != null &&
74-
_httpContextFeature.HttpContext.WebSockets.IsWebSocketRequest;
75-
76-
return !serverModeWebSocket;
77-
}
78-
70+
if (remoteEndPoint != null)
71+
{
72+
return remoteEndPoint;
73+
}
7974

80-
private static EndPoint? GetRemoteEndPoint(IHttpContextFeature? _httpContextFeature, EndPoint? remoteEndPoint)
81-
{
82-
if (_httpContextFeature != null && _httpContextFeature.HttpContext != null)
75+
if (httpContext != null)
8376
{
84-
var httpConnection = _httpContextFeature.HttpContext.Connection;
77+
var httpConnection = httpContext.Connection;
8578
var remoteAddress = httpConnection.RemoteIpAddress;
8679
if (remoteAddress != null)
8780
{
8881
return new IPEndPoint(remoteAddress, httpConnection.RemotePort);
8982
}
9083
}
9184

92-
return remoteEndPoint;
85+
return null;
9386
}
9487

95-
private static bool IsTlsConnection(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature)
88+
private static bool IsTlsConnection(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext)
9689
{
97-
return _httpContextFeature != null && _httpContextFeature.HttpContext != null
98-
? _httpContextFeature.HttpContext.Request.IsHttps
99-
: tlsConnectionFeature != null;
90+
return tlsConnectionFeature != null || (httpContext != null && httpContext.Request.IsHttps);
10091
}
10192

102-
private static X509Certificate2? GetClientCertificate(IHttpContextFeature? _httpContextFeature, ITlsConnectionFeature? tlsConnectionFeature)
93+
private static X509Certificate2? GetClientCertificate(ITlsConnectionFeature? tlsConnectionFeature, HttpContext? httpContext)
10394
{
104-
return _httpContextFeature != null && _httpContextFeature.HttpContext != null
105-
? _httpContextFeature.HttpContext.Connection.ClientCertificate
106-
: tlsConnectionFeature?.ClientCertificate;
95+
return tlsConnectionFeature != null
96+
? tlsConnectionFeature.ClientCertificate
97+
: httpContext?.Connection.ClientCertificate;
10798
}
10899

100+
public void SetAllowPacketFragmentation(bool value)
101+
{
102+
_allowPacketFragmentation = value;
103+
}
109104

110105
public async Task DisconnectAsync()
111106
{

Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ sealed class MqttClientChannelAdapter : IMqttChannelAdapter, IAsyncDisposable
2222
private MqttChannel? _channel;
2323
private readonly MqttPacketFormatterAdapter _packetFormatterAdapter;
2424
private readonly IMqttClientChannelOptions _channelOptions;
25+
private readonly bool _allowPacketFragmentation;
2526
private readonly MqttPacketInspector? _packetInspector;
26-
private readonly bool? _allowPacketFragmentation;
2727

2828
public MqttClientChannelAdapter(
2929
MqttPacketFormatterAdapter packetFormatterAdapter,
3030
IMqttClientChannelOptions channelOptions,
31-
MqttPacketInspector? packetInspector,
32-
bool? allowPacketFragmentation)
31+
bool allowPacketFragmentation,
32+
MqttPacketInspector? packetInspector)
3333
{
3434
_packetFormatterAdapter = packetFormatterAdapter;
3535
_channelOptions = channelOptions;
36-
_packetInspector = packetInspector;
3736
_allowPacketFragmentation = allowPacketFragmentation;
37+
_packetInspector = packetInspector;
3838
}
3939

4040
public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter;
@@ -49,6 +49,8 @@ public MqttClientChannelAdapter(
4949

5050
public bool IsSecureConnection => GetChannel().IsSecureConnection;
5151

52+
public bool IsWebSocketConnection => GetChannel().IsSecureConnection;
53+
5254

5355
public async Task ConnectAsync(CancellationToken cancellationToken)
5456
{
@@ -60,7 +62,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
6062
MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false),
6163
_ => throw new NotSupportedException(),
6264
};
63-
_channel = new MqttChannel(_packetFormatterAdapter, _connection, _packetInspector, _allowPacketFragmentation);
65+
_channel = new MqttChannel(_packetFormatterAdapter, _connection, httpContext: null, _packetInspector);
66+
_channel.SetAllowPacketFragmentation(_allowPacketFragmentation);
6467
}
6568
catch (Exception ex)
6669
{

0 commit comments

Comments
 (0)