Skip to content

Commit e01c4f4

Browse files
committed
Clean up tests, throw on expired cache entries
1 parent 6160ab9 commit e01c4f4

File tree

3 files changed

+117
-66
lines changed

3 files changed

+117
-66
lines changed

src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// This is a shared source file included in both ModelContextProtocol.Core and the test project.
55
// Do not reference symbols internal to the core project, as they won't be available in tests.
66

7-
using System;
7+
using System.Text;
88

99
namespace ModelContextProtocol.Server;
1010

@@ -13,8 +13,6 @@ namespace ModelContextProtocol.Server;
1313
/// </summary>
1414
/// <remarks>
1515
/// Event IDs are formatted as "{base64(sessionId)}:{base64(streamId)}:{sequence}".
16-
/// Base64 encoding is used because the MCP specification allows session IDs to contain
17-
/// any visible ASCII character (0x21-0x7E), including the ':' separator character.
1816
/// </remarks>
1917
internal static class DistributedCacheEventIdFormatter
2018
{
@@ -27,8 +25,8 @@ public static string Format(string sessionId, string streamId, long sequence)
2725
{
2826
// Base64-encode session and stream IDs so the event ID can be parsed
2927
// even if the original IDs contain the ':' separator character
30-
var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(sessionId));
31-
var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(streamId));
28+
var sessionBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(sessionId));
29+
var streamBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(streamId));
3230
return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}";
3331
}
3432

@@ -49,8 +47,8 @@ public static bool TryParse(string eventId, out string sessionId, out string str
4947

5048
try
5149
{
52-
sessionId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[0]));
53-
streamId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[1]));
50+
sessionId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[0]));
51+
streamId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[1]));
5452
return long.TryParse(parts[2], out sequence);
5553
}
5654
catch

src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,8 @@ public DistributedCacheEventStreamReader(
255255

256256
var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence);
257257
var eventKey = CacheKeys.Event(eventId);
258-
var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false);
259-
260-
if (eventBytes is null)
261-
{
262-
// Event may have expired; skip to next
263-
continue;
264-
}
258+
var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false)
259+
?? throw new McpException($"SSE event with ID '{eventId}' was not found in the cache. The event may have expired.");
265260

266261
var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent);
267262
if (storedEvent is not null)
@@ -290,19 +285,11 @@ public DistributedCacheEventStreamReader(
290285

291286
// Refresh metadata to get the latest sequence and completion status
292287
var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId);
293-
var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false);
294-
295-
if (metadataBytes is null)
296-
{
297-
// Metadata expired - treat stream as complete to avoid infinite loop
298-
yield break;
299-
}
288+
var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false)
289+
?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' was not found in the cache. The metadata may have expired.");
300290

301-
var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata);
302-
if (currentMetadata is null)
303-
{
304-
yield break;
305-
}
291+
var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata)
292+
?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' could not be deserialized.");
306293

307294
lastSequence = currentMetadata.LastSequence;
308295
isCompleted = currentMetadata.IsCompleted;

tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs

Lines changed: 106 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -756,8 +756,8 @@ public async Task ReadEventsAsync_InStreamingMode_WaitsForNewEvents()
756756
}
757757
}, CancellationToken);
758758

759-
// Wait a bit, then write a new event
760-
await Task.Delay(100, CancellationToken);
759+
// Write a new event - the reader should pick it up since it's in streaming mode
760+
// and won't complete until cancelled or the stream is disposed
761761
var newEvent = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
762762

763763
// Wait for read to complete (either event received or timeout)
@@ -812,8 +812,7 @@ public async Task ReadEventsAsync_InStreamingMode_YieldsNewlyWrittenEvents()
812812
}
813813
}, CancellationToken);
814814

815-
// Write 3 new events
816-
await Task.Delay(100, CancellationToken);
815+
// Write 3 new events - the reader should pick them up since it's in streaming mode
817816
var event1 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
818817
var event2 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
819818
var event3 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
@@ -856,25 +855,20 @@ public async Task ReadEventsAsync_InStreamingMode_CompletesWhenStreamIsDisposed(
856855
Assert.NotNull(reader);
857856

858857
// Act - Start reading, then dispose the stream
859-
var events = new List<SseItem<JsonRpcMessage?>>();
860858
var readTask = Task.Run(async () =>
861859
{
862860
await foreach (var evt in reader.ReadEventsAsync(CancellationToken))
863861
{
864-
events.Add(evt);
865862
}
866863
}, CancellationToken);
867864

868-
// Wait a bit, then dispose the writer
869-
await Task.Delay(100, CancellationToken);
865+
// Dispose the writer - the reader should detect this and exit gracefully
870866
await writer.DisposeAsync();
871867

872-
// Wait for read to complete with a timeout
873-
var timeoutTask = Task.Delay(TimeSpan.FromSeconds(2), CancellationToken);
874-
var completedTask = await Task.WhenAny(readTask, timeoutTask);
875-
876-
// Assert - The read should complete gracefully (not timeout)
877-
Assert.Same(readTask, completedTask);
868+
// Assert - The read should complete gracefully within timeout
869+
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
870+
timeoutCts.CancelAfter(TimeSpan.FromSeconds(2));
871+
await readTask.WaitAsync(timeoutCts.Token);
878872
}
879873

880874
[Fact]
@@ -900,8 +894,9 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation()
900894

901895
// Act - Start reading and then cancel
902896
using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
903-
var stopwatch = System.Diagnostics.Stopwatch.StartNew();
904897
var events = new List<SseItem<JsonRpcMessage?>>();
898+
var messageReceivedTcs = new TaskCompletionSource<bool>();
899+
var continueReadingTcs = new TaskCompletionSource<bool>();
905900
OperationCanceledException? capturedException = null;
906901

907902
var readTask = Task.Run(async () =>
@@ -911,6 +906,8 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation()
911906
await foreach (var evt in reader.ReadEventsAsync(cts.Token))
912907
{
913908
events.Add(evt);
909+
messageReceivedTcs.SetResult(true);
910+
await continueReadingTcs.Task;
914911
}
915912
}
916913
catch (OperationCanceledException ex)
@@ -919,17 +916,23 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation()
919916
}
920917
}, CancellationToken);
921918

922-
// Cancel after a short delay
923-
await Task.Delay(100, CancellationToken);
919+
// Write a message for the reader to consume
920+
await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(null), CancellationToken);
921+
922+
// Wait for the first message to be received
923+
await messageReceivedTcs.Task;
924+
925+
// Cancel so that ReadEventsAsync throws before reading the next message
924926
await cts.CancelAsync();
925927

928+
// Allow the message reader to continue
929+
continueReadingTcs.SetResult(true);
930+
926931
// Wait for read task to complete
927932
await readTask;
928-
stopwatch.Stop();
929933

930-
// Assert - Either cancelled exception or graceful exit, but should complete quickly
931-
Assert.Empty(events); // No events should have been received
932-
Assert.True(stopwatch.ElapsedMilliseconds < 1000, $"Should complete quickly after cancellation, took {stopwatch.ElapsedMilliseconds}ms");
934+
Assert.Single(events);
935+
Assert.NotNull(capturedException);
933936
}
934937

935938
[Fact]
@@ -953,7 +956,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling()
953956
var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken);
954957
Assert.NotNull(reader);
955958

956-
// Start reading in default mode (will wait for new events)
959+
// Start reading in streaming mode (will wait for new events)
957960
using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
958961
cts.CancelAfter(TimeSpan.FromSeconds(3));
959962
var events = new List<SseItem<JsonRpcMessage?>>();
@@ -968,16 +971,13 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling()
968971
readCompleted = true;
969972
}, CancellationToken);
970973

971-
// Wait a bit, then switch to polling mode
972-
await Task.Delay(100, CancellationToken);
974+
// Switch to polling mode - the reader should detect this and exit
973975
await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken);
974976

975-
// Wait for read task to complete (should complete quickly after mode switch)
976-
var timeoutTask = Task.Delay(TimeSpan.FromSeconds(1), CancellationToken);
977-
var completedTask = await Task.WhenAny(readTask, timeoutTask);
978-
979-
// Assert - Read should have completed after switching to polling mode
980-
Assert.Same(readTask, completedTask);
977+
// Assert - Read should complete within timeout after switching to polling mode
978+
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken);
979+
timeoutCts.CancelAfter(TimeSpan.FromSeconds(1));
980+
await readTask.WaitAsync(timeoutCts.Token);
981981
Assert.True(readCompleted);
982982
Assert.Empty(events); // No new events were written after the one we used to create the reader
983983
}
@@ -1271,7 +1271,7 @@ public void DefaultOptions_HaveReasonableDefaults()
12711271
}
12721272

12731273
[Fact]
1274-
public async Task ReadEventsAsync_Completes_WhenMetadataExpires()
1274+
public async Task ReadEventsAsync_ThrowsMcpException_WhenMetadataExpires()
12751275
{
12761276
// Arrange - Use a cache that allows us to simulate metadata expiration
12771277
var trackingCache = new TestDistributedCache();
@@ -1299,16 +1299,59 @@ public async Task ReadEventsAsync_Completes_WhenMetadataExpires()
12991299
// Now simulate metadata expiration
13001300
trackingCache.ExpireMetadata();
13011301

1302-
// Act - Read events; the reader should complete gracefully when metadata expires
1303-
// instead of looping indefinitely with the stale initial metadata
1304-
var events = new List<SseItem<JsonRpcMessage?>>();
1305-
await foreach (var evt in reader.ReadEventsAsync(CancellationToken))
1302+
// Act & Assert - Reader should throw McpException when metadata expires
1303+
var exception = await Assert.ThrowsAsync<McpException>(async () =>
13061304
{
1307-
events.Add(evt);
1308-
}
1305+
await foreach (var evt in reader.ReadEventsAsync(CancellationToken))
1306+
{
1307+
// Should not yield any events before throwing
1308+
}
1309+
});
13091310

1310-
// If we reach here without timeout, the reader correctly handled metadata expiration
1311-
Assert.Empty(events); // No new events after the initial one used to create the reader
1311+
Assert.Contains("session-1", exception.Message);
1312+
Assert.Contains("stream-1", exception.Message);
1313+
Assert.Contains("metadata", exception.Message, StringComparison.OrdinalIgnoreCase);
1314+
}
1315+
1316+
[Fact]
1317+
public async Task ReadEventsAsync_ThrowsMcpException_WhenEventExpires()
1318+
{
1319+
// Arrange - Use a cache that allows us to simulate event expiration
1320+
var trackingCache = new TestDistributedCache();
1321+
var store = new DistributedCacheEventStreamStore(trackingCache);
1322+
1323+
// Create a stream and write multiple events
1324+
var writer = await store.CreateStreamAsync(new SseEventStreamOptions
1325+
{
1326+
SessionId = "session-1",
1327+
StreamId = "stream-1",
1328+
Mode = SseEventStreamMode.Polling
1329+
}, CancellationToken);
1330+
1331+
var event1 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(new JsonRpcNotification { Method = "method1" }), CancellationToken);
1332+
var event2 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(new JsonRpcNotification { Method = "method2" }), CancellationToken);
1333+
var event3 = await writer.WriteEventAsync(new SseItem<JsonRpcMessage?>(new JsonRpcNotification { Method = "method3" }), CancellationToken);
1334+
1335+
// Create a reader starting from before the first event
1336+
var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0);
1337+
var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken);
1338+
Assert.NotNull(reader);
1339+
1340+
// Simulate event2 expiring from the cache
1341+
trackingCache.ExpireEvent(event2.EventId!);
1342+
1343+
// Act & Assert - Reader should throw McpException when an event is missing
1344+
var exception = await Assert.ThrowsAsync<McpException>(async () =>
1345+
{
1346+
var events = new List<SseItem<JsonRpcMessage?>>();
1347+
await foreach (var evt in reader.ReadEventsAsync(CancellationToken))
1348+
{
1349+
events.Add(evt);
1350+
}
1351+
});
1352+
1353+
Assert.Contains(event2.EventId!, exception.Message);
1354+
Assert.Contains("not found", exception.Message, StringComparison.OrdinalIgnoreCase);
13121355
}
13131356

13141357
[Fact]
@@ -1521,18 +1564,20 @@ public void EventIdFormatter_TryParse_ReturnsFalse_ForNonNumericSequence()
15211564

15221565
/// <summary>
15231566
/// A distributed cache that tracks all operations for verification in tests.
1524-
/// Supports tracking Set calls, counting metadata reads, and simulating metadata expiration.
1567+
/// Supports tracking Set calls, counting metadata reads, and simulating metadata/event expiration.
15251568
/// </summary>
15261569
private sealed class TestDistributedCache : IDistributedCache
15271570
{
15281571
private readonly MemoryDistributedCache _innerCache = new(Options.Create(new MemoryDistributedCacheOptions()));
15291572
private int _metadataReadCount;
15301573
private bool _metadataExpired;
1574+
private readonly HashSet<string> _expiredEventIds = [];
15311575

15321576
public List<(string Key, DistributedCacheEntryOptions Options)> SetCalls { get; } = [];
15331577
public int MetadataReadCount => _metadataReadCount;
15341578

15351579
public void ExpireMetadata() => _metadataExpired = true;
1580+
public void ExpireEvent(string eventId) => _expiredEventIds.Add(eventId);
15361581

15371582
public byte[]? Get(string key)
15381583
{
@@ -1544,6 +1589,10 @@ private sealed class TestDistributedCache : IDistributedCache
15441589
return null;
15451590
}
15461591
}
1592+
if (IsExpiredEvent(key))
1593+
{
1594+
return null;
1595+
}
15471596
return _innerCache.Get(key);
15481597
}
15491598

@@ -1557,9 +1606,26 @@ private sealed class TestDistributedCache : IDistributedCache
15571606
return Task.FromResult<byte[]?>(null);
15581607
}
15591608
}
1609+
if (IsExpiredEvent(key))
1610+
{
1611+
return Task.FromResult<byte[]?>(null);
1612+
}
15601613
return _innerCache.GetAsync(key, token);
15611614
}
15621615

1616+
private bool IsExpiredEvent(string key)
1617+
{
1618+
// Cache key format is "mcp:sse:event:{eventId}"
1619+
foreach (var expiredEventId in _expiredEventIds)
1620+
{
1621+
if (key.EndsWith(expiredEventId))
1622+
{
1623+
return true;
1624+
}
1625+
}
1626+
return false;
1627+
}
1628+
15631629
public void Refresh(string key) => _innerCache.Refresh(key);
15641630
public Task RefreshAsync(string key, CancellationToken token = default) => _innerCache.RefreshAsync(key, token);
15651631
public void Remove(string key) => _innerCache.Remove(key);

0 commit comments

Comments
 (0)