Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,40 +87,17 @@ public override async Task SendMessageAsync(
messageId = messageWithId.Id.ToString();
}

var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint)
{
Content = content,
};
CopyAdditionalHeaders(httpRequestMessage.Headers);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders);
var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

response.EnsureSuccessStatusCode();

var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);

// Check if the message was an initialize request
if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize)
{
// If the response is not a JSON-RPC response, it is an SSE message
if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
{
LogAcceptedPost(Name, messageId);
// The response will arrive as an SSE message
}
else
{
JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ??
throw new InvalidOperationException("Failed to initialize client");

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}

return;
}

// Otherwise, check if the response was accepted (the response will come as an SSE message)
if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase))
{
LogAcceptedPost(Name, messageId);
Expand Down Expand Up @@ -177,17 +154,13 @@ public override async ValueTask DisposeAsync()
}
}

internal Uri? MessageEndpoint => _messageEndpoint;

internal SseClientTransportOptions Options => _options;

private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)
{
try
{
using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint);
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
CopyAdditionalHeaders(request.Headers);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders);

using var response = await _httpClient.SendAsync(
request,
Expand Down Expand Up @@ -251,15 +224,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation
return;
}

string messageId = "(no id)";
if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}
catch (JsonException ex)
{
Expand Down Expand Up @@ -290,20 +255,6 @@ private void HandleEndpointEvent(string data)
_connectionEstablished.TrySetResult(true);
}

private void CopyAdditionalHeaders(HttpRequestHeaders headers)
{
if (_options.AdditionalHeaders is not null)
{
foreach (var header in _options.AdditionalHeaders)
{
if (!headers.TryAddWithoutValidation(header.Key, header.Value))
{
throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}.");
}
}
}
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} accepted SSE transport POST for message ID '{MessageId}'.")]
private partial void LogAcceptedPost(string endpointName, string messageId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient
/// <inheritdoc />
public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
{
if (_options.UseStreamableHttp)
{
return new StreamableHttpClientSessionTransport(_options, _httpClient, _loggerFactory, Name);
}

var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name);

try
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using static System.Net.WebRequestMethods;

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
Expand Down Expand Up @@ -30,13 +32,20 @@ public required Uri Endpoint
}
}

/// <summary>
/// Gets or sets a value indicating whether to use "Streamable HTTP" for the transport rather than "HTTP with SSE". Defaults to false.
/// <see href="https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http">Streamable HTTP transport specification</see>.
/// <see href="https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse">HTTP with SSE transport specification</see>.
/// </summary>
public bool UseStreamableHttp { get; init; }

/// <summary>
/// Gets a transport identifier used for logging purposes.
/// </summary>
public string? Name { get; init; }

/// <summary>
/// Gets or sets a timeout used to establish the initial connection to the SSE server.
/// Gets or sets a timeout used to establish the initial connection to the SSE server. Defaults to 30 seconds.
/// </summary>
/// <remarks>
/// This timeout controls how long the client waits for:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,7 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
{
string messageId = "(no id)";
if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,7 @@ private async Task ReadMessagesAsync()
{
if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) is JsonRpcMessage message)
{
string messageId = "(no id)";
if (message is JsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

LogTransportReceivedMessage(Name, messageId);
await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
LogTransportMessageWritten(Name, messageId);
}
else
{
Expand Down
Loading