Skip to content

Commit 8b64d37

Browse files
committed
Implement cancellation notifications
1 parent 989e3c7 commit 8b64d37

File tree

4 files changed

+201
-47
lines changed

4 files changed

+201
-47
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System.Text.Json.Serialization;
2+
3+
namespace ModelContextProtocol.Protocol.Messages;
4+
5+
/// <summary>
6+
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
7+
/// </summary>
8+
public sealed class CancelledNotification
9+
{
10+
/// <summary>
11+
/// The ID of the request to cancel.
12+
/// </summary>
13+
[JsonPropertyName("requestId")]
14+
public RequestId RequestId { get; set; }
15+
16+
/// <summary>
17+
/// An optional string describing the reason for the cancellation.
18+
/// </summary>
19+
[JsonPropertyName("reason")]
20+
public string? Reason { get; set; }
21+
}

src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs

Lines changed: 130 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ namespace ModelContextProtocol.Shared;
2222
/// </summary>
2323
internal abstract class McpJsonRpcEndpoint : IAsyncDisposable
2424
{
25+
/// <summary>
26+
/// In-flight request handling, indexed by request ID. The value provides a <see cref="CancellationTokenSource"/>
27+
/// that can be used to request cancellation of the in-flight handler.
28+
/// </summary>
29+
private static readonly ConcurrentDictionary<RequestId, CancellationTokenSource> s_handlingRequests = new();
30+
2531
private readonly string _id = Guid.NewGuid().ToString("N");
2632
private readonly ITransport _transport;
2733
private readonly ConcurrentDictionary<RequestId, TaskCompletionSource<IJsonRpcMessage>> _pendingRequests;
@@ -78,25 +84,69 @@ internal async Task ProcessMessagesAsync(CancellationToken cancellationToken)
7884
{
7985
_logger.TransportMessageRead(EndpointName, message.GetType().Name);
8086

81-
// Fire and forget the message handling task to avoid blocking the transport
82-
// If awaiting the task, the transport will not be able to read more messages,
83-
// which could lead to a deadlock if the handler sends a message back
8487
_ = ProcessMessageAsync();
8588
async Task ProcessMessageAsync()
8689
{
90+
IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
91+
CancellationTokenSource? combinedCts = null;
92+
try
93+
{
94+
// Register before we yield, so that the tracking is guaranteed to be there
95+
// when subsequent messages arrive, even if the asynchronous processing happens
96+
// out of order.
97+
if (messageWithId is not null)
98+
{
99+
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
100+
s_handlingRequests[messageWithId.Id] = combinedCts;
101+
}
102+
103+
// Fire and forget the message handling to avoid blocking the transport
104+
// If awaiting the task, the transport will not be able to read more messages,
105+
// which could lead to a deadlock if the handler sends a message back
87106
#if NET
88-
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
107+
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
89108
#else
90-
await default(ForceYielding);
109+
await default(ForceYielding);
91110
#endif
92-
try
93-
{
94-
await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);
111+
112+
// Handle the message.
113+
await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
95114
}
96115
catch (Exception ex)
97116
{
98-
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
99-
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
117+
// Only send responses for request errors that aren't user-initiated cancellation.
118+
bool isUserCancellation =
119+
ex is OperationCanceledException &&
120+
!cancellationToken.IsCancellationRequested &&
121+
combinedCts?.IsCancellationRequested is true;
122+
123+
if (!isUserCancellation && message is JsonRpcRequest request)
124+
{
125+
_logger.RequestHandlerError(EndpointName, request.Method, ex);
126+
await _transport.SendMessageAsync(new JsonRpcError
127+
{
128+
Id = request.Id,
129+
JsonRpc = "2.0",
130+
Error = new JsonRpcErrorDetail
131+
{
132+
Code = ErrorCodes.InternalError,
133+
Message = ex.Message
134+
}
135+
}, cancellationToken).ConfigureAwait(false);
136+
}
137+
else if (ex is not OperationCanceledException)
138+
{
139+
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
140+
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
141+
}
142+
}
143+
finally
144+
{
145+
if (messageWithId is not null)
146+
{
147+
s_handlingRequests.TryRemove(messageWithId.Id, out _);
148+
combinedCts!.Dispose();
149+
}
100150
}
101151
}
102152
}
@@ -136,6 +186,24 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
136186

137187
private async Task HandleNotification(JsonRpcNotification notification)
138188
{
189+
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
190+
if (notification.Method == NotificationMethods.CancelledNotification)
191+
{
192+
try
193+
{
194+
if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
195+
s_handlingRequests.TryGetValue(cn.RequestId, out var cts))
196+
{
197+
await cts.CancelAsync().ConfigureAwait(false);
198+
}
199+
}
200+
catch
201+
{
202+
// "Invalid cancellation notifications SHOULD be ignored"
203+
}
204+
}
205+
206+
// Handle user-defined notifications.
139207
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
140208
{
141209
foreach (var notificationHandler in handlers)
@@ -170,33 +238,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
170238
{
171239
if (_requestHandlers.TryGetValue(request.Method, out var handler))
172240
{
173-
try
174-
{
175-
_logger.RequestHandlerCalled(EndpointName, request.Method);
176-
var result = await handler(request, cancellationToken).ConfigureAwait(false);
177-
_logger.RequestHandlerCompleted(EndpointName, request.Method);
178-
await _transport.SendMessageAsync(new JsonRpcResponse
179-
{
180-
Id = request.Id,
181-
JsonRpc = "2.0",
182-
Result = result
183-
}, cancellationToken).ConfigureAwait(false);
184-
}
185-
catch (Exception ex)
241+
_logger.RequestHandlerCalled(EndpointName, request.Method);
242+
var result = await handler(request, cancellationToken).ConfigureAwait(false);
243+
_logger.RequestHandlerCompleted(EndpointName, request.Method);
244+
await _transport.SendMessageAsync(new JsonRpcResponse
186245
{
187-
_logger.RequestHandlerError(EndpointName, request.Method, ex);
188-
// Send error response
189-
await _transport.SendMessageAsync(new JsonRpcError
190-
{
191-
Id = request.Id,
192-
JsonRpc = "2.0",
193-
Error = new JsonRpcErrorDetail
194-
{
195-
Code = -32000, // Implementation defined error
196-
Message = ex.Message
197-
}
198-
}, cancellationToken).ConfigureAwait(false);
199-
}
246+
Id = request.Id,
247+
JsonRpc = "2.0",
248+
Result = result
249+
}, cancellationToken).ConfigureAwait(false);
200250
}
201251
else
202252
{
@@ -221,8 +271,11 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
221271
throw new McpClientException("Transport is not connected");
222272
}
223273

224-
// Set request ID
225-
request.Id = new($"{_id}-{Interlocked.Increment(ref _nextRequestId)}");
274+
// Set request ID if it's not already set to a valid identifier.
275+
if (request.Id.IsDefault)
276+
{
277+
request.Id = new($"{_id}-{Interlocked.Increment(ref _nextRequestId)}");
278+
}
226279

227280
var tcs = new TaskCompletionSource<IJsonRpcMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
228281
_pendingRequests[request.Id] = tcs;
@@ -279,7 +332,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
279332
}
280333
}
281334

282-
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
335+
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
283336
{
284337
Throw.IfNull(message);
285338

@@ -294,7 +347,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
294347
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()));
295348
}
296349

297-
return _transport.SendMessageAsync(message, cancellationToken);
350+
await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
351+
352+
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
353+
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
354+
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
355+
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
356+
GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
357+
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
358+
{
359+
tcs.TrySetCanceled(default);
360+
}
361+
}
362+
363+
private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
364+
{
365+
try
366+
{
367+
switch (notificationParams)
368+
{
369+
case null:
370+
return null;
371+
372+
case CancelledNotification cn:
373+
return cn;
374+
375+
case JsonElement je:
376+
return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
377+
378+
default:
379+
return JsonSerializer.Deserialize(
380+
JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo<object?>()),
381+
McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
382+
}
383+
}
384+
catch
385+
{
386+
return null;
387+
}
298388
}
299389

300390
/// <summary>

src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
121121
// MCP Request Params / Results
122122
[JsonSerializable(typeof(CallToolRequestParams))]
123123
[JsonSerializable(typeof(CallToolResponse))]
124+
[JsonSerializable(typeof(CancelledNotification))]
124125
[JsonSerializable(typeof(CompleteRequestParams))]
125126
[JsonSerializable(typeof(CompleteResult))]
126127
[JsonSerializable(typeof(CreateMessageRequestParams))]

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@
1010
using ModelContextProtocol.Server;
1111
using ModelContextProtocol.Tests.Transport;
1212
using ModelContextProtocol.Tests.Utils;
13-
using System;
1413
using System.Collections.Concurrent;
1514
using System.ComponentModel;
1615
using System.IO.Pipelines;
1716
using System.Text.Json;
1817
using System.Text.RegularExpressions;
19-
using System.Threading;
2018
using System.Threading.Channels;
2119

2220
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
@@ -85,7 +83,7 @@ public async Task Can_List_Registered_Tools()
8583
IMcpClient client = await CreateMcpClientForServer();
8684

8785
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
88-
Assert.Equal(12, tools.Count);
86+
Assert.Equal(13, tools.Count);
8987

9088
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
9189
Assert.Equal("Echo", echoTool.Name);
@@ -135,7 +133,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
135133
cancellationToken: TestContext.Current.CancellationToken);
136134

137135
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
138-
Assert.Equal(12, tools.Count);
136+
Assert.Equal(13, tools.Count);
139137

140138
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
141139
Assert.Equal("Echo", echoTool.Name);
@@ -163,7 +161,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
163161
IMcpClient client = await CreateMcpClientForServer();
164162

165163
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
166-
Assert.Equal(12, tools.Count);
164+
Assert.Equal(13, tools.Count);
167165

168166
Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
169167
client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification =>
@@ -183,7 +181,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
183181
await notificationRead;
184182

185183
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
186-
Assert.Equal(13, tools.Count);
184+
Assert.Equal(14, tools.Count);
187185
Assert.Contains(tools, t => t.Name == "NewTool");
188186

189187
notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
@@ -192,7 +190,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
192190
await notificationRead;
193191

194192
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
195-
Assert.Equal(12, tools.Count);
193+
Assert.Equal(13, tools.Count);
196194
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
197195
}
198196

@@ -508,6 +506,35 @@ public async Task HandlesIProgressParameter()
508506
}
509507
}
510508

509+
[Fact]
510+
public async Task CancellationNotificationsPropagateToToolTokens()
511+
{
512+
IMcpClient client = await CreateMcpClientForServer();
513+
514+
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
515+
Assert.NotNull(tools);
516+
Assert.NotEmpty(tools);
517+
McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation));
518+
519+
var requestId = new RequestId(Guid.NewGuid().ToString());
520+
var invokeTask = client.SendRequestAsync<CallToolResponse>(new JsonRpcRequest()
521+
{
522+
Method = RequestMethods.ToolsCall,
523+
Id = requestId,
524+
Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name },
525+
}, TestContext.Current.CancellationToken);
526+
527+
await client.SendNotificationAsync(
528+
NotificationMethods.CancelledNotification,
529+
parameters: new CancelledNotification()
530+
{
531+
RequestId = requestId,
532+
},
533+
cancellationToken: TestContext.Current.CancellationToken);
534+
535+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask);
536+
}
537+
511538
[McpServerToolType]
512539
public sealed class EchoTool(ObjectWithId objectFromDI)
513540
{
@@ -573,6 +600,21 @@ public static string EchoComplex(ComplexObject complex)
573600
return complex.Name!;
574601
}
575602

603+
[McpServerTool]
604+
public static async Task<string> InfiniteCancelableOperation(CancellationToken cancellationToken)
605+
{
606+
try
607+
{
608+
await Task.Delay(Timeout.Infinite, cancellationToken);
609+
}
610+
catch (Exception)
611+
{
612+
return "canceled";
613+
}
614+
615+
return "unreachable";
616+
}
617+
576618
[McpServerTool]
577619
public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";
578620

0 commit comments

Comments
 (0)