Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,37 @@ static void SetServerTags(Activity? activity, Uri? uri)

private async Task<object?> InvokeCoreAsyncCore(string methodName, Type returnType, object?[] args, CancellationToken cancellationToken)
{
async Task OnInvocationCanceled(InvocationRequest irq)
{
// We need to take the connection lock in order to ensure we a) have a connection and b) are the only one accessing the write end of the pipe.
await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false);
try
{
if (_state.CurrentConnectionStateUnsynchronized != null)
{
Log.SendingCancellation(_logger, irq.InvocationId);

// Don't pass irq.CancellationToken, that would result in canceling the Flush and a delayed CancelInvocationMessage being sent.
await SendHubMessage(_state.CurrentConnectionStateUnsynchronized, new CancelInvocationMessage(irq.InvocationId), cancellationToken: default).ConfigureAwait(false);
}
else
{
Log.UnableToSendCancellation(_logger, irq.InvocationId);
}
}
catch
{
// Connection closed while trying to cancel an invocation. This is fine to ignore.
}
finally
{
_state.ReleaseConnectionLock();
}

// Cancel the invocation
irq.Dispose();
}

var readers = default(Dictionary<string, object>);

CheckDisposed();
Expand All @@ -1094,6 +1125,11 @@ static void SetServerTags(Activity? activity, Uri? uri)
var irq = InvocationRequest.Invoke(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, activity, out invocationTask);
await InvokeCore(connectionState, methodName, irq, args, streamIds?.ToArray(), cancellationToken).ConfigureAwait(false);

if (cancellationToken.CanBeCanceled)
{
cancellationToken.Register(state => _ = OnInvocationCanceled((InvocationRequest)state!), irq);
}

LaunchStreams(connectionState, readers, cancellationToken);
}
finally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,35 @@ await connection.ReceiveJsonMessage(
}
}

[Fact]
public async Task CanCancelTokenDuringInvoke_SendsCancelInvocation()
{
using (StartVerifiableLog())
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);

await hubConnection.StartAsync().DefaultTimeout();

using var cts = new CancellationTokenSource();
var invokeTask = hubConnection.InvokeAsync<int>("TestMethod", cts.Token);

var item = await connection.ReadSentJsonAsync().DefaultTimeout();
var invocationId = item["invocationId"];

// Cancel the invocation
cts.Cancel();

// Should receive CancelInvocationMessage
item = await connection.ReadSentJsonAsync().DefaultTimeout();
Assert.Equal(HubProtocolConstants.CancelInvocationMessageType, item["type"]);
Assert.Equal(invocationId, item["invocationId"]);

// Invocation on client-side completes with cancellation
await Assert.ThrowsAsync<TaskCanceledException>(async () => await invokeTask).DefaultTimeout();
}
}

[Fact]
public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages()
{
Expand Down
33 changes: 26 additions & 7 deletions src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,17 @@ public override Task DispatchMessageAsync(HubConnectionContext connection, HubMe
return ProcessInvocation(connection, streamInvocationMessage, isStreamResponse: true);

case CancelInvocationMessage cancelInvocationMessage:
// Check if there is an associated active stream and cancel it if it exists.
// The cts will be removed when the streaming method completes executing
// Check if there is an associated active invocation or stream and cancel it if it exists.
// The cts will be removed when the hub method completes executing
if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId!, out var cts))
{
Log.CancelStream(_logger, cancelInvocationMessage.InvocationId!);
Log.CancelInvocation(_logger, cancelInvocationMessage.InvocationId!);
cts.Cancel();
}
else
{
// Stream can be canceled on the server while client is canceling stream.
Log.UnexpectedCancel(_logger);
// Invocation can be canceled on the server while client is canceling invocation.
Log.UnexpectedCancel(_logger, cancelInvocationMessage.InvocationId!);
}
break;

Expand Down Expand Up @@ -390,7 +390,8 @@ static async Task ExecuteInvocation(DefaultHubDispatcher<THub> dispatcher,
IHubActivator<THub> hubActivator,
HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage,
bool isStreamCall)
bool isStreamCall,
CancellationTokenSource? cts)
{
var logger = dispatcher._logger;
var enableDetailedErrors = dispatcher._enableDetailedErrors;
Expand All @@ -406,6 +407,18 @@ static async Task ExecuteInvocation(DefaultHubDispatcher<THub> dispatcher,
// We want to take HubMethodNameAttribute into account which will be the same as what the invocation target is
var activity = StartActivity(SignalRServerActivitySource.InvocationIn, ActivityKind.Server, connection.OriginalActivity, scope.ServiceProvider, hubMethodInvocationMessage.Target, hubMethodInvocationMessage.Headers, logger);

// Register the CancellationTokenSource if present so CancelInvocationMessage can cancel it
if (cts != null && !string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
if (!connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId!, cts))
{
Log.InvocationIdInUse(logger, hubMethodInvocationMessage.InvocationId);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"Invocation ID '{hubMethodInvocationMessage.InvocationId}' is already in use.");
return;
}
}

object? result;
try
{
Expand All @@ -430,6 +443,12 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
Activity.Current = previousActivity;
}

// Remove the CancellationTokenSource from active requests
if (cts != null && !string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
connection.ActiveRequestCancellationSources.TryRemove(hubMethodInvocationMessage.InvocationId!, out _);
}

// Stream response handles cleanup in StreamResultsAsync
// And normal invocations handle cleanup below in the finally
if (isStreamCall)
Expand All @@ -446,7 +465,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
}
}

invocation = ExecuteInvocation(this, methodExecutor, hub, arguments, scope, hubActivator, connection, hubMethodInvocationMessage, isStreamCall);
invocation = ExecuteInvocation(this, methodExecutor, hub, arguments, scope, hubActivator, connection, hubMethodInvocationMessage, isStreamCall, cts);
}

if (isStreamCall || isStreamResponse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ public static void SendingResult(ILogger logger, string? invocationId, ObjectMet
[LoggerMessage(9, LogLevel.Trace, "'{HubName}' hub method '{HubMethod}' is bound.", EventName = "HubMethodBound")]
public static partial void HubMethodBound(ILogger logger, string hubName, string hubMethod);

[LoggerMessage(10, LogLevel.Debug, "Canceling stream for invocation {InvocationId}.", EventName = "CancelStream")]
public static partial void CancelStream(ILogger logger, string invocationId);
[LoggerMessage(10, LogLevel.Debug, "Canceling invocation {InvocationId}.", EventName = "CancelInvocation")]
public static partial void CancelInvocation(ILogger logger, string invocationId);

[LoggerMessage(11, LogLevel.Debug, "CancelInvocationMessage received unexpectedly.", EventName = "UnexpectedCancel")]
public static partial void UnexpectedCancel(ILogger logger);
[LoggerMessage(11, LogLevel.Debug, "CancelInvocationMessage received for {InvocationId} but invocation was not found.", EventName = "UnexpectedCancel")]
public static partial void UnexpectedCancel(ILogger logger, string invocationId);

[LoggerMessage(12, LogLevel.Debug, "Received stream hub invocation: {InvocationMessage}.", EventName = "ReceivedStreamHubInvocation")]
public static partial void ReceivedStreamHubInvocation(ILogger logger, StreamInvocationMessage invocationMessage);
Expand Down
4 changes: 2 additions & 2 deletions src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where((p, index) =>
{
// Only streams can take CancellationTokens currently
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
// CancellationTokens are synthetic arguments provided by the server
if (p.ParameterType == typeof(CancellationToken))
{
HasSyntheticArguments = true;
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,15 @@ public async Task<int> LongRunningMethod()
return 12;
}

public async Task<int> CancelableInvocation(CancellationToken token)
{
_tcsService.StartedMethod.SetResult(null);
// Wait for cancellation. Test timeout is enforced by .DefaultTimeout() in the test itself.
await token.WaitForCancellationAsync();
_tcsService.EndMethod.SetResult(null);
return 42;
}

public async Task<ChannelReader<string>> LongRunningStream()
{
_tcsService.StartedMethod.TrySetResult(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4376,22 +4376,35 @@ public async Task StreamHubMethodCanAcceptNullableParameterWithCancellationToken
}

[Fact]
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
public async Task InvokeHubMethodCanAcceptCancellationTokenAsArgument()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();

using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout();

var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).DefaultTimeout();
var invocationId = await client.SendInvocationAsync(nameof(LongRunningHub.CancelableInvocation)).DefaultTimeout();
// Wait for the hub method to start
await tcsService.StartedMethod.Task.DefaultTimeout();

var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().DefaultTimeout());
// Cancel the invocation which should trigger the CancellationToken in the hub method
await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).DefaultTimeout();

var result = await client.ReadAsync().DefaultTimeout();

Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error);
var completion = Assert.IsType<CompletionMessage>(result);
Assert.Null(completion.Error);

// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.DefaultTimeout();

client.Dispose();

Expand Down
Loading