diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
new file mode 100644
index 000000000..af304c897
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs
@@ -0,0 +1,40 @@
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+using System.Diagnostics;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+/// Provides the client side of a stdio-based session transport.
+internal sealed class StdioClientSessionTransport : StreamClientSessionTransport
+{
+ private readonly StdioClientTransportOptions _options;
+ private readonly Process _process;
+
+ public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, ILoggerFactory? loggerFactory)
+ : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory)
+ {
+ _process = process;
+ _options = options;
+ }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (_process.HasExited)
+ {
+ Logger.TransportNotConnected(EndpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ protected override ValueTask CleanupAsync(CancellationToken cancellationToken)
+ {
+ StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName);
+
+ return base.CleanupAsync(cancellationToken);
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs
deleted file mode 100644
index 35c957e52..000000000
--- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs
+++ /dev/null
@@ -1,326 +0,0 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
-using ModelContextProtocol.Logging;
-using ModelContextProtocol.Protocol.Messages;
-using ModelContextProtocol.Utils;
-using ModelContextProtocol.Utils.Json;
-using System.Diagnostics;
-using System.Text;
-using System.Text.Json;
-
-namespace ModelContextProtocol.Protocol.Transport;
-
-///
-/// Implements the MCP transport protocol over standard input/output streams.
-///
-internal sealed class StdioClientStreamTransport : TransportBase
-{
- private readonly StdioClientTransportOptions _options;
- private readonly McpServerConfig _serverConfig;
- private readonly ILogger _logger;
- private readonly JsonSerializerOptions _jsonOptions;
- private readonly DataReceivedEventHandler _logProcessErrors;
- private readonly SemaphoreSlim _sendLock = new(1, 1);
- private Process? _process;
- private Task? _readTask;
- private CancellationTokenSource? _shutdownCts;
- private bool _processStarted;
-
- private string EndpointName => $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})";
-
- ///
- /// Initializes a new instance of the StdioTransport class.
- ///
- /// Configuration options for the transport.
- /// The server configuration for the transport.
- /// A logger factory for creating loggers.
- public StdioClientStreamTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null)
- : base(loggerFactory)
- {
- Throw.IfNull(options);
- Throw.IfNull(serverConfig);
-
- _options = options;
- _serverConfig = serverConfig;
- _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
- _logProcessErrors = (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)");
- _jsonOptions = McpJsonUtilities.DefaultOptions;
- }
-
- ///
- public async Task ConnectAsync(CancellationToken cancellationToken = default)
- {
- if (IsConnected)
- {
- _logger.TransportAlreadyConnected(EndpointName);
- throw new McpTransportException("Transport is already connected");
- }
-
- try
- {
- _logger.TransportConnecting(EndpointName);
-
- _shutdownCts = new CancellationTokenSource();
-
- UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false);
-
- var startInfo = new ProcessStartInfo
- {
- FileName = _options.Command,
- RedirectStandardInput = true,
- RedirectStandardOutput = true,
- RedirectStandardError = true,
- UseShellExecute = false,
- CreateNoWindow = true,
- WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
- StandardOutputEncoding = noBomUTF8,
- StandardErrorEncoding = noBomUTF8,
-#if NET
- StandardInputEncoding = noBomUTF8,
-#endif
- };
-
- if (!string.IsNullOrWhiteSpace(_options.Arguments))
- {
- startInfo.Arguments = _options.Arguments;
- }
-
- if (_options.EnvironmentVariables != null)
- {
- foreach (var entry in _options.EnvironmentVariables)
- {
- startInfo.Environment[entry.Key] = entry.Value;
- }
- }
-
- _logger.CreateProcessForTransport(EndpointName, _options.Command,
- startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)),
- startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString());
-
- _process = new Process { StartInfo = startInfo };
-
- // Set up error logging
- _process.ErrorDataReceived += _logProcessErrors;
-
- // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core,
- // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but
- // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks
- // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core,
- // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start
- // call, to ensure it picks up the correct encoding.
-#if NET
- _processStarted = _process.Start();
-#else
- Encoding originalInputEncoding = Console.InputEncoding;
- try
- {
- Console.InputEncoding = noBomUTF8;
- _processStarted = _process.Start();
- }
- finally
- {
- Console.InputEncoding = originalInputEncoding;
- }
-#endif
-
- if (!_processStarted)
- {
- _logger.TransportProcessStartFailed(EndpointName);
- throw new McpTransportException("Failed to start MCP server process");
- }
-
- _logger.TransportProcessStarted(EndpointName, _process.Id);
-
- _process.BeginErrorReadLine();
-
- // Start reading messages in the background
- _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None);
- _logger.TransportReadingMessages(EndpointName);
-
- SetConnected(true);
- }
- catch (Exception ex)
- {
- _logger.TransportConnectFailed(EndpointName, ex);
- await CleanupAsync(cancellationToken).ConfigureAwait(false);
- throw new McpTransportException("Failed to connect transport", ex);
- }
- }
-
- ///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
- {
- using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
-
- if (!IsConnected || _process?.HasExited == true)
- {
- _logger.TransportNotConnected(EndpointName);
- throw new McpTransportException("Transport is not connected");
- }
-
- string id = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- id = messageWithId.Id.ToString();
- }
-
- try
- {
- var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo());
- _logger.TransportSendingMessage(EndpointName, id, json);
- _logger.TransportMessageBytesUtf8(EndpointName, json);
-
- // Write the message followed by a newline using our UTF-8 writer
- await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false);
- await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false);
-
- _logger.TransportSentMessage(EndpointName, id);
- }
- catch (Exception ex)
- {
- _logger.TransportSendFailed(EndpointName, id, ex);
- throw new McpTransportException("Failed to send message", ex);
- }
- }
-
- ///
- public override async ValueTask DisposeAsync()
- {
- await CleanupAsync(CancellationToken.None).ConfigureAwait(false);
- }
-
- private async Task ReadMessagesAsync(CancellationToken cancellationToken)
- {
- try
- {
- _logger.TransportEnteringReadMessagesLoop(EndpointName);
-
- while (!cancellationToken.IsCancellationRequested && !_process!.HasExited)
- {
- _logger.TransportWaitingForMessage(EndpointName);
- var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false);
- if (line == null)
- {
- _logger.TransportEndOfStream(EndpointName);
- break;
- }
-
- if (string.IsNullOrWhiteSpace(line))
- {
- continue;
- }
-
- _logger.TransportReceivedMessage(EndpointName, line);
- _logger.TransportMessageBytesUtf8(EndpointName, line);
-
- await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
- }
- _logger.TransportExitingReadMessagesLoop(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportReadMessagesCancelled(EndpointName);
- // Normal shutdown
- }
- catch (Exception ex)
- {
- _logger.TransportReadMessagesFailed(EndpointName, ex);
- }
- finally
- {
- await CleanupAsync(cancellationToken).ConfigureAwait(false);
- }
- }
-
- private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
- {
- try
- {
- line=line.Trim();//Fixes an error when the service prefixes nonprintable characters
- var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo());
- if (message != null)
- {
- string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- messageId = messageWithId.Id.ToString();
- }
- _logger.TransportReceivedMessageParsed(EndpointName, messageId);
- await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
- _logger.TransportMessageWritten(EndpointName, messageId);
- }
- else
- {
- _logger.TransportMessageParseUnexpectedType(EndpointName, line);
- }
- }
- catch (JsonException ex)
- {
- _logger.TransportMessageParseFailed(EndpointName, line, ex);
- }
- }
-
- private async Task CleanupAsync(CancellationToken cancellationToken)
- {
- _logger.TransportCleaningUp(EndpointName);
-
- if (_process is Process process && _processStarted && !process.HasExited)
- {
- try
- {
- // Wait for the process to exit
- _logger.TransportWaitingForShutdown(EndpointName);
-
- // Kill the while process tree because the process may spawn child processes
- // and Node.js does not kill its children when it exits properly
- process.KillTree(_options.ShutdownTimeout);
- }
- catch (Exception ex)
- {
- _logger.TransportShutdownFailed(EndpointName, ex);
- }
- finally
- {
- process.ErrorDataReceived -= _logProcessErrors;
- process.Dispose();
- _process = null;
- }
- }
-
- if (_shutdownCts is { } shutdownCts)
- {
- await shutdownCts.CancelAsync().ConfigureAwait(false);
- shutdownCts.Dispose();
- _shutdownCts = null;
- }
-
- if (_readTask is Task readTask)
- {
- try
- {
- _logger.TransportWaitingForReadTask(EndpointName);
- await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
- }
- catch (TimeoutException)
- {
- _logger.TransportCleanupReadTaskTimeout(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportCleanupReadTaskCancelled(EndpointName);
- }
- catch (Exception ex)
- {
- _logger.TransportCleanupReadTaskFailed(EndpointName, ex);
- }
- finally
- {
- _logger.TransportReadTaskCleanedUp(EndpointName);
- _readTask = null;
- }
- }
-
- SetConnected(false);
- _logger.TransportCleanedUp(EndpointName);
- }
-}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
index d2b51b950..0fca07e27 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs
@@ -1,10 +1,16 @@
using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Logging;
using ModelContextProtocol.Utils;
+using System.Diagnostics;
+using System.Text;
+
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
namespace ModelContextProtocol.Protocol.Transport;
///
-/// Implements the MCP transport protocol over standard input/output streams.
+/// Provides a client MCP transport implemented via "stdio" (standard input/output).
///
public sealed class StdioClientTransport : IClientTransport
{
@@ -13,7 +19,7 @@ public sealed class StdioClientTransport : IClientTransport
private readonly ILoggerFactory? _loggerFactory;
///
- /// Initializes a new instance of the StdioTransport class.
+ /// Initializes a new instance of the class.
///
/// Configuration options for the transport.
/// The server configuration for the transport.
@@ -31,17 +37,121 @@ public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig
///
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
- var streamTransport = new StdioClientStreamTransport(_options, _serverConfig, _loggerFactory);
+ string endpointName = $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})";
+
+ Process? process = null;
+ bool processStarted = false;
+ ILogger logger = (ILogger?)_loggerFactory?.CreateLogger() ?? NullLogger.Instance;
try
{
- await streamTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
- return streamTransport;
+ logger.TransportConnecting(endpointName);
+
+ UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false);
+
+ ProcessStartInfo startInfo = new()
+ {
+ FileName = _options.Command,
+ RedirectStandardInput = true,
+ RedirectStandardOutput = true,
+ RedirectStandardError = true,
+ UseShellExecute = false,
+ CreateNoWindow = true,
+ WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
+ StandardOutputEncoding = noBomUTF8,
+ StandardErrorEncoding = noBomUTF8,
+#if NET
+ StandardInputEncoding = noBomUTF8,
+#endif
+ };
+
+ if (!string.IsNullOrWhiteSpace(_options.Arguments))
+ {
+ startInfo.Arguments = _options.Arguments;
+ }
+
+ if (_options.EnvironmentVariables != null)
+ {
+ foreach (var entry in _options.EnvironmentVariables)
+ {
+ startInfo.Environment[entry.Key] = entry.Value;
+ }
+ }
+
+ logger.CreateProcessForTransport(endpointName, _options.Command,
+ startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)),
+ startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString());
+
+ process = new() { StartInfo = startInfo };
+
+ // Set up error logging
+ process.ErrorDataReceived += (sender, args) => logger.TransportError(endpointName, args.Data ?? "(no data)");
+
+ // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core,
+ // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but
+ // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks
+ // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core,
+ // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start
+ // call, to ensure it picks up the correct encoding.
+#if NET
+ processStarted = process.Start();
+#else
+ Encoding originalInputEncoding = Console.InputEncoding;
+ try
+ {
+ Console.InputEncoding = noBomUTF8;
+ processStarted = process.Start();
+ }
+ finally
+ {
+ Console.InputEncoding = originalInputEncoding;
+ }
+#endif
+
+ if (!processStarted)
+ {
+ logger.TransportProcessStartFailed(endpointName);
+ throw new McpTransportException("Failed to start MCP server process");
+ }
+
+ logger.TransportProcessStarted(endpointName, process.Id);
+
+ process.BeginErrorReadLine();
+
+ return new StdioClientSessionTransport(_options, process, endpointName, _loggerFactory);
}
- catch
+ catch (Exception ex)
+ {
+ logger.TransportConnectFailed(endpointName, ex);
+ DisposeProcess(process, processStarted, logger, _options.ShutdownTimeout, endpointName);
+ throw new McpTransportException("Failed to connect transport", ex);
+ }
+ }
+
+ internal static void DisposeProcess(
+ Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
+ {
+ if (process is not null)
{
- await streamTransport.DisposeAsync().ConfigureAwait(false);
- throw;
+ try
+ {
+ if (processStarted && !process.HasExited)
+ {
+ // Wait for the process to exit.
+ // Kill the while process tree because the process may spawn child processes
+ // and Node.js does not kill its children when it exits properly.
+ logger.TransportWaitingForShutdown(endpointName);
+ process.KillTree(shutdownTimeout);
+ }
+ }
+ catch (Exception ex)
+ {
+ logger.TransportShutdownFailed(endpointName, ex);
+ }
+ finally
+ {
+ process.Dispose();
+ }
}
}
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
index 7779edc99..58077dbb2 100644
--- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
+++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs
@@ -1,38 +1,15 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
-using ModelContextProtocol.Logging;
-using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;
-using ModelContextProtocol.Utils.Json;
-using System.Text;
-using System.Text.Json;
namespace ModelContextProtocol.Protocol.Transport;
///
-/// Provides an implementation of the MCP transport protocol over standard input/output streams.
+/// Provides a server MCP transport implemented via "stdio" (standard input/output).
///
-public sealed class StdioServerTransport : TransportBase, ITransport
+public sealed class StdioServerTransport : StreamServerTransport
{
- private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();
-
- private readonly string _serverName;
- private readonly ILogger _logger;
-
- private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions;
- private readonly TextReader _stdInReader;
- private readonly Stream _stdOutStream;
-
- private readonly SemaphoreSlim _sendLock = new(1, 1);
- private readonly CancellationTokenSource _shutdownCts = new();
-
- private readonly Task _readLoopCompleted;
- private int _disposed = 0;
-
- private string EndpointName => $"Server (stdio) ({_serverName})";
-
///
/// Initializes a new instance of the class, using
/// and for input and output streams.
@@ -69,14 +46,6 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg
{
}
- private static string GetServerName(McpServerOptions serverOptions)
- {
- Throw.IfNull(serverOptions);
- Throw.IfNull(serverOptions.ServerInfo);
- Throw.IfNull(serverOptions.ServerInfo.Name);
- return serverOptions.ServerInfo.Name;
- }
-
///
/// Initializes a new instance of the class, using
/// and for input and output streams.
@@ -90,186 +59,20 @@ private static string GetServerName(McpServerOptions serverOptions)
/// to , as that will interfere with the transport's output.
///
///
- public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory)
- : this(serverName, stdinStream: null, stdoutStream: null, loggerFactory)
- {
- }
-
- ///
- /// Initializes a new instance of the class with explicit input/output streams.
- ///
- /// The name of the server.
- /// The input to use as standard input. If , will be used.
- /// The output to use as standard output. If , will be used.
- /// Optional logger factory used for logging employed by the transport.
- /// is .
- ///
- ///
- /// This constructor is useful for testing scenarios where you want to redirect input/output.
- ///
- ///
- public StdioServerTransport(string serverName, Stream? stdinStream = null, Stream? stdoutStream = null, ILoggerFactory? loggerFactory = null)
- : base(loggerFactory)
- {
- Throw.IfNull(serverName);
-
- _serverName = serverName;
- _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
-
- _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8);
- _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput());
-
- SetConnected(true);
- _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token);
- }
-
- ///
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null)
+ : base(Console.OpenStandardInput(),
+ new BufferedStream(Console.OpenStandardOutput()),
+ serverName ?? throw new ArgumentNullException(nameof(serverName)),
+ loggerFactory)
{
- using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
-
- if (!IsConnected)
- {
- _logger.TransportNotConnected(EndpointName);
- throw new McpTransportException("Transport is not connected");
- }
-
- string id = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- id = messageWithId.Id.ToString();
- }
-
- try
- {
- _logger.TransportSendingMessage(EndpointName, id);
-
- await JsonSerializer.SerializeAsync(_stdOutStream, message, _jsonOptions.GetTypeInfo(), cancellationToken).ConfigureAwait(false);
- await _stdOutStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
- await _stdOutStream.FlushAsync(cancellationToken).ConfigureAwait(false);;
-
- _logger.TransportSentMessage(EndpointName, id);
- }
- catch (Exception ex)
- {
- _logger.TransportSendFailed(EndpointName, id, ex);
- throw new McpTransportException("Failed to send message", ex);
- }
}
- private async Task ReadMessagesAsync()
- {
- CancellationToken shutdownToken = _shutdownCts.Token;
- try
- {
- _logger.TransportEnteringReadMessagesLoop(EndpointName);
-
- while (!shutdownToken.IsCancellationRequested)
- {
- _logger.TransportWaitingForMessage(EndpointName);
-
- var line = await _stdInReader.ReadLineAsync(shutdownToken).ConfigureAwait(false);
- if (string.IsNullOrWhiteSpace(line))
- {
- if (line is null)
- {
- _logger.TransportEndOfStream(EndpointName);
- break;
- }
-
- continue;
- }
-
- _logger.TransportReceivedMessage(EndpointName, line);
- _logger.TransportMessageBytesUtf8(EndpointName, line);
-
- try
- {
- if (JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()) is { } message)
- {
- string messageId = "(no id)";
- if (message is IJsonRpcMessageWithId messageWithId)
- {
- messageId = messageWithId.Id.ToString();
- }
- _logger.TransportReceivedMessageParsed(EndpointName, messageId);
-
- await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
- _logger.TransportMessageWritten(EndpointName, messageId);
- }
- else
- {
- _logger.TransportMessageParseUnexpectedType(EndpointName, line);
- }
- }
- catch (JsonException ex)
- {
- _logger.TransportMessageParseFailed(EndpointName, line, ex);
- // Continue reading even if we fail to parse a message
- }
- }
-
- _logger.TransportExitingReadMessagesLoop(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportReadMessagesCancelled(EndpointName);
- }
- catch (Exception ex)
- {
- _logger.TransportReadMessagesFailed(EndpointName, ex);
- }
- finally
- {
- SetConnected(false);
- }
- }
-
- ///
- public override async ValueTask DisposeAsync()
+ private static string GetServerName(McpServerOptions serverOptions)
{
- if (Interlocked.Exchange(ref _disposed, 1) != 0)
- {
- return;
- }
-
- try
- {
- _logger.TransportCleaningUp(EndpointName);
-
- // Signal to the stdin reading loop to stop.
- await _shutdownCts.CancelAsync().ConfigureAwait(false);
- _shutdownCts.Dispose();
-
- // Dispose of stdin/out. Cancellation may not be able to wake up operations
- // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor.
- _stdInReader?.Dispose();
- _stdOutStream?.Dispose();
+ Throw.IfNull(serverOptions);
+ Throw.IfNull(serverOptions.ServerInfo);
+ Throw.IfNull(serverOptions.ServerInfo.Name);
- // Make sure the work has quiesced.
- try
- {
- _logger.TransportWaitingForReadTask(EndpointName);
- await _readLoopCompleted.ConfigureAwait(false);
- _logger.TransportReadTaskCleanedUp(EndpointName);
- }
- catch (TimeoutException)
- {
- _logger.TransportCleanupReadTaskTimeout(EndpointName);
- }
- catch (OperationCanceledException)
- {
- _logger.TransportCleanupReadTaskCancelled(EndpointName);
- }
- catch (Exception ex)
- {
- _logger.TransportCleanupReadTaskFailed(EndpointName, ex);
- }
- }
- finally
- {
- SetConnected(false);
- _logger.TransportCleanedUp(EndpointName);
- }
+ return serverOptions.ServerInfo.Name;
}
}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
new file mode 100644
index 000000000..8359afa65
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs
@@ -0,0 +1,187 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.Text.Json;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+/// Provides the client side of a stream-based session transport.
+internal class StreamClientSessionTransport : TransportBase
+{
+ private readonly TextReader _serverOutput;
+ private readonly TextWriter _serverInput;
+ private readonly SemaphoreSlim _sendLock = new(1, 1);
+ private CancellationTokenSource? _shutdownCts = new();
+ private Task? _readTask;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ public StreamClientSessionTransport(
+ TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory)
+ : base(loggerFactory)
+ {
+ Logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
+ _serverOutput = serverOutput;
+ _serverInput = serverInput;
+ EndpointName = endpointName;
+
+ // Start reading messages in the background
+ Logger.TransportReadingMessages(endpointName);
+ _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None);
+
+ SetConnected(true);
+ }
+
+ protected ILogger Logger { get; private set; }
+
+ protected string EndpointName { get; }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (!IsConnected)
+ {
+ Logger.TransportNotConnected(EndpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)));
+
+ using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ Logger.TransportSendingMessage(EndpointName, id, json);
+ Logger.TransportMessageBytesUtf8(EndpointName, json);
+
+ // Write the message followed by a newline using our UTF-8 writer
+ await _serverInput.WriteLineAsync(json).ConfigureAwait(false);
+ await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false);
+
+ Logger.TransportSentMessage(EndpointName, id);
+ }
+ catch (Exception ex)
+ {
+ Logger.TransportSendFailed(EndpointName, id, ex);
+ throw new McpTransportException("Failed to send message", ex);
+ }
+ }
+
+ ///
+ public override ValueTask DisposeAsync() =>
+ CleanupAsync(CancellationToken.None);
+
+ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
+ {
+ try
+ {
+ Logger.TransportEnteringReadMessagesLoop(EndpointName);
+
+ while (!cancellationToken.IsCancellationRequested)
+ {
+ Logger.TransportWaitingForMessage(EndpointName);
+ if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
+ {
+ Logger.TransportEndOfStream(EndpointName);
+ break;
+ }
+
+ if (string.IsNullOrWhiteSpace(line))
+ {
+ continue;
+ }
+
+ Logger.TransportReceivedMessage(EndpointName, line);
+ Logger.TransportMessageBytesUtf8(EndpointName, line);
+
+ await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
+ }
+ Logger.TransportExitingReadMessagesLoop(EndpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ Logger.TransportReadMessagesCancelled(EndpointName);
+ }
+ catch (Exception ex)
+ {
+ Logger.TransportReadMessagesFailed(EndpointName, ex);
+ }
+ finally
+ {
+ await CleanupAsync(cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
+ {
+ try
+ {
+ var message = (IJsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)));
+ if (message != null)
+ {
+ string messageId = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ messageId = messageWithId.Id.ToString();
+ }
+
+ Logger.TransportReceivedMessageParsed(EndpointName, messageId);
+ await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
+ Logger.TransportMessageWritten(EndpointName, messageId);
+ }
+ else
+ {
+ Logger.TransportMessageParseUnexpectedType(EndpointName, line);
+ }
+ }
+ catch (JsonException ex)
+ {
+ Logger.TransportMessageParseFailed(EndpointName, line, ex);
+ }
+ }
+
+ protected virtual async ValueTask CleanupAsync(CancellationToken cancellationToken)
+ {
+ Logger.TransportCleaningUp(EndpointName);
+
+ if (Interlocked.Exchange(ref _shutdownCts, null) is { } shutdownCts)
+ {
+ await shutdownCts.CancelAsync().ConfigureAwait(false);
+ shutdownCts.Dispose();
+ }
+
+ if (Interlocked.Exchange(ref _readTask, null) is Task readTask)
+ {
+ try
+ {
+ Logger.TransportWaitingForReadTask(EndpointName);
+ await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
+ Logger.TransportReadTaskCleanedUp(EndpointName);
+ }
+ catch (TimeoutException)
+ {
+ Logger.TransportCleanupReadTaskTimeout(EndpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ Logger.TransportCleanupReadTaskCancelled(EndpointName);
+ }
+ catch (Exception ex)
+ {
+ Logger.TransportCleanupReadTaskFailed(EndpointName, ex);
+ }
+ }
+
+ SetConnected(false);
+ Logger.TransportCleanedUp(EndpointName);
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs
new file mode 100644
index 000000000..80bd61df5
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs
@@ -0,0 +1,47 @@
+using Microsoft.Extensions.Logging;
+using ModelContextProtocol.Utils;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides a client MCP transport implemented around a pair of input/output streams.
+///
+public sealed class StreamClientTransport : IClientTransport
+{
+ private readonly Stream _serverInput;
+ private readonly Stream _serverOutput;
+ private readonly ILoggerFactory? _loggerFactory;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ ///
+ /// The stream representing the connected server's input.
+ /// Writes to this stream will be sent to the server.
+ ///
+ ///
+ /// The stream representing the connected server's output.
+ /// Reads from this stream will receive messages from the server.
+ ///
+ /// A logger factory for creating loggers.
+ public StreamClientTransport(
+ Stream serverInput, Stream serverOutput, ILoggerFactory? loggerFactory = null)
+ {
+ Throw.IfNull(serverInput);
+ Throw.IfNull(serverOutput);
+
+ _serverInput = serverInput;
+ _serverOutput = serverOutput;
+ _loggerFactory = loggerFactory;
+ }
+
+ ///
+ public Task ConnectAsync(CancellationToken cancellationToken = default)
+ {
+ return Task.FromResult(new StreamClientSessionTransport(
+ new StreamWriter(_serverInput),
+ new StreamReader(_serverOutput),
+ "Client (stream)",
+ _loggerFactory));
+ }
+}
diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
new file mode 100644
index 000000000..ebdf36350
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs
@@ -0,0 +1,209 @@
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using ModelContextProtocol.Logging;
+using ModelContextProtocol.Protocol.Messages;
+using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
+using System.IO.Pipelines;
+using System.Text;
+using System.Text.Json;
+
+namespace ModelContextProtocol.Protocol.Transport;
+
+///
+/// Provides a server MCP transport implemented around a pair of input/output streams.
+///
+public class StreamServerTransport : TransportBase, ITransport
+{
+ private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();
+
+ private readonly ILogger _logger;
+
+ private readonly TextReader _inputReader;
+ private readonly Stream _outputStream;
+ private readonly string _endpointName;
+
+ private readonly SemaphoreSlim _sendLock = new(1, 1);
+ private readonly CancellationTokenSource _shutdownCts = new();
+
+ private readonly Task _readLoopCompleted;
+ private int _disposed = 0;
+
+ ///
+ /// Initializes a new instance of the class with explicit input/output streams.
+ ///
+ /// The input to use as standard input.
+ /// The output to use as standard output.
+ /// Optional name of the server, used for diagnostic purposes, like logging.
+ /// Optional logger factory used for logging employed by the transport.
+ /// is .
+ /// is .
+ public StreamServerTransport(Stream inputStream, Stream outputStream, string? serverName = null, ILoggerFactory? loggerFactory = null)
+ : base(loggerFactory)
+ {
+ Throw.IfNull(inputStream);
+ Throw.IfNull(outputStream);
+
+ _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance;
+
+ _inputReader = new StreamReader(inputStream, Encoding.UTF8);
+ _outputStream = outputStream;
+
+ SetConnected(true);
+ _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token);
+
+ _endpointName = serverName is not null ? $"Server (stream) ({serverName})" : "Server (stream)";
+ }
+
+ ///
+ public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ {
+ if (!IsConnected)
+ {
+ _logger.TransportNotConnected(_endpointName);
+ throw new McpTransportException("Transport is not connected");
+ }
+
+ using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
+
+ string id = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ id = messageWithId.Id.ToString();
+ }
+
+ try
+ {
+ _logger.TransportSendingMessage(_endpointName, id);
+
+ await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false);
+ await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
+ await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);;
+
+ _logger.TransportSentMessage(_endpointName, id);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportSendFailed(_endpointName, id, ex);
+ throw new McpTransportException("Failed to send message", ex);
+ }
+ }
+
+ private async Task ReadMessagesAsync()
+ {
+ CancellationToken shutdownToken = _shutdownCts.Token;
+ try
+ {
+ _logger.TransportEnteringReadMessagesLoop(_endpointName);
+
+ while (!shutdownToken.IsCancellationRequested)
+ {
+ _logger.TransportWaitingForMessage(_endpointName);
+
+ var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false);
+ if (string.IsNullOrWhiteSpace(line))
+ {
+ if (line is null)
+ {
+ _logger.TransportEndOfStream(_endpointName);
+ break;
+ }
+
+ continue;
+ }
+
+ _logger.TransportReceivedMessage(_endpointName, line);
+ _logger.TransportMessageBytesUtf8(_endpointName, line);
+
+ try
+ {
+ if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message)
+ {
+ string messageId = "(no id)";
+ if (message is IJsonRpcMessageWithId messageWithId)
+ {
+ messageId = messageWithId.Id.ToString();
+ }
+ _logger.TransportReceivedMessageParsed(_endpointName, messageId);
+
+ await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false);
+ _logger.TransportMessageWritten(_endpointName, messageId);
+ }
+ else
+ {
+ _logger.TransportMessageParseUnexpectedType(_endpointName, line);
+ }
+ }
+ catch (JsonException ex)
+ {
+ _logger.TransportMessageParseFailed(_endpointName, line, ex);
+ // Continue reading even if we fail to parse a message
+ }
+ }
+
+ _logger.TransportExitingReadMessagesLoop(_endpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ _logger.TransportReadMessagesCancelled(_endpointName);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportReadMessagesFailed(_endpointName, ex);
+ }
+ finally
+ {
+ SetConnected(false);
+ }
+ }
+
+ ///
+ public override async ValueTask DisposeAsync()
+ {
+ if (Interlocked.Exchange(ref _disposed, 1) != 0)
+ {
+ return;
+ }
+
+ try
+ {
+ _logger.TransportCleaningUp(_endpointName);
+
+ // Signal to the stdin reading loop to stop.
+ await _shutdownCts.CancelAsync().ConfigureAwait(false);
+ _shutdownCts.Dispose();
+
+ // Dispose of stdin/out. Cancellation may not be able to wake up operations
+ // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor.
+ _inputReader?.Dispose();
+ _outputStream?.Dispose();
+
+ // Make sure the work has quiesced.
+ try
+ {
+ _logger.TransportWaitingForReadTask(_endpointName);
+ await _readLoopCompleted.ConfigureAwait(false);
+ _logger.TransportReadTaskCleanedUp(_endpointName);
+ }
+ catch (TimeoutException)
+ {
+ _logger.TransportCleanupReadTaskTimeout(_endpointName);
+ }
+ catch (OperationCanceledException)
+ {
+ _logger.TransportCleanupReadTaskCancelled(_endpointName);
+ }
+ catch (Exception ex)
+ {
+ _logger.TransportCleanupReadTaskFailed(_endpointName, ex);
+ }
+ }
+ finally
+ {
+ SetConnected(false);
+ _logger.TransportCleanedUp(_endpointName);
+ }
+
+ GC.SuppressFinalize(this);
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs
index 710679e9f..3a2f2ab77 100644
--- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs
+++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs
@@ -23,7 +23,7 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper)
sc.AddSingleton(LoggerFactory);
sc.AddMcpServer().WithStdioServerTransport();
// Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport.
- sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()));
+ sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()));
for (int f = 0; f < 10; f++)
{
string name = $"Method{f}";
@@ -53,19 +53,17 @@ public async ValueTask DisposeAsync()
private async Task CreateMcpClientForServer()
{
- var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream());
- var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream());
-
- var serverConfig = new McpServerConfig()
- {
- Id = "TestServer",
- Name = "TestServer",
- TransportType = "ignored",
- };
-
return await McpClientFactory.CreateAsync(
- serverConfig,
- createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory),
+ new McpServerConfig()
+ {
+ Id = "TestServer",
+ Name = "TestServer",
+ TransportType = "ignored",
+ },
+ createTransportFunc: (_, _) => new StreamClientTransport(
+ serverInput: _clientToServerPipe.Writer.AsStream(),
+ serverOutput: _serverToClientPipe.Reader.AsStream(),
+ LoggerFactory),
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);
}
diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs
index 846ceebcc..6168c6648 100644
--- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs
+++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs
@@ -30,7 +30,7 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper
sc.AddSingleton(LoggerFactory);
_builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts();
// Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport.
- sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory));
+ sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory));
sc.AddSingleton(new ObjectWithId());
_serviceProvider = sc.BuildServiceProvider();
@@ -55,19 +55,17 @@ public async ValueTask DisposeAsync()
private async Task CreateMcpClientForServer()
{
- var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream());
- var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream());
-
- var serverConfig = new McpServerConfig()
- {
- Id = "TestServer",
- Name = "TestServer",
- TransportType = "ignored",
- };
-
return await McpClientFactory.CreateAsync(
- serverConfig,
- createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory),
+ new McpServerConfig()
+ {
+ Id = "TestServer",
+ Name = "TestServer",
+ TransportType = "ignored",
+ },
+ createTransportFunc: (_, _) => new StreamClientTransport(
+ serverInput: _clientToServerPipe.Writer.AsStream(),
+ serverOutput: _serverToClientPipe.Reader.AsStream(),
+ LoggerFactory),
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);
}
diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs
index 73fdeadef..c615e12f1 100644
--- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs
+++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs
@@ -34,7 +34,7 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper)
sc.AddSingleton(LoggerFactory);
_builder = sc.AddMcpServer().WithStdioServerTransport().WithTools();
// Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport.
- sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory));
+ sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory));
sc.AddSingleton(new ObjectWithId());
_serviceProvider = sc.BuildServiceProvider();
@@ -59,19 +59,17 @@ public async ValueTask DisposeAsync()
private async Task CreateMcpClientForServer()
{
- var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream());
- var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream());
-
- var serverConfig = new McpServerConfig()
- {
- Id = "TestServer",
- Name = "TestServer",
- TransportType = "ignored",
- };
-
return await McpClientFactory.CreateAsync(
- serverConfig,
- createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory),
+ new McpServerConfig()
+ {
+ Id = "TestServer",
+ Name = "TestServer",
+ TransportType = "ignored",
+ },
+ createTransportFunc: (_, _) => new StreamClientTransport(
+ serverInput: _clientToServerPipe.Writer.AsStream(),
+ _serverToClientPipe.Reader.AsStream(),
+ LoggerFactory),
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);
}
@@ -117,23 +115,21 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
var stdinPipe = new Pipe();
var stdoutPipe = new Pipe();
- await using var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream());
+ await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream());
await using var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider);
var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken);
- using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream());
- using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream());
-
- var serverConfig = new McpServerConfig()
- {
- Id = $"TestServer_{i}",
- Name = $"TestServer_{i}",
- TransportType = "ignored",
- };
-
await using (var client = await McpClientFactory.CreateAsync(
- serverConfig,
- createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory),
+ new McpServerConfig()
+ {
+ Id = $"TestServer_{i}",
+ Name = $"TestServer_{i}",
+ TransportType = "ignored",
+ },
+ createTransportFunc: (_, _) => new StreamClientTransport(
+ serverInput: stdinPipe.Writer.AsStream(),
+ serverOutput: stdoutPipe.Reader.AsStream(),
+ LoggerFactory),
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken))
{
diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs
index 5857f3c4d..6a0bd98a1 100644
--- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs
+++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs
@@ -55,7 +55,7 @@ public void Constructor_Throws_For_Null_Options()
[Fact]
public async Task Should_Start_In_Connected_State()
{
- await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), Stream.Null, LoggerFactory);
+ await using var transport = new StreamServerTransport(new Pipe().Reader.AsStream(), Stream.Null, loggerFactory: LoggerFactory);
Assert.True(transport.IsConnected);
}
@@ -65,11 +65,10 @@ public async Task SendMessageAsync_Should_Send_Message()
{
using var output = new MemoryStream();
- await using var transport = new StdioServerTransport(
- _serverOptions.ServerInfo.Name,
+ await using var transport = new StreamServerTransport(
new Pipe().Reader.AsStream(),
output,
- LoggerFactory);
+ loggerFactory: LoggerFactory);
// Verify transport is connected
Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync");
@@ -87,7 +86,7 @@ public async Task SendMessageAsync_Should_Send_Message()
[Fact]
public async Task DisposeAsync_Should_Dispose_Resources()
{
- await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory);
+ await using var transport = new StreamServerTransport(Stream.Null, Stream.Null, loggerFactory: LoggerFactory);
await transport.DisposeAsync();
@@ -104,11 +103,10 @@ public async Task ReadMessagesAsync_Should_Read_Messages()
Pipe pipe = new();
using var input = pipe.Reader.AsStream();
- await using var transport = new StdioServerTransport(
- _serverOptions.ServerInfo.Name,
+ await using var transport = new StreamServerTransport(
input,
Stream.Null,
- LoggerFactory);
+ loggerFactory: LoggerFactory);
// Verify transport is connected
Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync");
@@ -128,7 +126,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages()
[Fact]
public async Task CleanupAsync_Should_Cleanup_Resources()
{
- var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory);
+ var transport = new StreamServerTransport(Stream.Null, Stream.Null, loggerFactory: LoggerFactory);
await transport.DisposeAsync();
@@ -141,11 +139,10 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
// Use a reader that won't terminate
using var output = new MemoryStream();
- await using var transport = new StdioServerTransport(
- _serverOptions.ServerInfo.Name,
+ await using var transport = new StreamServerTransport(
new Pipe().Reader.AsStream(),
output,
- LoggerFactory);
+ loggerFactory: LoggerFactory);
// Verify transport is connected
Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync");
diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs
deleted file mode 100644
index d41f0b979..000000000
--- a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs
+++ /dev/null
@@ -1,79 +0,0 @@
-using Microsoft.Extensions.Logging;
-using ModelContextProtocol.Protocol.Messages;
-using ModelContextProtocol.Protocol.Transport;
-using ModelContextProtocol.Utils.Json;
-using System.Text.Json;
-
-namespace ModelContextProtocol.Tests.Transport;
-
-internal sealed class StreamClientTransport : TransportBase, IClientTransport
-{
- private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions;
- private readonly Task? _readTask;
- private readonly CancellationTokenSource _shutdownCts = new CancellationTokenSource();
- private readonly TextReader _serverStdoutReader;
- private readonly TextWriter _serverStdinWriter;
-
- public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader, ILoggerFactory loggerFactory)
- : base(loggerFactory)
- {
- _serverStdoutReader = serverStdoutReader;
- _serverStdinWriter = serverStdinWriter;
- _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None);
- SetConnected(true);
- }
-
- public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this);
-
- public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
- {
- string id = message is IJsonRpcMessageWithId messageWithId ?
- messageWithId.Id.ToString() :
- "(no id)";
-
- await _serverStdinWriter.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false);
- await _serverStdinWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
- }
-
- private async Task ReadMessagesAsync(CancellationToken cancellationToken)
- {
- try
- {
- while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line)
- {
- if (!string.IsNullOrWhiteSpace(line))
- {
- try
- {
- if (JsonSerializer.Deserialize(line.Trim(), _jsonOptions) is { } message)
- {
- await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
- }
- }
- catch (JsonException)
- {
- }
- }
- }
- }
- catch (OperationCanceledException)
- {
- }
- }
-
- public override async ValueTask DisposeAsync()
- {
- if (_shutdownCts is { } shutdownCts)
- {
- await shutdownCts.CancelAsync().ConfigureAwait(false);
- shutdownCts.Dispose();
- }
-
- if (_readTask is Task readTask)
- {
- await readTask.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false);
- }
-
- SetConnected(false);
- }
-}