Skip to content

Commit 114feac

Browse files
committed
Amend test fix
1 parent e2d3ba0 commit 114feac

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed

tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
344344
return default;
345345
});
346346

347-
// Wait for the client's GET SSE stream to be established before sending notifications
348-
await faultingStreamHandler.WaitForStreamAsync(TestContext.Current.CancellationToken);
347+
// Wait for the client's unsolicited message stream to be established before sending notifications
348+
await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken);
349349

350350
// Send a custom notification to the client on the unsolicited message stream
351351
await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken);

tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ internal sealed class FaultingStreamHandler : DelegatingHandler
1111
{
1212
private FaultingStream? _lastStream;
1313
private TaskCompletionSource? _reconnectTcs;
14-
private TaskCompletionSource _streamAvailableTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
14+
private TaskCompletionSource _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
1515

16-
public Task WaitForStreamAsync(CancellationToken cancellationToken = default)
17-
=> _streamAvailableTcs.Task.WaitAsync(cancellationToken);
16+
public Task WaitForUnsolicitedMessageStreamAsync(CancellationToken cancellationToken = default)
17+
=> _unsolicitedMessageStreamReadyTcs.Task.WaitAsync(cancellationToken);
18+
19+
internal void SignalUnsolicitedMessageStreamReady() => _unsolicitedMessageStreamReadyTcs.TrySetResult();
1820

1921
public async Task<ReconnectAttempt> TriggerFaultAsync(CancellationToken cancellationToken)
2022
{
@@ -28,7 +30,9 @@ public async Task<ReconnectAttempt> TriggerFaultAsync(CancellationToken cancella
2830
throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection.");
2931
}
3032

31-
_streamAvailableTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
33+
// Reset the TCS so we can wait for the reconnected unsolicited message stream
34+
_unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
35+
3236
_reconnectTcs = new();
3337
await _lastStream.TriggerFaultAsync(cancellationToken);
3438

@@ -51,6 +55,7 @@ protected override async Task<HttpResponseMessage> SendAsync(
5155
_reconnectTcs = null;
5256
}
5357

58+
var isGetRequest = request.Method == HttpMethod.Get;
5459
var response = await base.SendAsync(request, cancellationToken);
5560

5661
// Only wrap SSE streams (text/event-stream)
@@ -69,7 +74,12 @@ protected override async Task<HttpResponseMessage> SendAsync(
6974

7075
response.Content = newContent;
7176

72-
_streamAvailableTcs.TrySetResult();
77+
// For GET requests (unsolicited message stream), set up the stream to signal
78+
// when first data is read. This ensures the server's transport handler is ready.
79+
if (isGetRequest)
80+
{
81+
_lastStream.SetReadyCallback(SignalUnsolicitedMessageStreamReady);
82+
}
7383
}
7484

7585
return response;
@@ -96,10 +106,14 @@ private sealed class FaultingStream(Stream innerStream) : Stream
96106
{
97107
private readonly CancellationTokenSource _cts = new();
98108
private TaskCompletionSource? _faultTcs;
109+
private Action? _readyCallback;
110+
private bool _readySignaled;
99111
private bool _disposed;
100112

101113
public bool IsDisposed => _disposed;
102114

115+
public void SetReadyCallback(Action callback) => _readyCallback = callback;
116+
103117
public async Task TriggerFaultAsync(CancellationToken cancellationToken)
104118
{
105119
if (_faultTcs is not null)
@@ -138,6 +152,12 @@ public override async ValueTask<int> ReadAsync(Memory<byte> buffer, Cancellation
138152

139153
_cts.Token.ThrowIfCancellationRequested();
140154

155+
if (bytesRead > 0 && !_readySignaled)
156+
{
157+
_readySignaled = true;
158+
_readyCallback?.Invoke();
159+
}
160+
141161
return bytesRead;
142162
}
143163
catch (OperationCanceledException) when (_cts.IsCancellationRequested)

0 commit comments

Comments
 (0)