Skip to content

Commit 6e1b89d

Browse files
Fix incoming connections with same ID sometimes failing incorrectly (#50417)
* Fix incoming connections with same ID sometimes failing incorrectly * using * fb
1 parent 3561949 commit 6e1b89d

File tree

3 files changed

+112
-82
lines changed

3 files changed

+112
-82
lines changed

src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,36 @@ internal async Task<bool> CancelPreviousPoll(HttpContext context)
581581
}
582582
}
583583

584+
internal SetTransportState TrySetTransport(HttpTransportType transportType, HttpConnectionsMetrics metrics)
585+
{
586+
lock (_stateLock)
587+
{
588+
if (TransportType == HttpTransportType.None)
589+
{
590+
TransportType = transportType;
591+
592+
if (HttpConnectionsEventSource.Log.IsEnabled() || MetricsContext.ConnectionDurationEnabled)
593+
{
594+
StartTimestamp = Stopwatch.GetTimestamp();
595+
}
596+
597+
HttpConnectionsEventSource.Log.ConnectionStart(ConnectionId);
598+
599+
metrics.ConnectionTransportStart(MetricsContext, transportType);
600+
}
601+
else if (TransportType != transportType)
602+
{
603+
return SetTransportState.CannotChange;
604+
}
605+
else if (!ClientReconnectExpected())
606+
{
607+
return SetTransportState.AlreadyActive;
608+
}
609+
610+
return SetTransportState.Success;
611+
}
612+
}
613+
584614
public void MarkInactive()
585615
{
586616
lock (_stateLock)
@@ -718,6 +748,19 @@ public void OnReconnected(Func<PipeWriter, Task> notifyOnReconnect)
718748
}
719749
}
720750

751+
// If the connection is using the Stateful Reconnect feature or using LongPolling
752+
internal bool ClientReconnectExpected()
753+
{
754+
return UseStatefulReconnect == true || TransportType == HttpTransportType.LongPolling;
755+
}
756+
757+
internal enum SetTransportState
758+
{
759+
Success,
760+
AlreadyActive,
761+
CannotChange,
762+
}
763+
721764
private static partial class Log
722765
{
723766
[LoggerMessage(1, LogLevel.Trace, "Disposing connection {TransportConnectionId}.", EventName = "DisposingConnection")]

src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Buffers;
5-
using System.Diagnostics;
65
using System.Security.Claims;
76
using System.Security.Principal;
87
using Microsoft.AspNetCore.Authentication;
@@ -570,30 +569,30 @@ private async Task<bool> EnsureConnectionStateAsync(HttpConnectionContext connec
570569
return false;
571570
}
572571

573-
// Set the IHttpConnectionFeature now that we can access it.
574-
connection.Features.Set(context.Features.Get<IHttpConnectionFeature>());
575-
576-
if (connection.TransportType == HttpTransportType.None)
572+
switch (connection.TrySetTransport(transportType, _metrics))
577573
{
578-
if (HttpConnectionsEventSource.Log.IsEnabled() || connection.MetricsContext.ConnectionDurationEnabled)
579-
{
580-
connection.StartTimestamp = Stopwatch.GetTimestamp();
581-
}
574+
case HttpConnectionContext.SetTransportState.Success:
575+
break;
582576

583-
connection.TransportType = transportType;
577+
case HttpConnectionContext.SetTransportState.AlreadyActive:
578+
Log.ConnectionAlreadyActive(_logger, connection.ConnectionId, context.TraceIdentifier);
584579

585-
HttpConnectionsEventSource.Log.ConnectionStart(connection.ConnectionId);
586-
_metrics.ConnectionTransportStart(connection.MetricsContext, transportType);
587-
}
588-
else if (connection.TransportType != transportType)
589-
{
590-
context.Response.ContentType = "text/plain";
591-
context.Response.StatusCode = StatusCodes.Status400BadRequest;
592-
Log.CannotChangeTransport(_logger, connection.TransportType, transportType);
593-
await context.Response.WriteAsync("Cannot change transports mid-connection");
594-
return false;
580+
// Reject the request with a 409 conflict
581+
context.Response.StatusCode = StatusCodes.Status409Conflict;
582+
context.Response.ContentType = "text/plain";
583+
return false;
584+
585+
case HttpConnectionContext.SetTransportState.CannotChange:
586+
context.Response.ContentType = "text/plain";
587+
context.Response.StatusCode = StatusCodes.Status400BadRequest;
588+
Log.CannotChangeTransport(_logger, connection.TransportType, transportType);
589+
await context.Response.WriteAsync("Cannot change transports mid-connection");
590+
return false;
595591
}
596592

593+
// Set the IHttpConnectionFeature now that we can access it.
594+
connection.Features.Set(context.Features.Get<IHttpConnectionFeature>());
595+
597596
// Configure transport-specific features.
598597
if (transportType == HttpTransportType.LongPolling)
599598
{

src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs

Lines changed: 50 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,6 @@ public async Task TransportEndingGracefullyWaitsOnApplication(HttpTransportType
450450
var manager = CreateConnectionManager(LoggerFactory);
451451
var dispatcher = CreateDispatcher(manager, LoggerFactory);
452452
var connection = manager.CreateConnection();
453-
connection.TransportType = transportType;
454453

455454
using (var strm = new MemoryStream())
456455
{
@@ -1001,7 +1000,6 @@ public async Task CompletedEndPointEndsConnection()
10011000
{
10021001
var manager = CreateConnectionManager(LoggerFactory);
10031002
var connection = manager.CreateConnection();
1004-
connection.TransportType = HttpTransportType.ServerSentEvents;
10051003

10061004
var dispatcher = CreateDispatcher(manager, LoggerFactory);
10071005

@@ -1036,7 +1034,6 @@ bool ExpectedErrors(WriteContext writeContext)
10361034
{
10371035
var manager = CreateConnectionManager(LoggerFactory);
10381036
var connection = manager.CreateConnection();
1039-
connection.TransportType = HttpTransportType.ServerSentEvents;
10401037

10411038
var dispatcher = CreateDispatcher(manager, LoggerFactory);
10421039
var services = new ServiceCollection();
@@ -1315,12 +1312,11 @@ public async Task SSEConnectionClosesWhenSendTimeoutReached()
13151312
{
13161313
var manager = CreateConnectionManager(LoggerFactory);
13171314
var connection = manager.CreateConnection();
1318-
connection.TransportType = HttpTransportType.ServerSentEvents;
13191315
var dispatcher = CreateDispatcher(manager, LoggerFactory);
13201316
var services = new ServiceCollection();
13211317
services.AddSingleton<TestConnectionHandler>();
13221318
var context = MakeRequest("/foo", connection, services);
1323-
SetTransport(context, connection.TransportType);
1319+
SetTransport(context, HttpTransportType.ServerSentEvents);
13241320
var builder = new ConnectionBuilder(services.BuildServiceProvider());
13251321
builder.UseConnectionHandler<TestConnectionHandler>();
13261322
var app = builder.Build();
@@ -1352,13 +1348,12 @@ bool ExpectedErrors(WriteContext writeContext)
13521348
{
13531349
var manager = CreateConnectionManager(LoggerFactory);
13541350
var connection = manager.CreateConnection();
1355-
connection.TransportType = HttpTransportType.WebSockets;
13561351
var dispatcher = CreateDispatcher(manager, LoggerFactory);
13571352
var sync = new SyncPoint();
13581353
var services = new ServiceCollection();
13591354
services.AddSingleton<TestConnectionHandler>();
13601355
var context = MakeRequest("/foo", connection, services);
1361-
SetTransport(context, connection.TransportType, sync);
1356+
SetTransport(context, HttpTransportType.WebSockets, sync);
13621357
var builder = new ConnectionBuilder(services.BuildServiceProvider());
13631358
builder.UseConnectionHandler<TestConnectionHandler>();
13641359
var app = builder.Build();
@@ -1414,7 +1409,6 @@ public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTrans
14141409
{
14151410
var manager = CreateConnectionManager(LoggerFactory);
14161411
var connection = manager.CreateConnection();
1417-
connection.TransportType = transportType;
14181412

14191413
var dispatcher = CreateDispatcher(manager, LoggerFactory);
14201414

@@ -1434,7 +1428,10 @@ public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTrans
14341428

14351429
await dispatcher.ExecuteAsync(context2, options, app).DefaultTimeout();
14361430

1431+
Assert.False(request1.IsCompleted);
1432+
14371433
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
1434+
Assert.NotSame(connection.HttpContext, context2);
14381435

14391436
var webSocketTask = Task.CompletedTask;
14401437

@@ -1586,7 +1583,6 @@ public async Task RequestToDisposedConnectionIdReturns404(HttpTransportType tran
15861583
{
15871584
var manager = CreateConnectionManager(LoggerFactory);
15881585
var connection = manager.CreateConnection();
1589-
connection.TransportType = transportType;
15901586
connection.Status = HttpConnectionStatus.Disposed;
15911587

15921588
var dispatcher = CreateDispatcher(manager, LoggerFactory);
@@ -1653,7 +1649,6 @@ public async Task BlockingConnectionWorksWithStreamingConnections()
16531649
{
16541650
var manager = CreateConnectionManager(LoggerFactory);
16551651
var connection = manager.CreateConnection();
1656-
connection.TransportType = HttpTransportType.ServerSentEvents;
16571652

16581653
var dispatcher = CreateDispatcher(manager, LoggerFactory);
16591654

@@ -1777,7 +1772,6 @@ public async Task TransferModeSet(HttpTransportType transportType, TransferForma
17771772
{
17781773
var manager = CreateConnectionManager(LoggerFactory);
17791774
var connection = manager.CreateConnection();
1780-
connection.TransportType = transportType;
17811775

17821776
var dispatcher = CreateDispatcher(manager, LoggerFactory);
17831777

@@ -2434,7 +2428,6 @@ public async Task DisableReconnectDisallowsReplacementConnection()
24342428
options.WebSockets.CloseTimeout = TimeSpan.FromMilliseconds(1);
24352429
// pretend negotiate occurred
24362430
var connection = manager.CreateConnection(options, negotiateVersion: 1, useStatefulReconnect: true);
2437-
connection.TransportType = HttpTransportType.WebSockets;
24382431

24392432
var dispatcher = CreateDispatcher(manager, LoggerFactory);
24402433
var services = new ServiceCollection();
@@ -2800,12 +2793,11 @@ public async Task SSEConnectionClosingTriggersConnectionClosedToken()
28002793
{
28012794
var manager = CreateConnectionManager(LoggerFactory);
28022795
var connection = manager.CreateConnection();
2803-
connection.TransportType = HttpTransportType.ServerSentEvents;
28042796
var dispatcher = CreateDispatcher(manager, LoggerFactory);
28052797
var services = new ServiceCollection();
28062798
services.AddSingleton<NeverEndingConnectionHandler>();
28072799
var context = MakeRequest("/foo", connection, services);
2808-
SetTransport(context, connection.TransportType);
2800+
SetTransport(context, HttpTransportType.ServerSentEvents);
28092801

28102802
var builder = new ConnectionBuilder(services.BuildServiceProvider());
28112803
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
@@ -2827,7 +2819,6 @@ public async Task WebSocketConnectionClosingTriggersConnectionClosedToken()
28272819
{
28282820
var manager = CreateConnectionManager(LoggerFactory);
28292821
var connection = manager.CreateConnection();
2830-
connection.TransportType = HttpTransportType.WebSockets;
28312822

28322823
var dispatcher = CreateDispatcher(manager, LoggerFactory);
28332824
var services = new ServiceCollection();
@@ -2875,14 +2866,13 @@ public async Task AbortingConnectionAbortsHttpContextAndTriggersConnectionClosed
28752866
{
28762867
var manager = CreateConnectionManager(LoggerFactory);
28772868
var connection = manager.CreateConnection();
2878-
connection.TransportType = HttpTransportType.ServerSentEvents;
28792869
var dispatcher = CreateDispatcher(manager, LoggerFactory);
28802870
var services = new ServiceCollection();
28812871
services.AddSingleton<NeverEndingConnectionHandler>();
28822872
var context = MakeRequest("/foo", connection, services);
28832873
var lifetimeFeature = new CustomHttpRequestLifetimeFeature();
28842874
context.Features.Set<IHttpRequestLifetimeFeature>(lifetimeFeature);
2885-
SetTransport(context, connection.TransportType);
2875+
SetTransport(context, HttpTransportType.ServerSentEvents);
28862876

28872877
var builder = new ConnectionBuilder(services.BuildServiceProvider());
28882878
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
@@ -3062,7 +3052,6 @@ public async Task LongRunningActivityTagSetOnExecuteAsync()
30623052
{
30633053
var manager = CreateConnectionManager(LoggerFactory);
30643054
var connection = manager.CreateConnection();
3065-
connection.TransportType = HttpTransportType.ServerSentEvents;
30663055

30673056
var dispatcher = CreateDispatcher(manager, LoggerFactory);
30683057
var services = new ServiceCollection();
@@ -3313,56 +3302,56 @@ public async Task AuthenticationExpirationUsesCorrectScheme(HttpTransportType tr
33133302
var JwtTokenHandler = new JwtSecurityTokenHandler();
33143303

33153304
using var host = CreateHost(services =>
3305+
{
3306+
// Set default to Cookie auth but use JWT auth for the endpoint
3307+
// This makes sure we take the scheme into account when grabbing the token expiration
3308+
services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme)
3309+
.AddCookie()
3310+
.AddJwtBearer(options =>
33163311
{
3317-
// Set default to Cookie auth but use JWT auth for the endpoint
3318-
// This makes sure we take the scheme into account when grabbing the token expiration
3319-
services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme)
3320-
.AddCookie()
3321-
.AddJwtBearer(options =>
3312+
options.TokenValidationParameters =
3313+
new TokenValidationParameters
3314+
{
3315+
LifetimeValidator = (before, expires, token, parameters) => expires > DateTime.UtcNow,
3316+
ValidateAudience = false,
3317+
ValidateIssuer = false,
3318+
ValidateActor = false,
3319+
ValidateLifetime = true,
3320+
IssuerSigningKey = SecurityKey
3321+
};
3322+
3323+
options.Events = new JwtBearerEvents
33223324
{
3323-
options.TokenValidationParameters =
3324-
new TokenValidationParameters
3325-
{
3326-
LifetimeValidator = (before, expires, token, parameters) => expires > DateTime.UtcNow,
3327-
ValidateAudience = false,
3328-
ValidateIssuer = false,
3329-
ValidateActor = false,
3330-
ValidateLifetime = true,
3331-
IssuerSigningKey = SecurityKey
3332-
};
3333-
3334-
options.Events = new JwtBearerEvents
3325+
OnMessageReceived = context =>
33353326
{
3336-
OnMessageReceived = context =>
3327+
var accessToken = context.Request.Query["access_token"];
3328+
3329+
if (!string.IsNullOrEmpty(accessToken) &&
3330+
(context.HttpContext.WebSockets.IsWebSocketRequest || context.Request.Headers["Accept"] == "text/event-stream"))
33373331
{
3338-
var accessToken = context.Request.Query["access_token"];
3339-
3340-
if (!string.IsNullOrEmpty(accessToken) &&
3341-
(context.HttpContext.WebSockets.IsWebSocketRequest || context.Request.Headers["Accept"] == "text/event-stream"))
3342-
{
3343-
context.Token = context.Request.Query["access_token"];
3344-
}
3345-
return Task.CompletedTask;
3332+
context.Token = context.Request.Query["access_token"];
33463333
}
3347-
};
3348-
});
3349-
}, endpoints =>
3350-
{
3351-
endpoints.MapConnectionHandler<JwtConnectionHandler>("/foo", o => o.CloseOnAuthenticationExpiration = true);
3334+
return Task.CompletedTask;
3335+
}
3336+
};
3337+
});
3338+
}, endpoints =>
3339+
{
3340+
endpoints.MapConnectionHandler<JwtConnectionHandler>("/foo", o => o.CloseOnAuthenticationExpiration = true);
33523341

3353-
endpoints.MapGet("/generatetoken", context =>
3354-
{
3355-
return context.Response.WriteAsync(GenerateToken(context));
3356-
});
3342+
endpoints.MapGet("/generatetoken", context =>
3343+
{
3344+
return context.Response.WriteAsync(GenerateToken(context));
3345+
});
33573346

3358-
string GenerateToken(HttpContext httpContext)
3359-
{
3360-
var claims = new[] { new Claim(ClaimTypes.NameIdentifier, httpContext.Request.Query["user"]) };
3361-
var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256);
3362-
var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.UtcNow.AddMinutes(1), signingCredentials: credentials);
3363-
return JwtTokenHandler.WriteToken(token);
3364-
}
3365-
}, LoggerFactory);
3347+
string GenerateToken(HttpContext httpContext)
3348+
{
3349+
var claims = new[] { new Claim(ClaimTypes.NameIdentifier, httpContext.Request.Query["user"]) };
3350+
var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256);
3351+
var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.UtcNow.AddMinutes(1), signingCredentials: credentials);
3352+
return JwtTokenHandler.WriteToken(token);
3353+
}
3354+
}, LoggerFactory);
33663355

33673356
host.Start();
33683357

@@ -3639,7 +3628,6 @@ private static async Task CheckTransportSupported(HttpTransportType supportedTra
36393628
{
36403629
var manager = CreateConnectionManager(loggerFactory);
36413630
var connection = manager.CreateConnection();
3642-
connection.TransportType = transportType;
36433631

36443632
var dispatcher = CreateDispatcher(manager, loggerFactory);
36453633
using (var strm = new MemoryStream())

0 commit comments

Comments
 (0)