Skip to content

Commit 46cff23

Browse files
Refactor notification methods to use SendNotificationAsync and remove IMcpSession interface
1 parent 4f3a74e commit 46cff23

File tree

5 files changed

+69
-187
lines changed

5 files changed

+69
-187
lines changed

src/ModelContextProtocol/McpEndpointExtensions.cs

Lines changed: 6 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -155,85 +155,14 @@ public static Task NotifyProgressAsync(
155155
{
156156
Throw.IfNull(endpoint);
157157

158-
return endpoint.SendMessageAsync(new JsonRpcNotification()
159-
{
160-
Method = NotificationMethods.ProgressNotification,
161-
Params = JsonSerializer.SerializeToNode(new ProgressNotification
158+
return endpoint.SendNotificationAsync(
159+
method: NotificationMethods.ProgressNotification,
160+
parameters: new ProgressNotification
162161
{
163162
ProgressToken = progressToken,
164163
Progress = progress,
165-
}, McpJsonUtilities.JsonContext.Default.ProgressNotification),
166-
}, cancellationToken);
167-
}
168-
169-
170-
/// <summary>
171-
/// Notifies the connected endpoint of an event.
172-
/// </summary>
173-
/// <param name="endpoint">The endpoint issuing the notification.</param>
174-
/// <param name="notification">The notification to send.</param>
175-
/// <param name="cancellationToken">A token to cancel the operation.</param>
176-
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
177-
/// <returns>A task representing the completion of the operation.</returns>
178-
public static Task NotifyAsync(
179-
this IMcpEndpoint endpoint,
180-
JsonRpcNotification notification,
181-
CancellationToken cancellationToken = default)
182-
{
183-
Throw.IfNull(endpoint);
184-
185-
return endpoint.SendMessageAsync(notification, cancellationToken);
186-
}
187-
188-
/// <summary>
189-
/// Notifies the connected endpoint of an event.
190-
/// </summary>
191-
/// <param name="endpoint">The endpoint issuing the notification.</param>
192-
/// <param name="method">The method to call.</param>
193-
/// <param name="parameters">The parameters to send.</param>
194-
/// <param name="cancellationToken">A token to cancel the operation.</param>
195-
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
196-
/// <returns>A task representing the completion of the operation.</returns>
197-
public static Task NotifyAsync(
198-
this IMcpEndpoint endpoint,
199-
string method,
200-
JsonNode? parameters = null,
201-
CancellationToken cancellationToken = default)
202-
{
203-
Throw.IfNull(endpoint);
204-
205-
return endpoint.NotifyAsync(new()
206-
{
207-
Method = method,
208-
Params = parameters,
209-
}, cancellationToken);
210-
}
211-
212-
/// <summary>
213-
/// Notifies the connected endpoint that a request has been cancelled.
214-
/// </summary>
215-
/// <param name="endpoint">The endpoint issuing the notification.</param>
216-
/// <param name="requestId">The ID of the request to cancel.</param>
217-
/// <param name="reason">An optional reason for the cancellation.</param>
218-
/// <param name="cancellationToken">A token to cancel the operation.</param>
219-
/// <returns>A task representing the completion of the operation.</returns>
220-
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
221-
public static Task NotifyCancelAsync(
222-
this IMcpEndpoint endpoint,
223-
RequestId requestId,
224-
string? reason = null,
225-
CancellationToken cancellationToken = default)
226-
{
227-
Throw.IfNull(endpoint);
228-
229-
return endpoint.SendMessageAsync(new JsonRpcNotification()
230-
{
231-
Method = NotificationMethods.CancelledNotification,
232-
Params = JsonSerializer.SerializeToNode(new CancelledNotification
233-
{
234-
RequestId = requestId,
235-
Reason = reason,
236-
}, McpJsonUtilities.JsonContext.Default.CancelledNotification),
237-
}, cancellationToken);
164+
},
165+
McpJsonUtilities.JsonContext.Default.ProgressNotification,
166+
cancellationToken: cancellationToken);
238167
}
239168
}

src/ModelContextProtocol/Shared/IMcpSession.cs

Lines changed: 0 additions & 26 deletions
This file was deleted.

tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,13 @@ public async Task Should_Not_Intercept_Sent_Notifications()
433433
cancellationToken: token);
434434

435435
// Act
436-
await client.NotifyCancelAsync(
437-
requestId: new("abc"),
438-
reason: "Cancelled",
439-
cancellationToken: token);
436+
await client.SendNotificationAsync(
437+
method: NotificationMethods.CancelledNotification,
438+
parameters: new CancelledNotification
439+
{
440+
RequestId = new("abc"),
441+
Reason = "Cancelled",
442+
}, cancellationToken: token);
440443
await Assert.ThrowsAsync<TimeoutException>(
441444
async () => await clientReceived.Task
442445
.WaitAsync(TimeSpan.FromSeconds(5), token));
@@ -449,35 +452,36 @@ public async Task Can_Notify_Cancel()
449452
{
450453
// Arrange
451454
var token = TestContext.Current.CancellationToken;
455+
TaskCompletionSource clientReceived = new();
452456
await using var client = await CreateMcpClientForServer(
453-
cancellationToken: token);
457+
options: CreateClientOptions(new Dictionary<string, Func<JsonRpcNotification, Task>>()
458+
{
459+
[NotificationMethods.CancelledNotification] = notification =>
460+
{
461+
InvalidOperationException exception = new("Should not intercept sent notifications");
462+
clientReceived.TrySetException(exception);
463+
return clientReceived.Task;
464+
}
465+
}), cancellationToken: token);
454466
RequestId expectedRequestId = new("abc");
455467
var expectedReason = "Cancelled";
456468

457469
// Act
458-
await client.NotifyCancelAsync(
459-
requestId: expectedRequestId,
460-
reason: expectedReason,
461-
cancellationToken: token);
470+
await client.SendNotificationAsync(
471+
method: NotificationMethods.CancelledNotification,
472+
parameters: new CancelledNotification
473+
{
474+
RequestId = expectedRequestId,
475+
Reason = expectedReason,
476+
}, cancellationToken: token);
462477

463478
// Assert
464-
var result = await _clientToServerPipe.Reader.ReadAsync(token);
465-
var copyBytes = new byte[result.Buffer.Length];
466-
result.Buffer.CopyTo(copyBytes);
467-
Utf8JsonReader reader = new(copyBytes);
468-
var jsonRpcNotification = JsonSerializer.Deserialize<JsonRpcNotification>(ref reader);
469-
Assert.NotNull(jsonRpcNotification);
470-
Assert.Equal(NotificationMethods.CancelledNotification, jsonRpcNotification.Method);
471-
472-
var parameters = jsonRpcNotification.Params;
473-
Assert.NotNull(parameters);
474-
var cancelledNotification = JsonSerializer.Deserialize<CancelledNotification>(parameters.ToString());
475-
Assert.NotNull(cancelledNotification);
476-
Assert.Equal(expectedRequestId.ToString(), cancelledNotification.RequestId.ToString());
477-
Assert.Equal(expectedReason, cancelledNotification.Reason);
479+
await Assert.ThrowsAsync<TimeoutException>(
480+
async () => await clientReceived.Task
481+
.WaitAsync(TimeSpan.FromSeconds(3), token));
478482
}
479483

480-
private McpClientOptions CreateClientOptions(
484+
private static McpClientOptions CreateClientOptions(
481485
IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? notificationHandlers = null)
482486
=> new()
483487
{

tests/ModelContextProtocol.Tests/Server/McpServerTests.cs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.Extensions.AI;
22
using Microsoft.Extensions.DependencyInjection;
3+
using Microsoft.Extensions.Logging;
34
using ModelContextProtocol.Protocol.Messages;
45
using ModelContextProtocol.Protocol.Types;
56
using ModelContextProtocol.Server;
@@ -692,10 +693,11 @@ public async Task NotifyCancel_Should_Be_Handled()
692693
// Arrange
693694
TaskCompletionSource<JsonRpcNotification> notificationReceived = new();
694695
TaskCompletionSource notificationIntercepted = new();
695-
await using TestServerTransport transport = new();
696+
await using TestServerTransport transport = new(LoggerFactory);
696697
transport.OnMessageSent = (message) =>
697698
{
698-
if (message is JsonRpcNotification notification && notification.Method == NotificationMethods.CancelledNotification)
699+
if (message is JsonRpcNotification notification
700+
&& notification.Method == NotificationMethods.CancelledNotification)
699701
notificationReceived.TrySetResult(notification);
700702
};
701703
var options = CreateOptions(new()
@@ -712,10 +714,13 @@ public async Task NotifyCancel_Should_Be_Handled()
712714
// Act
713715
var token = TestContext.Current.CancellationToken;
714716
Task serverTask = server.RunAsync(token);
715-
await server.NotifyCancelAsync(
716-
requestId: new("abc"),
717-
reason: "Cancelled",
718-
cancellationToken: token);
717+
await server.SendNotificationAsync(
718+
NotificationMethods.CancelledNotification,
719+
new CancelledNotification
720+
{
721+
RequestId = new("abc"),
722+
Reason = "Cancelled",
723+
}, cancellationToken: token);
719724
await server.DisposeAsync();
720725
await serverTask.WaitAsync(TimeSpan.FromSeconds(1), token);
721726
var notification = await notificationReceived.Task.WaitAsync(TimeSpan.FromSeconds(1), token);
Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,32 @@
1-
using ModelContextProtocol.Protocol.Messages;
1+
using Microsoft.Extensions.Logging;
2+
using ModelContextProtocol.Protocol.Messages;
23
using ModelContextProtocol.Protocol.Transport;
34
using ModelContextProtocol.Protocol.Types;
45
using System.Text.Json;
5-
using System.Threading.Channels;
66

77
namespace ModelContextProtocol.Tests.Utils;
88

9-
public class TestServerTransport : ITransport
9+
public class TestServerTransport : TransportBase
1010
{
11-
private readonly Channel<IJsonRpcMessage> _messageChannel;
12-
13-
public bool IsConnected { get; set; }
14-
15-
public ChannelReader<IJsonRpcMessage> MessageReader => _messageChannel;
16-
1711
public List<IJsonRpcMessage> SentMessages { get; } = [];
1812

1913
public Action<IJsonRpcMessage>? OnMessageSent { get; set; }
2014

21-
public TestServerTransport()
15+
public TestServerTransport(ILoggerFactory? loggerFactory = null)
16+
: base(loggerFactory)
2217
{
23-
_messageChannel = Channel.CreateUnbounded<IJsonRpcMessage>(new UnboundedChannelOptions
24-
{
25-
SingleReader = true,
26-
SingleWriter = true,
27-
});
28-
IsConnected = true;
18+
SetConnected(true);
2919
}
3020

31-
public ValueTask DisposeAsync()
21+
public override ValueTask DisposeAsync()
3222
{
33-
_messageChannel.Writer.TryComplete();
34-
IsConnected = false;
23+
GC.SuppressFinalize(this);
24+
SentMessages.Clear();
25+
SetConnected(false);
3526
return ValueTask.CompletedTask;
3627
}
3728

38-
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
29+
public async override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
3930
{
4031
SentMessages.Add(message);
4132
if (message is JsonRpcRequest request)
@@ -55,45 +46,24 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
5546
OnMessageSent?.Invoke(message);
5647
}
5748

58-
private async Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken)
59-
{
60-
await WriteMessageAsync(new JsonRpcResponse
49+
private Task ListRoots(JsonRpcRequest request, CancellationToken cancellationToken)
50+
=> WriteMessageAsync(request.Id, new ListRootsResult
6151
{
62-
Id = request.Id,
63-
Result = JsonSerializer.SerializeToNode(new ListRootsResult
64-
{
65-
Roots = []
66-
}),
52+
Roots = []
6753
}, cancellationToken);
68-
}
6954

70-
private async Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken)
71-
{
72-
await WriteMessageAsync(new JsonRpcResponse
55+
private Task Sampling(JsonRpcRequest request, CancellationToken cancellationToken)
56+
=> WriteMessageAsync(request.Id, new CreateMessageResult
7357
{
74-
Id = request.Id,
75-
Result = JsonSerializer.SerializeToNode(new CreateMessageResult { Content = new(), Model = "model", Role = "role" }),
58+
Content = new(),
59+
Model = "model",
60+
Role = "role"
7661
}, cancellationToken);
77-
}
7862

79-
private async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
80-
{
81-
await _messageChannel.Writer.WriteAsync(message, cancellationToken);
82-
}
83-
84-
public async Task NotifyCancelAsync(
85-
RequestId id, string? reason = null, CancellationToken cancellationToken = default)
86-
{
87-
JsonRpcNotification notification = new()
63+
private Task WriteMessageAsync<T>(RequestId id, T message, CancellationToken cancellationToken)
64+
=> WriteMessageAsync(new JsonRpcResponse
8865
{
89-
Method = NotificationMethods.CancelledNotification,
90-
Params = JsonSerializer.SerializeToNode(new CancelledNotification
91-
{
92-
RequestId = id,
93-
Reason = reason,
94-
}),
95-
};
96-
97-
await WriteMessageAsync(notification, cancellationToken);
98-
}
66+
Id = id,
67+
Result = JsonSerializer.SerializeToNode(message),
68+
}, cancellationToken);
9969
}

0 commit comments

Comments
 (0)