Skip to content

Commit 1e0755c

Browse files
Refactor MCP interfaces and implement session management; remove IMcpEndpoint and enhance notification handling
1 parent c5e37fb commit 1e0755c

File tree

10 files changed

+155
-103
lines changed

10 files changed

+155
-103
lines changed

src/ModelContextProtocol/Client/IMcpClient.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
using ModelContextProtocol.Protocol.Types;
2+
using ModelContextProtocol.Shared;
3+
using ModelContextProtocol.Protocol.Messages;
24

35
namespace ModelContextProtocol.Client;
46

57
/// <summary>
68
/// Represents an instance of an MCP client connecting to a specific server.
79
/// </summary>
8-
public interface IMcpClient : IMcpEndpoint
10+
public interface IMcpClient : IMcpSession
911
{
1012
/// <summary>
1113
/// Gets the capabilities supported by the server.
@@ -23,4 +25,21 @@ public interface IMcpClient : IMcpEndpoint
2325
/// It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt.
2426
/// </summary>
2527
string? ServerInstructions { get; }
28+
29+
30+
/// <summary>
31+
/// Adds a handler for server notifications of a specific method.
32+
/// </summary>
33+
/// <param name="method">The notification method to handle.</param>
34+
/// <param name="handler">The async handler function to process notifications.</param>
35+
/// <remarks>
36+
/// <para>
37+
/// Each method may have multiple handlers. Adding a handler for a method that already has one
38+
/// will not replace the existing handler.
39+
/// </para>
40+
/// <para>
41+
/// <see cref="NotificationMethods"> provides constants for common notification methods.</see>
42+
/// </para>
43+
/// </remarks>
44+
void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler);
2645
}

src/ModelContextProtocol/IMcpEndpoint.cs

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/ModelContextProtocol/Server/IMcpServer.cs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
using ModelContextProtocol.Protocol.Types;
2+
using ModelContextProtocol.Shared;
3+
using ModelContextProtocol.Protocol.Messages;
24

35
namespace ModelContextProtocol.Server;
46

57
/// <summary>
68
/// Represents a server that can communicate with a client using the MCP protocol.
79
/// </summary>
8-
public interface IMcpServer : IMcpEndpoint
10+
public interface IMcpServer : IMcpSession
911
{
1012
/// <summary>
1113
/// Gets the capabilities supported by the client.
@@ -29,4 +31,20 @@ public interface IMcpServer : IMcpEndpoint
2931
/// Runs the server, listening for and handling client requests.
3032
/// </summary>
3133
Task RunAsync(CancellationToken cancellationToken = default);
34+
35+
/// <summary>
36+
/// Adds a handler for server notifications of a specific method.
37+
/// </summary>
38+
/// <param name="method">The notification method to handle.</param>
39+
/// <param name="handler">The async handler function to process notifications.</param>
40+
/// <remarks>
41+
/// <para>
42+
/// Each method may have multiple handlers. Adding a handler for a method that already has one
43+
/// will not replace the existing handler.
44+
/// </para>
45+
/// <para>
46+
/// <see cref="NotificationMethods"> provides constants for common notification methods.</see>
47+
/// </para>
48+
/// </remarks>
49+
void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler);
3250
}

src/ModelContextProtocol/Shared/IMcpSession.cs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,8 @@ namespace ModelContextProtocol.Shared;
55
/// <summary>
66
/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers.
77
/// </summary>
8-
public interface IMcpSession : IDisposable
8+
public interface IMcpSession : IAsyncDisposable
99
{
10-
/// <summary>
11-
/// The name of the endpoint for logging and debug purposes.
12-
/// </summary>
13-
string EndpointName { get; set; }
14-
15-
/// <summary>
16-
/// Starts processing messages from the transport. This method will block until the transport is disconnected.
17-
/// This is generally started in a background task or thread from the initialization logic of the derived class.
18-
/// </summary>
19-
Task ProcessMessagesAsync(CancellationToken cancellationToken);
20-
2110
/// <summary>
2211
/// Sends a generic JSON-RPC request to the server.
2312
/// </summary>

src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace ModelContextProtocol.Shared;
1515
/// This is especially true as a client represents a connection to one and only one server, and vice versa.
1616
/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction.
1717
/// </summary>
18-
public abstract class McpJsonRpcEndpoint : IMcpEndpoint, IAsyncDisposable
18+
internal abstract class McpJsonRpcEndpoint : IMcpSession
1919
{
2020
private readonly RequestHandlers _requestHandlers = [];
2121
private readonly NotificationHandlers _notificationHandlers = [];
@@ -66,21 +66,8 @@ public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task
6666
/// <param name="request">The request instance</param>
6767
/// <param name="cancellationToken">The token for cancellation.</param>
6868
/// <returns>The MCP response.</returns>
69-
public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class
70-
{
71-
using var registration = cancellationToken.Register(async () =>
72-
{
73-
try
74-
{
75-
await this.NotifyCancelAsync(request.Id).ConfigureAwait(false);
76-
}
77-
catch (Exception ex)
78-
{
79-
_logger.LogError(ex, "An error occurred while notifying cancellation for request {RequestId}.", request.Id);
80-
}
81-
});
82-
return await GetSessionOrThrow().SendRequestAsync<TResult>(request, cancellationToken);
83-
}
69+
public Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken = default) where TResult : class
70+
=> GetSessionOrThrow().SendRequestAsync<TResult>(request, cancellationToken);
8471

8572
/// <summary>
8673
/// Sends a notification over the protocol.
@@ -166,7 +153,8 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
166153
}
167154
finally
168155
{
169-
_session?.Dispose();
156+
var valueTask = _session?.DisposeAsync().ConfigureAwait(false);
157+
if (valueTask is not null) await valueTask.Value;
170158
_sessionCts?.Dispose();
171159
}
172160

@@ -178,6 +166,6 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
178166
/// </summary>
179167
/// <returns>The current session.</returns>
180168
/// <exception cref="InvalidOperationException">Thrown if the session is not started.</exception>
181-
protected IMcpSession GetSessionOrThrow()
169+
protected McpSession GetSessionOrThrow()
182170
=> _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages.");
183171
}

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace ModelContextProtocol.Shared;
1414
/// <summary>
1515
/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers.
1616
/// </summary>
17-
internal sealed class McpSession : IMcpSession
17+
internal sealed class McpSession : IMcpSession, IAsyncDisposable
1818
{
1919
private readonly ITransport _transport;
2020
private readonly RequestHandlers _requestHandlers;
@@ -258,6 +258,18 @@ await _transport.SendMessageAsync(new JsonRpcResponse
258258
/// <returns>A task containing the server's response.</returns>
259259
public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, CancellationToken cancellationToken) where TResult : class
260260
{
261+
using var registration = cancellationToken.Register(async () =>
262+
{
263+
try
264+
{
265+
await this.NotifyCancelAsync(request.Id).ConfigureAwait(false);
266+
}
267+
catch (Exception ex)
268+
{
269+
_logger.LogError(ex, "An error occurred while notifying cancellation for request {RequestId}.", request.Id);
270+
}
271+
});
272+
261273
if (!_transport.IsConnected)
262274
{
263275
_logger.EndpointNotConnected(EndpointName);
@@ -380,13 +392,39 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
380392
}
381393
}
382394

383-
public void Dispose()
395+
/// <summary>
396+
/// Asynchronously releases resources used by the session.
397+
/// </summary>
398+
/// <returns>A task that represents the asynchronous dispose operation.</returns>
399+
public async ValueTask DisposeAsync()
384400
{
385401
// Complete all pending requests with cancellation
386402
foreach (var entry in _pendingRequests)
387403
{
388404
entry.Value.TrySetCanceled();
389405
}
390406
_pendingRequests.Clear();
407+
408+
// Dispose any remaining cancellation token sources
409+
foreach (var cts in _handlingRequests.Values)
410+
{
411+
cts.Dispose();
412+
}
413+
_handlingRequests.Clear();
414+
415+
// Asynchronously dispose the transport if it's IAsyncDisposable
416+
if (_transport is IAsyncDisposable asyncDisposableTransport)
417+
{
418+
await asyncDisposableTransport.DisposeAsync().ConfigureAwait(false);
419+
}
420+
else if (_transport is IDisposable disposableTransport)
421+
{
422+
disposableTransport.Dispose();
423+
}
424+
}
425+
426+
public void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler)
427+
{
428+
_notificationHandlers.Add(method, handler);
391429
}
392430
}

src/ModelContextProtocol/McpEndpointExtensions.cs renamed to src/ModelContextProtocol/Shared/McpSessionExtensions.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
using ModelContextProtocol.Protocol.Messages;
2+
using ModelContextProtocol.Shared;
23
using ModelContextProtocol.Utils;
34

45
namespace ModelContextProtocol;
56

6-
/// <summary>Provides extension methods for interacting with an <see cref="IMcpEndpoint"/>.</summary>
7+
/// <summary>Provides extension methods for interacting with an <see cref="IMcpSession"/>.</summary>
78
public static class McpEndpointExtensions
89
{
910
/// <summary>
@@ -15,7 +16,7 @@ public static class McpEndpointExtensions
1516
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
1617
/// <returns>A task representing the completion of the operation.</returns>
1718
public static Task NotifyAsync(
18-
this IMcpEndpoint endpoint,
19+
this IMcpSession endpoint,
1920
JsonRpcNotification notification,
2021
CancellationToken cancellationToken = default)
2122
{
@@ -34,7 +35,7 @@ public static Task NotifyAsync(
3435
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
3536
/// <returns>A task representing the completion of the operation.</returns>
3637
public static Task NotifyAsync(
37-
this IMcpEndpoint endpoint,
38+
this IMcpSession endpoint,
3839
string method,
3940
object? parameters = null,
4041
CancellationToken cancellationToken = default)
@@ -56,7 +57,7 @@ public static Task NotifyAsync(
5657
/// <returns>A task representing the completion of the operation.</returns>
5758
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
5859
public static Task NotifyProgressAsync(
59-
this IMcpEndpoint endpoint,
60+
this IMcpSession endpoint,
6061
ProgressToken progressToken,
6162
ProgressNotificationValue progress,
6263
CancellationToken cancellationToken = default)
@@ -82,7 +83,7 @@ public static Task NotifyProgressAsync(
8283
/// <returns>A task representing the completion of the operation.</returns>
8384
/// <exception cref="ArgumentNullException"><paramref name="endpoint"/> is <see langword="null"/>.</exception>
8485
public static Task NotifyCancelAsync(
85-
this IMcpEndpoint endpoint,
86+
this IMcpSession endpoint,
8687
RequestId requestId,
8788
string? reason = null,
8889
CancellationToken cancellationToken = default)

src/ModelContextProtocol/TokenProgress.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
using ModelContextProtocol.Protocol.Messages;
2+
using ModelContextProtocol.Shared;
23

34
namespace ModelContextProtocol;
45

56
/// <summary>
67
/// Provides an <see cref="IProgress{ProgressNotificationValue}"/> tied to a specific progress token and that will issue
78
/// progress notifications on the supplied endpoint.
89
/// </summary>
9-
internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress<ProgressNotificationValue>
10+
internal sealed class TokenProgress(IMcpSession endpoint, ProgressToken progressToken) : IProgress<ProgressNotificationValue>
1011
{
1112
/// <inheritdoc />
1213
public void Report(ProgressNotificationValue value)

tests/ModelContextProtocol.Tests/CancelledNotificationTests.cs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,23 @@ public class CancelledNotificationTests(
1515
public async Task NotifyCancelAsync_SendsCorrectNotification()
1616
{
1717
// Arrange
18-
await using var endpoint = fixture.CreateEndpoint();
19-
await using var transport = fixture.CreateTransport();
20-
var cancellationToken = TestContext.Current.CancellationToken;
21-
endpoint.Start(transport, cancellationToken);
18+
var token = TestContext.Current.CancellationToken;
19+
var clientTransport = fixture.CreateClientTransport();
20+
await using var endpoint = await fixture.CreateClientEndpointAsync(clientTransport);
21+
var transport = await clientTransport.ConnectAsync(token);
2222

2323
var requestId = new RequestId("test-request-id-123");
2424
const string reason = "Operation was cancelled by the user";
2525

2626
// Act
27-
await endpoint.NotifyCancelAsync(requestId, reason, cancellationToken);
27+
await endpoint.NotifyCancelAsync(requestId, reason, token);
2828

2929
// Assert
30-
Assert.Single(transport.SentMessages);
31-
var notification = Assert.IsType<JsonRpcNotification>(transport.SentMessages[0]);
30+
Assert.Equal(1, transport.MessageReader.Count);
31+
var message = await transport.MessageReader.ReadAsync(token);
32+
Assert.NotNull(message);
33+
34+
var notification = Assert.IsType<JsonRpcNotification>(message);
3235
Assert.Equal(NotificationMethods.CancelledNotification, notification.Method);
3336

3437
var cancelParams = Assert.IsType<CancelledNotification>(notification.Params);
@@ -40,10 +43,10 @@ public async Task NotifyCancelAsync_SendsCorrectNotification()
4043
public async Task SendRequestAsync_Cancellation_SendsNotification()
4144
{
4245
// Arrange
43-
await using var endpoint = fixture.CreateEndpoint();
44-
await using var transport = fixture.CreateTransport();
45-
endpoint.Start(transport, CancellationToken.None);
46-
46+
var token = TestContext.Current.CancellationToken;
47+
var clientTransport = fixture.CreateClientTransport();
48+
await using var endpoint = await fixture.CreateClientEndpointAsync(clientTransport);
49+
var transport = await clientTransport.ConnectAsync(token);
4750
var requestId = new RequestId("test-request-id-123");
4851
JsonRpcRequest request = new()
4952
{
@@ -68,15 +71,19 @@ public async Task SendRequestAsync_Cancellation_SendsNotification()
6871
}
6972

7073
// Assert
71-
Assert.NotEmpty(transport.SentMessages);
72-
Assert.Equal(2, transport.SentMessages.Count);
73-
var notification = Assert.IsType<JsonRpcNotification>(transport.SentMessages[0]);
74+
Assert.Equal(2, transport.MessageReader.Count);
75+
var message = await transport.MessageReader.ReadAsync(token);
76+
Assert.NotNull(message);
77+
78+
var notification = Assert.IsType<JsonRpcNotification>(message);
7479
Assert.Equal(NotificationMethods.CancelledNotification, notification.Method);
7580

7681
var cancelParams = Assert.IsType<CancelledNotification>(notification.Params);
7782
Assert.Equal(requestId, cancelParams.RequestId);
7883

79-
var requestMessage = Assert.IsType<JsonRpcRequest>(transport.SentMessages[1]);
84+
message = await transport.MessageReader.ReadAsync(token);
85+
Assert.NotNull(message);
86+
var requestMessage = Assert.IsType<JsonRpcRequest>(message);
8087
Assert.Equal(request.Id, requestMessage.Id);
8188
Assert.Equal(request.Method, requestMessage.Method);
8289
Assert.Equal(request.Params, requestMessage.Params);

0 commit comments

Comments
 (0)