diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 50601f666..5714affd2 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -18,6 +18,7 @@ internal sealed partial class AutoDetectingClientSessionTransport : ITransport private readonly ILogger _logger; private readonly string _name; private readonly Channel _messageChannel; + private string? _protocolVersion; public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) { @@ -43,6 +44,20 @@ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOp /// internal ITransport? ActiveTransport { get; private set; } + /// + public string? ProtocolVersion + { + get => ActiveTransport?.ProtocolVersion ?? _protocolVersion; + set + { + _protocolVersion = value; + if (ActiveTransport is { } transport) + { + transport.ProtocolVersion = value; + } + } + } + public ChannelReader MessageReader => _messageChannel.Reader; /// @@ -70,6 +85,10 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can { LogUsingStreamableHttp(_name); ActiveTransport = streamableHttpTransport; + if (_protocolVersion is { } protocolVersion) + { + ActiveTransport.ProtocolVersion = protocolVersion; + } } else { diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index 5c5ad2fe4..7a8c0fde2 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -118,6 +118,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) // Connect transport _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); InitializeSession(_sessionTransport); + // We don't want the ConnectAsync token to cancel the session after we've successfully connected. // The base class handles cleaning up the session in DisposeAsync without our help. StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); @@ -164,6 +165,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); } + _sessionTransport.ProtocolVersion = initializeResponse.ProtocolVersion; + // Send initialized notification await SendMessageAsync( new JsonRpcNotification { Method = NotificationMethods.InitializedNotification }, diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index fd2466eaf..e3300a991 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -91,7 +91,7 @@ public override async Task SendMessageAsync( { Content = content, }; - StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, protocolVersion: ProtocolVersion); var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); if (!response.IsSuccessStatusCode) @@ -152,7 +152,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, protocolVersion: ProtocolVersion); using var response = await _httpClient.SendAsync( request, diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index e99aa2ae6..ceb748d4d 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -85,7 +85,7 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes }, }; - CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId); + CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId, ProtocolVersion); var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); @@ -170,7 +170,7 @@ private async Task ReceiveUnsolicitedMessagesAsync() // Send a GET request to handle any unsolicited messages not sent over a POST response. using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); request.Headers.Accept.Add(s_textEventStreamMediaType); - CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId); + CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId, ProtocolVersion); using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false); @@ -245,23 +245,30 @@ private void LogJsonException(JsonException ex, string data) } } - internal static void CopyAdditionalHeaders(HttpRequestHeaders headers, Dictionary? additionalHeaders, string? sessionId = null) + internal static void CopyAdditionalHeaders( + HttpRequestHeaders headers, + Dictionary? additionalHeaders, + string? sessionId = null, + string? protocolVersion = null) { if (sessionId is not null) { headers.Add("mcp-session-id", sessionId); } - if (additionalHeaders is null) + if (protocolVersion is not null) { - return; + headers.Add("MCP-Protocol-Version", protocolVersion); } - foreach (var header in additionalHeaders) + if (additionalHeaders is not null) { - if (!headers.TryAddWithoutValidation(header.Key, header.Value)) + foreach (var header in additionalHeaders) { - throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); + if (!headers.TryAddWithoutValidation(header.Key, header.Value)) + { + throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); + } } } } diff --git a/src/ModelContextProtocol.Core/Protocol/ITransport.cs b/src/ModelContextProtocol.Core/Protocol/ITransport.cs index 4fc36e96b..38ceb3e1c 100644 --- a/src/ModelContextProtocol.Core/Protocol/ITransport.cs +++ b/src/ModelContextProtocol.Core/Protocol/ITransport.cs @@ -60,4 +60,12 @@ public interface ITransport : IAsyncDisposable /// /// Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); + + /// Gets or sets the protocol version that's in use. + /// + /// Setting the protocol version does not change the protocol version actively employed by the transport. + /// It provides that information to the transport for situations where the transport needs to be able to + /// propagate the version information, for example as part of HTTP headers or for logging and diagnostic purposes. + /// + string? ProtocolVersion { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs index 9be9c6fa5..ed92ca3d1 100644 --- a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs +++ b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs @@ -59,6 +59,9 @@ internal TransportBase(string name, Channel? messageChannel, ILo /// Gets the logger used by this transport. private protected ILogger Logger => _logger; + /// + public string? ProtocolVersion { get; set; } + /// /// Gets the name that identifies this transport endpoint in logs. /// diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index 82562d9d3..d1d1976fe 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -34,6 +34,9 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? private bool _isConnected; + /// + public string? ProtocolVersion { get; set; } + /// /// Starts the transport and writes the JSON-RPC messages sent via /// to the SSE response stream until cancellation is requested or the transport is disposed. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 174cf9bd7..e1264b44d 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -18,6 +18,9 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); + /// + public string? ProtocolVersion { get; set; } + /// /// True, if data was written to the respond body. /// False, if nothing was written because the request body did not contain any messages to respond to. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 6b024ff6f..52b2522b8 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -35,6 +35,9 @@ public sealed class StreamableHttpServerTransport : ITransport private readonly CancellationTokenSource _disposeCts = new(); private int _getRequestStarted; + + /// + public string? ProtocolVersion { get; set; } /// /// Configures whether the transport should be in stateless mode that does not require all requests for a given session diff --git a/tests/Common/Utils/TestServerTransport.cs b/tests/Common/Utils/TestServerTransport.cs index cd12504a0..6f2bb9e2e 100644 --- a/tests/Common/Utils/TestServerTransport.cs +++ b/tests/Common/Utils/TestServerTransport.cs @@ -10,6 +10,8 @@ public class TestServerTransport : ITransport public bool IsConnected { get; set; } + public string? ProtocolVersion { get; set; } + public ChannelReader MessageReader => _messageChannel; public List SentMessages { get; } = []; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index d7f8433b3..ce284ad24 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -28,7 +28,7 @@ private async Task StartAsync() Services = _app.Services, }); - _app.MapPost("/mcp", (JsonRpcMessage message) => + _app.MapPost("/mcp", (HttpContext context, JsonRpcMessage message) => { if (message is not JsonRpcRequest request) { @@ -36,6 +36,8 @@ private async Task StartAsync() return Results.Accepted(); } + const string ExpectedProtocolVersion = "2024-11-05"; + if (request.Method == "initialize") { return Results.Json(new JsonRpcResponse @@ -43,7 +45,7 @@ private async Task StartAsync() Id = request.Id, Result = JsonSerializer.SerializeToNode(new InitializeResult { - ProtocolVersion = "2024-11-05", + ProtocolVersion = ExpectedProtocolVersion, Capabilities = new() { Tools = new(), @@ -57,6 +59,15 @@ private async Task StartAsync() }); } + if (!context.Request.Headers.TryGetValue("MCP-Protocol-Version", out var actualVersion)) + { + throw new Exception("Request headers did not contain MCP-Protocol-Version."); + } + else if (ExpectedProtocolVersion != actualVersion) + { + throw new Exception($"Unexpected protocol version: {actualVersion}"); + } + if (request.Method == "tools/list") { return Results.Json(new JsonRpcResponse diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index 2b8b75616..571073f6e 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -110,6 +110,8 @@ private class NopTransport : ITransport, IClientTransport public bool IsConnected => true; + public string? ProtocolVersion { get; set; } + public ChannelReader MessageReader => _channel.Reader; public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this);