Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelContextProtocol.AspNetCore/SseHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public async Task HandleSseRequestAsync(HttpContext context)

var requestPath = (context.Request.PathBase + context.Request.Path).ToString();
var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)];
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}");
await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId);

var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
await using var httpMcpSession = new HttpMcpSession<SseResponseStreamTransport>(sessionId, transport, userIdClaim, httpMcpServerOptions.Value.TimeProvider);
Expand Down
21 changes: 11 additions & 10 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
using ModelContextProtocol.AspNetCore.Stateless;
using ModelContextProtocol.Protocol;
Expand Down Expand Up @@ -136,6 +135,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
var transport = new StreamableHttpServerTransport
{
Stateless = true,
SessionId = sessionId,
};
session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId);
}
Expand Down Expand Up @@ -184,7 +184,10 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> StartNewS
if (!HttpServerTransportOptions.Stateless)
{
sessionId = MakeNewSessionId();
transport = new();
transport = new()
{
SessionId = sessionId,
};
context.Response.Headers["mcp-session-id"] = sessionId;
}
else
Expand Down Expand Up @@ -286,21 +289,19 @@ internal static string MakeNewSessionId()

private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport)
{
context.Response.OnStarting(() =>
transport.OnInitRequestReceived = initRequestParams =>
{
var statelessId = new StatelessSessionId
{
ClientInfo = transport?.InitializeRequest?.ClientInfo,
ClientInfo = initRequestParams?.ClientInfo,
UserIdClaim = GetUserIdClaim(context.User),
};

var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId);
var sessionId = Protector.Protect(sessionJson);

context.Response.Headers["mcp-session-id"] = sessionId;

return Task.CompletedTask;
});
transport.SessionId = Protector.Protect(sessionJson);
context.Response.Headers["mcp-session-id"] = transport.SessionId;
return ValueTask.CompletedTask;
};
}

internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOp

public ChannelReader<JsonRpcMessage> MessageReader => _messageChannel.Reader;

string? ITransport.SessionId => ActiveTransport?.SessionId;

/// <inheritdoc/>
public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
Expand Down
15 changes: 15 additions & 0 deletions src/ModelContextProtocol.Core/Client/McpClient.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using System.Diagnostics;
using System.Text.Json;

namespace ModelContextProtocol.Client;
Expand Down Expand Up @@ -93,6 +94,20 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, IL
}
}

/// <inheritdoc/>
public string? SessionId
{
get
{
if (_sessionTransport is null)
{
throw new InvalidOperationException("Must have already initialized a session when invoking this property.");
}

return _sessionTransport.SessionId;
}
}

/// <inheritdoc/>
public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected.");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
private readonly CancellationTokenSource _connectionCts;
private readonly ILogger _logger;

private string? _mcpSessionId;
private Task? _getReceiveTask;

public StreamableHttpClientSessionTransport(
Expand Down Expand Up @@ -85,7 +84,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
},
};

CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId);
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, SessionId);

var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -124,7 +123,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
// We've successfully initialized! Copy session-id and start GET request if any.
if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues))
{
_mcpSessionId = sessionIdValues.FirstOrDefault();
SessionId = sessionIdValues.FirstOrDefault();
}

_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
Expand Down Expand Up @@ -170,7 +169,7 @@ private async Task ReceiveUnsolicitedMessagesAsync()
// Send a GET request to handle any unsolicited messages not sent over a POST response.
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
request.Headers.Accept.Add(s_textEventStreamMediaType);
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId);
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId);

using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false);

Expand Down
8 changes: 8 additions & 0 deletions src/ModelContextProtocol.Core/IMcpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ namespace ModelContextProtocol;
/// </remarks>
public interface IMcpEndpoint : IAsyncDisposable
{
/// <summary>Gets an identifier associated with the current MCP session.</summary>
/// <remarks>
/// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE.
/// Can return <see langword="null"/> if the session hasn't initialized or if the transport doesn't
/// support multiple sessions (as is the case with STDIO).
/// </remarks>
string? SessionId { get; }

/// <summary>
/// Sends a JSON-RPC request to the connected endpoint and waits for a response.
/// </summary>
Expand Down
8 changes: 8 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/ITransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ namespace ModelContextProtocol.Protocol;
/// </remarks>
public interface ITransport : IAsyncDisposable
{
/// <summary>Gets an identifier associated with the current MCP session.</summary>
/// <remarks>
/// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE.
/// Can return <see langword="null"/> if the session hasn't initialized or if the transport doesn't
/// support multiple sessions (as is the case with STDIO).
/// </remarks>
string? SessionId { get; }

/// <summary>
/// Gets a channel reader for receiving messages from the transport.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/TransportBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ internal TransportBase(string name, Channel<JsonRpcMessage>? messageChannel, ILo
/// <summary>Gets the logger used by this transport.</summary>
private protected ILogger Logger => _logger;

/// <inheritdoc/>
public virtual string? SessionId { get; protected set; }

/// <summary>
/// Gets the name that identifies this transport endpoint in logs.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace ModelContextProtocol.Server;
internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer
{
public string EndpointName => server.EndpointName;
public string? SessionId => transport?.SessionId ?? server.SessionId;
public ClientCapabilities? ClientCapabilities => server.ClientCapabilities;
public Implementation? ClientInfo => server.ClientInfo;
public McpServerOptions ServerOptions => server.ServerOptions;
Expand Down
4 changes: 4 additions & 0 deletions src/ModelContextProtocol.Core/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ void Register<TPrimitive>(McpServerPrimitiveCollection<TPrimitive>? collection,
InitializeSession(transport);
}

/// <inheritdoc/>
public string? SessionId => _sessionTransport.SessionId;

/// <inheritdoc/>
public ServerCapabilities ServerCapabilities { get; } = new();

/// <inheritdoc />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ namespace ModelContextProtocol.Server;
/// These messages should be passed to <see cref="OnMessageReceivedAsync(JsonRpcMessage, CancellationToken)"/>.
/// Defaults to "/message".
/// </param>
public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message") : ITransport
/// <param name="sessionId">The identifier corresponding to the current MCP session.</param>
public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message", string? sessionId = null) : ITransport
{
private readonly SseWriter _sseWriter = new(messageEndpoint);
private readonly Channel<JsonRpcMessage> _incomingChannel = Channel.CreateBounded<JsonRpcMessage>(new BoundedChannelOptions(1)
Expand All @@ -49,6 +50,9 @@ public async Task RunAsync(CancellationToken cancellationToken)
/// <inheritdoc/>
public ChannelReader<JsonRpcMessage> MessageReader => _incomingChannel.Reader;

/// <inheritdoc/>
public string? SessionId { get; } = sessionId;

/// <inheritdoc/>
public async ValueTask DisposeAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport

public ChannelReader<JsonRpcMessage> MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages.");

string? ITransport.SessionId => parentTransport.SessionId;

/// <returns>
/// True, if data was written to the respond body.
/// False, if nothing was written because the request body did not contain any <see cref="JsonRpcRequest"/> messages to respond to.
Expand Down Expand Up @@ -79,10 +81,11 @@ private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, Cancella
{
_pendingRequest = request.Id;

// Store client capabilities so they can be serialized by "stateless" callers for use in later requests.
if (parentTransport.Stateless && request.Method == RequestMethods.Initialize)
// Invoke the initialize request callback if applicable.
if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize)
{
parentTransport.InitializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams);
var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams);
await onInitRequest(initializeRequest).ConfigureAwait(false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,18 @@ public sealed class StreamableHttpServerTransport : ITransport
public bool Stateless { get; init; }

/// <summary>
/// Gets the initialize request if it was received by <see cref="HandlePostRequest(IDuplexPipe, CancellationToken)"/> and <see cref="Stateless"/> is set to <see langword="true"/>.
/// Gets or sets a callback to be invoked before handling the initialize request.
/// </summary>
public InitializeRequestParams? InitializeRequest { get; internal set; }
public Func<InitializeRequestParams?, ValueTask>? OnInitRequestReceived { get; set; }

/// <inheritdoc/>
public ChannelReader<JsonRpcMessage> MessageReader => _incomingChannel.Reader;

internal ChannelWriter<JsonRpcMessage> MessageWriter => _incomingChannel.Writer;

/// <inheritdoc/>
public string? SessionId { get; set; }

/// <summary>
/// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by
/// writing any unsolicited JSON-RPC messages sent via <see cref="SendMessageAsync"/>
Expand Down
2 changes: 2 additions & 0 deletions tests/Common/Utils/TestServerTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public class TestServerTransport : ITransport

public Action<JsonRpcMessage>? OnMessageSent { get; set; }

public string? SessionId => null;

public TestServerTransport()
{
_messageChannel = Channel.CreateUnbounded<JsonRpcMessage>(new UnboundedChannelOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ public async Task Connect_TestServer_ShouldProvideServerFields()
// Assert
Assert.NotNull(client.ServerCapabilities);
Assert.NotNull(client.ServerInfo);

if (ClientTransportOptions.Endpoint.AbsolutePath.EndsWith("/sse"))
{
Assert.Null(client.SessionId);
}
else
{
Assert.NotNull(client.SessionId);
}
}

[Fact]
Expand Down Expand Up @@ -90,6 +99,35 @@ public async Task CallTool_Sse_EchoServer()
Assert.Equal("Echo: Hello MCP!", textContent.Text);
}

[Fact]
public async Task CallTool_EchoSessionId_ReturnsTheSameSessionId()
{
// arrange

// act
await using var client = await GetClientAsync();
var result1 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken);
var result2 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken);
var result3 = await client.CallToolAsync("echoSessionId", cancellationToken: TestContext.Current.CancellationToken);

// assert
Assert.NotNull(result1);
Assert.NotNull(result2);
Assert.NotNull(result3);

Assert.False(result1.IsError);
Assert.False(result2.IsError);
Assert.False(result3.IsError);

var textContent1 = Assert.Single(result1.Content);
var textContent2 = Assert.Single(result2.Content);
var textContent3 = Assert.Single(result3.Content);

Assert.NotNull(textContent1.Text);
Assert.Equal(textContent1.Text, textContent2.Text);
Assert.Equal(textContent1.Text, textContent3.Text);
}

[Fact]
public async Task ListResources_Sse_TestServer()
{
Expand Down
17 changes: 17 additions & 0 deletions tests/ModelContextProtocol.TestServer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ private static ToolsCapability ConfigureTools()
"""),
},
new Tool()
{
Name = "echoSessionId",
Description = "Echoes the session id back to the client.",
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
{
"type": "object"
}
""", McpJsonUtilities.DefaultOptions),
},
new Tool()
{
Name = "sampleLLM",
Description = "Samples from an LLM using MCP's sampling feature.",
Expand Down Expand Up @@ -170,6 +180,13 @@ private static ToolsCapability ConfigureTools()
Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }]
};
}
else if (request.Params?.Name == "echoSessionId")
{
return new CallToolResponse()
{
Content = [new Content() { Text = request.Server.SessionId, Type = "text" }]
};
}
else if (request.Params?.Name == "sampleLLM")
{
if (request.Params?.Arguments is null ||
Expand Down
17 changes: 17 additions & 0 deletions tests/ModelContextProtocol.TestSseServer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
""", McpJsonUtilities.DefaultOptions),
},
new Tool()
{
Name = "echoSessionId",
Description = "Echoes the session id back to the client.",
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
{
"type": "object"
}
""", McpJsonUtilities.DefaultOptions),
},
new Tool()
{
Name = "sampleLLM",
Description = "Samples from an LLM using MCP's sampling feature.",
Expand Down Expand Up @@ -168,6 +178,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
Content = [new Content() { Text = "Echo: " + message.ToString(), Type = "text" }]
};
}
else if (request.Params.Name == "echoSessionId")
{
return new CallToolResponse()
{
Content = [new Content() { Text = request.Server.SessionId, Type = "text" }]
};
}
else if (request.Params.Name == "sampleLLM")
{
if (request.Params.Arguments is null ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ private class NopTransport : ITransport, IClientTransport
private readonly Channel<JsonRpcMessage> _channel = Channel.CreateUnbounded<JsonRpcMessage>();

public bool IsConnected => true;
public string? SessionId => null;

public ChannelReader<JsonRpcMessage> MessageReader => _channel.Reader;

Expand Down
Loading
Loading