From 2b8bd7b0a7ffdb6cd18d5da8294e99e520d6e187 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Thu, 29 May 2025 21:08:28 +0300 Subject: [PATCH 1/5] Pass session id's to MCP endpoints. --- .../SseHandler.cs | 2 +- .../StreamableHttpHandler.cs | 20 ++++++---- .../AutoDetectingClientSessionTransport.cs | 2 + .../Client/McpClient.cs | 11 ++++++ .../StreamableHttpClientSessionTransport.cs | 7 ++-- src/ModelContextProtocol.Core/IMcpEndpoint.cs | 5 +++ .../Protocol/ITransport.cs | 8 ++++ .../Protocol/TransportBase.cs | 3 ++ .../Server/DestinationBoundMcpServer.cs | 1 + .../Server/McpServer.cs | 4 ++ .../Server/SseResponseStreamTransport.cs | 6 ++- .../Server/StreamableHttpPostTransport.cs | 9 +++-- .../Server/StreamableHttpServerTransport.cs | 7 +++- tests/Common/Utils/TestServerTransport.cs | 2 + .../HttpServerIntegrationTests.cs | 38 +++++++++++++++++++ .../Program.cs | 17 +++++++++ .../Program.cs | 17 +++++++++ .../Client/McpClientFactoryTests.cs | 1 + .../ClientIntegrationTests.cs | 18 +++++++++ .../Server/McpServerTests.cs | 1 + 20 files changed, 161 insertions(+), 18 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index 251e3bf4f..c5ac5a948 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -31,7 +31,7 @@ public async Task HandleSseRequestAsync(HttpContext context) var requestPath = (context.Request.PathBase + context.Request.Path).ToString(); var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)]; - await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}"); + await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId); var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User); await using var httpMcpSession = new HttpMcpSession(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider); diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 6077efa10..a3afe20c0 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -4,7 +4,6 @@ using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; @@ -136,6 +135,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context) var transport = new StreamableHttpServerTransport { Stateless = true, + SessionId = sessionId, }; session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); } @@ -184,7 +184,10 @@ private async ValueTask> StartNewS if (!HttpServerTransportOptions.Stateless) { sessionId = MakeNewSessionId(); - transport = new(); + transport = new() + { + SessionId = sessionId, + }; context.Response.Headers["mcp-session-id"] = sessionId; } else @@ -286,19 +289,22 @@ internal static string MakeNewSessionId() private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) { - context.Response.OnStarting(() => + transport.OnInitRequestReceived = initRequestParams => { var statelessId = new StatelessSessionId { - ClientInfo = transport?.InitializeRequest?.ClientInfo, + ClientInfo = initRequestParams.ClientInfo, UserIdClaim = GetUserIdClaim(context.User), }; var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId); - var sessionId = Protector.Protect(sessionJson); - - context.Response.Headers["mcp-session-id"] = sessionId; + transport.SessionId = Protector.Protect(sessionJson); + }; + context.Response.OnStarting(() => + { + Debug.Assert(transport.SessionId is not null); + context.Response.Headers["mcp-session-id"] = transport.SessionId; return Task.CompletedTask; }); } diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 50601f666..39ae7e81d 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -45,6 +45,8 @@ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOp public ChannelReader MessageReader => _messageChannel.Reader; + string? ITransport.SessionId => ActiveTransport?.SessionId; + /// public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index 5c5ad2fe4..1d1ffe50f 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using System.Diagnostics; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -93,6 +94,16 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL } } + /// + public ITransport Transport + { + get + { + Debug.Assert(_sessionTransport is not null, "Must have already initialized a session when invoking this property."); + return _sessionTransport!; + } + } + /// public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index e99aa2ae6..771f0cfd1 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -27,7 +27,6 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private readonly CancellationTokenSource _connectionCts; private readonly ILogger _logger; - private string? _mcpSessionId; private Task? _getReceiveTask; public StreamableHttpClientSessionTransport( @@ -85,7 +84,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes }, }; - CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId); + CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId); var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); @@ -124,7 +123,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes // We've successfully initialized! Copy session-id and start GET request if any. if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues)) { - _mcpSessionId = sessionIdValues.FirstOrDefault(); + SessionId = sessionIdValues.FirstOrDefault(); } _getReceiveTask = ReceiveUnsolicitedMessagesAsync(); @@ -170,7 +169,7 @@ private async Task ReceiveUnsolicitedMessagesAsync() // Send a GET request to handle any unsolicited messages not sent over a POST response. using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); request.Headers.Accept.Add(s_textEventStreamMediaType); - CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId); + CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId); using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false); diff --git a/src/ModelContextProtocol.Core/IMcpEndpoint.cs b/src/ModelContextProtocol.Core/IMcpEndpoint.cs index a38b8c97c..46e6164f6 100644 --- a/src/ModelContextProtocol.Core/IMcpEndpoint.cs +++ b/src/ModelContextProtocol.Core/IMcpEndpoint.cs @@ -28,6 +28,11 @@ namespace ModelContextProtocol; /// public interface IMcpEndpoint : IAsyncDisposable { + /// + /// Gets the underlying transport driving the current MCP endpoint. + /// + ITransport Transport { get; } + /// /// Sends a JSON-RPC request to the connected endpoint and waits for a response. /// diff --git a/src/ModelContextProtocol.Core/Protocol/ITransport.cs b/src/ModelContextProtocol.Core/Protocol/ITransport.cs index 4fc36e96b..e35b3a6fb 100644 --- a/src/ModelContextProtocol.Core/Protocol/ITransport.cs +++ b/src/ModelContextProtocol.Core/Protocol/ITransport.cs @@ -25,6 +25,14 @@ namespace ModelContextProtocol.Protocol; /// public interface ITransport : IAsyncDisposable { + /// Gets an identifier associated with the current MCP session. + /// + /// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE. + /// Can return if the session hasn't initialized or if the transport doesn't + /// support multiple sessions (as is the case with STDIO). + /// + string? SessionId { get; } + /// /// Gets a channel reader for receiving messages from the transport. /// diff --git a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs index 9be9c6fa5..a5ed1d3a0 100644 --- a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs +++ b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs @@ -59,6 +59,9 @@ internal TransportBase(string name, Channel? messageChannel, ILo /// Gets the logger used by this transport. private protected ILogger Logger => _logger; + /// + public virtual string? SessionId { get; protected set; } + /// /// Gets the name that identifies this transport endpoint in logs. /// diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index db594da8e..029baae24 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -6,6 +6,7 @@ namespace ModelContextProtocol.Server; internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer { public string EndpointName => server.EndpointName; + public ITransport Transport => transport ?? server.Transport; public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; public Implementation? ClientInfo => server.ClientInfo; public McpServerOptions ServerOptions => server.ServerOptions; diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 808300f05..dd16d6f0b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -96,6 +96,10 @@ void Register(McpServerPrimitiveCollection? collection, InitializeSession(transport); } + /// + public ITransport Transport => _sessionTransport; + + /// public ServerCapabilities ServerCapabilities { get; } = new(); /// diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index 82562d9d3..438421f28 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -23,7 +23,8 @@ namespace ModelContextProtocol.Server; /// These messages should be passed to . /// Defaults to "/message". /// -public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message") : ITransport +/// The identifier corresponding to the current MCP session. +public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message", string? sessionId = null) : ITransport { private readonly SseWriter _sseWriter = new(messageEndpoint); private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) @@ -49,6 +50,9 @@ public async Task RunAsync(CancellationToken cancellationToken) /// public ChannelReader MessageReader => _incomingChannel.Reader; + /// + public string? SessionId { get; } = sessionId; + /// public async ValueTask DisposeAsync() { diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 174cf9bd7..e98393622 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -18,6 +18,8 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); + string? ITransport.SessionId => parentTransport.SessionId; + /// /// True, if data was written to the respond body. /// False, if nothing was written because the request body did not contain any messages to respond to. @@ -79,10 +81,11 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella { _pendingRequest = request.Id; - // Store client capabilities so they can be serialized by "stateless" callers for use in later requests. - if (parentTransport.Stateless && request.Method == RequestMethods.Initialize) + // Invoke the initialize request callback if applicable. + if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) { - parentTransport.InitializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + onInitRequest(initializeRequest!); } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 6b024ff6f..3d637da7c 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -46,15 +46,18 @@ public sealed class StreamableHttpServerTransport : ITransport public bool Stateless { get; init; } /// - /// Gets the initialize request if it was received by and is set to . + /// Gets or sets a callback to be invoked before handling the initialize request. /// - public InitializeRequestParams? InitializeRequest { get; internal set; } + public Action? OnInitRequestReceived { get; set; } /// public ChannelReader MessageReader => _incomingChannel.Reader; internal ChannelWriter MessageWriter => _incomingChannel.Writer; + /// + public string? SessionId { get; set; } + /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via diff --git a/tests/Common/Utils/TestServerTransport.cs b/tests/Common/Utils/TestServerTransport.cs index cd12504a0..1679fee65 100644 --- a/tests/Common/Utils/TestServerTransport.cs +++ b/tests/Common/Utils/TestServerTransport.cs @@ -16,6 +16,8 @@ public class TestServerTransport : ITransport public Action? OnMessageSent { get; set; } + public string? SessionId => null; + public TestServerTransport() { _messageChannel = Channel.CreateUnbounded(new UnboundedChannelOptions diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 30187faad..13bbb470b 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -52,6 +52,15 @@ public async Task Connect_TestServer_ShouldProvideServerFields() // Assert Assert.NotNull(client.ServerCapabilities); Assert.NotNull(client.ServerInfo); + + if (ClientTransportOptions.Endpoint.AbsolutePath.EndsWith("/sse")) + { + Assert.Null(client.Transport.SessionId); + } + else + { + Assert.NotNull(client.Transport.SessionId); + } } [Fact] @@ -90,6 +99,35 @@ public async Task CallTool_Sse_EchoServer() Assert.Equal("Echo: Hello MCP!", textContent.Text); } + [Fact] + public async Task CallTool_EchoSessionId_ReturnsTheSameSessionId() + { + // arrange + + // act + await using var client = await GetClientAsync(); + var result1 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken); + var result2 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken); + var result3 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(result1); + Assert.NotNull(result2); + Assert.NotNull(result3); + + Assert.False(result1.IsError); + Assert.False(result2.IsError); + Assert.False(result3.IsError); + + var textContent1 = Assert.Single(result1.Content); + var textContent2 = Assert.Single(result2.Content); + var textContent3 = Assert.Single(result3.Content); + + Assert.NotNull(textContent1.Text); + Assert.Equal(textContent1.Text, textContent2.Text); + Assert.Equal(textContent1.Text, textContent3.Text); + } + [Fact] public async Task ListResources_Sse_TestServer() { diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index dbecbf481..b7e9295b9 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -133,6 +133,16 @@ private static ToolsCapability ConfigureTools() """), }, new Tool() + { + Name = "echoSessionId", + Description = "Echoes the session id back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object" + } + """, McpJsonUtilities.DefaultOptions), + }, + new Tool() { Name = "sampleLLM", Description = "Samples from an LLM using MCP's sampling feature.", @@ -170,6 +180,13 @@ private static ToolsCapability ConfigureTools() Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }] }; } + else if (request.Params?.Name == "echoSessionId") + { + return new CallToolResponse() + { + Content = [new Content() { Text = request.Server.Transport.SessionId, Type = "text" }] + }; + } else if (request.Params?.Name == "sampleLLM") { if (request.Params?.Arguments is null || diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 2e37c1d7b..6fe9cc7ea 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -128,6 +128,16 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st """, McpJsonUtilities.DefaultOptions), }, new Tool() + { + Name = "echoSessionId", + Description = "Echoes the session id back to the client.", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object" + } + """, McpJsonUtilities.DefaultOptions), + }, + new Tool() { Name = "sampleLLM", Description = "Samples from an LLM using MCP's sampling feature.", @@ -168,6 +178,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }] }; } + else if (request.Params.Name == "echoSessionId") + { + return new CallToolResponse() + { + Content = [new Content() { Text = request.Server.Transport.SessionId, Type = "text" }] + }; + } else if (request.Params.Name == "sampleLLM") { if (request.Params.Arguments is null || diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 2b8b75616..9b7f44569 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -109,6 +109,7 @@ private class NopTransport : ITransport, IClientTransport private readonly Channel _channel = Channel.CreateUnbounded(); public bool IsConnected => true; + public string? SessionId => null; public ChannelReader MessageReader => _channel.Reader; diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 14fd02e98..1fa1d5592 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -54,6 +54,8 @@ public async Task Connect_ShouldProvideServerFields(string clientId) Assert.NotNull(client.ServerInfo); if (clientId != "everything") // Note: Comment the below assertion back when the everything server is updated to provide instructions Assert.NotNull(client.ServerInstructions); + + Assert.Null(client.Transport.SessionId); } [Theory] @@ -94,6 +96,22 @@ public async Task CallTool_Stdio_EchoServer(string clientId) Assert.Equal("Echo: Hello MCP!", textContent.Text); } + [Fact] + public async Task CallTool_Stdio_EchoSessionId_ReturnsNull() + { + // arrange + + // act + await using var client = await _fixture.CreateClientAsync("test_server"); + var result = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken); + + // assert + Assert.NotNull(result); + Assert.False(result.IsError); + var textContent = Assert.Single(result.Content, c => c.Type == "text"); + Assert.Null(textContent.Text); + } + [Theory] [MemberData(nameof(GetClients))] public async Task CallTool_Stdio_ViaAIFunction_EchoServer(string clientId) diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 75966b6eb..7d039777d 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -650,6 +650,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public ValueTask DisposeAsync() => default; + public ITransport Transport => throw new NotImplementedException(); public Implementation? ClientInfo => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); public LoggingLevel? LoggingLevel => throw new NotImplementedException(); From 06b1cdca541ecdc5587ec11e218f8616f4a069b7 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Fri, 30 May 2025 09:40:24 +0300 Subject: [PATCH 2/5] Update src/ModelContextProtocol/Client/McpClient.cs Co-authored-by: Stephen Halter --- src/ModelContextProtocol.Core/Client/McpClient.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index 1d1ffe50f..a43e5f1c7 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -99,8 +99,7 @@ public ITransport Transport { get { - Debug.Assert(_sessionTransport is not null, "Must have already initialized a session when invoking this property."); - return _sessionTransport!; + return _sessionTransport ?? throw new InvalidOperationException("Must have already initialized a session when invoking this property."); } } From 4450743158a098b6454e30ccc6c33f120c644bb9 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Fri, 30 May 2025 09:40:35 +0300 Subject: [PATCH 3/5] Update src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs Co-authored-by: Stephen Halter --- .../StreamableHttpHandler.cs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index a3afe20c0..3bbbc451e 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -299,14 +299,8 @@ private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttp var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId); transport.SessionId = Protector.Protect(sessionJson); - }; - - context.Response.OnStarting(() => - { - Debug.Assert(transport.SessionId is not null); context.Response.Headers["mcp-session-id"] = transport.SessionId; - return Task.CompletedTask; - }); + }; } internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) From e6182e79c98ec206170c4ecd5313c17fe574a42d Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Fri, 30 May 2025 09:57:01 +0300 Subject: [PATCH 4/5] Make init callback asynchronous. --- src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs | 3 ++- .../Server/StreamableHttpPostTransport.cs | 2 +- .../Server/StreamableHttpServerTransport.cs | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 3bbbc451e..a38fa7c6d 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -293,13 +293,14 @@ private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttp { var statelessId = new StatelessSessionId { - ClientInfo = initRequestParams.ClientInfo, + ClientInfo = initRequestParams?.ClientInfo, UserIdClaim = GetUserIdClaim(context.User), }; var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId); transport.SessionId = Protector.Protect(sessionJson); context.Response.Headers["mcp-session-id"] = transport.SessionId; + return ValueTask.CompletedTask; }; } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index e98393622..343b57485 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -85,7 +85,7 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) { var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - onInitRequest(initializeRequest!); + await onInitRequest(initializeRequest).ConfigureAwait(false); } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 3d637da7c..1f5775e66 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -48,7 +48,7 @@ public sealed class StreamableHttpServerTransport : ITransport /// /// Gets or sets a callback to be invoked before handling the initialize request. /// - public Action? OnInitRequestReceived { get; set; } + public Func? OnInitRequestReceived { get; set; } /// public ChannelReader MessageReader => _incomingChannel.Reader; From 0176c6c27b5c5500f32d8be6815885e7c6d71d0f Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Thu, 12 Jun 2025 17:38:25 +0300 Subject: [PATCH 5/5] Address feedback. --- src/ModelContextProtocol.Core/Client/McpClient.cs | 9 +++++++-- src/ModelContextProtocol.Core/IMcpEndpoint.cs | 11 +++++++---- .../Server/DestinationBoundMcpServer.cs | 2 +- src/ModelContextProtocol.Core/Server/McpServer.cs | 2 +- .../HttpServerIntegrationTests.cs | 4 ++-- tests/ModelContextProtocol.TestServer/Program.cs | 2 +- tests/ModelContextProtocol.TestSseServer/Program.cs | 2 +- .../ClientIntegrationTests.cs | 2 +- .../Server/McpServerTests.cs | 2 +- 9 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index a43e5f1c7..43639db2c 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -95,11 +95,16 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL } /// - public ITransport Transport + public string? SessionId { get { - return _sessionTransport ?? throw new InvalidOperationException("Must have already initialized a session when invoking this property."); + if (_sessionTransport is null) + { + throw new InvalidOperationException("Must have already initialized a session when invoking this property."); + } + + return _sessionTransport.SessionId; } } diff --git a/src/ModelContextProtocol.Core/IMcpEndpoint.cs b/src/ModelContextProtocol.Core/IMcpEndpoint.cs index 46e6164f6..ea825e682 100644 --- a/src/ModelContextProtocol.Core/IMcpEndpoint.cs +++ b/src/ModelContextProtocol.Core/IMcpEndpoint.cs @@ -28,10 +28,13 @@ namespace ModelContextProtocol; /// public interface IMcpEndpoint : IAsyncDisposable { - /// - /// Gets the underlying transport driving the current MCP endpoint. - /// - ITransport Transport { get; } + /// Gets an identifier associated with the current MCP session. + /// + /// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE. + /// Can return if the session hasn't initialized or if the transport doesn't + /// support multiple sessions (as is the case with STDIO). + /// + string? SessionId { get; } /// /// Sends a JSON-RPC request to the connected endpoint and waits for a response. diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index 029baae24..d286d1ef4 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Server; internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer { public string EndpointName => server.EndpointName; - public ITransport Transport => transport ?? server.Transport; + public string? SessionId => transport?.SessionId ?? server.SessionId; public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; public Implementation? ClientInfo => server.ClientInfo; public McpServerOptions ServerOptions => server.ServerOptions; diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index dd16d6f0b..b715edda9 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -97,7 +97,7 @@ void Register(McpServerPrimitiveCollection? collection, } /// - public ITransport Transport => _sessionTransport; + public string? SessionId => _sessionTransport.SessionId; /// public ServerCapabilities ServerCapabilities { get; } = new(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 13bbb470b..8fc7fb3dd 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -55,11 +55,11 @@ public async Task Connect_TestServer_ShouldProvideServerFields() if (ClientTransportOptions.Endpoint.AbsolutePath.EndsWith("/sse")) { - Assert.Null(client.Transport.SessionId); + Assert.Null(client.SessionId); } else { - Assert.NotNull(client.Transport.SessionId); + Assert.NotNull(client.SessionId); } } diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index b7e9295b9..312a1823b 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -184,7 +184,7 @@ private static ToolsCapability ConfigureTools() { return new CallToolResponse() { - Content = [new Content() { Text = request.Server.Transport.SessionId, Type = "text" }] + Content = [new Content() { Text = request.Server.SessionId, Type = "text" }] }; } else if (request.Params?.Name == "sampleLLM") diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index 6fe9cc7ea..cf078c25c 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -182,7 +182,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { return new CallToolResponse() { - Content = [new Content() { Text = request.Server.Transport.SessionId, Type = "text" }] + Content = [new Content() { Text = request.Server.SessionId, Type = "text" }] }; } else if (request.Params.Name == "sampleLLM") diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 1fa1d5592..dfd8d767d 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -55,7 +55,7 @@ public async Task Connect_ShouldProvideServerFields(string clientId) if (clientId != "everything") // Note: Comment the below assertion back when the everything server is updated to provide instructions Assert.NotNull(client.ServerInstructions); - Assert.Null(client.Transport.SessionId); + Assert.Null(client.SessionId); } [Theory] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 7d039777d..0e1bb429a 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -650,7 +650,7 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public ValueTask DisposeAsync() => default; - public ITransport Transport => throw new NotImplementedException(); + public string? SessionId => throw new NotImplementedException(); public Implementation? ClientInfo => throw new NotImplementedException(); public IServiceProvider? Services => throw new NotImplementedException(); public LoggingLevel? LoggingLevel => throw new NotImplementedException();