Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ internal sealed partial class AutoDetectingClientSessionTransport : ITransport
private readonly ILogger _logger;
private readonly string _name;
private readonly Channel<JsonRpcMessage> _messageChannel;
private string? _protocolVersion;

public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName)
{
Expand All @@ -43,6 +44,20 @@ public AutoDetectingClientSessionTransport(SseClientTransportOptions transportOp
/// </summary>
internal ITransport? ActiveTransport { get; private set; }

/// <inheritdoc />
public string? ProtocolVersion
{
get => ActiveTransport?.ProtocolVersion ?? _protocolVersion;
set
{
_protocolVersion = value;
if (ActiveTransport is { } transport)
{
transport.ProtocolVersion = value;
}
}
}

public ChannelReader<JsonRpcMessage> MessageReader => _messageChannel.Reader;

/// <inheritdoc/>
Expand Down Expand Up @@ -70,6 +85,10 @@ private async Task InitializeAsync(JsonRpcMessage message, CancellationToken can
{
LogUsingStreamableHttp(_name);
ActiveTransport = streamableHttpTransport;
if (_protocolVersion is { } protocolVersion)
{
ActiveTransport.ProtocolVersion = protocolVersion;
}
}
else
{
Expand Down
3 changes: 3 additions & 0 deletions src/ModelContextProtocol.Core/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
// Connect transport
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
InitializeSession(_sessionTransport);

// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
// The base class handles cleaning up the session in DisposeAsync without our help.
StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);
Expand Down Expand Up @@ -164,6 +165,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}");
}

_sessionTransport.ProtocolVersion = initializeResponse.ProtocolVersion;

// Send initialized notification
await SendMessageAsync(
new JsonRpcNotification { Method = NotificationMethods.InitializedNotification },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public override async Task SendMessageAsync(
{
Content = content,
};
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, protocolVersion: ProtocolVersion);
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

if (!response.IsSuccessStatusCode)
Expand Down Expand Up @@ -152,7 +152,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
{
using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint);
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, protocolVersion: ProtocolVersion);

using var response = await _httpClient.SendAsync(
request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
},
};

CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId);
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId, ProtocolVersion);

var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -170,7 +170,7 @@ private async Task ReceiveUnsolicitedMessagesAsync()
// Send a GET request to handle any unsolicited messages not sent over a POST response.
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
request.Headers.Accept.Add(s_textEventStreamMediaType);
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId);
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId, ProtocolVersion);

using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false);

Expand Down Expand Up @@ -245,23 +245,30 @@ private void LogJsonException(JsonException ex, string data)
}
}

internal static void CopyAdditionalHeaders(HttpRequestHeaders headers, Dictionary<string, string>? additionalHeaders, string? sessionId = null)
internal static void CopyAdditionalHeaders(
HttpRequestHeaders headers,
Dictionary<string, string>? additionalHeaders,
string? sessionId = null,
string? protocolVersion = null)
{
if (sessionId is not null)
{
headers.Add("mcp-session-id", sessionId);
}

if (additionalHeaders is null)
if (protocolVersion is not null)
{
return;
headers.Add("MCP-Protocol-Version", protocolVersion);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
headers.Add("MCP-Protocol-Version", protocolVersion);
headers.Add("mcp-protocol-version", protocolVersion);

}

foreach (var header in additionalHeaders)
if (additionalHeaders is not null)
{
if (!headers.TryAddWithoutValidation(header.Key, header.Value))
foreach (var header in additionalHeaders)
{
throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}.");
if (!headers.TryAddWithoutValidation(header.Key, header.Value))
{
throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}.");
}
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/ITransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,12 @@ public interface ITransport : IAsyncDisposable
/// </para>
/// </remarks>
Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default);

/// <summary>Gets or sets the protocol version that's in use.</summary>
/// <remarks>
/// Setting the protocol version does not change the protocol version actively employed by the transport.
/// It provides that information to the transport for situations where the transport needs to be able to
/// propagate the version information, for example as part of HTTP headers or for logging and diagnostic purposes.
/// </remarks>
string? ProtocolVersion { get; set; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having this as a public settable property is a little weird when, as the remarks note, it doesn't change the behavior of the transport.

I was thinking we should try to minimize the change to the public API surface and do something more similar to what I instructed copilot to do in halter73#11. It does end up double-deserializing the InitializeResult, but I think that's better layering.

If we think it's useful to have a public property, I think it should have an internal setter, and we should call it in McpServer as well. And it should probably be on IMcpEndpoint rather than ITransport. However, I think we should probably just get rid of this property alltogether for now, and contain the product changes to just a few lines in StreamableHttpClientSessionTransport.cs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong preference.

I wasn't aware you were already working on it. I'll just close this one and you can continue with the other approach.

}
3 changes: 3 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/TransportBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ internal TransportBase(string name, Channel<JsonRpcMessage>? messageChannel, ILo
/// <summary>Gets the logger used by this transport.</summary>
private protected ILogger Logger => _logger;

/// <inheritdoc />
public string? ProtocolVersion { get; set; }

/// <summary>
/// Gets the name that identifies this transport endpoint in logs.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string?

private bool _isConnected;

/// <inheritdoc />
public string? ProtocolVersion { get; set; }

/// <summary>
/// Starts the transport and writes the JSON-RPC messages sent via <see cref="SendMessageAsync"/>
/// to the SSE response stream until cancellation is requested or the transport is disposed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport

public ChannelReader<JsonRpcMessage> MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages.");

/// <inheritdoc />
public string? ProtocolVersion { get; set; }

/// <returns>
/// True, if data was written to the respond body.
/// False, if nothing was written because the request body did not contain any <see cref="JsonRpcRequest"/> messages to respond to.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public sealed class StreamableHttpServerTransport : ITransport
private readonly CancellationTokenSource _disposeCts = new();

private int _getRequestStarted;

/// <inheritdoc />
public string? ProtocolVersion { get; set; }

/// <summary>
/// Configures whether the transport should be in stateless mode that does not require all requests for a given session
Expand Down
2 changes: 2 additions & 0 deletions tests/Common/Utils/TestServerTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ public class TestServerTransport : ITransport

public bool IsConnected { get; set; }

public string? ProtocolVersion { get; set; }

public ChannelReader<JsonRpcMessage> MessageReader => _messageChannel;

public List<JsonRpcMessage> SentMessages { get; } = [];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,24 @@ private async Task StartAsync()
Services = _app.Services,
});

_app.MapPost("/mcp", (JsonRpcMessage message) =>
_app.MapPost("/mcp", (HttpContext context, JsonRpcMessage message) =>
{
if (message is not JsonRpcRequest request)
{
// Ignore all non-request notifications.
return Results.Accepted();
}

const string ExpectedProtocolVersion = "2024-11-05";

if (request.Method == "initialize")
{
return Results.Json(new JsonRpcResponse
{
Id = request.Id,
Result = JsonSerializer.SerializeToNode(new InitializeResult
{
ProtocolVersion = "2024-11-05",
ProtocolVersion = ExpectedProtocolVersion,
Capabilities = new()
{
Tools = new(),
Expand All @@ -57,6 +59,15 @@ private async Task StartAsync()
});
}

if (!context.Request.Headers.TryGetValue("MCP-Protocol-Version", out var actualVersion))
{
throw new Exception("Request headers did not contain MCP-Protocol-Version.");
}
else if (ExpectedProtocolVersion != actualVersion)
{
throw new Exception($"Unexpected protocol version: {actualVersion}");
}

if (request.Method == "tools/list")
{
return Results.Json(new JsonRpcResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ private class NopTransport : ITransport, IClientTransport

public bool IsConnected => true;

public string? ProtocolVersion { get; set; }

public ChannelReader<JsonRpcMessage> MessageReader => _channel.Reader;

public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult<ITransport>(this);
Expand Down