Skip to content

Commit 756ceb7

Browse files
made sure to dispose of registration for cancellation tokens.
1 parent 4e5f954 commit 756ceb7

File tree

3 files changed

+11
-14
lines changed

3 files changed

+11
-14
lines changed

src/ModelContextProtocol/Client/IMcpClient.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using ModelContextProtocol.Protocol.Types;
2-
using ModelContextProtocol.Shared;
32

43
namespace ModelContextProtocol.Client;
54

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,15 @@ await _transport.SendMessageAsync(new JsonRpcResponse
296296
}, cancellationToken).ConfigureAwait(false);
297297
}
298298

299-
private void RegisterCancellation(CancellationToken cancellationToken, RequestId requestId)
300-
{
301-
cancellationToken.Register(async () => await SendMessageAsync(new JsonRpcNotification
299+
private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, RequestId requestId)
300+
=> cancellationToken.Register(() => _ = SendMessageAsync(new JsonRpcNotification
302301
{
303302
Method = NotificationMethods.CancelledNotification,
304303
Params = JsonSerializer.SerializeToNode(new CancelledNotification
305304
{
306305
RequestId = requestId,
307306
}, McpJsonUtilities.JsonContext.Default.CancelledNotification)
308307
}));
309-
}
310308

311309
public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler)
312310
{
@@ -326,7 +324,7 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcN
326324
/// <returns>A task containing the server's response.</returns>
327325
public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken)
328326
{
329-
RegisterCancellation(cancellationToken, request.Id);
327+
using var registration = RegisterCancellation(cancellationToken, request.Id);
330328
if (!_transport.IsConnected)
331329
{
332330
_logger.EndpointNotConnected(EndpointName);
@@ -373,7 +371,7 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
373371

374372
_logger.RequestSentAwaitingResponse(EndpointName, request.Method, request.Id.ToString());
375373
var response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
376-
374+
377375
if (response is JsonRpcError error)
378376
{
379377
_logger.RequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code);
@@ -399,6 +397,7 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
399397
finally
400398
{
401399
_pendingRequests.TryRemove(request.Id, out _);
400+
await registration.DisposeAsync().ConfigureAwait(false);
402401
FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags);
403402
}
404403
}

tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ public async Task Can_Handle_Notify_Cancel()
386386
var token = TestContext.Current.CancellationToken;
387387
TaskCompletionSource<JsonRpcNotification> clientReceived = new();
388388
await using var client = await CreateMcpClientForServer(
389-
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, cancellationToken) =>
389+
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, _) =>
390390
{
391391
clientReceived.TrySetResult(notification);
392392
return clientReceived.Task;
@@ -408,7 +408,6 @@ await NotifyClientAsync(
408408

409409
// Assert
410410
Assert.NotNull(notification.Params);
411-
// Parse the Params string back to a CancelledNotification
412411
var cancelled = JsonSerializer.Deserialize<CancelledNotification>(notification.Params.ToString());
413412
Assert.NotNull(cancelled);
414413
Assert.Equal(rpcNotification.RequestId.ToString(), cancelled.RequestId.ToString());
@@ -422,7 +421,7 @@ public async Task Should_Not_Intercept_Sent_Notifications()
422421
var token = TestContext.Current.CancellationToken;
423422
TaskCompletionSource<JsonRpcNotification> clientReceived = new();
424423
await using var client = await CreateMcpClientForServer(
425-
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, cancellationToken) =>
424+
options: CreateClientOptions([new(NotificationMethods.CancelledNotification, (notification, _) =>
426425
{
427426
var exception = new InvalidOperationException("Should not intercept sent notifications");
428427
clientReceived.TrySetException(exception);
@@ -454,7 +453,7 @@ public async Task Can_Notify_Cancel()
454453
await using var client = await CreateMcpClientForServer(
455454
options: CreateClientOptions(new Dictionary<string, Func<JsonRpcNotification, CancellationToken, Task>>()
456455
{
457-
[NotificationMethods.CancelledNotification] = (notification, cancellationToken) =>
456+
[NotificationMethods.CancelledNotification] = (notification, _) =>
458457
{
459458
InvalidOperationException exception = new("Should not intercept sent notifications");
460459
clientReceived.TrySetException(exception);
@@ -491,7 +490,7 @@ private static McpClientOptions CreateClientOptions(
491490

492491
private async Task NotifyClientAsync(
493492
string message, object? parameters = null, CancellationToken token = default)
494-
=> await NotifyPipeAsync(_serverToClientPipe, message, parameters, token);
493+
=> await NotifyPipeAsync(_serverToClientPipe, message, parameters, token).ConfigureAwait(false);
495494
private async static Task NotifyPipeAsync(
496495
Pipe pipe, string message, object? parameters = null, CancellationToken token = default)
497496
{
@@ -500,7 +499,7 @@ private async static Task NotifyPipeAsync(
500499
Method = message,
501500
Params = parameters is not null ? JsonSerializer.Serialize(parameters) : null,
502501
});
503-
await pipe.Writer.WriteAsync(bytes, token);
504-
await pipe.Writer.CompleteAsync(); // Signal the end of the message
502+
await pipe.Writer.WriteAsync(bytes, token).ConfigureAwait(false);
503+
await pipe.Writer.CompleteAsync();
505504
}
506505
}

0 commit comments

Comments
 (0)