diff --git a/.github/workflows/ci-build-test.yml b/.github/workflows/ci-build-test.yml index 2bcbb4d1b..7bdfd9258 100644 --- a/.github/workflows/ci-build-test.yml +++ b/.github/workflows/ci-build-test.yml @@ -62,7 +62,7 @@ jobs: node-version: '20' - name: 📦 Install dependencies for tests - run: npm install @modelcontextprotocol/server-everything + run: npm install @modelcontextprotocol/server-everything@2025.12.18 - name: 📦 Install dependencies for tests run: npm install @modelcontextprotocol/server-memory diff --git a/src/Common/ServerSentEvents/ArrayBuffer.cs b/src/Common/ServerSentEvents/ArrayBuffer.cs new file mode 100644 index 000000000..bc5191d3a --- /dev/null +++ b/src/Common/ServerSentEvents/ArrayBuffer.cs @@ -0,0 +1,198 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Copied from https://github.com/dotnet/runtime/blob/dcbf3413c5f7ae431a68fd0d3f09af095b525887/src/libraries/Common/src/System/Net/ArrayBuffer.cs + +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Net.ServerSentEvents; + +// Warning: Mutable struct! +// The purpose of this struct is to simplify buffer management. +// It manages a sliding buffer where bytes can be added at the end and removed at the beginning. +// [ActiveSpan/Memory] contains the current buffer contents; these bytes will be preserved +// (copied, if necessary) on any call to EnsureAvailableBytes. +// [AvailableSpan/Memory] contains the available bytes past the end of the current content, +// and can be written to in order to add data to the end of the buffer. +// Commit(byteCount) will extend the ActiveSpan by [byteCount] bytes into the AvailableSpan. +// Discard(byteCount) will discard [byteCount] bytes as the beginning of the ActiveSpan. + +[StructLayout(LayoutKind.Auto)] +internal struct ArrayBuffer : IDisposable +{ +#if NET + private static int ArrayMaxLength => Array.MaxLength; +#else + private const int ArrayMaxLength = 0X7FFFFFC7; +#endif + + private readonly bool _usePool; + private byte[] _bytes; + private int _activeStart; + private int _availableStart; + + // Invariants: + // 0 <= _activeStart <= _availableStart <= bytes.Length + + public ArrayBuffer(int initialSize, bool usePool = false) + { + Debug.Assert(initialSize > 0 || usePool); + + _usePool = usePool; + _bytes = initialSize == 0 + ? Array.Empty() + : usePool ? ArrayPool.Shared.Rent(initialSize) : new byte[initialSize]; + _activeStart = 0; + _availableStart = 0; + } + + public ArrayBuffer(byte[] buffer) + { + Debug.Assert(buffer.Length > 0); + + _usePool = false; + _bytes = buffer; + _activeStart = 0; + _availableStart = 0; + } + + public void Dispose() + { + _activeStart = 0; + _availableStart = 0; + + byte[] array = _bytes; + _bytes = null!; + + if (array is not null) + { + ReturnBufferIfPooled(array); + } + } + + // This is different from Dispose as the instance remains usable afterwards (_bytes will not be null). + public void ClearAndReturnBuffer() + { + Debug.Assert(_usePool); + Debug.Assert(_bytes is not null); + + _activeStart = 0; + _availableStart = 0; + + byte[] bufferToReturn = _bytes!; + _bytes = Array.Empty(); + ReturnBufferIfPooled(bufferToReturn); + } + + public int ActiveLength => _availableStart - _activeStart; + public Span ActiveSpan => new Span(_bytes, _activeStart, _availableStart - _activeStart); + public ReadOnlySpan ActiveReadOnlySpan => new ReadOnlySpan(_bytes, _activeStart, _availableStart - _activeStart); + public Memory ActiveMemory => new Memory(_bytes, _activeStart, _availableStart - _activeStart); + + public int AvailableLength => _bytes.Length - _availableStart; + public Span AvailableSpan => _bytes.AsSpan(_availableStart); + public Memory AvailableMemory => _bytes.AsMemory(_availableStart); + public Memory AvailableMemorySliced(int length) => new Memory(_bytes, _availableStart, length); + + public int Capacity => _bytes.Length; + public int ActiveStartOffset => _activeStart; + + public byte[] DangerousGetUnderlyingBuffer() => _bytes; + + public void Discard(int byteCount) + { + Debug.Assert(byteCount <= ActiveLength, $"Expected {byteCount} <= {ActiveLength}"); + _activeStart += byteCount; + + if (_activeStart == _availableStart) + { + _activeStart = 0; + _availableStart = 0; + } + } + + public void Commit(int byteCount) + { + Debug.Assert(byteCount <= AvailableLength); + _availableStart += byteCount; + } + + // Ensure at least [byteCount] bytes to write to. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void EnsureAvailableSpace(int byteCount) + { + if (byteCount > AvailableLength) + { + EnsureAvailableSpaceCore(byteCount); + } + } + + private void EnsureAvailableSpaceCore(int byteCount) + { + Debug.Assert(AvailableLength < byteCount); + + if (_bytes.Length == 0) + { + Debug.Assert(_usePool && _activeStart == 0 && _availableStart == 0); + _bytes = ArrayPool.Shared.Rent(byteCount); + return; + } + + int totalFree = _activeStart + AvailableLength; + if (byteCount <= totalFree) + { + // We can free up enough space by just shifting the bytes down, so do so. + Buffer.BlockCopy(_bytes, _activeStart, _bytes, 0, ActiveLength); + _availableStart = ActiveLength; + _activeStart = 0; + Debug.Assert(byteCount <= AvailableLength); + return; + } + + int desiredSize = ActiveLength + byteCount; + + if ((uint)desiredSize > ArrayMaxLength) + { + throw new OutOfMemoryException(); + } + + // Double the existing buffer size (capped at Array.MaxLength). + int newSize = Math.Max(desiredSize, (int)Math.Min(ArrayMaxLength, 2 * (uint)_bytes.Length)); + + byte[] newBytes = _usePool ? + ArrayPool.Shared.Rent(newSize) : + new byte[newSize]; + byte[] oldBytes = _bytes; + + if (ActiveLength != 0) + { + Buffer.BlockCopy(oldBytes, _activeStart, newBytes, 0, ActiveLength); + } + + _availableStart = ActiveLength; + _activeStart = 0; + + _bytes = newBytes; + ReturnBufferIfPooled(oldBytes); + + Debug.Assert(byteCount <= AvailableLength); + } + + public void Grow() + { + EnsureAvailableSpaceCore(AvailableLength + 1); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ReturnBufferIfPooled(byte[] buffer) + { + // The buffer may be Array.Empty() + if (_usePool && buffer.Length > 0) + { + ArrayPool.Shared.Return(buffer); + } + } +} diff --git a/src/Common/ServerSentEvents/PooledByteBufferWriter.cs b/src/Common/ServerSentEvents/PooledByteBufferWriter.cs new file mode 100644 index 000000000..d8928d6dd --- /dev/null +++ b/src/Common/ServerSentEvents/PooledByteBufferWriter.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Copied from https://github.com/dotnet/runtime/blob/dcbf3413c5f7ae431a68fd0d3f09af095b525887/src/libraries/System.Net.ServerSentEvents/src/System/Net/ServerSentEvents/PooledByteBufferWriter.cs + +using System.Buffers; +using System.Diagnostics; + +namespace System.Net.ServerSentEvents; + +internal sealed class PooledByteBufferWriter : IBufferWriter, IDisposable +{ + private const int MinimumBufferSize = 256; + private ArrayBuffer _buffer = new(initialSize: 256, usePool: true); + + public void Advance(int count) => _buffer.Commit(count); + + public Memory GetMemory(int sizeHint = 0) + { + _buffer.EnsureAvailableSpace(Math.Max(sizeHint, MinimumBufferSize)); + return _buffer.AvailableMemory; + } + + public Span GetSpan(int sizeHint = 0) + { + _buffer.EnsureAvailableSpace(Math.Max(sizeHint, MinimumBufferSize)); + return _buffer.AvailableSpan; + } + + public ReadOnlyMemory WrittenMemory => _buffer.ActiveMemory; + public int Capacity => _buffer.Capacity; + public int WrittenCount => _buffer.ActiveLength; + public void Reset() => _buffer.Discard(_buffer.ActiveLength); + public void Dispose() => _buffer.Dispose(); +} diff --git a/src/Common/ServerSentEvents/SseEventWriter.cs b/src/Common/ServerSentEvents/SseEventWriter.cs new file mode 100644 index 000000000..bf61e73af --- /dev/null +++ b/src/Common/ServerSentEvents/SseEventWriter.cs @@ -0,0 +1,136 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Based on https://github.com/dotnet/runtime/blob/dcbf3413c5f7ae431a68fd0d3f09af095b525887/src/libraries/System.Net.ServerSentEvents/src/System/Net/ServerSentEvents/SseFormatter.cs + +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Net.ServerSentEvents; + +/// +/// Provides methods for writing SSE events to a stream. +/// +internal sealed class SseEventWriter : IDisposable +{ + private static readonly byte[] s_newLine = "\n"u8.ToArray(); + + private readonly Stream _destination; + private readonly PooledByteBufferWriter _bufferWriter = new(); + private readonly PooledByteBufferWriter _userDataBufferWriter = new(); + + /// + /// Initializes a new instance of the class with the specified destination stream and item formatter. + /// + /// The stream to write SSE events to. + /// is . + public SseEventWriter(Stream destination) + { + _destination = destination ?? throw new ArgumentNullException(nameof(destination)); + } + + /// + /// Writes an SSE item to the destination stream. + /// + /// The SSE item to write. + /// + /// The token to monitor for cancellation requests. + /// A task representing the asynchronous write operation. + public async ValueTask WriteAsync(SseItem item, Action, IBufferWriter> itemFormatter, CancellationToken cancellationToken = default) + { + itemFormatter(item, _userDataBufferWriter); + + FormatSseEvent( + _bufferWriter, + eventType: item.EventType, + data: _userDataBufferWriter.WrittenMemory.Span, + eventId: item.EventId, + reconnectionInterval: item.ReconnectionInterval); + + await _destination.WriteAsync(_bufferWriter.WrittenMemory, cancellationToken).ConfigureAwait(false); + await _destination.FlushAsync(cancellationToken).ConfigureAwait(false); + + _userDataBufferWriter.Reset(); + _bufferWriter.Reset(); + } + + private static void FormatSseEvent( + IBufferWriter bufferWriter, + string? eventType, + ReadOnlySpan data, + string? eventId, + TimeSpan? reconnectionInterval) + { + if (eventType is not null) + { + Debug.Assert(!eventType.ContainsLineBreaks()); + + bufferWriter.WriteUtf8String("event: "u8); + bufferWriter.WriteUtf8String(eventType); + bufferWriter.WriteUtf8String(s_newLine); + } + + WriteLinesWithPrefix(bufferWriter, prefix: "data: "u8, data); + bufferWriter.Write(s_newLine); + + if (eventId is not null) + { + Debug.Assert(!eventId.ContainsLineBreaks()); + + bufferWriter.WriteUtf8String("id: "u8); + bufferWriter.WriteUtf8String(eventId); + bufferWriter.WriteUtf8String(s_newLine); + } + + if (reconnectionInterval is { } retry) + { + Debug.Assert(retry >= TimeSpan.Zero); + + bufferWriter.WriteUtf8String("retry: "u8); + bufferWriter.WriteUtf8Number((long)retry.TotalMilliseconds); + bufferWriter.WriteUtf8String(s_newLine); + } + + bufferWriter.WriteUtf8String(s_newLine); + } + + private static void WriteLinesWithPrefix(IBufferWriter writer, ReadOnlySpan prefix, ReadOnlySpan data) + { + // Writes a potentially multi-line string, prefixing each line with the given prefix. + // Both \n and \r\n sequences are normalized to \n. + + while (true) + { + writer.WriteUtf8String(prefix); + + int i = data.IndexOfAny((byte)'\r', (byte)'\n'); + if (i < 0) + { + writer.WriteUtf8String(data); + return; + } + + int lineLength = i; + if (data[i++] == '\r' && i < data.Length && data[i] == '\n') + { + i++; + } + + ReadOnlySpan nextLine = data.Slice(0, lineLength); + data = data.Slice(i); + + writer.WriteUtf8String(nextLine); + writer.WriteUtf8String(s_newLine); + } + } + + /// + public void Dispose() + { + _bufferWriter.Dispose(); + _userDataBufferWriter.Dispose(); + } +} diff --git a/src/Common/ServerSentEvents/SseEventWriterHelpers.cs b/src/Common/ServerSentEvents/SseEventWriterHelpers.cs new file mode 100644 index 000000000..53e315df2 --- /dev/null +++ b/src/Common/ServerSentEvents/SseEventWriterHelpers.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Copied from https://github.com/dotnet/runtime/blob/dcbf3413c5f7ae431a68fd0d3f09af095b525887/src/libraries/System.Net.ServerSentEvents/src/System/Net/ServerSentEvents/Helpers.cs + +using System.Buffers; +using System.Diagnostics; +using System.Globalization; +using System.Text; + +namespace System.Net.ServerSentEvents; + +internal static class SseEventWriterHelpers +{ + public static void WriteUtf8Number(this IBufferWriter writer, long value) + { +#if NET + const int MaxDecimalDigits = 20; + Span buffer = writer.GetSpan(MaxDecimalDigits); + Debug.Assert(MaxDecimalDigits <= buffer.Length); + + bool success = value.TryFormat(buffer, out int bytesWritten, provider: CultureInfo.InvariantCulture); + Debug.Assert(success); + writer.Advance(bytesWritten); +#else + writer.WriteUtf8String(value.ToString(CultureInfo.InvariantCulture)); +#endif + } + + public static void WriteUtf8String(this IBufferWriter writer, ReadOnlySpan value) + { + if (value.IsEmpty) + { + return; + } + + Span buffer = writer.GetSpan(value.Length); + Debug.Assert(value.Length <= buffer.Length); + value.CopyTo(buffer); + writer.Advance(value.Length); + } + + public static void WriteUtf8String(this IBufferWriter writer, ReadOnlySpan value) + { + if (value.IsEmpty) + { + return; + } + +#if NET + int maxByteCount = Encoding.UTF8.GetMaxByteCount(value.Length); + Span buffer = writer.GetSpan(maxByteCount); + Debug.Assert(buffer.Length >= maxByteCount); + + int bytesWritten = Encoding.UTF8.GetBytes(value, buffer); + writer.Advance(bytesWritten); +#else + // netstandard2.0 doesn't have the Span overload of GetBytes + byte[] bytes = Encoding.UTF8.GetBytes(value.ToString()); + Span buffer = writer.GetSpan(bytes.Length); + bytes.AsSpan().CopyTo(buffer); + writer.Advance(bytes.Length); +#endif + } + + public static bool ContainsLineBreaks(this ReadOnlySpan text) => + text.IndexOfAny('\r', '\n') >= 0; + + public static bool ContainsLineBreaks(this string? text) => + text is not null && text.AsSpan().ContainsLineBreaks(); +} diff --git a/src/Common/ServerSentEvents/SseItem.cs b/src/Common/ServerSentEvents/SseItem.cs new file mode 100644 index 000000000..566e08eb4 --- /dev/null +++ b/src/Common/ServerSentEvents/SseItem.cs @@ -0,0 +1,31 @@ +namespace System.Net.ServerSentEvents; + +/// +/// Provides factory methods for creating server-sent event (SSE) items with specific event types and data payloads. +/// +internal static class SseItem +{ + /// + /// Creates a new server-sent event (SSE) message containing the specified data and the default event type. + /// + /// The type of the data to include in the SSE message. + /// The data to include in the SSE message. Can be null. + /// An representing an SSE message with the specified data and the default event type. + public static SseItem Message(T? data) + => new(data: data, SseParser.EventTypeDefault); + + /// + /// Creates a new Server-Sent Events (SSE) item representing a 'prime' event with no data. + /// + /// An instance representing a 'prime' event with no data. + public static SseItem Prime() + => new(data: default, eventType: "prime"); + + /// + /// Creates a server-sent event (SSE) item representing the specified endpoint. + /// + /// The endpoint string to include in the SSE item. Cannot be null. + /// An containing the specified endpoint value. + public static SseItem Endpoint(string endpoint) + => new(endpoint, "endpoint"); +} diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 67f4f4e1d..1ede12e7f 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -43,6 +43,21 @@ public class HttpServerTransportOptions /// public bool Stateless { get; set; } + /// + /// Gets or sets the event store for resumability support. + /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header. + /// + /// + /// When configured, the server will: + /// + /// Generate unique event IDs for each SSE message + /// Store events for later replay + /// Replay missed events when a client reconnects with a Last-Event-ID header + /// Send priming events to establish resumability before any actual messages + /// + /// + public ISseEventStreamStore? EventStreamStore { get; set; } + /// /// Gets or sets a value that indicates whether the server uses a single execution context for the entire session. /// diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 8c78d7516..ad87f7a4c 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -41,10 +41,12 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo if (!streamableHttpHandler.HttpServerTransportOptions.Stateless) { - // The GET and DELETE endpoints are not mapped in Stateless mode since there's no way to send unsolicited messages - // for the GET to handle, and there is no server-side state for the DELETE to clean up. + // The GET endpoint is not mapped in Stateless mode since there's no way to send unsolicited messages. + // Resuming streams via GET is currently not supported in Stateless mode. streamableHttpGroup.MapGet("", streamableHttpHandler.HandleGetRequestAsync) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); + + // The DELETE endpoint is not mapped in Stateless mode since there is no server-side state for the DELETE to clean up. streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); // Map legacy HTTP with SSE endpoints only if not in Stateless mode, because we cannot guarantee the /message requests diff --git a/src/ModelContextProtocol.AspNetCore/SseEventStreamReaderExtensions.cs b/src/ModelContextProtocol.AspNetCore/SseEventStreamReaderExtensions.cs new file mode 100644 index 000000000..7c6970c70 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/SseEventStreamReaderExtensions.cs @@ -0,0 +1,53 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Buffers; +using System.Net.ServerSentEvents; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Provides extension methods for . +/// +internal static class SseEventStreamReaderExtensions +{ + /// + /// Copies all events from the reader to the destination stream in SSE format. + /// + /// The event stream reader to copy events from. + /// The destination stream to write SSE-formatted events to. + /// A token to cancel the operation. + /// A task that represents the asynchronous copy operation. + /// Thrown when or is null. + public static async Task CopyToAsync(this ISseEventStreamReader reader, Stream destination, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(reader); + ArgumentNullException.ThrowIfNull(destination); + + Utf8JsonWriter? jsonWriter = null; + var jsonTypeInfo = (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)); + + var events = reader.ReadEventsAsync(cancellationToken); + await SseFormatter.WriteAsync(events, destination, FormatEvent, cancellationToken).ConfigureAwait(false); + + void FormatEvent(SseItem item, IBufferWriter writer) + { + if (item.Data is null) + { + return; + } + + if (jsonWriter is null) + { + jsonWriter = new Utf8JsonWriter(writer); + } + else + { + jsonWriter.Reset(writer); + } + + JsonSerializer.Serialize(jsonWriter, item.Data, jsonTypeInfo); + } + } +} diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index c0f59363a..6e11b9b86 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -23,6 +23,7 @@ internal sealed class StreamableHttpHandler( ILoggerFactory loggerFactory) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + private const string LastEventIdHeaderName = "Last-Event-ID"; private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); @@ -88,10 +89,57 @@ await WriteJsonRpcErrorAsync(context, return; } + var lastEventId = context.Request.Headers[LastEventIdHeaderName].ToString(); + if (!string.IsNullOrEmpty(lastEventId)) + { + await HandleResumedStreamAsync(context, session, lastEventId); + } + else + { + await HandleUnsolicitedMessageStreamAsync(context, session); + } + } + + private async Task HandleResumedStreamAsync(HttpContext context, StreamableHttpSession session, string lastEventId) + { + if (HttpServerTransportOptions.Stateless) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The Last-Event-ID header is not supported in stateless mode.", + StatusCodes.Status400BadRequest); + return; + } + + var eventStreamReader = await GetEventStreamReaderAsync(context, lastEventId); + if (eventStreamReader is null) + { + // There was an error obtaining the event stream; consider the request failed. + return; + } + + if (!string.Equals(session.Id, eventStreamReader.SessionId, StringComparison.Ordinal)) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The Last-Event-ID header refers to a session with a different session ID.", + StatusCodes.Status400BadRequest); + return; + } + + using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping); + var cancellationToken = sseCts.Token; + + await using var _ = await session.AcquireReferenceAsync(cancellationToken); + + InitializeSseResponse(context); + await eventStreamReader.CopyToAsync(context.Response.Body, context.RequestAborted); + } + + private async Task HandleUnsolicitedMessageStreamAsync(HttpContext context, StreamableHttpSession session) + { if (!session.TryStartGetRequest()) { await WriteJsonRpcErrorAsync(context, - "Bad Request: This server does not support multiple GET requests. Start a new session to get a new GET SSE response.", + "Bad Request: This server does not support multiple GET requests. Start a new session or use Last-Event-ID header to resume.", StatusCodes.Status400BadRequest); return; } @@ -120,6 +168,12 @@ await WriteJsonRpcErrorAsync(context, } } + private static async Task HandleResumePostResponseStreamAsync(HttpContext context, ISseEventStreamReader eventStreamReader) + { + InitializeSseResponse(context); + await eventStreamReader.CopyToAsync(context.Response.Body, context.RequestAborted); + } + public async Task HandleDeleteRequestAsync(HttpContext context) { var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString(); @@ -131,14 +185,13 @@ public async Task HandleDeleteRequestAsync(HttpContext context) private async ValueTask GetSessionAsync(HttpContext context, string sessionId) { - StreamableHttpSession? session; - if (string.IsNullOrEmpty(sessionId)) { await WriteJsonRpcErrorAsync(context, "Bad Request: Mcp-Session-Id header is required", StatusCodes.Status400BadRequest); return null; } - else if (!sessionManager.TryGetValue(sessionId, out session)) + + if (!sessionManager.TryGetValue(sessionId, out var session)) { // -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does. // One of the few other usages I found was from some Ethereum JSON-RPC documentation and this @@ -194,12 +247,15 @@ private async ValueTask StartNewSessionAsync(HttpContext { SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, + EventStreamStore = HttpServerTransportOptions.EventStreamStore, }; context.Response.Headers[McpSessionIdHeaderName] = sessionId; } else { // In stateless mode, each request is independent. Don't set any session ID on the transport. + // If in the future we support resuming stateless requests, we should populate + // the event stream store and retry interval here as well. sessionId = ""; transport = new() { @@ -246,6 +302,28 @@ private async ValueTask CreateSessionAsync( return session; } + private async ValueTask GetEventStreamReaderAsync(HttpContext context, string lastEventId) + { + if (HttpServerTransportOptions.EventStreamStore is not { } eventStreamStore) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: This server does not support resuming streams.", + StatusCodes.Status400BadRequest); + return null; + } + + var eventStreamReader = await eventStreamStore.GetStreamReaderAsync(lastEventId, context.RequestAborted); + if (eventStreamReader is null) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The specified Last-Event-ID is either invalid or expired.", + StatusCodes.Status400BadRequest); + return null; + } + + return eventStreamReader; + } + private static Task WriteJsonRpcErrorAsync(HttpContext context, string errorMessage, int statusCode, int errorCode = -32000) { var jsonRpcError = new JsonRpcError diff --git a/src/ModelContextProtocol.Core/.editorconfig b/src/ModelContextProtocol.Core/.editorconfig new file mode 100644 index 000000000..3a5001118 --- /dev/null +++ b/src/ModelContextProtocol.Core/.editorconfig @@ -0,0 +1,2 @@ +[*.cs] +dotnet_diagnostic.CA2007.severity = error # CA2007: Do not directly await a Task without ConfigureAwait diff --git a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index 43b6ef30d..0a8776ad6 100644 --- a/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -108,4 +108,31 @@ public required Uri Endpoint /// Gets sor sets the authorization provider to use for authentication. /// public ClientOAuthOptions? OAuth { get; set; } + + /// + /// Gets or sets the maximum number of consecutive reconnection attempts when an SSE stream is disconnected. + /// + /// + /// The maximum number of reconnection attempts. The default is 2. + /// + /// + /// When an SSE stream is disconnected (e.g., due to a network issue), the client will attempt to + /// reconnect using the Last-Event-ID header to resume from where it left off. This property controls + /// how many reconnection attempts are made before giving up. + /// + public int MaxReconnectionAttempts { get; set; } = 2; + + /// + /// Gets or sets the default interval at which the client attempts reconnection after an SSE stream is disconnected. + /// + /// + ///

+ /// The default value is 1 second. + ///

+ ///

+ /// If the server sends a message specifying a different reconnection interval, that new value will be used for all + /// subsequent reconnection attempts for that stream. + ///

+ ///
+ public TimeSpan DefaultReconnectionInterval { get; set; } = TimeSpan.FromSeconds(1); } diff --git a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs index a9c228d43..78b0d6f4a 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs @@ -19,7 +19,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation { try { - await base.SendMessageAsync(message, cancellationToken); + await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } catch (IOException) { @@ -58,7 +58,7 @@ protected override async ValueTask CleanupAsync(Exception? error = null, Cancell } // And handle cleanup in the base type. - await base.CleanupAsync(error, cancellationToken); + await base.CleanupAsync(error, cancellationToken).ConfigureAwait(false); } private async ValueTask GetUnexpectedExitExceptionAsync(CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index 534249038..017512589 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -5,6 +5,7 @@ using System.Text.Json; using ModelContextProtocol.Protocol; using System.Threading.Channels; +using System.Net; namespace ModelContextProtocol.Client; @@ -105,8 +106,18 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes } else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") { - using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken); - rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false); + var sseState = new SseStreamState(); + using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var sseResponse = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, sseState, cancellationToken).ConfigureAwait(false); + rpcResponseOrError = sseResponse.Response; + + // Resumability: If POST SSE stream ended without a response but we have a Last-Event-ID (from priming), + // attempt to resume by sending a GET request with Last-Event-ID header. The server will replay + // events from the event store, allowing us to receive the pending response. + if (rpcResponseOrError is null && rpcRequest is not null && sseState.LastEventId is not null) + { + rpcResponseOrError = await SendGetSseRequestWithRetriesAsync(rpcRequest, sseState, cancellationToken).ConfigureAwait(false); + } } if (rpcRequest is null) @@ -155,7 +166,7 @@ public override async ValueTask DisposeAsync() // Send DELETE request to terminate the session. Only send if we have a session ID, per MCP spec. if (_options.OwnsSession && !string.IsNullOrEmpty(SessionId)) { - await SendDeleteRequest(); + await SendDeleteRequest().ConfigureAwait(false); } if (_getReceiveTask != null) @@ -188,56 +199,147 @@ public override async ValueTask DisposeAsync() private async Task ReceiveUnsolicitedMessagesAsync() { - // Send a GET request to handle any unsolicited messages not sent over a POST response. - using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); - request.Headers.Accept.Add(s_textEventStreamMediaType); - CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion); + var state = new SseStreamState(); - // Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages. - HttpResponseMessage response; - try - { - response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false); - } - catch (HttpRequestException) + // Continuously receive unsolicited messages until canceled + while (!_connectionCts.Token.IsCancellationRequested) { - return; - } + await SendGetSseRequestWithRetriesAsync( + relatedRpcRequest: null, + state, + _connectionCts.Token).ConfigureAwait(false); - using (response) - { - if (!response.IsSuccessStatusCode) + // If we exhausted retries without receiving any events, stop trying + if (state.LastEventId is null) { return; } - - using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false); - await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false); } } - private async Task ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) + /// + /// Sends a GET request for SSE with retry logic and resumability support. + /// + private async Task SendGetSseRequestWithRetriesAsync( + JsonRpcRequest? relatedRpcRequest, + SseStreamState state, + CancellationToken cancellationToken) { - await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + int attempt = 0; + + // Delay before first attempt if we're reconnecting (have a Last-Event-ID) + bool shouldDelay = state.LastEventId is not null; + + while (attempt < _options.MaxReconnectionAttempts) { - if (sseEvent.EventType != "message") + cancellationToken.ThrowIfCancellationRequested(); + + if (shouldDelay) { - continue; + var delay = state.RetryInterval ?? _options.DefaultReconnectionInterval; + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); } + shouldDelay = true; - var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false); + using var request = new HttpRequestMessage(HttpMethod.Get, _options.Endpoint); + request.Headers.Accept.Add(s_textEventStreamMediaType); + CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion, state.LastEventId); - // The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes - // a GET request for any notifications that might need to be sent after the completion of each POST. - if (rpcResponseOrError is not null) + HttpResponseMessage response; + try { - return rpcResponseOrError; + response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); + } + catch (HttpRequestException) + { + attempt++; + continue; + } + + using (response) + { + if (response.StatusCode >= HttpStatusCode.InternalServerError) + { + // Server error; retry. + attempt++; + continue; + } + + if (!response.IsSuccessStatusCode) + { + // If the server could be reached but returned a non-success status code, + // retrying likely won't change that. + return null; + } + + using var responseStream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var sseResponse = await ProcessSseResponseAsync(responseStream, relatedRpcRequest, state, cancellationToken).ConfigureAwait(false); + + if (sseResponse.Response is { } rpcResponseOrError) + { + return rpcResponseOrError; + } + + // If we reach here, then the stream closed without the response. + + if (sseResponse.IsNetworkError || state.LastEventId is null) + { + // No event ID means server may not support resumability; don't retry indefinitely. + attempt++; + } + else + { + // We have an event ID, so we continue polling to receive more events. + // The server should eventually send a response or return an error. + attempt = 0; + } } } return null; } + private async Task ProcessSseResponseAsync( + Stream responseStream, + JsonRpcRequest? relatedRpcRequest, + SseStreamState state, + CancellationToken cancellationToken) + { + try + { + await foreach (SseItem sseEvent in SseParser.Create(responseStream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + { + // Track event ID and retry interval for resumability + if (!string.IsNullOrEmpty(sseEvent.EventId)) + { + state.LastEventId = sseEvent.EventId; + } + if (sseEvent.ReconnectionInterval.HasValue) + { + state.RetryInterval = sseEvent.ReconnectionInterval.Value; + } + + // Skip events with empty data + if (string.IsNullOrEmpty(sseEvent.Data)) + { + continue; + } + + var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false); + if (rpcResponseOrError is not null) + { + return new() { Response = rpcResponseOrError }; + } + } + } + catch (Exception ex) when (ex is IOException or HttpRequestException) + { + return new() { IsNetworkError = true }; + } + + return default; + } + private async Task ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken) { LogTransportReceivedMessageSensitive(Name, data); @@ -292,7 +394,8 @@ internal static void CopyAdditionalHeaders( HttpRequestHeaders headers, IDictionary? additionalHeaders, string? sessionId, - string? protocolVersion) + string? protocolVersion, + string? lastEventId = null) { if (sessionId is not null) { @@ -304,6 +407,11 @@ internal static void CopyAdditionalHeaders( headers.Add("MCP-Protocol-Version", protocolVersion); } + if (lastEventId is not null) + { + headers.Add("Last-Event-ID", lastEventId); + } + if (additionalHeaders is null) { return; @@ -317,4 +425,22 @@ internal static void CopyAdditionalHeaders( } } } + + /// + /// Tracks state across SSE stream connections. + /// + private sealed class SseStreamState + { + public string? LastEventId { get; set; } + public TimeSpan? RetryInterval { get; set; } + } + + /// + /// Represents the result of processing an SSE response. + /// + private readonly struct SseResponse + { + public JsonRpcMessageWithId? Response { get; init; } + public bool IsNetworkError { get; init; } + } } diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 0b915c9f1..648072691 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -29,16 +29,38 @@ internal sealed partial class McpSessionHandler : IAsyncDisposable "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); /// The latest version of the protocol supported by this implementation. - internal const string LatestProtocolVersion = "2025-06-18"; + internal const string LatestProtocolVersion = "2025-11-25"; /// All protocol versions supported by this implementation. internal static readonly string[] SupportedProtocolVersions = [ "2024-11-05", "2025-03-26", + "2025-06-18", LatestProtocolVersion, ]; + /// + /// Checks if the given protocol version supports priming events. + /// + /// The protocol version to check. + /// True if the protocol version supports resumability. + /// + /// Priming events are only supported in protocol version >= 2025-11-25. + /// Older clients may crash when receiving SSE events with empty data. + /// + internal static bool SupportsPrimingEvent(string? protocolVersion) + { + const string MinResumabilityProtocolVersion = "2025-11-25"; + + if (protocolVersion is null) + { + return false; + } + + return string.Compare(protocolVersion, MinResumabilityProtocolVersion, StringComparison.Ordinal) >= 0; + } + private readonly bool _isServer; private readonly string _transportKind; private readonly ITransport _transport; diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index 5cd8339bf..9e22a5c0e 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -23,6 +23,7 @@ + diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs new file mode 100644 index 000000000..01c642355 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamReader.cs @@ -0,0 +1,36 @@ +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.Server; + +/// +/// Provides read access to an SSE event stream, allowing events to be consumed asynchronously. +/// +public interface ISseEventStreamReader +{ + /// + /// Gets the session ID associated with the stream being read. + /// + string SessionId { get; } + + /// + /// Gets the ID of the stream. + /// + /// + /// This value is guaranteed to be unique on a per-session basis. + /// + string StreamId { get; } + + /// + /// Gets the messages from the stream as an . + /// + /// A token to cancel the operation. + /// An of containing JSON-RPC messages. + /// + /// If the stream's mode is set to , the returned + /// messages will only include the currently-available events starting at the last event ID specified + /// when the reader was created. Otherwise, the returned messages will continue until the associated + /// is disposed. + /// + IAsyncEnumerable> ReadEventsAsync(CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs new file mode 100644 index 000000000..3d9d9b948 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamStore.cs @@ -0,0 +1,23 @@ +namespace ModelContextProtocol.Server; + +/// +/// Provides storage and retrieval of SSE event streams, enabling resumability and redelivery of events. +/// +public interface ISseEventStreamStore +{ + /// + /// Creates a new SSE event stream with the specified options. + /// + /// The configuration options for the new stream. + /// A token to cancel the operation. + /// A writer for the newly created event stream. + ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default); + + /// + /// Gets a reader for an existing event stream based on the last event ID. + /// + /// The ID of the last event received by the client, used to resume from that point. + /// A token to cancel the operation. + /// A reader for the event stream, or null if no matching stream is found. + ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs b/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs new file mode 100644 index 000000000..43ddb2361 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/ISseEventStreamWriter.cs @@ -0,0 +1,30 @@ +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.Server; + +/// +/// Provides write access to an SSE event stream, allowing events to be written and tracked with unique IDs. +/// +public interface ISseEventStreamWriter : IAsyncDisposable +{ + /// + /// Sets the mode of the event stream. + /// + /// The new mode to set for the event stream. + /// A token to cancel the operation. + /// A task that represents the asynchronous operation. + ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default); + + /// + /// Writes an event to the stream. + /// + /// The original . + /// A token to cancel the operation. + /// A new with a populated event ID. + /// + /// If the provided already has an event ID, this method skips writing the event. + /// Otherwise, an event ID unique to all sessions and streams is generated and assigned to the event. + /// + ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default); +} diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index 7cea77f6c..c8cab6eeb 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -558,7 +558,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) try { - return await handler(request, cancellationToken); + return await handler(request, cancellationToken).ConfigureAwait(false); } catch (Exception e) when (e is not OperationCanceledException and not McpProtocolException) { diff --git a/src/ModelContextProtocol.Core/Server/McpSseEventWriterExtensions.cs b/src/ModelContextProtocol.Core/Server/McpSseEventWriterExtensions.cs new file mode 100644 index 000000000..3f62021f0 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpSseEventWriterExtensions.cs @@ -0,0 +1,70 @@ +using ModelContextProtocol.Protocol; +using System.Buffers; +using System.Net.ServerSentEvents; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// +/// Provides MCP extension methods for . +/// +internal static class McpSseEventWriterExtensions +{ + [ThreadStatic] + private static Utf8JsonWriter? _jsonWriter; + + /// + /// Writes an SSE item containing a . + /// + /// The . + /// The SSE item containing the . + /// The token to monitor for cancellation requests. + /// A task representing the asynchronous write operation. + public static ValueTask WriteAsync(this SseEventWriter writer, SseItem item, CancellationToken cancellationToken = default) + => writer.WriteAsync(item, FormatJsonRpcMessage, cancellationToken); + + /// + /// Writes an SSE item containing a . + /// + /// The . + /// The SSE item containing the string. + /// The token to monitor for cancellation requests. + /// A task representing the asynchronous write operation. + public static ValueTask WriteAsync(this SseEventWriter writer, SseItem item, CancellationToken cancellationToken = default) + => writer.WriteAsync(item, FormatString, cancellationToken); + + /// + /// Formats a message by writing it as JSON to the buffer writer. + /// + private static void FormatJsonRpcMessage(SseItem item, IBufferWriter writer) + { + if (item.Data is null) + { + return; + } + + if (_jsonWriter is null) + { + _jsonWriter = new Utf8JsonWriter(writer); + } + else + { + _jsonWriter.Reset(writer); + } + + JsonSerializer.Serialize(_jsonWriter, item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); + } + + /// + /// Formats a string by writing it as UTF-8 to the buffer writer. + /// + private static void FormatString(SseItem item, IBufferWriter writer) + { + if (item.Data is null) + { + return; + } + + writer.WriteUtf8String(item.Data); + } +} diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index a8d1f66c9..d27da571a 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -82,4 +82,21 @@ public McpServer Server /// including the method name, parameters, request ID, and associated transport and user information. /// public JsonRpcRequest JsonRpcRequest { get; } + + /// + /// Ends the current response and enables polling for updates from the server. + /// + /// The interval at which the client should poll for updates. + /// The cancellation token. + /// A that completes when polling has been enabled. + /// Thrown when the transport does not support polling. + public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken = default) + { + if (JsonRpcRequest.Context?.RelatedTransport is not StreamableHttpPostTransport transport) + { + throw new InvalidOperationException("Polling is only supported for Streamable HTTP transports."); + } + + await transport.EnablePollingAsync(retryInterval, cancellationToken).ConfigureAwait(false); + } } diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs new file mode 100644 index 000000000..2b7704d3d --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseEventStreamMode.cs @@ -0,0 +1,20 @@ +namespace ModelContextProtocol.Server; + +/// +/// Represents the mode of an SSE event stream. +/// +public enum SseEventStreamMode +{ + /// + /// Causes the event stream returned by to only end when + /// the associated gets disposed. + /// + Streaming = 0, + + /// + /// Causes the event stream returned by to end + /// after the most recent event has been consumed. This forces clients to keep making new requests in order to receive + /// the latest messages. + /// + Polling = 1, +} diff --git a/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs b/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs new file mode 100644 index 000000000..6d5be24ef --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/SseEventStreamOptions.cs @@ -0,0 +1,22 @@ +namespace ModelContextProtocol.Server; + +/// +/// Configuration options for creating an SSE event stream. +/// +public sealed class SseEventStreamOptions +{ + /// + /// Gets or sets the session ID associated with the event stream. + /// + public required string SessionId { get; set; } + + /// + /// Gets or sets the stream ID that uniquely identifies this stream within a session. + /// + public required string StreamId { get; set; } + + /// + /// Gets or sets the mode of the event stream. Defaults to . + /// + public SseEventStreamMode Mode { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index afdf29943..315e4819e 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; using System.Security.Claims; using System.Threading.Channels; @@ -27,14 +28,17 @@ namespace ModelContextProtocol.Server; /// The identifier corresponding to the current MCP session. public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message", string? sessionId = null) : ITransport { - private readonly SseWriter _sseWriter = new(messageEndpoint); private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) { SingleReader = true, SingleWriter = false, }); + private readonly SemaphoreSlim _lock = new(1, 1); + private readonly TaskCompletionSource _completedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly SseEventWriter _sseWriter = new(sseResponseStream); private bool _isConnected; + private bool _disposed; /// /// Starts the transport and writes the JSON-RPC messages sent via @@ -45,7 +49,14 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? public async Task RunAsync(CancellationToken cancellationToken = default) { _isConnected = true; - await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); + + // Write the endpoint event first + if (messageEndpoint is not null) + { + await _sseWriter.WriteAsync(SseItem.Endpoint(messageEndpoint), cancellationToken).ConfigureAwait(false); + } + + await _completedTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); } /// @@ -57,17 +68,34 @@ public async Task RunAsync(CancellationToken cancellationToken = default) /// public async ValueTask DisposeAsync() { + using var _ = await _lock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + + _disposed = true; _isConnected = false; _incomingChannel.Writer.TryComplete(); - await _sseWriter.DisposeAsync().ConfigureAwait(false); + _completedTcs.TrySetResult(true); + _sseWriter.Dispose(); } /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Throw.IfNull(message); - // If the underlying writer has been disposed, just drop the message. - await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + using var _ = await _lock.LockAsync(cancellationToken).ConfigureAwait(false); + + // If disposed, just drop the message. + if (_disposed) + { + return; + } + + await _sseWriter.WriteAsync(SseItem.Message(message), cancellationToken).ConfigureAwait(false); } /// diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs deleted file mode 100644 index a2314e623..000000000 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ /dev/null @@ -1,120 +0,0 @@ -using ModelContextProtocol.Protocol; -using System.Buffers; -using System.Net.ServerSentEvents; -using System.Text; -using System.Text.Json; -using System.Threading.Channels; - -namespace ModelContextProtocol.Server; - -internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOptions? channelOptions = null) : IAsyncDisposable -{ - private readonly Channel> _messages = Channel.CreateBounded>(channelOptions ?? new BoundedChannelOptions(1) - { - SingleReader = true, - SingleWriter = false, - }); - - private Utf8JsonWriter? _jsonWriter; - private Task? _writeTask; - private CancellationToken? _writeCancellationToken; - - private readonly SemaphoreSlim _disposeLock = new(1, 1); - private bool _disposed; - - public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; } - - public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) - { - Throw.IfNull(sseResponseStream); - - // When messageEndpoint is set, the very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single - // item of a different type, so we fib and special-case the "endpoint" event type in the formatter. - if (messageEndpoint is not null && !_messages.Writer.TryWrite(new SseItem(null, "endpoint"))) - { - throw new InvalidOperationException("You must call RunAsync before calling SendMessageAsync."); - } - - _writeCancellationToken = cancellationToken; - - var messages = _messages.Reader.ReadAllAsync(cancellationToken); - if (MessageFilter is not null) - { - messages = MessageFilter(messages, cancellationToken); - } - - _writeTask = SseFormatter.WriteAsync(messages, sseResponseStream, WriteJsonRpcMessageToBuffer, cancellationToken); - return _writeTask; - } - - public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - Throw.IfNull(message); - - using var _ = await _disposeLock.LockAsync(cancellationToken).ConfigureAwait(false); - - if (_disposed) - { - // Don't throw ObjectDisposedException here; just return false to indicate the message wasn't sent. - // The calling transport can determine what to do in this case (drop the message, or fall back to another transport). - return false; - } - - // Emit redundant "event: message" lines for better compatibility with other SDKs. - await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); - return true; - } - - public async ValueTask DisposeAsync() - { - using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); - - if (_disposed) - { - return; - } - - _messages.Writer.Complete(); - try - { - if (_writeTask is not null) - { - await _writeTask.ConfigureAwait(false); - } - } - catch (OperationCanceledException) when (_writeCancellationToken?.IsCancellationRequested == true) - { - // Ignore exceptions caused by intentional cancellation during shutdown. - } - finally - { - _jsonWriter?.Dispose(); - _disposed = true; - } - } - - private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter writer) - { - if (item.EventType == "endpoint" && messageEndpoint is not null) - { - writer.Write(Encoding.UTF8.GetBytes(messageEndpoint)); - return; - } - - JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!); - } - - private Utf8JsonWriter GetUtf8JsonWriter(IBufferWriter writer) - { - if (_jsonWriter is null) - { - _jsonWriter = new Utf8JsonWriter(writer); - } - else - { - _jsonWriter.Reset(writer); - } - - return _jsonWriter; - } -} diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 1109c2b2b..f0f94c270 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,9 +1,6 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; using System.Diagnostics; -using System.IO.Pipelines; using System.Net.ServerSentEvents; -using System.Runtime.CompilerServices; -using System.Security.Claims; using System.Text.Json; using System.Threading.Channels; @@ -13,17 +10,25 @@ namespace ModelContextProtocol.Server; /// Handles processing the request/response body pairs for the Streamable HTTP transport. /// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream) : ITransport +internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream, CancellationToken sessionCancellationToken) : ITransport { - private readonly SseWriter _sseWriter = new(); + private readonly SemaphoreSlim _messageLock = new(1, 1); + private readonly TaskCompletionSource _httpResponseTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly SseEventWriter _httpSseWriter = new(responseStream); + + private TaskCompletionSource? _storeStreamTcs; + private ISseEventStreamWriter? _storeSseWriter; + private RequestId _pendingRequest; + private bool _finalResponseMessageSent; + private bool _httpResponseCompleted; public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages."); string? ITransport.SessionId => parentTransport.SessionId; /// - /// True, if data was written to the respond body. + /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// @@ -35,11 +40,11 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio { _pendingRequest = request.Id; - // Invoke the initialize request callback if applicable. - if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) + // Invoke the initialize request handler if applicable. + if (request.Method == RequestMethods.Initialize) { var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - await onInitRequest(initializeRequest).ConfigureAwait(false); + await parentTransport.HandleInitRequestAsync(initializeRequest).ConfigureAwait(false); } } @@ -51,15 +56,28 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio message.Context.ExecutionContext = ExecutionContext.Capture(); } - await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); - if (_pendingRequest.Id is null) { + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); return false; } - _sseWriter.MessageFilter = StopOnFinalResponseFilter; - await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false); + using (await _messageLock.LockAsync(cancellationToken).ConfigureAwait(false)) + { + var primingItem = await TryStartSseEventStreamAsync(_pendingRequest).ConfigureAwait(false); + if (primingItem.HasValue) + { + await _httpSseWriter.WriteAsync(primingItem.Value, cancellationToken).ConfigureAwait(false); + } + + // Ensure that we've sent the priming event before processing the incoming request. + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); + } + + // Wait for the response to be written before returning from the handler. + // This keeps the HTTP response open until the final response message is sent. + await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + return true; } @@ -72,31 +90,136 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); } - bool isAccepted = await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); - if (!isAccepted) + using var _ = await _messageLock.LockAsync().ConfigureAwait(false); + + try { - // The underlying writer didn't accept the message because the underlying request has completed. - // Rather than drop the message, fall back to sending it via the parent transport. - await parentTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + if (_finalResponseMessageSent) + { + // The final response message has already been sent. + // Rather than drop the message, fall back to sending it via the parent transport. + await parentTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + return; + } + + var item = new SseItem(message, SseParser.EventTypeDefault); + + if (_storeSseWriter is not null) + { + item = await _storeSseWriter.WriteEventAsync(item, cancellationToken).ConfigureAwait(false); + } + + if (!_httpResponseCompleted) + { + // Only write the message to the response if the response has not completed. + + try + { + await _httpSseWriter.WriteAsync(item, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + _httpResponseTcs.TrySetException(ex); + } + } + } + finally + { + // Complete the response if this is the final message. + if ((message is JsonRpcResponse or JsonRpcError) && ((JsonRpcMessageWithId)message).Id == _pendingRequest) + { + _finalResponseMessageSent = true; + _httpResponseTcs.TrySetResult(true); + _storeStreamTcs?.TrySetResult(true); + } } } - public async ValueTask DisposeAsync() + public async ValueTask EnablePollingAsync(TimeSpan retryInterval, CancellationToken cancellationToken) { - await _sseWriter.DisposeAsync().ConfigureAwait(false); + if (parentTransport.Stateless) + { + throw new InvalidOperationException("Polling is not supported in stateless mode."); + } + + using var _ = await _messageLock.LockAsync(cancellationToken).ConfigureAwait(false); + + if (_storeSseWriter is null) + { + throw new InvalidOperationException($"Polling requires an event stream store to be configured."); + } + + // Send the priming event with the new retry interval. + var primingItem = await _storeSseWriter.WriteEventAsync( + sseItem: new SseItem() { ReconnectionInterval = retryInterval }, + cancellationToken) + .ConfigureAwait(false); + + // Write to the response stream if it still exists. + if (!_httpResponseCompleted) + { + await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false); + } + + // Set the mode to 'Polling' so that the replay stream ends as soon as all available messages have been sent. + // This prevents the client from immediately establishing another long-lived connection. + await _storeSseWriter.SetModeAsync(SseEventStreamMode.Polling, cancellationToken).ConfigureAwait(false); + + // Signal completion so HandlePostAsync can return. + _httpResponseTcs.TrySetResult(true); } - private async IAsyncEnumerable> StopOnFinalResponseFilter(IAsyncEnumerable> messages, [EnumeratorCancellation] CancellationToken cancellationToken) + private async ValueTask?> TryStartSseEventStreamAsync(RequestId requestId) { - await foreach (var message in messages.WithCancellation(cancellationToken)) + Debug.Assert(_storeSseWriter is null); + + _storeSseWriter = await parentTransport.TryCreateEventStreamAsync( + streamId: requestId.Id!.ToString()!, + cancellationToken: sessionCancellationToken) + .ConfigureAwait(false); + + if (_storeSseWriter is null) { - yield return message; + return null; + } - if (message.Data is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message.Data).Id == _pendingRequest) + _storeStreamTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _ = HandleStoreStreamDisposalAsync(_storeStreamTcs.Task); + + return await _storeSseWriter.WriteEventAsync(SseItem.Prime(), sessionCancellationToken).ConfigureAwait(false); + + async Task HandleStoreStreamDisposalAsync(Task streamTask) + { + try { - // Complete the SSE response stream now that all pending requests have been processed. - break; + await streamTask.WaitAsync(sessionCancellationToken).ConfigureAwait(false); } + finally + { + using var _ = await _messageLock.LockAsync().ConfigureAwait(false); + + await _storeSseWriter!.DisposeAsync().ConfigureAwait(false); + } + } + } + + public async ValueTask DisposeAsync() + { + using var _ = await _messageLock.LockAsync().ConfigureAwait(false); + + if (_httpResponseCompleted) + { + return; } + + _httpResponseCompleted = true; + + _httpResponseTcs.TrySetResult(true); + + _httpSseWriter.Dispose(); + + // Don't dispose the event stream writer here, as we may continue to write to the event store + // after disposal if there are pending messages. } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index c99b1fa39..6307726d0 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,7 @@ using ModelContextProtocol.Protocol; -using System.IO.Pipelines; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Net.ServerSentEvents; using System.Security.Claims; using System.Threading.Channels; @@ -21,24 +23,27 @@ namespace ModelContextProtocol.Server; /// public sealed class StreamableHttpServerTransport : ITransport { - // For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages. - private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1) - { - SingleReader = true, - SingleWriter = false, - FullMode = BoundedChannelFullMode.DropOldest, - }); + /// + /// The stream ID used for unsolicited messages sent via the standalone GET SSE stream. + /// + public static readonly string UnsolicitedMessageStreamId = "__get__"; + private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1) { SingleReader = true, SingleWriter = false, }); - private readonly CancellationTokenSource _disposeCts = new(); + private readonly CancellationTokenSource _transportDisposedCts = new(); + private readonly SemaphoreSlim _unsolicitedMessageLock = new(1, 1); - private int _getRequestStarted; + private SseEventWriter? _httpSseWriter; + private ISseEventStreamWriter? _storeSseWriter; + private TaskCompletionSource? _httpResponseTcs; + private bool _getHttpRequestStarted; + private bool _getHttpResponseCompleted; /// - public string? SessionId { get; set; } + public string? SessionId { get; init; } /// /// Gets or initializes a value that indicates whether the transport should be in stateless mode that does not require all requests for a given session @@ -59,15 +64,30 @@ public sealed class StreamableHttpServerTransport : ITransport public bool FlowExecutionContextFromRequests { get; init; } /// - /// Gets or sets a callback to be invoked before handling the initialize request. + /// Gets or sets the event store for resumability support. + /// When set, events are stored and can be replayed when clients reconnect with a Last-Event-ID header. /// - public Func? OnInitRequestReceived { get; set; } + public ISseEventStreamStore? EventStreamStore { get; init; } + + /// + /// Gets or sets the negotiated protocol version for this session. + /// + internal string? NegotiatedProtocolVersion { get; private set; } /// public ChannelReader MessageReader => _incomingChannel.Reader; internal ChannelWriter MessageWriter => _incomingChannel.Writer; + /// + /// Handles the initialize request by capturing the protocol version and invoking the user callback. + /// + internal async ValueTask HandleInitRequestAsync(InitializeRequestParams? initParams) + { + // Capture the negotiated protocol version for resumability checks + NegotiatedProtocolVersion = initParams?.ProtocolVersion; + } + /// /// Handles an optional SSE GET request a client using the Streamable HTTP transport might make by /// writing any unsolicited JSON-RPC messages sent via @@ -78,8 +98,7 @@ public sealed class StreamableHttpServerTransport : ITransport /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. /// is . /// - /// is and GET requests are not supported in stateless mode, - /// or a GET request has already been started for this session. + /// is and GET requests are not supported in stateless mode. /// public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default) { @@ -90,13 +109,27 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo throw new InvalidOperationException("GET requests are not supported in stateless mode."); } - if (Interlocked.Exchange(ref _getRequestStarted, 1) == 1) + using (await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false)) { - throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); + if (_getHttpRequestStarted) + { + throw new InvalidOperationException("Session resumption is not yet supported. Please start a new session."); + } + + _getHttpRequestStarted = true; + _httpSseWriter = new SseEventWriter(sseResponseStream); + _httpResponseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _storeSseWriter = await TryCreateEventStreamAsync(streamId: UnsolicitedMessageStreamId, cancellationToken).ConfigureAwait(false); + if (_storeSseWriter is not null) + { + var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime(), cancellationToken).ConfigureAwait(false); + await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false); + } } - // We do not need to reference _disposeCts like in HandlePostRequest, because the session ending completes the _sseWriter gracefully. - await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); + // Wait for the response to be written before returning from the handler. + // This keeps the HTTP response open until the final response message is sent. + await _httpResponseTcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); } /// @@ -122,9 +155,15 @@ public async Task HandlePostRequestAsync(JsonRpcMessage message, Stream re Throw.IfNull(message); Throw.IfNull(responseStream); - using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(this, responseStream); - return await postTransport.HandlePostAsync(message, postCts.Token).ConfigureAwait(false); + var postTransport = new StreamableHttpPostTransport(this, responseStream, _transportDisposedCts.Token); + using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_transportDisposedCts.Token, cancellationToken); + await using (postTransport.ConfigureAwait(false)) + { + return await postTransport.HandlePostAsync( + message, + cancellationToken: postCts.Token) + .ConfigureAwait(false); + } } /// @@ -137,28 +176,94 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode."); } - // If the underlying writer has been disposed, just drop the message. - await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + using var _ = await _unsolicitedMessageLock.LockAsync(cancellationToken).ConfigureAwait(false); + + if (!_getHttpRequestStarted) + { + // Clients are not required to make a GET request for unsolicited messages. + // If no GET request has been made, drop the message. + return; + } + + Debug.Assert(_httpSseWriter is not null); + Debug.Assert(_httpResponseTcs is not null); + + var item = SseItem.Message(message); + + if (_storeSseWriter is not null) + { + item = await _storeSseWriter.WriteEventAsync(item, cancellationToken).ConfigureAwait(false); + } + + if (!_getHttpResponseCompleted) + { + // Only write the message to the response if the response has not completed. + + try + { + await _httpSseWriter!.WriteAsync(item, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) when (!cancellationToken.IsCancellationRequested) + { + _httpResponseTcs!.TrySetException(ex); + } + } } /// public async ValueTask DisposeAsync() { + using var _ = await _unsolicitedMessageLock.LockAsync().ConfigureAwait(false); + + if (_getHttpResponseCompleted) + { + return; + } + + _getHttpResponseCompleted = true; + try { _incomingChannel.Writer.TryComplete(); - await _disposeCts.CancelAsync(); + await _transportDisposedCts.CancelAsync().ConfigureAwait(false); } finally { try { - await _sseWriter.DisposeAsync().ConfigureAwait(false); + _httpResponseTcs?.TrySetResult(true); + _httpSseWriter?.Dispose(); + + if (_storeSseWriter is not null) + { + await _storeSseWriter.DisposeAsync().ConfigureAwait(false); + } } finally { - _disposeCts.Dispose(); + _transportDisposedCts.Dispose(); } } } + + internal async ValueTask TryCreateEventStreamAsync(string streamId, CancellationToken cancellationToken) + { + if (EventStreamStore is null || !McpSessionHandler.SupportsPrimingEvent(NegotiatedProtocolVersion)) + { + return null; + } + + // We use the 'Streaming' stream mode so that in the case of an unexpected network disconnection, + // the client can continue reading the remaining messages in a single, streamed response. + const SseEventStreamMode Mode = SseEventStreamMode.Streaming; + + var sseEventStreamWriter = await EventStreamStore.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = SessionId ?? Guid.NewGuid().ToString("N"), + StreamId = streamId, + Mode = Mode, + }, cancellationToken).ConfigureAwait(false); + + return sseEventStreamWriter; + } } diff --git a/src/ModelContextProtocol/.editorconfig b/src/ModelContextProtocol/.editorconfig new file mode 100644 index 000000000..3a5001118 --- /dev/null +++ b/src/ModelContextProtocol/.editorconfig @@ -0,0 +1,2 @@ +[*.cs] +dotnet_diagnostic.CA2007.severity = error # CA2007: Do not directly await a Task without ConfigureAwait diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index 728304070..46edd23f6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -1,5 +1,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; namespace ModelContextProtocol.AspNetCore.Tests; @@ -56,4 +58,37 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name); } + + [Fact] + public async Task EnablePollingAsync_ThrowsInvalidOperationException_InSseMode() + { + InvalidOperationException? capturedException = null; + var pollingTool = McpServerTool.Create(async (RequestContext context) => + { + try + { + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + } + catch (InvalidOperationException ex) + { + capturedException = ex; + } + + return "Complete"; + }, options: new() { Name = "polling_tool" }); + + Builder.Services.AddMcpServer().WithHttpTransport().WithTools([pollingTool]); + + await using var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync(); + + await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(capturedException); + Assert.Contains("Streamable HTTP", capturedException.Message, StringComparison.OrdinalIgnoreCase); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs index cce2e4f0f..9e48dd473 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs @@ -283,4 +283,75 @@ public async Task CanResumeSessionWithMapMcpAndRunSessionHandler() Assert.Equal(1, runSessionCount); } + + [Fact] + public async Task EnablePollingAsync_ThrowsInvalidOperationException_InStatelessMode() + { + Assert.SkipUnless(Stateless, "This test only applies to stateless mode."); + + InvalidOperationException? capturedException = null; + var pollingTool = McpServerTool.Create(async (RequestContext context) => + { + try + { + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + } + catch (InvalidOperationException ex) + { + capturedException = ex; + } + + return "Complete"; + }, options: new() { Name = "polling_tool" }); + + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools([pollingTool]); + + await using var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync(); + + await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(capturedException); + Assert.Contains("stateless", capturedException.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task EnablePollingAsync_ThrowsInvalidOperationException_WhenNoEventStreamStoreConfigured() + { + Assert.SkipWhen(Stateless, "This test only applies to stateful mode without an event stream store."); + + InvalidOperationException? capturedException = null; + var pollingTool = McpServerTool.Create(async (RequestContext context) => + { + try + { + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + } + catch (InvalidOperationException ex) + { + capturedException = ex; + } + + return "Complete"; + }, options: new() { Name = "polling_tool" }); + + // Configure without EventStreamStore + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools([pollingTool]); + + await using var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var mcpClient = await ConnectAsync(); + + await mcpClient.CallToolAsync("polling_tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(capturedException); + Assert.Contains("event stream store", capturedException.Message, StringComparison.OrdinalIgnoreCase); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs new file mode 100644 index 000000000..e3425b253 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -0,0 +1,676 @@ +using System.ComponentModel; +using System.Diagnostics; +using System.Net; +using System.Net.ServerSentEvents; +using System.Text; +using System.Text.Json.Nodes; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Integration tests for SSE resumability with full client-server flow. +/// These tests use McpClient for end-to-end testing and only use raw HTTP +/// for SSE format verification where McpClient abstracts away the details. +/// +public class ResumabilityIntegrationTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private const string InitializeRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"TestClient","version":"1.0.0"}}} + """; + + [Fact] + public async Task Server_StoresEvents_WhenEventStoreConfigured() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act - Make a tool call which generates events + var result = await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Assert - Events were stored + Assert.NotNull(result); + Assert.True(eventStreamStore.StoreEventCallCount > 0, "Expected events to be stored when EventStore is configured"); + } + + [Fact] + public async Task Server_StoresMultipleEvents_ForMultipleToolCalls() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act - Make multiple tool calls + var initialCount = eventStreamStore.StoreEventCallCount; + + await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test1" }, + cancellationToken: TestContext.Current.CancellationToken); + + var countAfterFirst = eventStreamStore.StoreEventCallCount; + + await client.CallToolAsync("echo", + new Dictionary { ["message"] = "test2" }, + cancellationToken: TestContext.Current.CancellationToken); + + var countAfterSecond = eventStreamStore.StoreEventCallCount; + + // Assert - More events were stored for each call + Assert.True(countAfterFirst > initialCount, "Expected more events after first call"); + Assert.True(countAfterSecond > countAfterFirst, "Expected more events after second call"); + } + + [Fact] + public async Task Client_CanMakeMultipleRequests_WithResumabilityEnabled() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act - Make many requests to verify stability + for (int i = 0; i < 5; i++) + { + var result = await client.CallToolAsync("echo", + new Dictionary { ["message"] = $"test{i}" }, + cancellationToken: TestContext.Current.CancellationToken); + + var textContent = Assert.Single(result.Content.OfType()); + Assert.Equal($"Echo: test{i}", textContent.Text); + } + + // Assert - All requests succeeded and events were stored + Assert.True(eventStreamStore.StoreEventCallCount >= 5, "Expected events to be stored for each request"); + } + + [Fact] + public async Task Ping_WorksWithResumabilityEnabled() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act & Assert - Ping should work + await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task ListTools_WorksWithResumabilityEnabled() + { + // Arrange + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + await using var client = await ConnectClientAsync(); + + // Act + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(tools); + Assert.Single(tools); + } + + [Fact] + public async Task Server_WithoutEventStore_DoesNotIncludeEventId() + { + // Arrange - Server without event store + await using var app = await CreateServerAsync(); + + // Act + var sseResponse = await SendInitializeAndReadSseResponseAsync(InitializeRequest); + + // Assert - No event IDs or retry field when EventStore is not configured + Assert.True(sseResponse.LastEventId is null, "Did not expect event IDs when EventStore is not configured"); + } + + [Fact] + public async Task Server_DoesNotSendPrimingEvents_ToOlderProtocolVersionClients() + { + // Arrange - Server with resumability enabled + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + + // Use an older protocol version that doesn't support resumability + const string OldProtocolInitRequest = """ + {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"OldClient","version":"1.0.0"}}} + """; + + var sseResponse = await SendInitializeAndReadSseResponseAsync(OldProtocolInitRequest); + + // Assert - Old clients should not receive event IDs or retry fields (no priming events) + Assert.True(sseResponse.LastEventId is null, "Old protocol clients should not receive event IDs"); + + // Event store should not have been called for old clients + Assert.Equal(0, eventStreamStore.StoreEventCallCount); + } + + [Fact] + public async Task Client_CanPollResponse_FromServer() + { + const string ProgressToolName = "progress_tool"; + var clientReceivedInitialValueTcs = new TaskCompletionSource(); + var clientReceivedPolledValueTcs = new TaskCompletionSource(); + var progressTool = McpServerTool.Create(async (RequestContext context, IProgress progress) => + { + progress.Report(new() { Progress = 0, Message = "Initial value" }); + + await clientReceivedInitialValueTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + await context.EnablePollingAsync(retryInterval: TimeSpan.FromSeconds(1)); + + progress.Report(new() { Progress = 50, Message = "Polled value" }); + + await clientReceivedPolledValueTcs.Task.WaitAsync(TestContext.Current.CancellationToken); ; + + return "Complete"; + }, options: new() { Name = ProgressToolName }); + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder => + { + builder.WithTools([progressTool]); + }); + await using var client = await ConnectClientAsync(); + + var progressHandler = new Progress(value => + { + switch (value.Message) + { + case "Initial value": + Assert.True(clientReceivedInitialValueTcs.TrySetResult(), "Received the initial value more than once."); + break; + case "Polled value": + Assert.True(clientReceivedPolledValueTcs.TrySetResult(), "Received the polled value more than once."); + break; + default: + throw new UnreachableException($"Unknown progress message '{value.Message}'"); + } + }); + + var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError is true); + Assert.Equal("Complete", result.Content.OfType().Single().Text); + } + + [Fact] + public async Task Client_CanResumePostResponseStream_AfterDisconnection() + { + using var faultingStreamHandler = new FaultingStreamHandler() + { + InnerHandler = SocketsHttpHandler, + }; + + HttpClient = new(faultingStreamHandler); + ConfigureHttpClient(HttpClient); + + const string ProgressToolName = "progress_tool"; + const string InitialMessage = "Initial notification"; + const string ReplayedMessage = "Replayed notification"; + const string ResultMessage = "Complete"; + + var clientReceivedInitialValueTcs = new TaskCompletionSource(); + var clientReceivedReconnectValueTcs = new TaskCompletionSource(); + var progressTool = McpServerTool.Create(async (RequestContext context, IProgress progress, CancellationToken cancellationToken) => + { + progress.Report(new() { Progress = 0, Message = InitialMessage }); + + // Make sure the client receives one message before we disconnect. + await clientReceivedInitialValueTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Simulate a network disconnection by faulting the response stream. + var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); + + // Send another message that the client should receive after reconnecting. + progress.Report(new() { Progress = 50, Message = ReplayedMessage }); + + reconnectAttempt.Continue(); + + // Wait for the client to receive the message via replay. + await clientReceivedReconnectValueTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Return the final result with the client still connected. + return ResultMessage; + }, options: new() { Name = ProgressToolName }); + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder => + { + builder.WithTools([progressTool]); + }); + await using var client = await ConnectClientAsync(); + + var initialNotificationReceivedCount = 0; + var replayedNotificationReceivedCount = 0; + var progressHandler = new Progress(value => + { + switch (value.Message) + { + case InitialMessage: + initialNotificationReceivedCount++; + clientReceivedInitialValueTcs.TrySetResult(); + break; + case ReplayedMessage: + replayedNotificationReceivedCount++; + clientReceivedReconnectValueTcs.TrySetResult(); + break; + default: + throw new UnreachableException($"Unknown progress message '{value.Message}'"); + } + }); + + var result = await client.CallToolAsync(ProgressToolName, progress: progressHandler, cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError is true); + Assert.Equal(1, initialNotificationReceivedCount); + Assert.Equal(1, replayedNotificationReceivedCount); + Assert.Equal(ResultMessage, result.Content.OfType().Single().Text); + } + + [Fact] + public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() + { + using var faultingStreamHandler = new FaultingStreamHandler() + { + InnerHandler = SocketsHttpHandler, + }; + + HttpClient = new(faultingStreamHandler); + ConfigureHttpClient(HttpClient); + + var eventStreamStore = new TestSseEventStreamStore(); + + // Capture the server instance via RunSessionHandler + var serverTcs = new TaskCompletionSource(); + + await using var app = await CreateServerAsync(eventStreamStore, configureTransport: options => + { + options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) => + { + serverTcs.TrySetResult(mcpServer); + return mcpServer.RunAsync(cancellationToken); + }; + }); + + await using var client = await ConnectClientAsync(); + + // Get the server instance + var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Set up notification tracking with unique messages + var clientReceivedInitialNotificationTcs = new TaskCompletionSource(); + var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(); + var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(); + + const string CustomNotificationMethod = "test/custom_notification"; + const string InitialMessage = "Initial notification"; + const string ReplayedMessage = "Replayed notification"; + const string ReconnectMessage = "Reconnect notification"; + + var initialNotificationReceivedCount = 0; + var replayedNotificationReceivedCount = 0; + var reconnectNotificationReceivedCount = 0; + + await using var _ = client.RegisterNotificationHandler(CustomNotificationMethod, (notification, cancellationToken) => + { + var message = notification.Params?["message"]?.GetValue(); + switch (message) + { + case InitialMessage: + initialNotificationReceivedCount++; + clientReceivedInitialNotificationTcs.TrySetResult(); + break; + case ReplayedMessage: + replayedNotificationReceivedCount++; + clientReceivedReplayedNotificationTcs.TrySetResult(); + break; + case ReconnectMessage: + reconnectNotificationReceivedCount++; + clientReceivedReconnectNotificationTcs.TrySetResult(); + break; + default: + throw new UnreachableException($"Unknown notification message '{message}'"); + } + return default; + }); + + // Send a custom notification to the client on the unsolicited message stream + await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken); + + // Wait for client to receive the first notification + await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Fault the unsolicited message stream (GET SSE) + var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); + + // Send another notification while the client is disconnected - this should be stored + await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReplayedMessage }, cancellationToken: TestContext.Current.CancellationToken); + + // Allow the client to reconnect + reconnectAttempt.Continue(); + + // Wait for client to receive the notification via replay + await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Send a final notification while the client has reconnected - this should be handled by the transport + await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken); + + // Wait for the client to receive the final notification + await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + + // Assert each notification was received exactly once + Assert.Equal(1, initialNotificationReceivedCount); + Assert.Equal(1, replayedNotificationReceivedCount); + Assert.Equal(1, reconnectNotificationReceivedCount); + } + + [Fact] + public async Task Server_Returns400_WhenLastEventIdRefersToWrongSession() + { + // Arrange - Create server with event store + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore); + + // First, initialize a session and make a call to generate some events + using var initRequest = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream") } + }, + Content = new StringContent(InitializeRequest, Encoding.UTF8, "application/json"), + }; + var initResponse = await HttpClient.SendAsync(initRequest, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + initResponse.EnsureSuccessStatusCode(); + + // Get the session ID from the response + var sessionId = initResponse.Headers.GetValues("Mcp-Session-Id").First(); + + // Read the SSE response to get an event ID + await using var initStream = await initResponse.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + string? eventId = null; + await foreach (var sseItem in SseParser.Create(initStream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + if (!string.IsNullOrEmpty(sseItem.EventId)) + { + eventId = sseItem.EventId; + } + } + + Assert.NotNull(eventId); + + // Act - Try to resume with a different session ID but the same event ID + var wrongSessionId = "wrong-session-id"; + using var resumeRequest = new HttpRequestMessage(HttpMethod.Get, "/") + { + Headers = + { + Accept = { new("text/event-stream") }, + } + }; + resumeRequest.Headers.Add("Mcp-Session-Id", wrongSessionId); + resumeRequest.Headers.Add("Last-Event-ID", eventId); + + var resumeResponse = await HttpClient.SendAsync(resumeRequest, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + + // Assert - First we get 404 because the wrong session doesn't exist + Assert.Equal(HttpStatusCode.NotFound, resumeResponse.StatusCode); + + // Now test with an existing session but event ID from a different session + // Create a second session + using var initRequest2 = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream") } + }, + Content = new StringContent(InitializeRequest, Encoding.UTF8, "application/json"), + }; + var initResponse2 = await HttpClient.SendAsync(initRequest2, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + initResponse2.EnsureSuccessStatusCode(); + + var sessionId2 = initResponse2.Headers.GetValues("Mcp-Session-Id").First(); + Assert.NotEqual(sessionId, sessionId2); + + // Read the second session's response + await using var initStream2 = await initResponse2.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var _ in SseParser.Create(initStream2).EnumerateAsync(TestContext.Current.CancellationToken)) + { + // Consume the stream + } + + // Try to use session 2's ID but with an event ID from session 1 + using var mismatchRequest = new HttpRequestMessage(HttpMethod.Get, "/") + { + Headers = + { + Accept = { new("text/event-stream") }, + } + }; + mismatchRequest.Headers.Add("Mcp-Session-Id", sessionId2); + mismatchRequest.Headers.Add("Last-Event-ID", eventId); // This event ID belongs to session 1 + + var mismatchResponse = await HttpClient.SendAsync(mismatchRequest, HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); + + // Assert - Should get 400 Bad Request because the event ID doesn't match the session + Assert.Equal(HttpStatusCode.BadRequest, mismatchResponse.StatusCode); + + // Verify the error message + var responseBody = await mismatchResponse.Content.ReadAsStringAsync(TestContext.Current.CancellationToken); + var errorResponse = JsonNode.Parse(responseBody); + Assert.NotNull(errorResponse); + var errorMessage = errorResponse["error"]?["message"]?.GetValue(); + Assert.Equal("Bad Request: The Last-Event-ID header refers to a session with a different session ID.", errorMessage); + } + + [Fact] + public async Task EnablePollingAsync_SendsSseItemWithRetryField() + { + // Arrange + const string PollingToolName = "polling_tool"; + var expectedRetryInterval = TimeSpan.FromSeconds(5); + var pollingTool = McpServerTool.Create(async (RequestContext context) => + { + await context.EnablePollingAsync(retryInterval: expectedRetryInterval); + return "Polling enabled"; + }, options: new() { Name = PollingToolName }); + + var eventStreamStore = new TestSseEventStreamStore(); + await using var app = await CreateServerAsync(eventStreamStore, configureServer: builder => + { + builder.WithTools([pollingTool]); + }); + await using var client = await ConnectClientAsync(); + + // Act - Call the tool that enables polling + var result = await client.CallToolAsync(PollingToolName, cancellationToken: TestContext.Current.CancellationToken); + + // Assert - The result should be successful + Assert.False(result.IsError is true); + Assert.Equal("Polling enabled", result.Content.OfType().Single().Text); + + // Verify that the event store received the retry interval + Assert.Contains(expectedRetryInterval, eventStreamStore.StoredReconnectionIntervals); + } + + [Fact] + public async Task PostResponse_EndsAndSseEventStreamWriterIsDisposed_WhenWriteEventAsyncIsCanceled() + { + var blockingStore = new BlockingEventStreamStore(); + await using var app = await CreateServerAsync(blockingStore); + await using var client = await ConnectClientAsync(); + + // Enable blocking now that initialization is complete + blockingStore.EnableBlocking(); + + using var callCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + + // Start calling the tool - this will eventually trigger WriteEventAsync for the response + var callTask = client.CallToolAsync("echo", + new Dictionary { ["message"] = "test" }, + cancellationToken: callCts.Token).AsTask(); + + // Wait for the writer to block on WriteEventAsync for the response message + await blockingStore.WriteEventBlockedTask.WaitAsync(TestContext.Current.CancellationToken); + + // Cancel the token while the writer is blocked - this causes an OCE to bubble up + // to SendMessageAsync + await callCts.CancelAsync(); + + // The call should complete (with an error or cancellation) without hanging + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); + + // The call task should throw an OCE due to cancellation + await Assert.ThrowsAsync(() => callTask).WaitAsync(timeoutCts.Token); + + // Wait for the writer to be disposed + await blockingStore.DisposedTask.WaitAsync(timeoutCts.Token); + } + + [McpServerToolType] + private class ResumabilityTestTools + { + [McpServerTool(Name = "echo"), Description("Echoes the message back")] + public static string Echo(string message) => $"Echo: {message}"; + } + + private async Task CreateServerAsync( + ISseEventStreamStore? eventStreamStore = null, + Action? configureServer = null, + Action? configureTransport = null) + { + var serverBuilder = Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.EventStreamStore = eventStreamStore; + configureTransport?.Invoke(options); + }) + .WithTools(); + + configureServer?.Invoke(serverBuilder); + + var app = Builder.Build(); + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private async Task ConnectClientAsync() + { + var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = new Uri("http://localhost:5000/"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + return await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + } + + private async Task SendInitializeAndReadSseResponseAsync(string initializeRequest) + { + using var requestContent = new StringContent(initializeRequest, Encoding.UTF8, "application/json"); + using var request = new HttpRequestMessage(HttpMethod.Post, "/") + { + Headers = + { + Accept = { new("application/json"), new("text/event-stream") } + }, + Content = requestContent, + }; + + var response = await HttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, + TestContext.Current.CancellationToken); + + response.EnsureSuccessStatusCode(); + + var sseResponse = new SseResponse(); + await using var stream = await response.Content.ReadAsStreamAsync(TestContext.Current.CancellationToken); + await foreach (var sseItem in SseParser.Create(stream).EnumerateAsync(TestContext.Current.CancellationToken)) + { + if (!string.IsNullOrEmpty(sseItem.EventId)) + { + sseResponse.LastEventId = sseItem.EventId; + } + } + + return sseResponse; + } + + private struct SseResponse + { + public string? LastEventId { get; set; } + } + + /// + /// A test event stream store that blocks on WriteEventAsync for response messages, + /// allowing the test to cancel the operation and verify proper cleanup. + /// + private sealed class BlockingEventStreamStore : ISseEventStreamStore + { + private readonly TaskCompletionSource _writeEventBlockedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _disposedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private bool _blockingEnabled; + + public Task WriteEventBlockedTask => _writeEventBlockedTcs.Task; + public Task DisposedTask => _disposedTcs.Task; + + public void EnableBlocking() => _blockingEnabled = true; + + public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) + => new(new BlockingEventStreamWriter(this)); + + public ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default) + => throw new NotSupportedException("This test store does not support reading streams."); + + private sealed class BlockingEventStreamWriter : ISseEventStreamWriter + { + private readonly BlockingEventStreamStore _store; + + public BlockingEventStreamWriter(BlockingEventStreamStore store) + { + _store = store; + } + + public ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) => default; + + public async ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default) + { + // Skip if already has an event ID (replay) + if (sseItem.EventId is not null) + { + return sseItem; + } + + // Block when we receive a response and blocking is enabled + if (sseItem.Data is JsonRpcResponse && _store._blockingEnabled) + { + // Signal that we're blocked + _store._writeEventBlockedTcs.TrySetResult(); + + // Wait to be canceled + await new TaskCompletionSource().Task.WaitAsync(cancellationToken); + } + + return sseItem with { EventId = "0" }; + } + + public ValueTask DisposeAsync() + { + _store._disposedTcs.TrySetResult(); + return default; + } + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index 7d8bbfd45..a6b0fc70e 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -287,6 +287,35 @@ async Task GetFirstNotificationAsync() Assert.Equal("test-method", await GetFirstNotificationAsync()); } + [Fact] + public async Task SendNotificationAsync_DoesNotThrow_WhenNoGetRequestHasBeenMade() + { + // Clients are not required to make a GET request for unsolicited messages. + // If no GET request has been made, the messages should be dropped rather than throwing. + McpServer? server = null; + + Builder.Services.AddMcpServer() + .WithHttpTransport(options => + { + options.RunSessionHandler = (httpContext, mcpServer, cancellationToken) => + { + server = mcpServer; + return mcpServer.RunAsync(cancellationToken); + }; + }); + + await StartAsync(); + + await CallInitializeAndValidateAsync(); + Assert.NotNull(server); + + // Calling SendNotificationAsync before a GET request should not throw. + // The notification should be silently dropped. + var exception = await Record.ExceptionAsync(() => + server.SendNotificationAsync("test-method", TestContext.Current.CancellationToken)); + Assert.Null(exception); + } + [Fact] public async Task SecondGetRequests_IsRejected_AsBadRequest() { @@ -567,6 +596,7 @@ public async Task McpServer_UsedOutOfScope_CanSendNotifications() SetSessionId(sessionId); // Call the subscribe method to capture the McpServer instance. + using var getResponse = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); using var response = await HttpClient.PostAsync("", JsonContent(SubscribeToResource("file:///test")), TestContext.Current.CancellationToken); var rpcResponse = await AssertSingleSseResponseAsync(response); AssertType(rpcResponse.Result); @@ -574,7 +604,6 @@ public async Task McpServer_UsedOutOfScope_CanSendNotifications() // Check the captured McpServer instance can send a notification. await capturedServer.SendNotificationAsync(NotificationMethods.ResourceUpdatedNotification, TestContext.Current.CancellationToken); - using var getResponse = await HttpClient.GetAsync("", HttpCompletionOption.ResponseHeadersRead, TestContext.Current.CancellationToken); JsonRpcMessage? firstSseMessage = await ReadSseAsync(getResponse.Content) .Select(data => JsonSerializer.Deserialize(data, McpJsonUtilities.DefaultOptions)) .FirstOrDefaultAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs new file mode 100644 index 000000000..cace4d8be --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs @@ -0,0 +1,175 @@ +using System.Diagnostics; +using System.Net; + +namespace ModelContextProtocol.AspNetCore.Tests.Utils; + +/// +/// A message handler that wraps SSE response streams and can trigger faults mid-stream +/// to simulate network disconnections during SSE streaming. +/// +internal sealed class FaultingStreamHandler : DelegatingHandler +{ + private FaultingStream? _lastStream; + private TaskCompletionSource? _reconnectTcs; + + public async Task TriggerFaultAsync(CancellationToken cancellationToken) + { + if (_lastStream is null or { IsDisposed: true }) + { + throw new InvalidOperationException("There is no active response stream to fault."); + } + + if (_reconnectTcs is not null) + { + throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection."); + } + + _reconnectTcs = new(); + await _lastStream.TriggerFaultAsync(cancellationToken); + + return new(_reconnectTcs); + } + + public sealed class ReconnectAttempt(TaskCompletionSource reconnectTcs) + { + public void Continue() + => reconnectTcs.SetResult(); + } + + protected override async Task SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) + { + if (_reconnectTcs is not null && request.Headers.Accept.Contains(new("text/event-stream"))) + { + // If we're blocking reconnection, wait until we're allowed to continue. + await _reconnectTcs.Task.WaitAsync(cancellationToken); + _reconnectTcs = null; + } + + var response = await base.SendAsync(request, cancellationToken); + + // Only wrap SSE streams (text/event-stream) + if (response.Content.Headers.ContentType?.MediaType == "text/event-stream") + { + var originalStream = await response.Content.ReadAsStreamAsync(cancellationToken); + _lastStream = new FaultingStream(originalStream); + var faultingContent = new FaultingStreamContent(_lastStream); + + // Copy headers from original content + var newContent = faultingContent; + foreach (var header in response.Content.Headers) + { + newContent.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + + response.Content = newContent; + } + + return response; + } + + private sealed class FaultingStreamContent(FaultingStream stream) : HttpContent + { + private readonly FaultingStream _manualStream = new(stream); + + protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context) + => throw new NotSupportedException(); + + protected override Task CreateContentReadStreamAsync() + => Task.FromResult(stream); + + protected override bool TryComputeLength(out long length) + { + length = -1; + return false; + } + } + + private sealed class FaultingStream(Stream innerStream) : Stream + { + private readonly CancellationTokenSource _cts = new(); + private TaskCompletionSource? _faultTcs; + private bool _disposed; + + public bool IsDisposed => _disposed; + + public async Task TriggerFaultAsync(CancellationToken cancellationToken) + { + if (_faultTcs is not null) + { + throw new InvalidOperationException("Only one fault can be triggered per stream."); + } + + _faultTcs = new TaskCompletionSource(); + + await _cts.CancelAsync(); + + // Use a timeout to detect if the fault is not observed by a read operation. + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(30)); + + try + { + await _faultTcs.Task.WaitAsync(timeoutCts.Token); + } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + throw new TimeoutException( + $"TriggerFaultAsync timed out after 30 seconds waiting for a read to observe the cancellation. " + + $"Stream disposed: {_disposed}, CTS cancelled: {_cts.IsCancellationRequested}"); + } + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + try + { + _cts.Token.ThrowIfCancellationRequested(); + + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cts.Token); + var bytesRead = await innerStream.ReadAsync(buffer, linkedCts.Token); + + _cts.Token.ThrowIfCancellationRequested(); + + return bytesRead; + } + catch (OperationCanceledException) when (_cts.IsCancellationRequested) + { + Debug.Assert(_faultTcs is not null); + + if (!_faultTcs.TrySetResult()) + { + throw new InvalidOperationException("Attempted to read an already-faulted stream."); + } + + throw new IOException("Simulated network disconnection."); + } + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException("Synchronous reads are not supported."); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + + public override bool CanRead => innerStream.CanRead; + public override bool CanSeek => innerStream.CanSeek; + public override bool CanWrite => innerStream.CanWrite; + public override long Length => innerStream.Length; + public override long Position { get => innerStream.Position; set => innerStream.Position = value; } + public override void Flush() => innerStream.Flush(); + public override long Seek(long offset, SeekOrigin origin) => innerStream.Seek(offset, origin); + public override void SetLength(long value) => innerStream.SetLength(value); + public override void Write(byte[] buffer, int offset, int count) => innerStream.Write(buffer, offset, count); + protected override void Dispose(bool disposing) + { + if (!disposing || _disposed) + { + return; + } + + _disposed = true; + innerStream.Dispose(); + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index fe70c2fa5..c93a27650 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -27,21 +27,24 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) return new(connection.ClientStream); }; - HttpClient = new HttpClient(SocketsHttpHandler) - { - BaseAddress = new Uri("http://localhost:5000/"), - Timeout = TimeSpan.FromSeconds(10), - }; + HttpClient = new HttpClient(SocketsHttpHandler); + ConfigureHttpClient(HttpClient); } public WebApplicationBuilder Builder { get; } - public HttpClient HttpClient { get; } + public HttpClient HttpClient { get; set; } public SocketsHttpHandler SocketsHttpHandler { get; } = new(); public KestrelInMemoryTransport KestrelInMemoryTransport { get; } = new(); + protected static void ConfigureHttpClient(HttpClient httpClient) + { + httpClient.BaseAddress = new Uri("http://localhost:5000/"); + httpClient.Timeout = TimeSpan.FromSeconds(10); + } + public override void Dispose() { HttpClient.Dispose(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs new file mode 100644 index 000000000..1072fbe69 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/TestSseEventStreamStore.cs @@ -0,0 +1,268 @@ +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Net.ServerSentEvents; +using System.Runtime.CompilerServices; + +namespace ModelContextProtocol.AspNetCore.Tests.Utils; + +/// +/// In-memory event store for testing resumability. +/// This is a simple implementation intended for testing, not for production use. +/// +public sealed class TestSseEventStreamStore : ISseEventStreamStore +{ + private readonly ConcurrentDictionary _streams = new(); + private readonly ConcurrentDictionary _eventLookup = new(); + private readonly List _storedEventIds = []; + private readonly List _storedReconnectionIntervals = []; + private readonly object _storedEventIdsLock = new(); + private int _storeEventCallCount; + private long _globalSequence; + + /// + /// Gets the number of times events have been stored. + /// + public int StoreEventCallCount => _storeEventCallCount; + + /// + /// Gets the list of stored event IDs in order. + /// + public IReadOnlyList StoredEventIds + { + get + { + lock (_storedEventIdsLock) + { + return [.. _storedEventIds]; + } + } + } + + /// + /// Gets the list of stored reconnection intervals in order. + /// + public IReadOnlyList StoredReconnectionIntervals + { + get + { + lock (_storedEventIdsLock) + { + return [.. _storedReconnectionIntervals]; + } + } + } + + /// + public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) + { + var streamKey = GetStreamKey(options.SessionId, options.StreamId); + var state = new StreamState(options.SessionId, options.StreamId, options.Mode); + if (!_streams.TryAdd(streamKey, state)) + { + throw new InvalidOperationException($"A stream with key '{streamKey}' has already been created."); + } + var writer = new InMemoryEventStreamWriter(this, state); + return new ValueTask(writer); + } + + /// + public ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default) + { + // Look up the event by its ID to find which stream it belongs to + if (!_eventLookup.TryGetValue(lastEventId, out var lookup)) + { + return new ValueTask((ISseEventStreamReader?)null); + } + + var reader = new InMemoryEventStreamReader(lookup.Stream, lookup.Sequence); + return new ValueTask(reader); + } + + private string GenerateEventId() => Interlocked.Increment(ref _globalSequence).ToString(); + + private void TrackEvent(string eventId, StreamState stream, long sequence, TimeSpan? reconnectionInterval = null) + { + _eventLookup[eventId] = (stream, sequence); + lock (_storedEventIdsLock) + { + _storedEventIds.Add(eventId); + if (reconnectionInterval.HasValue) + { + _storedReconnectionIntervals.Add(reconnectionInterval.Value); + } + } + Interlocked.Increment(ref _storeEventCallCount); + } + + private static string GetStreamKey(string sessionId, string streamId) => $"{sessionId}:{streamId}"; + + /// + /// Holds the state for a single stream. + /// + private sealed class StreamState + { + private readonly List<(SseItem Item, long Sequence)> _events = []; + private readonly object _lock = new(); + private TaskCompletionSource _newEventSignal = new(TaskCreationOptions.RunContinuationsAsynchronously); + private long _sequence; + + public StreamState(string sessionId, string streamId, SseEventStreamMode mode) + { + SessionId = sessionId; + StreamId = streamId; + Mode = mode; + } + + public string SessionId { get; } + public string StreamId { get; } + public SseEventStreamMode Mode { get; set; } + public bool IsCompleted { get; private set; } + + public long NextSequence() => Interlocked.Increment(ref _sequence); + + public void AddEvent(SseItem item, long sequence) + { + lock (_lock) + { + if (IsCompleted) + { + throw new InvalidOperationException("Cannot add events to a completed stream."); + } + + _events.Add((item, sequence)); + + var oldSignal = _newEventSignal; + _newEventSignal = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + oldSignal.TrySetResult(); + } + } + + public (List> Events, long LastSequence, Task NewEventSignal) GetEventsAfter(long sequence) + { + lock (_lock) + { + var result = new List>(); + long lastSequence = sequence; + + foreach (var (item, seq) in _events) + { + if (seq > sequence) + { + result.Add(item); + lastSequence = seq; + } + } + + return (result, lastSequence, _newEventSignal.Task); + } + } + + public void Complete() + { + lock (_lock) + { + IsCompleted = true; + _newEventSignal.TrySetResult(); + } + } + } + + private sealed class InMemoryEventStreamWriter : ISseEventStreamWriter + { + private readonly TestSseEventStreamStore _store; + private readonly StreamState _state; + private bool _disposed; + + public InMemoryEventStreamWriter(TestSseEventStreamStore store, StreamState state) + { + _store = store; + _state = state; + } + + public ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) + { + _state.Mode = mode; + return default; + } + + public ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default) + { + // Skip if already has an event ID + if (sseItem.EventId is not null) + { + return new ValueTask>(sseItem); + } + + var sequence = _state.NextSequence(); + var eventId = _store.GenerateEventId(); + var newItem = sseItem with { EventId = eventId }; + + _state.AddEvent(newItem, sequence); + _store.TrackEvent(eventId, _state, sequence, sseItem.ReconnectionInterval); + + return new ValueTask>(newItem); + } + + public ValueTask DisposeAsync() + { + if (_disposed) + { + return default; + } + + _disposed = true; + _state.Complete(); + return default; + } + } + + private sealed class InMemoryEventStreamReader : ISseEventStreamReader + { + private readonly StreamState _state; + private readonly long _startSequence; + + public InMemoryEventStreamReader(StreamState state, long startSequence) + { + _state = state; + _startSequence = startSequence; + } + + public string SessionId => _state.SessionId; + public string StreamId => _state.StreamId; + + public async IAsyncEnumerable> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + long lastSeenSequence = _startSequence; + + while (true) + { + // Get events after the last seen sequence + var (events, lastSequence, newEventSignal) = _state.GetEventsAfter(lastSeenSequence); + + foreach (var evt in events) + { + yield return evt; + } + + // Update to the sequence we actually retrieved + lastSeenSequence = lastSequence; + + // If in polling mode, stop after returning currently available events + if (_state.Mode == SseEventStreamMode.Polling) + { + yield break; + } + + // If the stream is completed, stop + if (_state.IsCompleted) + { + yield break; + } + + // Wait for new events or cancellation + await newEventSignal.WaitAsync(cancellationToken).ConfigureAwait(false); + } + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index 86cefcf10..9aeec6b1e 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -107,7 +107,7 @@ public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperat { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, Content = [new TextContentBlock { Text = "Hello" }] @@ -157,7 +157,7 @@ public async Task CreateSamplingHandler_ShouldHandleImageMessages() { Messages = [ - new SamplingMessage + new SamplingMessage { Role = Role.User, Content = [new ImageContentBlock @@ -492,7 +492,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient() public async Task ReturnsNegotiatedProtocolVersion(string? protocolVersion) { await using McpClient client = await CreateMcpClientForServer(new() { ProtocolVersion = protocolVersion }); - Assert.Equal(protocolVersion ?? "2025-06-18", client.NegotiatedProtocolVersion); + Assert.Equal(protocolVersion ?? "2025-11-25", client.NegotiatedProtocolVersion); } [Fact] @@ -500,7 +500,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn { int getWeatherToolCallCount = 0; int askClientToolCallCount = 0; - + Server.ServerOptions.ToolCollection?.Add(McpServerTool.Create( async (McpServer server, string query, CancellationToken cancellationToken) => { @@ -513,14 +513,14 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn return $"Weather in {location}: sunny, 22°C"; }, "get_weather", "Gets the weather for a location"); - + var response = await server .AsSamplingChatClient() .AsBuilder() .UseFunctionInvocation() .Build() .GetResponseAsync(query, new ChatOptions { Tools = [weatherTool] }, cancellationToken); - + return response.Text ?? "No response"; }, new() { Name = "ask_client", Description = "Asks the client a question using sampling" })); @@ -530,7 +530,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn { int currentCall = samplingCallCount++; var lastMessage = messages.LastOrDefault(); - + // First call: Return a tool call request for get_weather if (currentCall == 0) { @@ -552,7 +552,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn string resultText = toolResult.Result?.ToString() ?? string.Empty; Assert.Contains("Weather in Paris: sunny", resultText); - + return Task.FromResult(new([ new ChatMessage(ChatRole.User, messages.First().Contents), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("call_weather_123", "get_weather", new Dictionary { ["location"] = "Paris" })]), @@ -577,7 +577,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Null(result.IsError); - + var textContent = result.Content.OfType().FirstOrDefault(); Assert.NotNull(textContent); Assert.Contains("Weather in Paris: sunny, 22", textContent.Text); @@ -585,7 +585,7 @@ public async Task EndToEnd_SamplingWithTools_ServerUsesIChatClientWithFunctionIn Assert.Equal(1, askClientToolCallCount); Assert.Equal(2, samplingCallCount); } - + /// Simple test IChatClient implementation for testing. private sealed class TestChatClient(Func, ChatOptions?, CancellationToken, Task> getResponse) : IChatClient { @@ -594,7 +594,7 @@ public Task GetResponseAsync( ChatOptions? options = null, CancellationToken cancellationToken = default) => getResponse(messages, options, cancellationToken); - + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( IEnumerable messages, ChatOptions? options, @@ -606,7 +606,7 @@ async IAsyncEnumerable IChatClient.GetStreamingResponseAsync yield return update; } } - + object? IChatClient.GetService(Type serviceType, object? serviceKey) => null; void IDisposable.Dispose() { } } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index 6f625866a..049d72d60 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -15,11 +15,13 @@ public class ClientIntegrationTestFixture public ClientIntegrationTestFixture() { + const string ServerEverythingVersion = "2025.12.18"; + EverythingServerTransportOptions = new() { Command = "npx", // Change to Arguments = ["mcp-server-everything"] if you want to run the server locally after creating a symlink - Arguments = ["-y", "--verbose", "@modelcontextprotocol/server-everything"], + Arguments = ["-y", "--verbose", $"@modelcontextprotocol/server-everything@{ServerEverythingVersion}"], Name = "Everything", }; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index e1e1011b4..e1d6e3191 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -804,10 +804,10 @@ public void ReturnDescription_NoReturnDescription_NoChange() public void ReturnDescription_StructuredOutputEnabled_WithExplicitDescription_NoSynthesis() { // When UseStructuredContent is true and Description is set, return description goes to output schema - McpServerTool tool = McpServerTool.Create(ToolWithReturnDescription, new() - { - Description = "Custom description", - UseStructuredContent = true + McpServerTool tool = McpServerTool.Create(ToolWithReturnDescription, new() + { + Description = "Custom description", + UseStructuredContent = true }); // Description should not have the return description appended @@ -815,6 +815,23 @@ public void ReturnDescription_StructuredOutputEnabled_WithExplicitDescription_No Assert.NotNull(tool.ProtocolTool.OutputSchema); } + [Fact] + public async Task EnablePollingAsync_ThrowsInvalidOperationException_WhenTransportIsNotStreamableHttpPost() + { + // Arrange + Mock mockServer = new(); + var jsonRpcRequest = CreateTestJsonRpcRequest(); + + // The JsonRpcRequest has no Context, so RelatedTransport will be null + var requestContext = new RequestContext(mockServer.Object, jsonRpcRequest); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => requestContext.EnablePollingAsync(TimeSpan.FromSeconds(1), TestContext.Current.CancellationToken).AsTask()); + + Assert.Contains("Streamable HTTP", exception.Message); + } + [Description("Tool that returns data.")] [return: Description("The computed result")] private static string ToolWithReturnDescription() => "result";