Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
<LangVersion>13</LangVersion>
<LangVersion>preview</LangVersion>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa
private readonly ILogger _logger;

private string? _mcpSessionId;
private string? _negotiatedProtocolVersion;
private Task? _getReceiveTask;

public StreamableHttpClientSessionTransport(
Expand Down Expand Up @@ -85,7 +86,7 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
},
};

CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId);
CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, _mcpSessionId, _negotiatedProtocolVersion);

var response = await _httpClient.SendAsync(httpRequestMessage, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -119,14 +120,17 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
throw new McpException($"Streamable HTTP POST response completed without a reply to request with ID: {rpcRequest.Id}");
}

if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseOrError is JsonRpcResponse)
if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseOrError is JsonRpcResponse initResponse)
{
// We've successfully initialized! Copy session-id and start GET request if any.
// We've successfully initialized! Copy session-id and protocol version, then start GET request if any.
if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues))
{
_mcpSessionId = sessionIdValues.FirstOrDefault();
}

var initializeResult = JsonSerializer.Deserialize(initResponse.Result, McpJsonUtilities.JsonContext.Default.InitializeResult);
_negotiatedProtocolVersion = initializeResult?.ProtocolVersion;

_getReceiveTask = ReceiveUnsolicitedMessagesAsync();
}

Expand Down Expand Up @@ -170,7 +174,7 @@ private async Task ReceiveUnsolicitedMessagesAsync()
// Send a GET request to handle any unsolicited messages not sent over a POST response.
using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint);
request.Headers.Accept.Add(s_textEventStreamMediaType);
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId);
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, _mcpSessionId, _negotiatedProtocolVersion);

using var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, _connectionCts.Token).ConfigureAwait(false);

Expand Down Expand Up @@ -245,13 +249,22 @@ private void LogJsonException(JsonException ex, string data)
}
}

internal static void CopyAdditionalHeaders(HttpRequestHeaders headers, Dictionary<string, string>? additionalHeaders, string? sessionId = null)
internal static void CopyAdditionalHeaders(
HttpRequestHeaders headers,
Dictionary<string, string>? additionalHeaders,
string? sessionId = null,
string? protocolVersion = null)
{
if (sessionId is not null)
{
headers.Add("mcp-session-id", sessionId);
}

if (protocolVersion is not null)
{
headers.Add("mcp-protocol-version", protocolVersion);
}

if (additionalHeaders is null)
{
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
<PackageId>ModelContextProtocol.Core</PackageId>
<Description>Core .NET SDK for the Model Context Protocol (MCP)</Description>
<PackageReadmeFile>README.md</PackageReadmeFile>
<LangVersion>preview</LangVersion>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
Expand Down
1 change: 0 additions & 1 deletion src/ModelContextProtocol/ModelContextProtocol.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
<PackageId>ModelContextProtocol</PackageId>
<Description>.NET SDK for the Model Context Protocol (MCP) with hosting and dependency injection extensions.</Description>
<PackageReadmeFile>README.md</PackageReadmeFile>
<LangVersion>preview</LangVersion>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Primitives;
using ModelContextProtocol.Client;

namespace ModelContextProtocol.AspNetCore.Tests;
Expand Down Expand Up @@ -143,4 +144,38 @@ public async Task SseMode_Works_WithSseEndpoint()

Assert.Equal("SseTestServer", mcpClient.ServerInfo.Name);
}

[Fact]
public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitialization()
{
var protocolVersionHeaderValues = new List<string?>();

Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools<EchoHttpContextUserTools>();

await using var app = Builder.Build();

app.Use(next =>
{
return async context =>
{
if (!StringValues.IsNullOrEmpty(context.Request.Headers["mcp-session-id"]))
{
protocolVersionHeaderValues.Add(context.Request.Headers["mcp-protocol-version"]);
}

await next(context);
};
});

app.MapMcp();

await app.StartAsync(TestContext.Current.CancellationToken);

await using var mcpClient = await ConnectAsync();
await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);

// The header should be included in the GET request, the initialized notification, and the tools/list call.
Assert.Equal(3, protocolVersionHeaderValues.Count);
Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v));
}
}
2 changes: 1 addition & 1 deletion tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ private ClaimsPrincipal CreateUser(string name)
"TestAuthType", "name", "role"));

[McpServerToolType]
private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor)
protected class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor)
{
[McpServerTool, Description("Echoes the input back to the client with their user name.")]
public string EchoWithUserName(string message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
<TargetFrameworks>net9.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>Latest</LangVersion>
<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
<RootNamespace>ModelContextProtocol.AspNetCore.Tests</RootNamespace>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
<TargetFrameworks>net9.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>Latest</LangVersion>

<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
Expand Down Expand Up @@ -45,7 +44,7 @@
<PackageReference Include="Moq" />
<PackageReference Include="OpenTelemetry" />
<PackageReference Include="OpenTelemetry.Exporter.InMemory" />
<PackageReference Include="System.Linq.AsyncEnumerable" />
<PackageReference Include="System.Linq.AsyncEnumerable" />
<PackageReference Include="xunit.v3" />
<PackageReference Include="xunit.runner.visualstudio">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down