diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index a89bd29c2..db96a53bb 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -20,7 +20,7 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream) : ITran private Utf8JsonWriter? _jsonWriter; /// - public bool IsConnected => _sseWriteTask?.IsCompleted == false; + public bool IsConnected { get; private set; } /// /// Starts the transport and writes the JSON-RPC messages sent via @@ -41,6 +41,8 @@ void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter()); } + IsConnected = true; + // The very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single item of a different type, // so we fib and special-case the "endpoint" event type in the formatter. _outgoingSseChannel.Writer.TryWrite(new SseItem(null, "endpoint")); @@ -55,6 +57,7 @@ void WriteJsonRpcMessageToBuffer(SseItem item, IBufferWriter public ValueTask DisposeAsync() { + IsConnected = false; _incomingChannel.Writer.TryComplete(); _outgoingSseChannel.Writer.TryComplete(); return new ValueTask(_sseWriteTask ?? Task.CompletedTask); @@ -63,7 +66,7 @@ public ValueTask DisposeAsync() /// public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - if (_sseWriteTask is null) + if (!IsConnected) { throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); } @@ -80,7 +83,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca /// Thrown when there is an attempt to process a message before calling . public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationToken cancellationToken) { - if (_sseWriteTask is null) + if (!IsConnected) { throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); }