Skip to content

Commit 265107f

Browse files
committed
Update HttpListenerSseServerTransport to use SseResponseStreamTransport
1 parent 125db2b commit 265107f

File tree

2 files changed

+39
-89
lines changed

2 files changed

+39
-89
lines changed

src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
using ModelContextProtocol.Server;
2-
using System.Net;
3-
using System.Text;
1+
using System.Net;
2+
using ModelContextProtocol.Server;
43

54
namespace ModelContextProtocol.Protocol.Transport;
65

@@ -17,8 +16,6 @@ internal class HttpListenerServerProvider : IDisposable
1716
private readonly int _port;
1817
private HttpListener? _listener;
1918
private CancellationTokenSource? _cts;
20-
private Func<string, CancellationToken, bool>? _messageHandler;
21-
private StreamWriter? _streamWriter;
2219
private bool _isRunning;
2320

2421
/// <summary>
@@ -30,34 +27,16 @@ public HttpListenerServerProvider(int port)
3027
_port = port;
3128
}
3229

33-
public Task InitializeMessageHandler(Func<string, CancellationToken, bool> messageHandler)
34-
{
35-
_messageHandler = messageHandler;
36-
return Task.CompletedTask;
37-
}
38-
39-
public async Task SendEvent(string data, string eventId)
40-
{
41-
if (_streamWriter == null)
42-
{
43-
throw new McpServerException("Stream writer not initialized");
44-
}
45-
if (eventId != null)
46-
{
47-
await _streamWriter.WriteLineAsync($"id: {eventId}").ConfigureAwait(false);
48-
}
49-
await _streamWriter.WriteLineAsync($"data: {data}").ConfigureAwait(false);
50-
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // Empty line to finish the event
51-
await _streamWriter.FlushAsync().ConfigureAwait(false);
52-
}
30+
public required Func<Stream, CancellationToken, Task> OnSseConnectionAsync { get; set; }
31+
public required Func<string, CancellationToken, Task<bool>> OnMessageAsync { get; set; }
5332

5433
/// <inheritdoc/>
5534
public Task StartAsync(CancellationToken cancellationToken = default)
5635
{
5736
if (_isRunning)
5837
return Task.CompletedTask;
5938

60-
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
39+
_cts = new CancellationTokenSource();
6140
_listener = new HttpListener();
6241
_listener.Prefixes.Add($"http://localhost:{_port}/");
6342
_listener.Start();
@@ -77,8 +56,6 @@ public Task StopAsync(CancellationToken cancellationToken = default)
7756
_cts?.Cancel();
7857
_listener?.Stop();
7958

80-
_streamWriter?.Close();
81-
8259
_isRunning = false;
8360
return Task.CompletedTask;
8461
}
@@ -163,28 +140,10 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
163140
response.Headers.Add("Cache-Control", "no-cache");
164141
response.Headers.Add("Connection", "keep-alive");
165142

166-
// Get the output stream and create a StreamWriter
167-
var outputStream = response.OutputStream;
168-
_streamWriter = new StreamWriter(outputStream, Encoding.UTF8) { AutoFlush = true };
169-
170143
// Keep the connection open until cancelled
171144
try
172145
{
173-
// Immediately send the "endpoint" event with the POST URL
174-
await _streamWriter.WriteLineAsync("event: endpoint").ConfigureAwait(false);
175-
await _streamWriter.WriteLineAsync($"data: {MessageEndpoint}").ConfigureAwait(false);
176-
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // blank line to end an SSE message
177-
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
178-
179-
// Keep the connection open by "pinging" or just waiting
180-
// until the client disconnects or the server is canceled.
181-
while (!cancellationToken.IsCancellationRequested && response.OutputStream.CanWrite)
182-
{
183-
// Do a periodic no-op to keep connection alive:
184-
await _streamWriter.WriteLineAsync(": keep-alive").ConfigureAwait(false);
185-
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
186-
await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
187-
}
146+
await OnSseConnectionAsync(response.OutputStream, cancellationToken).ConfigureAwait(false);
188147
}
189148
catch (TaskCanceledException)
190149
{
@@ -199,7 +158,6 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
199158
// Remove client on disconnect
200159
try
201160
{
202-
_streamWriter.Close();
203161
response.Close();
204162
}
205163
catch { /* Ignore errors during cleanup */ }
@@ -219,7 +177,7 @@ private async Task HandleMessageAsync(HttpListenerContext context, CancellationT
219177
}
220178

221179
// Process the message asynchronously
222-
if (_messageHandler != null && _messageHandler(requestBody, cancellationToken))
180+
if (await OnMessageAsync(requestBody, cancellationToken))
223181
{
224182
// Return 202 Accepted
225183
response.StatusCode = 202;

src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ public sealed class HttpListenerSseServerTransport : TransportBase, IServerTrans
1717
private readonly string _serverName;
1818
private readonly HttpListenerServerProvider _httpServerProvider;
1919
private readonly ILogger<HttpListenerSseServerTransport> _logger;
20-
private readonly JsonSerializerOptions _jsonOptions;
21-
private CancellationTokenSource? _shutdownCts;
22-
20+
private SseResponseStreamTransport? _sseResponseStreamTransport;
21+
2322
private string EndpointName => $"Server (SSE) ({_serverName})";
2423

2524
/// <summary>
@@ -44,28 +43,23 @@ public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactor
4443
{
4544
_serverName = serverName;
4645
_logger = loggerFactory.CreateLogger<HttpListenerSseServerTransport>();
47-
_jsonOptions = McpJsonUtilities.DefaultOptions;
48-
_httpServerProvider = new HttpListenerServerProvider(port);
46+
_httpServerProvider = new HttpListenerServerProvider(port)
47+
{
48+
OnSseConnectionAsync = OnSseConnectionAsync,
49+
OnMessageAsync = OnMessageAsync,
50+
};
4951
}
5052

5153
/// <inheritdoc/>
5254
public Task StartListeningAsync(CancellationToken cancellationToken = default)
5355
{
54-
_shutdownCts = new CancellationTokenSource();
55-
56-
_httpServerProvider.InitializeMessageHandler(HttpMessageHandler);
57-
_httpServerProvider.StartAsync(cancellationToken);
58-
59-
SetConnected(true);
60-
61-
return Task.CompletedTask;
56+
return _httpServerProvider.StartAsync(cancellationToken);
6257
}
6358

64-
6559
/// <inheritdoc/>
6660
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
6761
{
68-
if (!IsConnected)
62+
if (!IsConnected || _sseResponseStreamTransport is null)
6963
{
7064
_logger.TransportNotConnected(EndpointName);
7165
throw new McpTransportException("Transport is not connected");
@@ -79,10 +73,10 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
7973

8074
try
8175
{
82-
var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
76+
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
8377
_logger.TransportSendingMessage(EndpointName, id, json);
8478

85-
await _httpServerProvider.SendEvent(json, "message").ConfigureAwait(false);
79+
await _sseResponseStreamTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
8680

8781
_logger.TransportSentMessage(EndpointName, id);
8882
}
@@ -100,49 +94,47 @@ public override async ValueTask DisposeAsync()
10094
GC.SuppressFinalize(this);
10195
}
10296

103-
private async Task CleanupAsync(CancellationToken cancellationToken)
97+
private Task CleanupAsync(CancellationToken cancellationToken)
10498
{
10599
_logger.TransportCleaningUp(EndpointName);
106100

107-
if (_shutdownCts != null)
108-
{
109-
await _shutdownCts.CancelAsync().ConfigureAwait(false);
110-
_shutdownCts.Dispose();
111-
_shutdownCts = null;
112-
}
113-
114101
_httpServerProvider.Dispose();
115-
116102
SetConnected(false);
103+
117104
_logger.TransportCleanedUp(EndpointName);
105+
return Task.CompletedTask;
106+
}
107+
108+
private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken)
109+
{
110+
await using var sseResponseStreamTransport = new SseResponseStreamTransport(responseStream);
111+
_sseResponseStreamTransport = sseResponseStreamTransport;
112+
SetConnected(true);
113+
await sseResponseStreamTransport.RunAsync(cancellationToken);
118114
}
119115

120116
/// <summary>
121117
/// Handles HTTP messages received by the HTTP server provider.
122118
/// </summary>
123119
/// <returns>true if the message was accepted (return 202), false otherwise (return 400)</returns>
124-
private bool HttpMessageHandler(string request, CancellationToken cancellationToken)
120+
private async Task<bool> OnMessageAsync(string request, CancellationToken cancellationToken)
125121
{
126122
_logger.TransportReceivedMessage(EndpointName, request);
127123

128124
try
129125
{
130-
var message = JsonSerializer.Deserialize(request, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
126+
var message = JsonSerializer.Deserialize(request, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
131127
if (message != null)
132128
{
133-
// Fire-and-forget the message to the message channel
134-
Task.Run(async () =>
129+
string messageId = "(no id)";
130+
if (message is IJsonRpcMessageWithId messageWithId)
135131
{
136-
string messageId = "(no id)";
137-
if (message is IJsonRpcMessageWithId messageWithId)
138-
{
139-
messageId = messageWithId.Id.ToString();
140-
}
141-
142-
_logger.TransportReceivedMessageParsed(EndpointName, messageId);
143-
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
144-
_logger.TransportMessageWritten(EndpointName, messageId);
145-
}, cancellationToken);
132+
messageId = messageWithId.Id.ToString();
133+
}
134+
135+
_logger.TransportReceivedMessageParsed(EndpointName, messageId);
136+
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
137+
_logger.TransportMessageWritten(EndpointName, messageId);
146138

147139
return true;
148140
}

0 commit comments

Comments
 (0)