Skip to content

Commit f8732f5

Browse files
Add server option for seamless reconnect (#48789)
1 parent 14a4f45 commit f8732f5

File tree

17 files changed

+359
-6
lines changed

17 files changed

+359
-6
lines changed

src/SignalR/clients/csharp/Client/test/FunctionalTests/Startup.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public void Configure(IApplicationBuilder app)
8686

8787
app.UseEndpoints(endpoints =>
8888
{
89-
endpoints.MapHub<TestHub>("/default");
89+
endpoints.MapHub<TestHub>("/default", o => o.AllowAcks = true);
9090
endpoints.MapHub<DynamicTestHub>("/dynamic");
9191
endpoints.MapHub<TestHubT>("/hubT");
9292
endpoints.MapHub<HubWithAuthorization>("/authorizedhub");

src/SignalR/common/Http.Connections/src/HttpConnectionDispatcherOptions.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ public TimeSpan TransportSendTimeout
124124
/// </remarks>
125125
public bool CloseOnAuthenticationExpiration { get; set; }
126126

127+
/// <summary>
128+
/// Set to allow connections to ack messages, helps enable reconnects that keep connection state.
129+
/// </summary>
130+
/// <remarks>
131+
/// Keeps messages in memory until acked (up to a limit), and keeps connections around for a short time to allow stateful reconnects.
132+
/// </remarks>
133+
public bool AllowAcks { get; set; }
134+
127135
internal long TransportSendTimeoutTicks { get; private set; }
128136
internal bool TransportSendTimeoutEnabled => _transportSendTimeout != Timeout.InfiniteTimeSpan;
129137

src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatche
336336
}
337337

338338
var useAck = false;
339-
if (context.Request.Query.TryGetValue("UseAck", out var useAckValue))
339+
if (options.AllowAcks == true && context.Request.Query.TryGetValue("UseAck", out var useAckValue))
340340
{
341341
var useAckStringValue = useAckValue.ToString();
342342
bool.TryParse(useAckStringValue, out useAck);
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
#nullable enable
2+
Microsoft.AspNetCore.Http.Connections.HttpConnectionDispatcherOptions.AllowAcks.get -> bool
3+
Microsoft.AspNetCore.Http.Connections.HttpConnectionDispatcherOptions.AllowAcks.set -> void

src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
using Microsoft.AspNetCore.Authorization;
1919
using Microsoft.AspNetCore.Builder;
2020
using Microsoft.AspNetCore.Connections;
21+
using Microsoft.AspNetCore.Connections.Abstractions;
2122
using Microsoft.AspNetCore.Connections.Features;
2223
using Microsoft.AspNetCore.DataProtection;
2324
using Microsoft.AspNetCore.Hosting;
@@ -2258,6 +2259,87 @@ public async Task NegotiateDoesNotReturnWebSocketsWhenNotAvailable()
22582259
}
22592260
}
22602261

2262+
[Fact]
2263+
public async Task NegotiateDoesNotReturnUseAckWhenNotEnabledOnServer()
2264+
{
2265+
using (StartVerifiableLog())
2266+
{
2267+
var manager = CreateConnectionManager(LoggerFactory);
2268+
var dispatcher = CreateDispatcher(manager, LoggerFactory);
2269+
var context = new DefaultHttpContext();
2270+
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
2271+
var services = new ServiceCollection();
2272+
services.AddSingleton<TestConnectionHandler>();
2273+
services.AddOptions();
2274+
var ms = new MemoryStream();
2275+
context.Request.Path = "/foo";
2276+
context.Request.Method = "POST";
2277+
context.Response.Body = ms;
2278+
context.Request.QueryString = new QueryString("?negotiateVersion=1&UseAck=true");
2279+
await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { AllowAcks = false });
2280+
2281+
var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
2282+
Assert.False(negotiateResponse.TryGetValue("useAck", out _));
2283+
2284+
Assert.True(manager.TryGetConnection(negotiateResponse["connectionToken"].ToString(), out var connection));
2285+
Assert.Null(connection.Features.Get<IReconnectFeature>());
2286+
}
2287+
}
2288+
2289+
[Fact]
2290+
public async Task NegotiateDoesNotReturnUseAckWhenEnabledOnServerButNotRequestedByClient()
2291+
{
2292+
using (StartVerifiableLog())
2293+
{
2294+
var manager = CreateConnectionManager(LoggerFactory);
2295+
var dispatcher = CreateDispatcher(manager, LoggerFactory);
2296+
var context = new DefaultHttpContext();
2297+
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
2298+
var services = new ServiceCollection();
2299+
services.AddSingleton<TestConnectionHandler>();
2300+
services.AddOptions();
2301+
var ms = new MemoryStream();
2302+
context.Request.Path = "/foo";
2303+
context.Request.Method = "POST";
2304+
context.Response.Body = ms;
2305+
context.Request.QueryString = new QueryString("?negotiateVersion=1");
2306+
await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { AllowAcks = true });
2307+
2308+
var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
2309+
Assert.False(negotiateResponse.TryGetValue("useAck", out _));
2310+
2311+
Assert.True(manager.TryGetConnection(negotiateResponse["connectionToken"].ToString(), out var connection));
2312+
Assert.Null(connection.Features.Get<IReconnectFeature>());
2313+
}
2314+
}
2315+
2316+
[Fact]
2317+
public async Task NegotiateReturnsUseAckWhenEnabledOnServerAndRequestedByClient()
2318+
{
2319+
using (StartVerifiableLog())
2320+
{
2321+
var manager = CreateConnectionManager(LoggerFactory);
2322+
var dispatcher = CreateDispatcher(manager, LoggerFactory);
2323+
var context = new DefaultHttpContext();
2324+
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
2325+
var services = new ServiceCollection();
2326+
services.AddSingleton<TestConnectionHandler>();
2327+
services.AddOptions();
2328+
var ms = new MemoryStream();
2329+
context.Request.Path = "/foo";
2330+
context.Request.Method = "POST";
2331+
context.Response.Body = ms;
2332+
context.Request.QueryString = new QueryString("?negotiateVersion=1&UseAck=true");
2333+
await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { AllowAcks = true });
2334+
2335+
var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
2336+
Assert.True((bool)negotiateResponse["useAck"]);
2337+
2338+
Assert.True(manager.TryGetConnection(negotiateResponse["connectionToken"].ToString(), out var connection));
2339+
Assert.NotNull(connection.Features.Get<IReconnectFeature>());
2340+
}
2341+
}
2342+
22612343
private class ControllableMemoryStream : MemoryStream
22622344
{
22632345
private readonly SyncPoint _syncPoint;

src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ private static AckMessage BindAckMessage(long? sequenceId)
888888
{
889889
if (sequenceId is null)
890890
{
891-
throw new InvalidDataException("Missing 'sequenceId' in Ack message.");
891+
throw new InvalidDataException("Missing required property 'sequenceId'.");
892892
}
893893

894894
return new AckMessage(sequenceId.Value);
@@ -898,7 +898,7 @@ private static SequenceMessage BindSequenceMessage(long? sequenceId)
898898
{
899899
if (sequenceId is null)
900900
{
901-
throw new InvalidDataException("Missing 'sequenceId' in Sequence message.");
901+
throw new InvalidDataException("Missing required property 'sequenceId'.");
902902
}
903903

904904
return new SequenceMessage(sequenceId.Value);

src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder
6060
return PingMessage.Instance;
6161
case HubProtocolConstants.CloseMessageType:
6262
return CreateCloseMessage(ref reader, itemCount);
63+
case HubProtocolConstants.AckMessageType:
64+
return CreateAckMessage(ref reader);
65+
case HubProtocolConstants.SequenceMessageType:
66+
return CreateSequenceMessage(ref reader);
6367
default:
6468
// Future protocol changes can add message types, old clients can ignore them
6569
return null;
@@ -279,6 +283,16 @@ private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int
279283
return streams?.ToArray();
280284
}
281285

286+
private static AckMessage CreateAckMessage(ref MessagePackReader reader)
287+
{
288+
return new AckMessage(ReadInt64(ref reader, "sequenceId"));
289+
}
290+
291+
private static SequenceMessage CreateSequenceMessage(ref MessagePackReader reader)
292+
{
293+
return new SequenceMessage(ReadInt64(ref reader, "sequenceId"));
294+
}
295+
282296
private object?[] BindArguments(ref MessagePackReader reader, IReadOnlyList<Type> parameterTypes)
283297
{
284298
var argumentCount = ReadArrayLength(ref reader, "arguments");
@@ -395,6 +409,12 @@ private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer)
395409
case CloseMessage closeMessage:
396410
WriteCloseMessage(closeMessage, ref writer);
397411
break;
412+
case AckMessage ackMessage:
413+
WriteAckMessage(ackMessage, ref writer);
414+
break;
415+
case SequenceMessage sequenceMessage:
416+
WriteSequenceMessage(sequenceMessage, ref writer);
417+
break;
398418
default:
399419
throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
400420
}
@@ -555,6 +575,20 @@ private static void WritePingMessage(ref MessagePackWriter writer)
555575
writer.Write(HubProtocolConstants.PingMessageType);
556576
}
557577

578+
private static void WriteAckMessage(AckMessage message, ref MessagePackWriter writer)
579+
{
580+
writer.WriteArrayHeader(2);
581+
writer.Write(HubProtocolConstants.AckMessageType);
582+
writer.Write(message.SequenceId);
583+
}
584+
585+
private static void WriteSequenceMessage(SequenceMessage message, ref MessagePackWriter writer)
586+
{
587+
writer.WriteArrayHeader(2);
588+
writer.Write(HubProtocolConstants.SequenceMessageType);
589+
writer.Write(message.SequenceId);
590+
}
591+
558592
private static void PackHeaders(IDictionary<string, string>? headers, ref MessagePackWriter writer)
559593
{
560594
if (headers != null)
@@ -602,6 +636,18 @@ private static int ReadInt32(ref MessagePackReader reader, string field)
602636
}
603637
}
604638

639+
private static long ReadInt64(ref MessagePackReader reader, string field)
640+
{
641+
try
642+
{
643+
return reader.ReadInt64();
644+
}
645+
catch (Exception ex)
646+
{
647+
throw new InvalidDataException($"Reading '{field}' as Int64 failed.", ex);
648+
}
649+
}
650+
605651
protected static string? ReadString(ref MessagePackReader reader, IInvocationBinder binder, string field)
606652
{
607653
try

src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public class NewtonsoftJsonHubProtocol : IHubProtocol
3434
private const string ArgumentsPropertyName = "arguments";
3535
private const string HeadersPropertyName = "headers";
3636
private const string AllowReconnectPropertyName = "allowReconnect";
37+
private const string SequenceIdPropertyName = "sequenceId";
3738

3839
private const string ProtocolName = "json";
3940
private const int ProtocolVersion = 1;
@@ -136,6 +137,7 @@ public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
136137
Dictionary<string, string>? headers = null;
137138
var completed = false;
138139
var allowReconnect = false;
140+
long? sequenceId = null;
139141

140142
using (var reader = JsonUtils.CreateJsonTextReader(textReader))
141143
{
@@ -310,6 +312,13 @@ public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
310312
JsonUtils.CheckRead(reader);
311313
headers = ReadHeaders(reader);
312314
break;
315+
case SequenceIdPropertyName:
316+
sequenceId = JsonUtils.ReadAsInt64(reader, SequenceIdPropertyName);
317+
if (sequenceId is null)
318+
{
319+
throw new InvalidDataException($"Missing required property '{SequenceIdPropertyName}'.");
320+
}
321+
break;
313322
default:
314323
// Skip read the property name
315324
JsonUtils.CheckRead(reader);
@@ -447,6 +456,10 @@ public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
447456
return PingMessage.Instance;
448457
case HubProtocolConstants.CloseMessageType:
449458
return BindCloseMessage(error, allowReconnect);
459+
case HubProtocolConstants.AckMessageType:
460+
return BindAckMessage(sequenceId);
461+
case HubProtocolConstants.SequenceMessageType:
462+
return BindSequenceMessage(sequenceId);
450463
case null:
451464
throw new InvalidDataException($"Missing required property '{TypePropertyName}'.");
452465
default:
@@ -539,6 +552,14 @@ private void WriteMessageCore(HubMessage message, IBufferWriter<byte> stream)
539552
WriteMessageType(writer, HubProtocolConstants.CloseMessageType);
540553
WriteCloseMessage(m, writer);
541554
break;
555+
case AckMessage m:
556+
WriteMessageType(writer, HubProtocolConstants.AckMessageType);
557+
WriteAckMessage(m, writer);
558+
break;
559+
case SequenceMessage m:
560+
WriteMessageType(writer, HubProtocolConstants.SequenceMessageType);
561+
WriteSequenceMessage(m, writer);
562+
break;
542563
default:
543564
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
544565
}
@@ -685,6 +706,18 @@ private static void WriteMessageType(JsonTextWriter writer, int type)
685706
writer.WriteValue(type);
686707
}
687708

709+
private static void WriteAckMessage(AckMessage message, JsonTextWriter writer)
710+
{
711+
writer.WritePropertyName(SequenceIdPropertyName);
712+
writer.WriteValue(message.SequenceId);
713+
}
714+
715+
private static void WriteSequenceMessage(SequenceMessage message, JsonTextWriter writer)
716+
{
717+
writer.WritePropertyName(SequenceIdPropertyName);
718+
writer.WriteValue(message.SequenceId);
719+
}
720+
688721
private static HubMessage BindCancelInvocationMessage(string? invocationId)
689722
{
690723
if (string.IsNullOrEmpty(invocationId))
@@ -839,6 +872,26 @@ private static CloseMessage BindCloseMessage(string? error, bool allowReconnect)
839872
return new CloseMessage(error, allowReconnect);
840873
}
841874

875+
private static AckMessage BindAckMessage(long? sequenceId)
876+
{
877+
if (sequenceId is null)
878+
{
879+
throw new InvalidDataException($"Missing required property '{SequenceIdPropertyName}'.");
880+
}
881+
882+
return new AckMessage(sequenceId.Value);
883+
}
884+
885+
private static SequenceMessage BindSequenceMessage(long? sequenceId)
886+
{
887+
if (sequenceId is null)
888+
{
889+
throw new InvalidDataException($"Missing required property '{SequenceIdPropertyName}'.");
890+
}
891+
892+
return new SequenceMessage(sequenceId.Value);
893+
}
894+
842895
private object?[] BindArguments(JArray args, IReadOnlyList<Type> paramTypes)
843896
{
844897
var paramCount = paramTypes.Count;

src/SignalR/common/Shared/JsonUtils.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,23 @@ public static bool ReadAsBoolean(JsonTextReader reader, string propertyName)
131131
return Convert.ToInt32(reader.Value, CultureInfo.InvariantCulture);
132132
}
133133

134+
public static long? ReadAsInt64(JsonTextReader reader, string propertyName)
135+
{
136+
reader.Read();
137+
138+
if (reader.TokenType != JsonToken.Integer)
139+
{
140+
throw new InvalidDataException($"Expected '{propertyName}' to be of type {JTokenType.Integer}.");
141+
}
142+
143+
if (reader.Value == null)
144+
{
145+
return null;
146+
}
147+
148+
return Convert.ToInt64(reader.Value, CultureInfo.InvariantCulture);
149+
}
150+
134151
public static string? ReadAsString(JsonTextReader reader, string propertyName)
135152
{
136153
reader.Read();

src/SignalR/common/SignalR.Common/src/Protocol/HubProtocolConstants.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ public static class HubProtocolConstants
4444
public const int CloseMessageType = 7;
4545

4646
/// <summary>
47-
///
47+
/// Represents the ack message type.
4848
/// </summary>
4949
public const int AckMessageType = 8;
5050

5151
/// <summary>
52-
///
52+
/// Represents the sequence message type.
5353
/// </summary>
5454
public const int SequenceMessageType = 9;
5555
}

0 commit comments

Comments
 (0)