Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/ModelContextProtocol/Client/SseClientSessionTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,25 @@ public override async Task SendMessageAsync(
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders);
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

response.EnsureSuccessStatusCode();

var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);

if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
{
LogAcceptedPost(Name, messageId);
}
else
if (!response.IsSuccessStatusCode)
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogRejectedPostSensitive(Name, messageId, responseContent);
LogRejectedPostSensitive(Name, messageId, await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false));
}
else
{
LogRejectedPost(Name, messageId);
}

throw new InvalidOperationException("Failed to send message");
response.EnsureSuccessStatusCode();
}

if (response.Content.Headers.ContentType?.MediaType is "application/json")
{
// Certain MCP servers implementing SSE may return the response in the current POST request instead of the SSE stream.
// Even though this is not officially part of the SSE protocol, we handle it here.
await ProcessInboundMessage(await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false), cancellationToken).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -179,7 +178,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
break;

case "message":
await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false);
await ProcessInboundMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false);
break;
}
}
Expand All @@ -205,7 +204,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
}
}

private async Task ProcessSseMessage(string data, CancellationToken cancellationToken)
private async Task ProcessInboundMessage(string data, CancellationToken cancellationToken)
{
if (!IsConnected)
{
Expand Down
90 changes: 90 additions & 0 deletions tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using System.Text.Json;
using System.Text.Json.Serialization;
using TestServerWithHosting.Tools;

Expand Down Expand Up @@ -58,6 +59,34 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU
Assert.True(true);
}

[Fact]
public async Task ConnectAndReceiveMessage_ServerReturningResponseInPostRequest_WithFullEndpointEventUri()
{
await using var app = Builder.Build();
MapSseEndpointRespondingInPostRequest(app,
request => new JsonRpcResponse
{
Id = request.Id,
Result = JsonSerializer.SerializeToNode(
new InitializeResult
{
Capabilities = new(),
ServerInfo = new() { Name = "TestServer", Version = "1.0" },
ProtocolVersion = "2024-11-05"
}, McpJsonUtilities.DefaultOptions),
});

await app.StartAsync(TestContext.Current.CancellationToken);

await using var mcpClient = await ConnectMcpClientAsync();

// Send a test message through POST endpoint
JsonRpcRequest request = new() { Method = "TestMethod", Id = new(42), Params = null };
JsonRpcResponse response = await mcpClient.SendRequestAsync(request, cancellationToken: TestContext.Current.CancellationToken);

Assert.Equal(request.Id, response.Id);
}

[Fact]
public async Task ConnectAndReceiveNotification_InMemoryServer()
{
Expand Down Expand Up @@ -280,6 +309,67 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints)
});
}

private static void MapSseEndpointRespondingInPostRequest(IEndpointRouteBuilder endpoints, Func<JsonRpcRequest, JsonRpcResponse> requestHandler)
{
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();

var routeGroup = endpoints.MapGroup("");
SseResponseStreamTransport? session = null;

routeGroup.MapGet("/sse", async context =>
{
var response = context.Response;
var requestAborted = context.RequestAborted;

response.Headers.ContentType = "text/event-stream";

await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost/message");
session = transport;

try
{
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider);

try
{
await server.RunAsync(requestAborted);
}
finally
{
await transport.DisposeAsync();
await transportTask;
}
}
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
{
// RequestAborted always triggers when the client disconnects before a complete response body is written,
// but this is how SSE connections are typically closed.
}
});

routeGroup.MapPost("/message", async context =>
{
if (session is null)
{
await Results.BadRequest("Session not started.").ExecuteAsync(context);
return;
}
var message = (JsonRpcRequest?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcRequest)), context.RequestAborted);
if (message is null)
{
await Results.BadRequest("No message in request body.").ExecuteAsync(context);
return;
}

var response = requestHandler(message);

context.Response.StatusCode = StatusCodes.Status202Accepted;
await context.Response.WriteAsJsonAsync(response, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcResponse)), cancellationToken: context.RequestAborted);
});
}

public class Envelope
{
public required string Message { get; set; }
Expand Down
Loading