Skip to content

Commit 0a73151

Browse files
committed
Propagate CancellationToken request cancellation to remote endpoint
1 parent 4dd2f42 commit 0a73151

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

src/ModelContextProtocol/McpEndpointExtensions.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ public static Task NotifyProgressAsync(
155155
{
156156
Throw.IfNull(endpoint);
157157

158-
return endpoint.SendMessageAsync(new JsonRpcNotification()
159-
{
160-
Method = NotificationMethods.ProgressNotification,
161-
Params = JsonSerializer.SerializeToNode(new ProgressNotification
158+
return endpoint.SendNotificationAsync(
159+
NotificationMethods.ProgressNotification,
160+
new ProgressNotification
162161
{
163162
ProgressToken = progressToken,
164163
Progress = progress,
165-
}, McpJsonUtilities.JsonContext.Default.ProgressNotification),
166-
}, cancellationToken);
164+
},
165+
McpJsonUtilities.JsonContext.Default.ProgressNotification,
166+
cancellationToken);
167167
}
168168
}

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,24 @@ await _transport.SendMessageAsync(new JsonRpcResponse
296296
}, cancellationToken).ConfigureAwait(false);
297297
}
298298

299+
private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId)
300+
{
301+
if (!cancellationToken.CanBeCanceled)
302+
{
303+
return default;
304+
}
305+
306+
return cancellationToken.Register(static objState =>
307+
{
308+
var state = (Tuple<McpSession, RequestId>)objState!;
309+
_ = state.Item1.SendMessageAsync(new JsonRpcNotification
310+
{
311+
Method = NotificationMethods.CancelledNotification,
312+
Params = JsonSerializer.SerializeToNode(new CancelledNotification { RequestId = state.Item2 }, McpJsonUtilities.JsonContext.Default.CancelledNotification)
313+
});
314+
}, Tuple.Create(this, requestId));
315+
}
316+
299317
public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler)
300318
{
301319
Throw.IfNullOrWhiteSpace(method);
@@ -357,9 +375,16 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
357375
_logger.SendingRequest(EndpointName, request.Method);
358376

359377
await _transport.SendMessageAsync(request, cancellationToken).ConfigureAwait(false);
360-
361378
_logger.RequestSentAwaitingResponse(EndpointName, request.Method, request.Id.ToString());
362-
var response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
379+
380+
// Now that the request has been sent, register for cancellation. If we registered before,
381+
// a cancellation request could arrive before the server knew about that request ID, in which
382+
// case the server could ignore it.
383+
IJsonRpcMessage? response;
384+
using (var registration = RegisterCancellation(cancellationToken, request.Id))
385+
{
386+
response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
387+
}
363388

364389
if (response is JsonRpcError error)
365390
{

0 commit comments

Comments
 (0)