diff --git a/Directory.Packages.props b/Directory.Packages.props index 8a09ce3e1..c0ed13101 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -11,6 +11,8 @@ + + diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index b3d98dd0e..081ce0005 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; @@ -158,6 +159,10 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(BlobResourceContents))] [JsonSerializable(typeof(TextResourceContents))] + // Distributed cache event stream store + [JsonSerializable(typeof(DistributedCacheEventStreamStore.StreamMetadata))] + [JsonSerializable(typeof(DistributedCacheEventStreamStore.StoredEvent))] + // Other MCP Types [JsonSerializable(typeof(IReadOnlyDictionary))] [JsonSerializable(typeof(ProgressToken))] diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index 9e22a5c0e..9e18c7b76 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -49,6 +49,7 @@ + diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs new file mode 100644 index 000000000..5fa8525d3 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// This is a shared source file included in both ModelContextProtocol.Core and the test project. +// Do not reference symbols internal to the core project, as they won't be available in tests. + +#if NET +using System.Buffers; +using System.Buffers.Text; +using System.Diagnostics.CodeAnalysis; + +#endif +using System.Text; + +namespace ModelContextProtocol.Server; + +/// +/// Provides methods for formatting and parsing event IDs used by . +/// +/// +/// Event IDs are formatted as "{base64(sessionId)}:{base64(streamId)}:{sequence}". +/// +internal static class DistributedCacheEventIdFormatter +{ + private const char Separator = ':'; + + /// + /// Formats session ID, stream ID, and sequence number into an event ID string. + /// + public static string Format(string sessionId, string streamId, long sequence) + { + // Base64-encode session and stream IDs so the event ID can be parsed + // even if the original IDs contain the ':' separator character + var sessionBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(sessionId)); + var streamBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(streamId)); + return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}"; + } + + /// + /// Attempts to parse an event ID into its component parts. + /// + public static bool TryParse(string eventId, out string sessionId, out string streamId, out long sequence) + { + sessionId = string.Empty; + streamId = string.Empty; + sequence = 0; + +#if NET + ReadOnlySpan eventIdSpan = eventId.AsSpan(); + Span partRanges = stackalloc Range[4]; + int rangeCount = eventIdSpan.Split(partRanges, Separator); + if (rangeCount != 3) + { + return false; + } + + try + { + ReadOnlySpan sessionBase64 = eventIdSpan[partRanges[0]]; + ReadOnlySpan streamBase64 = eventIdSpan[partRanges[1]]; + ReadOnlySpan sequenceSpan = eventIdSpan[partRanges[2]]; + + if (!TryDecodeBase64ToString(sessionBase64, out sessionId!) || + !TryDecodeBase64ToString(streamBase64, out streamId!)) + { + return false; + } + + return long.TryParse(sequenceSpan, out sequence); + } + catch + { + return false; + } +#else + var parts = eventId.Split(Separator); + if (parts.Length != 3) + { + return false; + } + + try + { + sessionId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); + streamId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); + return long.TryParse(parts[2], out sequence); + } + catch + { + return false; + } +#endif + } + +#if NET + private static bool TryDecodeBase64ToString(ReadOnlySpan base64Chars, [NotNullWhen(true)] out string? result) + { + // Use a single buffer: base64 chars are ASCII (1:1 with UTF8 bytes), + // and decoded data is always smaller than encoded, so we can decode in-place. + int bufferLength = base64Chars.Length; + Span buffer = bufferLength <= 256 + ? stackalloc byte[bufferLength] + : new byte[bufferLength]; + + Encoding.UTF8.GetBytes(base64Chars, buffer); + + OperationStatus status = Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten); + if (status != OperationStatus.Done) + { + result = null; + return false; + } + + result = Encoding.UTF8.GetString(buffer[..bytesWritten]); + return true; + } +#endif +} diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs new file mode 100644 index 000000000..f2a595b2d --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -0,0 +1,384 @@ +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; +using System.Runtime.CompilerServices; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// +/// An implementation backed by . +/// +/// +/// +/// This implementation stores SSE events in a distributed cache, enabling resumability across +/// multiple server instances. Event IDs are encoded with session, stream, and sequence information +/// to allow efficient retrieval of events after a given point. +/// +/// +/// The writer maintains in-memory state for sequence number generation, as there is guaranteed +/// to be only one writer per stream. Readers may be created from separate processes. +/// +/// +public sealed partial class DistributedCacheEventStreamStore : ISseEventStreamStore +{ + private readonly IDistributedCache _cache; + private readonly DistributedCacheEventStreamStoreOptions _options; + private readonly ILogger _logger; + + /// + /// Initializes a new instance of the class. + /// + /// The distributed cache to use for storage. + /// Optional configuration options for the store. + /// Optional logger for diagnostic output. + public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null, ILogger? logger = null) + { + Throw.IfNull(cache); + _cache = cache; + _options = options ?? new(); + _logger = logger ?? NullLogger.Instance; + } + + /// + public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) + { + Throw.IfNull(options); + LogStreamCreated(options.SessionId, options.StreamId, options.Mode); + var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options, _logger); + return new ValueTask(writer); + } + + /// + public async ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default) + { + Throw.IfNull(lastEventId); + + // Parse the event ID to get session, stream, and sequence information + if (!DistributedCacheEventIdFormatter.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence)) + { + LogEventIdParsingFailed(lastEventId); + return null; + } + + // Check if the stream exists by looking for its metadata + var metadataKey = CacheKeys.StreamMetadata(sessionId, streamId); + var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); + if (metadataBytes is null) + { + LogStreamMetadataNotFound(sessionId, streamId); + return null; + } + + var metadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); + if (metadata is null) + { + LogStreamMetadataDeserializationFailed(sessionId, streamId); + return null; + } + + var startSequence = sequence + 1; + LogStreamReaderCreated(sessionId, streamId, startSequence, metadata.LastSequence); + return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options, _logger); + } + + /// + /// Provides methods for generating cache keys. + /// + /// + /// Cache keys are versioned to allow format changes without conflicts with existing entries. + /// When the cache format changes, increment to invalidate old entries. + /// + internal static class CacheKeys + { + /// + /// The current cache key version. Increment this when changing the cache format + /// to ensure old entries are ignored. + /// + private const string Version = "v1"; + private const string Prefix = $"mcp:sse:{Version}:"; + + public static string StreamMetadata(string sessionId, string streamId) => + $"{Prefix}meta:{sessionId}:{streamId}"; + + public static string Event(string eventId) => + $"{Prefix}event:{eventId}"; + + public static string StreamEventCount(string sessionId, string streamId) => + $"{Prefix}count:{sessionId}:{streamId}"; + } + + /// + /// Metadata about a stream stored in the cache. + /// + internal sealed class StreamMetadata + { + public SseEventStreamMode Mode { get; set; } + public bool IsCompleted { get; set; } + public long LastSequence { get; set; } + } + + /// + /// Serialized representation of an SSE event stored in the cache. + /// + internal sealed class StoredEvent + { + public string? EventType { get; set; } + public string? EventId { get; set; } + public int? ReconnectionIntervalMs { get; set; } + public JsonRpcMessage? Data { get; set; } + } + + private sealed partial class DistributedCacheEventStreamWriter : ISseEventStreamWriter + { + private readonly IDistributedCache _cache; + private readonly string _sessionId; + private readonly string _streamId; + private SseEventStreamMode _mode; + private readonly DistributedCacheEventStreamStoreOptions _options; + private readonly ILogger _logger; + private long _sequence; + private bool _disposed; + + public DistributedCacheEventStreamWriter( + IDistributedCache cache, + string sessionId, + string streamId, + SseEventStreamMode mode, + DistributedCacheEventStreamStoreOptions options, + ILogger logger) + { + _cache = cache; + _sessionId = sessionId; + _streamId = streamId; + _mode = mode; + _options = options; + _logger = logger; + } + + public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) + { + LogStreamModeChanged(_sessionId, _streamId, mode); + _mode = mode; + await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); + } + + public async ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default) + { + // Skip if already has an event ID + if (sseItem.EventId is not null) + { + LogEventAlreadyHasId(_sessionId, _streamId, sseItem.EventId); + return sseItem; + } + + // Generate a new sequence number and event ID + var sequence = Interlocked.Increment(ref _sequence); + var eventId = DistributedCacheEventIdFormatter.Format(_sessionId, _streamId, sequence); + var newItem = sseItem with { EventId = eventId }; + + // Store the event in the cache + var storedEvent = new StoredEvent + { + EventType = newItem.EventType, + EventId = eventId, + ReconnectionIntervalMs = newItem.ReconnectionInterval.HasValue + ? (int)newItem.ReconnectionInterval.Value.TotalMilliseconds + : null, + Data = newItem.Data, + }; + + var eventBytes = JsonSerializer.SerializeToUtf8Bytes(storedEvent, McpJsonUtilities.JsonContext.Default.StoredEvent); + var eventKey = CacheKeys.Event(eventId); + + await _cache.SetAsync(eventKey, eventBytes, new DistributedCacheEntryOptions + { + SlidingExpiration = _options.EventSlidingExpiration, + AbsoluteExpirationRelativeToNow = _options.EventAbsoluteExpiration, + }, cancellationToken).ConfigureAwait(false); + + // Update metadata with the latest sequence + await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); + + LogEventWritten(_sessionId, _streamId, eventId, sequence); + return newItem; + } + + private async ValueTask UpdateMetadataAsync(CancellationToken cancellationToken) + { + var metadata = new StreamMetadata + { + Mode = _mode, + IsCompleted = _disposed, + LastSequence = Interlocked.Read(ref _sequence), + }; + + var metadataBytes = JsonSerializer.SerializeToUtf8Bytes(metadata, McpJsonUtilities.JsonContext.Default.StreamMetadata); + var metadataKey = CacheKeys.StreamMetadata(_sessionId, _streamId); + + await _cache.SetAsync(metadataKey, metadataBytes, new DistributedCacheEntryOptions + { + SlidingExpiration = _options.MetadataSlidingExpiration, + AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration, + }, cancellationToken).ConfigureAwait(false); + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + + // Mark the stream as completed in the metadata + await UpdateMetadataAsync(CancellationToken.None).ConfigureAwait(false); + LogStreamWriterDisposed(_sessionId, _streamId, Interlocked.Read(ref _sequence)); + } + + [LoggerMessage(Level = LogLevel.Debug, Message = "Stream mode changed for session '{SessionId}', stream '{StreamId}' to {Mode}.")] + private partial void LogStreamModeChanged(string sessionId, string streamId, SseEventStreamMode mode); + + [LoggerMessage(Level = LogLevel.Trace, Message = "Event already has ID '{EventId}' for session '{SessionId}', stream '{StreamId}'. Skipping ID generation.")] + private partial void LogEventAlreadyHasId(string sessionId, string streamId, string eventId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Event written to session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")] + private partial void LogEventWritten(string sessionId, string streamId, string eventId, long sequence); + + [LoggerMessage(Level = LogLevel.Information, Message = "Stream writer disposed for session '{SessionId}', stream '{StreamId}'. Total events written: {TotalEvents}.")] + private partial void LogStreamWriterDisposed(string sessionId, string streamId, long totalEvents); + } + + private sealed partial class DistributedCacheEventStreamReader : ISseEventStreamReader + { + private readonly IDistributedCache _cache; + private readonly long _startSequence; + private readonly StreamMetadata _initialMetadata; + private readonly DistributedCacheEventStreamStoreOptions _options; + private readonly ILogger _logger; + + public DistributedCacheEventStreamReader( + IDistributedCache cache, + string sessionId, + string streamId, + long startSequence, + StreamMetadata initialMetadata, + DistributedCacheEventStreamStoreOptions options, + ILogger logger) + { + _cache = cache; + SessionId = sessionId; + StreamId = streamId; + _startSequence = startSequence; + _initialMetadata = initialMetadata; + _options = options; + _logger = logger; + } + + public string SessionId { get; } + public string StreamId { get; } + + public async IAsyncEnumerable> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Start from the sequence after the last received event + var currentSequence = _startSequence; + + // Use the initial metadata passed to the constructor for the first read. + var lastSequence = _initialMetadata.LastSequence; + var isCompleted = _initialMetadata.IsCompleted; + var mode = _initialMetadata.Mode; + + LogReadingEventsStarted(SessionId, StreamId, _startSequence, lastSequence); + + while (!cancellationToken.IsCancellationRequested) + { + // Read all available events from currentSequence + 1 to lastSequence + for (; currentSequence <= lastSequence; currentSequence++) + { + cancellationToken.ThrowIfCancellationRequested(); + + var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence); + var eventKey = CacheKeys.Event(eventId); + var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"SSE event with ID '{eventId}' was not found in the cache. The event may have expired."); + + var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent); + if (storedEvent is not null) + { + LogEventRead(SessionId, StreamId, eventId, currentSequence); + yield return new SseItem(storedEvent.Data, storedEvent.EventType) + { + EventId = storedEvent.EventId, + ReconnectionInterval = storedEvent.ReconnectionIntervalMs.HasValue + ? TimeSpan.FromMilliseconds(storedEvent.ReconnectionIntervalMs.Value) + : null, + }; + } + } + + // If in polling mode, stop after returning currently available events + if (mode == SseEventStreamMode.Polling) + { + LogReadingEventsCompletedPolling(SessionId, StreamId, currentSequence - 1); + yield break; + } + + // If the stream is completed and we've read all events, stop + if (isCompleted) + { + LogReadingEventsCompletedStreamEnded(SessionId, StreamId, currentSequence - 1); + yield break; + } + + // Wait before polling again for new events + LogWaitingForNewEvents(SessionId, StreamId, _options.PollingInterval); + await Task.Delay(_options.PollingInterval, cancellationToken).ConfigureAwait(false); + + // Refresh metadata to get the latest sequence and completion status + var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); + var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' was not found in the cache. The metadata may have expired."); + + var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata) + ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' could not be deserialized."); + + lastSequence = currentMetadata.LastSequence; + isCompleted = currentMetadata.IsCompleted; + mode = currentMetadata.Mode; + } + } + + [LoggerMessage(Level = LogLevel.Debug, Message = "Starting to read events for session '{SessionId}', stream '{StreamId}' starting at sequence {StartSequence}. Last available sequence: {LastSequence}.")] + private partial void LogReadingEventsStarted(string sessionId, string streamId, long startSequence, long lastSequence); + + [LoggerMessage(Level = LogLevel.Trace, Message = "Event read from session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")] + private partial void LogEventRead(string sessionId, string streamId, string eventId, long sequence); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Reading events completed for session '{SessionId}', stream '{StreamId}' in polling mode. Last sequence read: {LastSequence}.")] + private partial void LogReadingEventsCompletedPolling(string sessionId, string streamId, long lastSequence); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Reading events completed for session '{SessionId}', stream '{StreamId}' as stream has ended. Last sequence read: {LastSequence}.")] + private partial void LogReadingEventsCompletedStreamEnded(string sessionId, string streamId, long lastSequence); + + [LoggerMessage(Level = LogLevel.Trace, Message = "Waiting for new events on session '{SessionId}', stream '{StreamId}'. Polling interval: {PollingInterval}.")] + private partial void LogWaitingForNewEvents(string sessionId, string streamId, TimeSpan pollingInterval); + } + + [LoggerMessage(Level = LogLevel.Information, Message = "Stream created for session '{SessionId}', stream '{StreamId}' with mode {Mode}.")] + private partial void LogStreamCreated(string sessionId, string streamId, SseEventStreamMode mode); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Stream reader created for session '{SessionId}', stream '{StreamId}' starting at sequence {StartSequence}. Last available sequence: {LastSequence}.")] + private partial void LogStreamReaderCreated(string sessionId, string streamId, long startSequence, long lastSequence); + + [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to parse event ID '{EventId}'. Unable to create stream reader.")] + private partial void LogEventIdParsingFailed(string eventId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Stream metadata not found for session '{SessionId}', stream '{StreamId}'.")] + private partial void LogStreamMetadataNotFound(string sessionId, string streamId); + + [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to deserialize stream metadata for session '{SessionId}', stream '{StreamId}'.")] + private partial void LogStreamMetadataDeserializationFailed(string sessionId, string streamId); +} diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs new file mode 100644 index 000000000..1c8452136 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs @@ -0,0 +1,51 @@ +namespace ModelContextProtocol.Server; + +/// +/// Configuration options for . +/// +public sealed class DistributedCacheEventStreamStoreOptions +{ + /// + /// Gets or sets the sliding expiration for individual events in the cache. + /// + /// + /// Events are refreshed on each access. If an event is not accessed within this + /// time period, it may be evicted from the cache. + /// + public TimeSpan? EventSlidingExpiration { get; set; } = TimeSpan.FromMinutes(30); + + /// + /// Gets or sets the absolute expiration for individual events in the cache. + /// + /// + /// Events will be evicted from the cache after this time period, regardless of access. + /// + public TimeSpan? EventAbsoluteExpiration { get; set; } = TimeSpan.FromHours(2); + + /// + /// Gets or sets the sliding expiration for stream metadata in the cache. + /// + /// + /// Stream metadata includes mode and completion status. This should typically be + /// set to a longer duration than event expiration to allow for resumability. + /// + public TimeSpan? MetadataSlidingExpiration { get; set; } = TimeSpan.FromHours(1); + + /// + /// Gets or sets the absolute expiration for stream metadata in the cache. + /// + /// + /// Stream metadata will be evicted from the cache after this time period, regardless of access. + /// + public TimeSpan? MetadataAbsoluteExpiration { get; set; } = TimeSpan.FromHours(4); + + /// + /// Gets or sets the interval between polling attempts when a reader is waiting for new events + /// in mode. + /// + /// + /// This only affects readers. A shorter interval provides lower latency for new events + /// but increases cache access frequency. + /// + public TimeSpan PollingInterval { get; set; } = TimeSpan.FromMilliseconds(100); +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs index e3425b253..a75906463 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -279,6 +279,7 @@ public async Task Client_CanResumePostResponseStream_AfterDisconnection() [Fact] public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() { + var timeout = TimeSpan.FromSeconds(10); using var faultingStreamHandler = new FaultingStreamHandler() { InnerHandler = SocketsHttpHandler, @@ -304,12 +305,12 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() await using var client = await ConnectClientAsync(); // Get the server instance - var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + var server = await serverTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Set up notification tracking with unique messages - var clientReceivedInitialNotificationTcs = new TaskCompletionSource(); - var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(); - var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(); + var clientReceivedInitialNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); const string CustomNotificationMethod = "test/custom_notification"; const string InitialMessage = "Initial notification"; @@ -343,11 +344,14 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() return default; }); + // Wait for the client's unsolicited message stream to be established before sending notifications + await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken); + // Send a custom notification to the client on the unsolicited message stream await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken); // Wait for client to receive the first notification - await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedInitialNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Fault the unsolicited message stream (GET SSE) var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); @@ -359,13 +363,13 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() reconnectAttempt.Continue(); // Wait for client to receive the notification via replay - await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedReplayedNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Send a final notification while the client has reconnected - this should be handled by the transport await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken); // Wait for the client to receive the final notification - await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedReconnectNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Assert each notification was received exactly once Assert.Equal(1, initialNotificationReceivedCount); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs index cace4d8be..dc157735f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs @@ -11,6 +11,12 @@ internal sealed class FaultingStreamHandler : DelegatingHandler { private FaultingStream? _lastStream; private TaskCompletionSource? _reconnectTcs; + private TaskCompletionSource _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public Task WaitForUnsolicitedMessageStreamAsync(CancellationToken cancellationToken = default) + => _unsolicitedMessageStreamReadyTcs.Task.WaitAsync(cancellationToken); + + internal void SignalUnsolicitedMessageStreamReady() => _unsolicitedMessageStreamReadyTcs.TrySetResult(); public async Task TriggerFaultAsync(CancellationToken cancellationToken) { @@ -24,6 +30,9 @@ public async Task TriggerFaultAsync(CancellationToken cancella throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection."); } + // Reset the TCS so we can wait for the reconnected unsolicited message stream + _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + _reconnectTcs = new(); await _lastStream.TriggerFaultAsync(cancellationToken); @@ -46,6 +55,7 @@ protected override async Task SendAsync( _reconnectTcs = null; } + var isGetRequest = request.Method == HttpMethod.Get; var response = await base.SendAsync(request, cancellationToken); // Only wrap SSE streams (text/event-stream) @@ -63,6 +73,13 @@ protected override async Task SendAsync( } response.Content = newContent; + + // For GET requests (unsolicited message stream), set up the stream to signal + // when first data is read. This ensures the server's transport handler is ready. + if (isGetRequest) + { + _lastStream.SetReadyCallback(SignalUnsolicitedMessageStreamReady); + } } return response; @@ -89,10 +106,14 @@ private sealed class FaultingStream(Stream innerStream) : Stream { private readonly CancellationTokenSource _cts = new(); private TaskCompletionSource? _faultTcs; + private Action? _readyCallback; + private bool _readySignaled; private bool _disposed; public bool IsDisposed => _disposed; + public void SetReadyCallback(Action callback) => _readyCallback = callback; + public async Task TriggerFaultAsync(CancellationToken cancellationToken) { if (_faultTcs is not null) @@ -131,6 +152,12 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation _cts.Token.ThrowIfCancellationRequested(); + if (bytesRead > 0 && !_readySignaled) + { + _readySignaled = true; + _readyCallback?.Invoke(); + } + return bytesRead; } catch (OperationCanceledException) when (_cts.IsCancellationRequested) diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index e0fb3d1fa..84b0ee994 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -29,6 +29,10 @@ + + + + @@ -41,6 +45,7 @@ + diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs new file mode 100644 index 000000000..34188a694 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -0,0 +1,1752 @@ +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for . +/// +public class DistributedCacheEventStreamStoreTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + private static CancellationToken CancellationToken => TestContext.Current.CancellationToken; + + private static IDistributedCache CreateMemoryCache() + { + var options = Options.Create(new MemoryDistributedCacheOptions()); + return new MemoryDistributedCache(options); + } + + [Fact] + public void Constructor_ThrowsArgumentNullException_WhenCacheIsNull() + { + Assert.Throws("cache", () => new DistributedCacheEventStreamStore(null!)); + } + + [Fact] + public async Task CreateStreamAsync_ThrowsArgumentNullException_WhenOptionsIsNull() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Act & Assert + await Assert.ThrowsAsync("options", + async () => await store.CreateStreamAsync(null!, CancellationToken)); + } + + [Fact] + public async Task WriteEventAsync_AssignsUniqueEventId_WhenItemHasNoEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var item = new SseItem(null); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.NotNull(result.EventId); + Assert.NotEmpty(result.EventId); + } + + [Fact] + public async Task WriteEventAsync_SkipsAssigningEventId_WhenItemAlreadyHasEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var existingEventId = "existing-event-id"; + var item = new SseItem(null) { EventId = existingEventId }; + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.Equal(existingEventId, result.EventId); + } + + [Fact] + public async Task WriteEventAsync_PreservesDataProperty_InReturnedItem() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var message = new JsonRpcNotification { Method = "test/notification" }; + var item = new SseItem(message); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Data should be preserved in the returned item (same reference) + Assert.Same(message, result.Data); + } + + [Fact] + public async Task WriteEventAsync_PreservesEventTypeProperty_InReturnedItem() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var item = new SseItem(null, "custom-event-type"); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.Equal("custom-event-type", result.EventType); + } + + [Fact] + public async Task WriteEventAsync_PreservesReconnectionIntervalProperty_InStoredEvent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var expectedInterval = TimeSpan.FromSeconds(5); + var item = new SseItem(null) { ReconnectionInterval = expectedInterval }; + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert - ReconnectionInterval should be preserved in returned item + Assert.Equal(expectedInterval, result.ReconnectionInterval); + + // Get a reader and verify ReconnectionInterval is preserved after round-trip + var reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Reader should not return the event we just wrote (it starts after lastEventId) + Assert.Empty(events); + + // Write another event and verify it can be read with correct ReconnectionInterval + var secondItem = new SseItem(null) { ReconnectionInterval = TimeSpan.FromSeconds(10) }; + _ = await writer.WriteEventAsync(secondItem, CancellationToken); + + // Re-fetch reader using the first event ID to get the second event + reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken); + Assert.NotNull(reader); + + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + Assert.Single(events); + Assert.Equal(TimeSpan.FromSeconds(10), events[0].ReconnectionInterval); + } + + [Fact] + public async Task WriteEventAsync_HandlesNullReconnectionInterval_InStoredEvent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write an event WITH a reconnection interval first + var firstItem = new SseItem(null) { ReconnectionInterval = TimeSpan.FromSeconds(5) }; + var firstResult = await writer.WriteEventAsync(firstItem, CancellationToken); + + // Write an event WITHOUT a reconnection interval + var secondItem = new SseItem(null); + var secondResult = await writer.WriteEventAsync(secondItem, CancellationToken); + Assert.Null(secondResult.ReconnectionInterval); + + // Get a reader starting after the first event + var reader = await store.GetStreamReaderAsync(firstResult.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Should get the second event with null ReconnectionInterval + Assert.Single(events); + Assert.Null(events[0].ReconnectionInterval); + } + + [Fact] + public async Task WriteEventAsync_HandlesNullData_AssignsEventIdAndStoresEvent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var item = new SseItem(null); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Event ID should be assigned + Assert.NotNull(result.EventId); + + // Assert - Event should be retrievable + var reader = await store.GetStreamReaderAsync(result.EventId, CancellationToken); + Assert.NotNull(reader); + } + + [Fact] + public async Task WriteEventAsync_StoresEventWithCorrectSlidingExpiration() + { + // Arrange - Use a mock cache to verify expiration options + var mockCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventSlidingExpiration = TimeSpan.FromMinutes(15) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var item = new SseItem(null); + + // Act + await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Verify at least one call used the expected sliding expiration + Assert.Contains(mockCache.SetCalls, call => + call.Key.Contains("event:") && + call.Options.SlidingExpiration == TimeSpan.FromMinutes(15)); + } + + [Fact] + public async Task WriteEventAsync_StoresEventWithCorrectAbsoluteExpiration() + { + // Arrange + var mockCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventAbsoluteExpiration = TimeSpan.FromHours(3) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var item = new SseItem(null); + + // Act + await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.Contains(mockCache.SetCalls, call => + call.Key.Contains("event:") && + call.Options.AbsoluteExpirationRelativeToNow == TimeSpan.FromHours(3)); + } + + [Fact] + public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() + { + // Arrange + var mockCache = new TestDistributedCache(); + var store = new DistributedCacheEventStreamStore(mockCache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + var item = new SseItem(null); + + // Act + await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Metadata should have been updated + Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); + } + + [Fact] + public async Task SetModeAsync_PersistsModeChangeToMetadata() + { + // Arrange + var mockCache = new TestDistributedCache(); + var store = new DistributedCacheEventStreamStore(mockCache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + mockCache.SetCalls.Clear(); // Clear calls from CreateStreamAsync setup + + // Act + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Assert - Metadata should have been updated with the new mode + Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); + } + + [Fact] + public async Task SetModeAsync_ModeChangeReflectedInReader() + { + // Arrange + var cache = CreateMemoryCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(10) + }; + var store = new DistributedCacheEventStreamStore(cache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write an event to have something to read + var item = new SseItem(new JsonRpcNotification { Method = "test" }); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Get a reader based on the event ID (starting at sequence 1, reader will wait for seq 2+) + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Change mode to Polling while reader exists + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Assert - Reader should complete immediately in polling mode (no new events to read) + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(500)); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + + // In polling mode, reader should complete without waiting for new events + Assert.Empty(events); // No events after the one we used to create the reader + } + + [Fact] + public async Task DisposeAsync_MarksStreamAsCompleted() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write an event so we can get a reader + var item = new SseItem(null); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Act + await writer.DisposeAsync(); + + // Assert - Reader should see the stream as completed and exit immediately + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(500)); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + + // The reader should complete without waiting for new events because stream is completed + Assert.Empty(events); // No new events after the one we used to create the reader + } + + [Fact] + public async Task DisposeAsync_IsIdempotent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Act - Call DisposeAsync multiple times + await writer.DisposeAsync(); + await writer.DisposeAsync(); + await writer.DisposeAsync(); + + // Assert - No exception thrown, operation is idempotent + // If we got here without exception, the test passes + } + + [Fact] + public async Task DisposeAsync_UpdatesMetadata_WithIsCompletedFlag() + { + // Arrange + var mockCache = new TestDistributedCache(); + var store = new DistributedCacheEventStreamStore(mockCache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + mockCache.SetCalls.Clear(); // Clear calls from CreateStreamAsync + + // Act + await writer.DisposeAsync(); + + // Assert - Metadata should have been updated + Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); + } + + [Fact] + public async Task GetStreamReaderAsync_ThrowsArgumentNullException_WhenLastEventIdIsNull() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Act & Assert + await Assert.ThrowsAsync("lastEventId", + async () => await store.GetStreamReaderAsync(null!, CancellationToken)); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsNull_WhenEventIdIsUnparseable() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Act - Try various invalid event ID formats + var result1 = await store.GetStreamReaderAsync("invalid-format", CancellationToken); + var result2 = await store.GetStreamReaderAsync("only:two:parts:here", CancellationToken); + var result3 = await store.GetStreamReaderAsync("", CancellationToken); + + // Assert + Assert.Null(result1); + Assert.Null(result2); + Assert.Null(result3); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsNull_WhenStreamMetadataDoesNotExist() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Create a valid-looking event ID for a stream that doesn't exist + var fakeEventId = DistributedCacheEventIdFormatter.Format("nonexistent-session", "nonexistent-stream", 1); + + // Act + var reader = await store.GetStreamReaderAsync(fakeEventId, CancellationToken); + + // Assert + Assert.Null(reader); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsReaderWithCorrectSessionIdAndStreamId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "my-session", + StreamId = "my-stream", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write an event to get a valid event ID + var item = new SseItem(null); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Act + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + + // Assert + Assert.NotNull(reader); + Assert.Equal("my-session", reader.SessionId); + Assert.Equal("my-stream", reader.StreamId); + } + + [Fact] + public async Task ReadEventsAsync_ReturnsEventsInOrder() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write multiple events + var event1 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method1" }), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method2" }), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method3" }), CancellationToken); + + // Create a reader starting from before the first event (use a fake event ID with sequence 0) + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - Events should be in order + Assert.Equal(3, events.Count); + Assert.Equal(event1.EventId, events[0].EventId); + Assert.Equal(event2.EventId, events[1].EventId); + Assert.Equal(event3.EventId, events[2].EventId); + } + + [Fact] + public async Task ReadEventsAsync_ReturnsEmpty_WhenNoNewEventsExist() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write one event + var writtenItem = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader starting from the last event (so there are no new events to read) + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert + Assert.Empty(events); + } + + [Fact] + public async Task ReadEventsAsync_PreservesCorrectDataEventTypeAndEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message = new JsonRpcNotification { Method = "test/method" }; + var writtenItem = await writer.WriteEventAsync(new SseItem(message, "custom-event-type"), CancellationToken); + + // Create a reader starting from before the event + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert + Assert.Single(events); + var readEvent = events[0]; + Assert.Equal(writtenItem.EventId, readEvent.EventId); + Assert.Equal("custom-event-type", readEvent.EventType); + + var readMessage = Assert.IsType(readEvent.Data); + Assert.Equal("test/method", readMessage.Method); + } + + [Fact] + public async Task ReadEventsAsync_HandlesNullData() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writtenItem = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader starting from before the event + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert + Assert.Single(events); + Assert.Null(events[0].Data); + Assert.Equal(writtenItem.EventId, events[0].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_CompletesImmediatelyAfterReturningAvailableEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write events + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader from sequence 0 + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act - Should complete quickly without waiting for new events + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + stopwatch.Stop(); + + // Assert - Should have returned both events and completed quickly + Assert.Equal(2, events.Count); + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Polling mode should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_ReturnsOnlyEventsAfterLastEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write 3 events + var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader starting from event2 (should only return event3) + var reader = await store.GetStreamReaderAsync(event2.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - Only event3 should be returned + Assert.Single(events); + Assert.Equal(event3.EventId, events[0].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_ReturnsEmptyIfNoNewEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write one event and create a reader from that event (no events after it) + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - No new events should be returned + Assert.Empty(events); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_DoesNotWaitForNewEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write one event so we have a valid event ID, then create reader from it + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Should complete immediately without waiting (no new events after the one we started from) + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + stopwatch.Stop(); + + // Assert - Should complete quickly with no events + Assert.Empty(events); + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Polling mode should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); + } + + [Fact] + public async Task ReadEventsAsync_InStreamingMode_WaitsForNewEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write one event so we have a valid event ID + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Start reading and then write a new event + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(2)); + var events = new List>(); + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + if (events.Count >= 1) + { + // Got the event we were waiting for, cancel to stop + await cts.CancelAsync(); + } + } + }, CancellationToken); + + // Write a new event - the reader should pick it up since it's in streaming mode + // and won't complete until cancelled or the stream is disposed + var newEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Wait for read to complete (either event received or timeout) + try + { + await readTask; + } + catch (OperationCanceledException) + { + // Expected when we cancel after receiving event + } + + // Assert - Should have received the new event + Assert.Single(events); + Assert.Equal(newEvent.EventId, events[0].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InStreamingMode_YieldsNewlyWrittenEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write initial event + var initialEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(initialEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Write multiple events while reader is active + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(3)); + var events = new List>(); + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + if (events.Count >= 3) + { + await cts.CancelAsync(); + } + } + }, CancellationToken); + + // Write 3 new events - the reader should pick them up since it's in streaming mode + var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + try + { + await readTask; + } + catch (OperationCanceledException) + { + // Expected + } + + // Assert - Should have received all 3 events in order + Assert.Equal(3, events.Count); + Assert.Equal(event1.EventId, events[0].EventId); + Assert.Equal(event2.EventId, events[1].EventId); + Assert.Equal(event3.EventId, events[2].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InStreamingMode_CompletesWhenStreamIsDisposed() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write event to create a valid reader + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Start reading, then dispose the stream + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + } + }, CancellationToken); + + // Dispose the writer - the reader should detect this and exit gracefully + await writer.DisposeAsync(); + + // Assert - The read should complete gracefully within timeout + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); + await readTask.WaitAsync(timeoutCts.Token); + } + + [Fact] + public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write event to create a valid reader + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Start reading and then cancel + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + var events = new List>(); + var messageReceivedTcs = new TaskCompletionSource(); + var continueReadingTcs = new TaskCompletionSource(); + OperationCanceledException? capturedException = null; + + var readTask = Task.Run(async () => + { + try + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + messageReceivedTcs.SetResult(true); + await continueReadingTcs.Task; + } + } + catch (OperationCanceledException ex) + { + capturedException = ex; + } + }, CancellationToken); + + // Write a message for the reader to consume + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Wait for the first message to be received + await messageReceivedTcs.Task; + + // Cancel so that ReadEventsAsync throws before reading the next message + await cts.CancelAsync(); + + // Allow the message reader to continue + continueReadingTcs.SetResult(true); + + // Wait for read task to complete + await readTask; + + Assert.Single(events); + Assert.NotNull(capturedException); + } + + [Fact] + public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write an event to create a valid reader + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Start reading in streaming mode (will wait for new events) + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(10)); + var events = new List>(); + var readCompleted = false; + + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + readCompleted = true; + }, CancellationToken); + + // Switch to polling mode - the reader should detect this and exit + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Assert - Read should complete within timeout after switching to polling mode + await readTask.WaitAsync(cts.Token); + Assert.True(readCompleted); + Assert.Empty(events); // No new events were written after the one we used to create the reader + } + + [Fact] + public async Task ReadEventsAsync_PollingModeReturnsEventsThenCompletes() + { + // Arrange - Start in default mode, write some events, switch to polling, reader should return remaining events + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming + }, CancellationToken); + + // Write initial event and create reader from sequence 0 + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + + // Write events first + var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Switch to polling mode + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Get reader + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act - Read should return events and complete immediately + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + stopwatch.Stop(); + + // Assert + Assert.Equal(2, events.Count); + Assert.Equal(event1.EventId, events[0].EventId); + Assert.Equal(event2.EventId, events[1].EventId); + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); + } + + [Fact] + public async Task MultipleStreams_AreIsolated_EventsDoNotLeakBetweenStreams() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Create two streams with different session/stream IDs + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-2", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write events to each stream + var event1 = await writer1.WriteEventAsync(new SseItem(null, "event-from-stream1"), CancellationToken); + var event2 = await writer2.WriteEventAsync(new SseItem(null, "event-from-stream2"), CancellationToken); + + // Create readers for each stream from sequence 0 + var start1 = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var start2 = DistributedCacheEventIdFormatter.Format("session-2", "stream-2", 0); + + var reader1 = await store.GetStreamReaderAsync(start1, CancellationToken); + var reader2 = await store.GetStreamReaderAsync(start2, CancellationToken); + Assert.NotNull(reader1); + Assert.NotNull(reader2); + + // Act - Read from each reader + var events1 = new List>(); + await foreach (var evt in reader1.ReadEventsAsync(CancellationToken)) + { + events1.Add(evt); + } + + var events2 = new List>(); + await foreach (var evt in reader2.ReadEventsAsync(CancellationToken)) + { + events2.Add(evt); + } + + // Assert - Each reader should only see its own stream's events + Assert.Single(events1); + Assert.Equal("event-from-stream1", events1[0].EventType); + Assert.Equal(event1.EventId, events1[0].EventId); + + Assert.Single(events2); + Assert.Equal("event-from-stream2", events2[0].EventType); + Assert.Equal(event2.EventId, events2[0].EventId); + } + + [Fact] + public async Task MultipleStreams_SameSession_DifferentStreamIds_AreIsolated() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Create two streams with same session but different stream IDs + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "shared-session", + StreamId = "stream-A", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "shared-session", + StreamId = "stream-B", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write events to each stream + await writer1.WriteEventAsync(new SseItem(null, "from-A"), CancellationToken); + await writer2.WriteEventAsync(new SseItem(null, "from-B"), CancellationToken); + + // Create readers from sequence 0 + var reader1 = await store.GetStreamReaderAsync(DistributedCacheEventIdFormatter.Format("shared-session", "stream-A", 0), CancellationToken); + var reader2 = await store.GetStreamReaderAsync(DistributedCacheEventIdFormatter.Format("shared-session", "stream-B", 0), CancellationToken); + Assert.NotNull(reader1); + Assert.NotNull(reader2); + + // Act + var events1 = new List>(); + await foreach (var evt in reader1.ReadEventsAsync(CancellationToken)) + { + events1.Add(evt); + } + + var events2 = new List>(); + await foreach (var evt in reader2.ReadEventsAsync(CancellationToken)) + { + events2.Add(evt); + } + + // Assert + Assert.Single(events1); + Assert.Equal("from-A", events1[0].EventType); + + Assert.Single(events2); + Assert.Equal("from-B", events2[0].EventType); + } + + [Fact] + public async Task EventIds_AreGloballyUnique_AcrossStreams() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-2", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Act - Write events to each stream + var event1a = await writer1.WriteEventAsync(new SseItem(null), CancellationToken); + var event1b = await writer1.WriteEventAsync(new SseItem(null), CancellationToken); + var event2a = await writer2.WriteEventAsync(new SseItem(null), CancellationToken); + var event2b = await writer2.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert - All event IDs should be unique + var allEventIds = new[] { event1a.EventId, event1b.EventId, event2a.EventId, event2b.EventId }; + Assert.Equal(4, allEventIds.Distinct().Count()); + } + + [Fact] + public async Task WriteEventAsync_UsesConfiguredSlidingExpiration() + { + // Arrange + var mockCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventSlidingExpiration = TimeSpan.FromMinutes(30) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + mockCache.SetCalls.Clear(); + + // Act + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert - Event should be written with the configured sliding expiration + Assert.Contains(mockCache.SetCalls, call => + call.Key.Contains("event:") && + call.Options.SlidingExpiration == TimeSpan.FromMinutes(30)); + } + + [Fact] + public async Task WriteEventAsync_UsesConfiguredAbsoluteExpiration() + { + // Arrange + var mockCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventAbsoluteExpiration = TimeSpan.FromHours(6) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + mockCache.SetCalls.Clear(); + + // Act + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert - Event should be written with the configured absolute expiration (relative to now) + var eventCall = mockCache.SetCalls.FirstOrDefault(call => call.Key.Contains("event:")); + Assert.NotNull(eventCall.Key); + Assert.NotNull(eventCall.Options.AbsoluteExpirationRelativeToNow); + Assert.Equal(TimeSpan.FromHours(6), eventCall.Options.AbsoluteExpirationRelativeToNow); + } + + [Fact] + public async Task WriteEventAsync_UsesConfiguredMetadataExpiration() + { + // Arrange - Metadata is written when events are written + var mockCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + MetadataSlidingExpiration = TimeSpan.FromMinutes(45), + MetadataAbsoluteExpiration = TimeSpan.FromHours(12) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Act - Write an event, which also updates metadata + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert + var metadataCall = mockCache.SetCalls.FirstOrDefault(call => call.Key.Contains("meta:")); + Assert.NotNull(metadataCall.Key); + Assert.Equal(TimeSpan.FromMinutes(45), metadataCall.Options.SlidingExpiration); + Assert.Equal(TimeSpan.FromHours(12), metadataCall.Options.AbsoluteExpirationRelativeToNow); + } + + [Fact] + public void DefaultOptions_HaveReasonableDefaults() + { + // Arrange & Act + var options = new DistributedCacheEventStreamStoreOptions(); + + // Assert - Check that defaults are set reasonably + Assert.True(options.PollingInterval >= TimeSpan.FromMilliseconds(50), "Polling interval should be at least 50ms"); + Assert.True(options.EventSlidingExpiration > TimeSpan.Zero, "Event sliding expiration should be positive"); + Assert.True(options.EventAbsoluteExpiration > TimeSpan.Zero, "Event absolute expiration should be positive"); + Assert.True(options.MetadataSlidingExpiration > TimeSpan.Zero, "Metadata sliding expiration should be positive"); + Assert.True(options.MetadataAbsoluteExpiration > TimeSpan.Zero, "Metadata absolute expiration should be positive"); + } + + [Fact] + public async Task ReadEventsAsync_ThrowsMcpException_WhenMetadataExpires() + { + // Arrange - Use a cache that allows us to simulate metadata expiration + var trackingCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(10) // Fast polling to detect the bug quickly + }; + var store = new DistributedCacheEventStreamStore(trackingCache, customOptions); + + // Create a stream and write an event + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Streaming // Non-polling mode to trigger the waiting loop + }, CancellationToken); + + var item = new SseItem(new JsonRpcNotification { Method = "test" }); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Get a reader starting after the first event (so it will wait for more events) + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Now simulate metadata expiration + trackingCache.ExpireMetadata(); + + // Act & Assert - Reader should throw McpException when metadata expires + var exception = await Assert.ThrowsAsync(async () => + { + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + // Should not yield any events before throwing + } + }); + + Assert.Contains("session-1", exception.Message); + Assert.Contains("stream-1", exception.Message); + Assert.Contains("metadata", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ReadEventsAsync_ThrowsMcpException_WhenEventExpires() + { + // Arrange - Use a cache that allows us to simulate event expiration + var trackingCache = new TestDistributedCache(); + var store = new DistributedCacheEventStreamStore(trackingCache); + + // Create a stream and write multiple events + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var event1 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method1" }), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method2" }), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method3" }), CancellationToken); + + // Create a reader starting from before the first event + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Simulate event2 expiring from the cache + trackingCache.ExpireEvent(event2.EventId!); + + // Act & Assert - Reader should throw McpException when an event is missing + var exception = await Assert.ThrowsAsync(async () => + { + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + }); + + Assert.Contains(event2.EventId!, exception.Message); + Assert.Contains("not found", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ReadEventsAsync_DoesNotReadMetadata_InPollingMode() + { + // Arrange - Use a tracking cache to count metadata reads + var trackingCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(10) + }; + var store = new DistributedCacheEventStreamStore(trackingCache, customOptions); + + // Create a stream in POLLING mode - this allows the reader to exit after reading available events + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var item1 = new SseItem(new JsonRpcNotification { Method = "test1" }); + var item2 = new SseItem(new JsonRpcNotification { Method = "test2" }); + await writer.WriteEventAsync(item1, CancellationToken); + await writer.WriteEventAsync(item2, CancellationToken); + + // Get a reader starting before all events (use a fake event ID at sequence 0) + var zeroSequenceEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(zeroSequenceEventId, CancellationToken); + Assert.NotNull(reader); + + // GetStreamReaderAsync should have read metadata exactly once + Assert.Equal(1, trackingCache.MetadataReadCount); + + // Act - Read all events + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - In polling mode, the reader should: + // 1. Use initial metadata from GetStreamReaderAsync (no additional read needed) + // 2. Read all available events (2 events) + // 3. Exit immediately because mode is Polling + // + // Metadata read count should remain at 1 (only the initial read from GetStreamReaderAsync) + Assert.Equal(2, events.Count); + Assert.Equal(1, trackingCache.MetadataReadCount); + } + + [Fact] + public void EventIdFormatter_Format_CreatesValidEventId() + { + // Act + var eventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 42); + + // Assert + Assert.NotNull(eventId); + Assert.NotEmpty(eventId); + Assert.Contains(":", eventId); // Should contain separators + } + + [Fact] + public void EventIdFormatter_TryParse_RoundTripsSuccessfully() + { + // Arrange + var originalSessionId = "my-session-id"; + var originalStreamId = "my-stream-id"; + var originalSequence = 12345L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesEmptySessionAndStreamIds() + { + // Arrange + var originalSessionId = ""; + var originalStreamId = ""; + var originalSequence = 42L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesSpecialCharactersInSessionId() + { + // Arrange - Session IDs can contain any visible ASCII character per MCP spec + var originalSessionId = "session:with:colons:and|pipes"; + var originalStreamId = "stream-1"; + var originalSequence = 1L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesSpecialCharactersInStreamId() + { + // Arrange + var originalSessionId = "session-1"; + var originalStreamId = "stream:with:colons:and|special!chars@#$%"; + var originalSequence = 1L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesUnicodeCharacters() + { + // Arrange + var originalSessionId = "session-日本語-émojis-🎉"; + var originalStreamId = "stream-中文-العربية"; + var originalSequence = 999L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesZeroSequence() + { + // Act + var eventId = DistributedCacheEventIdFormatter.Format("session", "stream", 0); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out _, out _, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(0, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesLargeSequence() + { + // Act + var eventId = DistributedCacheEventIdFormatter.Format("session", "stream", long.MaxValue); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out _, out _, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(long.MaxValue, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForEmptyString() + { + // Act + var parsed = DistributedCacheEventIdFormatter.TryParse("", out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.False(parsed); + Assert.Equal(string.Empty, sessionId); + Assert.Equal(string.Empty, streamId); + Assert.Equal(0, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForInvalidFormat() + { + // Act & Assert - Various invalid formats + Assert.False(DistributedCacheEventIdFormatter.TryParse("no-separators", out _, out _, out _)); + Assert.False(DistributedCacheEventIdFormatter.TryParse("only:one", out _, out _, out _)); + Assert.False(DistributedCacheEventIdFormatter.TryParse("too:many:parts:here", out _, out _, out _)); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForInvalidBase64() + { + // Act - Invalid base64 in first part + var parsed = DistributedCacheEventIdFormatter.TryParse("!!!invalid!!!:c3RyZWFt:1", out _, out _, out _); + + // Assert + Assert.False(parsed); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForNonNumericSequence() + { + // Arrange - Valid base64 but non-numeric sequence + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream")); + var invalidEventId = $"{sessionBase64}:{streamBase64}:not-a-number"; + + // Act + var parsed = DistributedCacheEventIdFormatter.TryParse(invalidEventId, out _, out _, out _); + + // Assert + Assert.False(parsed); + } + + /// + /// A distributed cache that tracks all operations for verification in tests. + /// Supports tracking Set calls, counting metadata reads, and simulating metadata/event expiration. + /// + private sealed class TestDistributedCache : IDistributedCache + { + private readonly MemoryDistributedCache _innerCache = new(Options.Create(new MemoryDistributedCacheOptions())); + private int _metadataReadCount; + private bool _metadataExpired; + private readonly HashSet _expiredEventIds = []; + + public List<(string Key, DistributedCacheEntryOptions Options)> SetCalls { get; } = []; + public int MetadataReadCount => _metadataReadCount; + + public void ExpireMetadata() => _metadataExpired = true; + public void ExpireEvent(string eventId) => _expiredEventIds.Add(eventId); + + public byte[]? Get(string key) + { + if (key.Contains("meta:")) + { + Interlocked.Increment(ref _metadataReadCount); + if (_metadataExpired) + { + return null; + } + } + if (IsExpiredEvent(key)) + { + return null; + } + return _innerCache.Get(key); + } + + public Task GetAsync(string key, CancellationToken token = default) + { + if (key.Contains("meta:")) + { + Interlocked.Increment(ref _metadataReadCount); + if (_metadataExpired) + { + return Task.FromResult(null); + } + } + if (IsExpiredEvent(key)) + { + return Task.FromResult(null); + } + return _innerCache.GetAsync(key, token); + } + + private bool IsExpiredEvent(string key) + { + // Cache key format is "mcp:sse:event:{eventId}" + foreach (var expiredEventId in _expiredEventIds) + { + if (key.EndsWith(expiredEventId)) + { + return true; + } + } + return false; + } + + public void Refresh(string key) => _innerCache.Refresh(key); + public Task RefreshAsync(string key, CancellationToken token = default) => _innerCache.RefreshAsync(key, token); + public void Remove(string key) => _innerCache.Remove(key); + public Task RemoveAsync(string key, CancellationToken token = default) => _innerCache.RemoveAsync(key, token); + + public void Set(string key, byte[] value, DistributedCacheEntryOptions options) + { + SetCalls.Add((key, options)); + _innerCache.Set(key, value, options); + } + + public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default) + { + SetCalls.Add((key, options)); + return _innerCache.SetAsync(key, value, options, token); + } + } +}