Skip to content

Commit b1ed9b9

Browse files
Disable RequestTimeout middleware in SignalR and WebSockets (#47342)
1 parent 46d1c24 commit b1ed9b9

File tree

7 files changed

+185
-9
lines changed

7 files changed

+185
-9
lines changed

src/Middleware/WebSockets/src/Microsoft.AspNetCore.WebSockets.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
</PropertyGroup>
1313

1414
<ItemGroup>
15-
<Reference Include="Microsoft.AspNetCore.Http.Extensions" />
15+
<Reference Include="Microsoft.AspNetCore.Http" />
1616
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
1717
<Reference Include="Microsoft.Extensions.Options" />
1818
</ItemGroup>

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.AspNetCore.Builder;
77
using Microsoft.AspNetCore.Http;
88
using Microsoft.AspNetCore.Http.Features;
9+
using Microsoft.AspNetCore.Http.Timeouts;
910
using Microsoft.Extensions.Logging;
1011
using Microsoft.Extensions.Options;
1112
using Microsoft.Extensions.Primitives;
@@ -201,6 +202,9 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
201202
opaqueTransport = await _upgradeFeature!.UpgradeAsync(); // Sets status code to 101
202203
}
203204

205+
// Disable request timeout, if there is one, after the websocket has been accepted
206+
_context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();
207+
204208
return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
205209
{
206210
IsServer = true,

src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Net.Http;
66
using System.Net.WebSockets;
77
using System.Text;
8+
using Microsoft.AspNetCore.Http.Timeouts;
89
using Microsoft.AspNetCore.Testing;
910
using Microsoft.Net.Http.Headers;
1011

@@ -627,4 +628,38 @@ public async Task MultipleValueHeadersNotOverridden()
627628
}
628629
}
629630
}
631+
632+
[Fact]
633+
public async Task AcceptingWebSocketRequestDisablesTimeout()
634+
{
635+
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
636+
{
637+
context.Features.Set<IHttpRequestTimeoutFeature>(new HttpRequestTimeoutFeature());
638+
Assert.True(context.WebSockets.IsWebSocketRequest);
639+
var feature = Assert.IsType<HttpRequestTimeoutFeature>(context.Features.Get<IHttpRequestTimeoutFeature>());
640+
Assert.True(feature.Enabled);
641+
642+
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
643+
644+
Assert.False(feature.Enabled);
645+
}))
646+
{
647+
using (var client = new ClientWebSocket())
648+
{
649+
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
650+
}
651+
}
652+
}
653+
654+
internal sealed class HttpRequestTimeoutFeature : IHttpRequestTimeoutFeature
655+
{
656+
public bool Enabled { get; private set; } = true;
657+
658+
public CancellationToken RequestTimeoutToken => new CancellationToken();
659+
660+
public void DisableTimeout()
661+
{
662+
Enabled = false;
663+
}
664+
}
630665
}

src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Microsoft.AspNetCore.Connections;
66
using Microsoft.AspNetCore.Http.Connections;
77
using Microsoft.AspNetCore.Http.Connections.Internal;
8+
using Microsoft.AspNetCore.Http.Timeouts;
89
using Microsoft.AspNetCore.Routing;
910
using Microsoft.Extensions.DependencyInjection;
1011

@@ -111,6 +112,7 @@ public static ConnectionEndpointRouteBuilder MapConnections(this IEndpointRouteB
111112
var executehandler = app.Build();
112113

113114
var executeBuilder = endpoints.Map(pattern, executehandler);
115+
executeBuilder.WithMetadata(new DisableRequestTimeoutAttribute());
114116
conventionBuilders.Add(executeBuilder);
115117

116118
var compositeConventionBuilder = new CompositeEndpointConventionBuilder(conventionBuilders);

src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using Microsoft.AspNetCore.Connections;
99
using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
1010
using Microsoft.AspNetCore.Http.Features;
11+
using Microsoft.AspNetCore.Http.Timeouts;
1112
using Microsoft.AspNetCore.Internal;
1213
using Microsoft.Extensions.DependencyInjection;
1314
using Microsoft.Extensions.Logging;
@@ -213,6 +214,7 @@ private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connecti
213214
return;
214215
}
215216

217+
context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();
216218
var resultTask = await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!);
217219

218220
try
@@ -274,6 +276,7 @@ private async Task DoPersistentConnection(ConnectionDelegate connectionDelegate,
274276
{
275277
if (connection.TryActivatePersistentConnection(connectionDelegate, transport, context, _logger))
276278
{
279+
context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();
277280
// Wait for any of them to end
278281
await Task.WhenAny(connection.ApplicationTask!, connection.TransportTask!);
279282

@@ -439,7 +442,7 @@ private async Task ProcessSend(HttpContext context)
439442
}
440443
catch (OperationCanceledException)
441444
{
442-
// CancelPendingFlush has canceled pending writes caused by backpresure
445+
// CancelPendingFlush has canceled pending writes caused by backpressure
443446
Log.ConnectionDisposed(_logger, connection.ConnectionId);
444447

445448
context.Response.StatusCode = StatusCodes.Status404NotFound;

src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs

Lines changed: 117 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4-
using System;
54
using System.Buffers;
6-
using System.Collections.Generic;
75
using System.Diagnostics;
86
using System.IdentityModel.Tokens.Jwt;
9-
using System.IO;
107
using System.IO.Pipelines;
11-
using System.Linq;
128
using System.Net;
139
using System.Net.Http;
1410
using System.Net.WebSockets;
1511
using System.Security.Claims;
1612
using System.Security.Principal;
1713
using System.Text;
18-
using System.Threading;
19-
using System.Threading.Tasks;
2014
using Microsoft.AspNetCore.Authentication;
2115
using Microsoft.AspNetCore.Authentication.Cookies;
2216
using Microsoft.AspNetCore.Authentication.JwtBearer;
@@ -29,8 +23,10 @@
2923
using Microsoft.AspNetCore.Hosting.Server;
3024
using Microsoft.AspNetCore.Hosting.Server.Features;
3125
using Microsoft.AspNetCore.Http.Connections.Client;
26+
using Microsoft.AspNetCore.Http.Connections.Features;
3227
using Microsoft.AspNetCore.Http.Connections.Internal;
3328
using Microsoft.AspNetCore.Http.Features;
29+
using Microsoft.AspNetCore.Http.Timeouts;
3430
using Microsoft.AspNetCore.Internal;
3531
using Microsoft.AspNetCore.Routing;
3632
using Microsoft.AspNetCore.SignalR.Tests;
@@ -45,7 +41,6 @@
4541
using Moq;
4642
using Newtonsoft.Json;
4743
using Newtonsoft.Json.Linq;
48-
using Xunit;
4944

5045
namespace Microsoft.AspNetCore.Http.Connections.Tests;
5146

@@ -3112,6 +3107,109 @@ public async Task AuthenticationExpirationSetToMaxValueByDefault()
31123107
await connection.DisposeAsync();
31133108
}
31143109

3110+
[Theory]
3111+
[InlineData(HttpTransportType.ServerSentEvents)]
3112+
[InlineData(HttpTransportType.WebSockets)]
3113+
public async Task RequestTimeoutDisabledWhenConnected(HttpTransportType transportType)
3114+
{
3115+
using (StartVerifiableLog())
3116+
{
3117+
using var host = new HostBuilder()
3118+
.ConfigureWebHost(webHostBuilder =>
3119+
{
3120+
webHostBuilder
3121+
.UseKestrel()
3122+
.ConfigureLogging(o =>
3123+
{
3124+
o.AddProvider(new ForwardingLoggerProvider(LoggerFactory));
3125+
})
3126+
.ConfigureServices(services =>
3127+
{
3128+
services.AddConnections();
3129+
3130+
// Since tests run in parallel, it's possible multiple servers will startup,
3131+
// we use an ephemeral key provider to avoid filesystem contention issues
3132+
services.AddSingleton<IDataProtectionProvider, EphemeralDataProtectionProvider>();
3133+
})
3134+
.Configure(app =>
3135+
{
3136+
app.Use((c, n) =>
3137+
{
3138+
c.Features.Set<IHttpRequestTimeoutFeature>(new HttpRequestTimeoutFeature());
3139+
Assert.True(((HttpRequestTimeoutFeature)c.Features.Get<IHttpRequestTimeoutFeature>()).Enabled);
3140+
return n(c);
3141+
});
3142+
app.UseRouting();
3143+
app.UseEndpoints(endpoints =>
3144+
{
3145+
endpoints.MapConnectionHandler<TestConnectionHandler>("/foo");
3146+
});
3147+
})
3148+
.UseUrls("http://127.0.0.1:0");
3149+
})
3150+
.Build();
3151+
3152+
host.Start();
3153+
3154+
var manager = host.Services.GetRequiredService<HttpConnectionManager>();
3155+
var url = host.Services.GetService<IServer>().Features.Get<IServerAddressesFeature>().Addresses.Single();
3156+
3157+
var stream = new MemoryStream();
3158+
var connection = new HttpConnection(
3159+
new HttpConnectionOptions()
3160+
{
3161+
Url = new Uri(url + "/foo"),
3162+
Transports = transportType,
3163+
DefaultTransferFormat = TransferFormat.Text,
3164+
HttpMessageHandlerFactory = handler => new GetNegotiateHttpHandler(handler, stream)
3165+
},
3166+
LoggerFactory);
3167+
3168+
await connection.StartAsync();
3169+
3170+
var negotiateResponse = NegotiateProtocol.ParseResponse(stream.ToArray());
3171+
3172+
Assert.True(manager.TryGetConnection(negotiateResponse.ConnectionToken, out var context));
3173+
var feature = Assert.IsType<HttpRequestTimeoutFeature>(context.Features.Get<IHttpContextFeature>()?.HttpContext.Features.Get<IHttpRequestTimeoutFeature>());
3174+
Assert.False(feature.Enabled);
3175+
3176+
await connection.DisposeAsync();
3177+
}
3178+
}
3179+
3180+
[Fact]
3181+
public async Task DisableRequestTimeoutInLongPolling()
3182+
{
3183+
using (StartVerifiableLog())
3184+
{
3185+
var manager = CreateConnectionManager(LoggerFactory, TimeSpan.FromSeconds(5));
3186+
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
3187+
var options = new HttpConnectionDispatcherOptions();
3188+
var connection = manager.CreateConnection(options);
3189+
connection.TransportType = HttpTransportType.LongPolling;
3190+
3191+
var services = new ServiceCollection();
3192+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
3193+
builder.UseConnectionHandler<HttpContextConnectionHandler>();
3194+
var app = builder.Build();
3195+
var context = MakeRequest("/foo", connection, services);
3196+
context.Features.Set<IHttpRequestTimeoutFeature>(new HttpRequestTimeoutFeature());
3197+
Assert.True(((HttpRequestTimeoutFeature)context.Features.Get<IHttpRequestTimeoutFeature>()).Enabled);
3198+
3199+
// Initial poll will complete immediately
3200+
await dispatcher.ExecuteAsync(context, options, app).DefaultTimeout();
3201+
Assert.False(((HttpRequestTimeoutFeature)context.Features.Get<IHttpRequestTimeoutFeature>()).Enabled);
3202+
3203+
context.Features.Set<IHttpRequestTimeoutFeature>(new HttpRequestTimeoutFeature());
3204+
Assert.True(((HttpRequestTimeoutFeature)context.Features.Get<IHttpRequestTimeoutFeature>()).Enabled);
3205+
var pollTask = dispatcher.ExecuteAsync(context, options, app);
3206+
// disables on every poll
3207+
Assert.False(((HttpRequestTimeoutFeature)context.Features.Get<IHttpRequestTimeoutFeature>()).Enabled);
3208+
3209+
await connection.DisposeAsync().DefaultTimeout();
3210+
}
3211+
}
3212+
31153213
private class GetNegotiateHttpHandler : DelegatingHandler
31163214
{
31173215
private readonly MemoryStream _stream;
@@ -3502,3 +3600,15 @@ public class MessageWrapper
35023600
{
35033601
public ReadOnlySequence<byte> Buffer { get; set; }
35043602
}
3603+
3604+
internal sealed class HttpRequestTimeoutFeature : IHttpRequestTimeoutFeature
3605+
{
3606+
public bool Enabled { get; private set; } = true;
3607+
3608+
public CancellationToken RequestTimeoutToken => new CancellationToken();
3609+
3610+
public void DisableTimeout()
3611+
{
3612+
Enabled = false;
3613+
}
3614+
}

src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.AspNetCore.Hosting;
1111
using Microsoft.AspNetCore.Hosting.Server;
1212
using Microsoft.AspNetCore.Hosting.Server.Features;
13+
using Microsoft.AspNetCore.Http.Timeouts;
1314
using Microsoft.AspNetCore.Routing;
1415
using Microsoft.AspNetCore.SignalR.Tests;
1516
using Microsoft.AspNetCore.Testing;
@@ -405,6 +406,27 @@ public async Task MapConnectionHandlerWithWebSocketSubProtocolSetsProtocol()
405406
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
406407
}
407408

409+
[Fact]
410+
public void MapConnectionHandlerAddsDisableRequestTimeoutMetadata()
411+
{
412+
using var host = BuildWebHost<MyConnectionHandler>("/test", o => { });
413+
host.Start();
414+
415+
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
416+
// We register 2 endpoints (/negotiate and /)
417+
Assert.Collection(dataSource.Endpoints,
418+
endpoint =>
419+
{
420+
Assert.Equal("/test/negotiate", endpoint.DisplayName);
421+
Assert.Empty(endpoint.Metadata.GetOrderedMetadata<DisableRequestTimeoutAttribute>());
422+
},
423+
endpoint =>
424+
{
425+
Assert.Equal("/test", endpoint.DisplayName);
426+
Assert.Single(endpoint.Metadata.GetOrderedMetadata<DisableRequestTimeoutAttribute>());
427+
});
428+
}
429+
408430
private class MyConnectionHandler : ConnectionHandler
409431
{
410432
public override async Task OnConnectedAsync(ConnectionContext connection)

0 commit comments

Comments
 (0)