From 74c64be7d051916dd72a0cdee20caecd8753345d Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 31 Mar 2025 18:27:50 -0400 Subject: [PATCH] Split StdioClient/ServerTransports into Stdio : Stream --- .../Transport/StdioClientSessionTransport.cs | 40 +++ .../Transport/StdioClientStreamTransport.cs | 326 ------------------ .../Transport/StdioClientTransport.cs | 126 ++++++- .../Transport/StdioServerTransport.cs | 223 +----------- .../Transport/StreamClientSessionTransport.cs | 187 ++++++++++ .../Transport/StreamClientTransport.cs | 47 +++ .../Transport/StreamServerTransport.cs | 209 +++++++++++ .../Client/McpClientExtensionsTests.cs | 24 +- .../McpServerBuilderExtensionsPromptsTests.cs | 24 +- .../McpServerBuilderExtensionsToolsTests.cs | 48 ++- .../Transport/StdioServerTransportTests.cs | 21 +- .../Transport/StreamClientTransport.cs | 79 ----- 12 files changed, 667 insertions(+), 687 deletions(-) create mode 100644 src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs delete mode 100644 src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs create mode 100644 src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs delete mode 100644 tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs new file mode 100644 index 000000000..af304c897 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -0,0 +1,40 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using System.Diagnostics; + +namespace ModelContextProtocol.Protocol.Transport; + +/// Provides the client side of a stdio-based session transport. +internal sealed class StdioClientSessionTransport : StreamClientSessionTransport +{ + private readonly StdioClientTransportOptions _options; + private readonly Process _process; + + public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, ILoggerFactory? loggerFactory) + : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory) + { + _process = process; + _options = options; + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (_process.HasExited) + { + Logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + } + + /// + protected override ValueTask CleanupAsync(CancellationToken cancellationToken) + { + StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName); + + return base.CleanupAsync(cancellationToken); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs deleted file mode 100644 index 35c957e52..000000000 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs +++ /dev/null @@ -1,326 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using System.Diagnostics; -using System.Text; -using System.Text.Json; - -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// Implements the MCP transport protocol over standard input/output streams. -/// -internal sealed class StdioClientStreamTransport : TransportBase -{ - private readonly StdioClientTransportOptions _options; - private readonly McpServerConfig _serverConfig; - private readonly ILogger _logger; - private readonly JsonSerializerOptions _jsonOptions; - private readonly DataReceivedEventHandler _logProcessErrors; - private readonly SemaphoreSlim _sendLock = new(1, 1); - private Process? _process; - private Task? _readTask; - private CancellationTokenSource? _shutdownCts; - private bool _processStarted; - - private string EndpointName => $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; - - /// - /// Initializes a new instance of the StdioTransport class. - /// - /// Configuration options for the transport. - /// The server configuration for the transport. - /// A logger factory for creating loggers. - public StdioClientStreamTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) - { - Throw.IfNull(options); - Throw.IfNull(serverConfig); - - _options = options; - _serverConfig = serverConfig; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _logProcessErrors = (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); - _jsonOptions = McpJsonUtilities.DefaultOptions; - } - - /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) - { - if (IsConnected) - { - _logger.TransportAlreadyConnected(EndpointName); - throw new McpTransportException("Transport is already connected"); - } - - try - { - _logger.TransportConnecting(EndpointName); - - _shutdownCts = new CancellationTokenSource(); - - UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); - - var startInfo = new ProcessStartInfo - { - FileName = _options.Command, - RedirectStandardInput = true, - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true, - WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, - StandardOutputEncoding = noBomUTF8, - StandardErrorEncoding = noBomUTF8, -#if NET - StandardInputEncoding = noBomUTF8, -#endif - }; - - if (!string.IsNullOrWhiteSpace(_options.Arguments)) - { - startInfo.Arguments = _options.Arguments; - } - - if (_options.EnvironmentVariables != null) - { - foreach (var entry in _options.EnvironmentVariables) - { - startInfo.Environment[entry.Key] = entry.Value; - } - } - - _logger.CreateProcessForTransport(EndpointName, _options.Command, - startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)), - startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString()); - - _process = new Process { StartInfo = startInfo }; - - // Set up error logging - _process.ErrorDataReceived += _logProcessErrors; - - // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, - // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but - // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks - // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core, - // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start - // call, to ensure it picks up the correct encoding. -#if NET - _processStarted = _process.Start(); -#else - Encoding originalInputEncoding = Console.InputEncoding; - try - { - Console.InputEncoding = noBomUTF8; - _processStarted = _process.Start(); - } - finally - { - Console.InputEncoding = originalInputEncoding; - } -#endif - - if (!_processStarted) - { - _logger.TransportProcessStartFailed(EndpointName); - throw new McpTransportException("Failed to start MCP server process"); - } - - _logger.TransportProcessStarted(EndpointName, _process.Id); - - _process.BeginErrorReadLine(); - - // Start reading messages in the background - _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); - _logger.TransportReadingMessages(EndpointName); - - SetConnected(true); - } - catch (Exception ex) - { - _logger.TransportConnectFailed(EndpointName, ex); - await CleanupAsync(cancellationToken).ConfigureAwait(false); - throw new McpTransportException("Failed to connect transport", ex); - } - } - - /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); - - if (!IsConnected || _process?.HasExited == true) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); - _logger.TransportSendingMessage(EndpointName, id, json); - _logger.TransportMessageBytesUtf8(EndpointName, json); - - // Write the message followed by a newline using our UTF-8 writer - await _process!.StandardInput.WriteLineAsync(json).ConfigureAwait(false); - await _process.StandardInput.FlushAsync(cancellationToken).ConfigureAwait(false); - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } - } - - /// - public override async ValueTask DisposeAsync() - { - await CleanupAsync(CancellationToken.None).ConfigureAwait(false); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) - { - try - { - _logger.TransportEnteringReadMessagesLoop(EndpointName); - - while (!cancellationToken.IsCancellationRequested && !_process!.HasExited) - { - _logger.TransportWaitingForMessage(EndpointName); - var line = await _process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false); - if (line == null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - - if (string.IsNullOrWhiteSpace(line)) - { - continue; - } - - _logger.TransportReceivedMessage(EndpointName, line); - _logger.TransportMessageBytesUtf8(EndpointName, line); - - await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); - } - _logger.TransportExitingReadMessagesLoop(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportReadMessagesCancelled(EndpointName); - // Normal shutdown - } - catch (Exception ex) - { - _logger.TransportReadMessagesFailed(EndpointName, ex); - } - finally - { - await CleanupAsync(cancellationToken).ConfigureAwait(false); - } - } - - private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) - { - try - { - line=line.Trim();//Fixes an error when the service prefixes nonprintable characters - var message = JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()); - if (message != null) - { - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - else - { - _logger.TransportMessageParseUnexpectedType(EndpointName, line); - } - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, line, ex); - } - } - - private async Task CleanupAsync(CancellationToken cancellationToken) - { - _logger.TransportCleaningUp(EndpointName); - - if (_process is Process process && _processStarted && !process.HasExited) - { - try - { - // Wait for the process to exit - _logger.TransportWaitingForShutdown(EndpointName); - - // Kill the while process tree because the process may spawn child processes - // and Node.js does not kill its children when it exits properly - process.KillTree(_options.ShutdownTimeout); - } - catch (Exception ex) - { - _logger.TransportShutdownFailed(EndpointName, ex); - } - finally - { - process.ErrorDataReceived -= _logProcessErrors; - process.Dispose(); - _process = null; - } - } - - if (_shutdownCts is { } shutdownCts) - { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - _shutdownCts = null; - } - - if (_readTask is Task readTask) - { - try - { - _logger.TransportWaitingForReadTask(EndpointName); - await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); - } - catch (TimeoutException) - { - _logger.TransportCleanupReadTaskTimeout(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportCleanupReadTaskCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(EndpointName, ex); - } - finally - { - _logger.TransportReadTaskCleanedUp(EndpointName); - _readTask = null; - } - } - - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } -} diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index d2b51b950..0fca07e27 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -1,10 +1,16 @@ using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; using ModelContextProtocol.Utils; +using System.Diagnostics; +using System.Text; + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously namespace ModelContextProtocol.Protocol.Transport; /// -/// Implements the MCP transport protocol over standard input/output streams. +/// Provides a client MCP transport implemented via "stdio" (standard input/output). /// public sealed class StdioClientTransport : IClientTransport { @@ -13,7 +19,7 @@ public sealed class StdioClientTransport : IClientTransport private readonly ILoggerFactory? _loggerFactory; /// - /// Initializes a new instance of the StdioTransport class. + /// Initializes a new instance of the class. /// /// Configuration options for the transport. /// The server configuration for the transport. @@ -31,17 +37,121 @@ public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - var streamTransport = new StdioClientStreamTransport(_options, _serverConfig, _loggerFactory); + string endpointName = $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; + + Process? process = null; + bool processStarted = false; + ILogger logger = (ILogger?)_loggerFactory?.CreateLogger() ?? NullLogger.Instance; try { - await streamTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - return streamTransport; + logger.TransportConnecting(endpointName); + + UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); + + ProcessStartInfo startInfo = new() + { + FileName = _options.Command, + RedirectStandardInput = true, + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true, + WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, + StandardOutputEncoding = noBomUTF8, + StandardErrorEncoding = noBomUTF8, +#if NET + StandardInputEncoding = noBomUTF8, +#endif + }; + + if (!string.IsNullOrWhiteSpace(_options.Arguments)) + { + startInfo.Arguments = _options.Arguments; + } + + if (_options.EnvironmentVariables != null) + { + foreach (var entry in _options.EnvironmentVariables) + { + startInfo.Environment[entry.Key] = entry.Value; + } + } + + logger.CreateProcessForTransport(endpointName, _options.Command, + startInfo.Arguments, string.Join(", ", startInfo.Environment.Select(kvp => kvp.Key + "=" + kvp.Value)), + startInfo.WorkingDirectory, _options.ShutdownTimeout.ToString()); + + process = new() { StartInfo = startInfo }; + + // Set up error logging + process.ErrorDataReceived += (sender, args) => logger.TransportError(endpointName, args.Data ?? "(no data)"); + + // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, + // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but + // StandardInputEncoding doesn't exist on .NET Framework; instead, it always picks + // up the encoding from Console.InputEncoding. As such, when not targeting .NET Core, + // we temporarily change Console.InputEncoding to no-BOM UTF-8 around the Process.Start + // call, to ensure it picks up the correct encoding. +#if NET + processStarted = process.Start(); +#else + Encoding originalInputEncoding = Console.InputEncoding; + try + { + Console.InputEncoding = noBomUTF8; + processStarted = process.Start(); + } + finally + { + Console.InputEncoding = originalInputEncoding; + } +#endif + + if (!processStarted) + { + logger.TransportProcessStartFailed(endpointName); + throw new McpTransportException("Failed to start MCP server process"); + } + + logger.TransportProcessStarted(endpointName, process.Id); + + process.BeginErrorReadLine(); + + return new StdioClientSessionTransport(_options, process, endpointName, _loggerFactory); } - catch + catch (Exception ex) + { + logger.TransportConnectFailed(endpointName, ex); + DisposeProcess(process, processStarted, logger, _options.ShutdownTimeout, endpointName); + throw new McpTransportException("Failed to connect transport", ex); + } + } + + internal static void DisposeProcess( + Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName) + { + if (process is not null) { - await streamTransport.DisposeAsync().ConfigureAwait(false); - throw; + try + { + if (processStarted && !process.HasExited) + { + // Wait for the process to exit. + // Kill the while process tree because the process may spawn child processes + // and Node.js does not kill its children when it exits properly. + logger.TransportWaitingForShutdown(endpointName); + process.KillTree(shutdownTimeout); + } + } + catch (Exception ex) + { + logger.TransportShutdownFailed(endpointName, ex); + } + finally + { + process.Dispose(); + } } } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs index 7779edc99..58077dbb2 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs @@ -1,38 +1,15 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Logging; -using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; -using System.Text; -using System.Text.Json; namespace ModelContextProtocol.Protocol.Transport; /// -/// Provides an implementation of the MCP transport protocol over standard input/output streams. +/// Provides a server MCP transport implemented via "stdio" (standard input/output). /// -public sealed class StdioServerTransport : TransportBase, ITransport +public sealed class StdioServerTransport : StreamServerTransport { - private static readonly byte[] s_newlineBytes = "\n"u8.ToArray(); - - private readonly string _serverName; - private readonly ILogger _logger; - - private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions; - private readonly TextReader _stdInReader; - private readonly Stream _stdOutStream; - - private readonly SemaphoreSlim _sendLock = new(1, 1); - private readonly CancellationTokenSource _shutdownCts = new(); - - private readonly Task _readLoopCompleted; - private int _disposed = 0; - - private string EndpointName => $"Server (stdio) ({_serverName})"; - /// /// Initializes a new instance of the class, using /// and for input and output streams. @@ -69,14 +46,6 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg { } - private static string GetServerName(McpServerOptions serverOptions) - { - Throw.IfNull(serverOptions); - Throw.IfNull(serverOptions.ServerInfo); - Throw.IfNull(serverOptions.ServerInfo.Name); - return serverOptions.ServerInfo.Name; - } - /// /// Initializes a new instance of the class, using /// and for input and output streams. @@ -90,186 +59,20 @@ private static string GetServerName(McpServerOptions serverOptions) /// to , as that will interfere with the transport's output. /// /// - public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory) - : this(serverName, stdinStream: null, stdoutStream: null, loggerFactory) - { - } - - /// - /// Initializes a new instance of the class with explicit input/output streams. - /// - /// The name of the server. - /// The input to use as standard input. If , will be used. - /// The output to use as standard output. If , will be used. - /// Optional logger factory used for logging employed by the transport. - /// is . - /// - /// - /// This constructor is useful for testing scenarios where you want to redirect input/output. - /// - /// - public StdioServerTransport(string serverName, Stream? stdinStream = null, Stream? stdoutStream = null, ILoggerFactory? loggerFactory = null) - : base(loggerFactory) - { - Throw.IfNull(serverName); - - _serverName = serverName; - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - - _stdInReader = new StreamReader(stdinStream ?? Console.OpenStandardInput(), Encoding.UTF8); - _stdOutStream = stdoutStream ?? new BufferedStream(Console.OpenStandardOutput()); - - SetConnected(true); - _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); - } - - /// - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null) + : base(Console.OpenStandardInput(), + new BufferedStream(Console.OpenStandardOutput()), + serverName ?? throw new ArgumentNullException(nameof(serverName)), + loggerFactory) { - using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); - - if (!IsConnected) - { - _logger.TransportNotConnected(EndpointName); - throw new McpTransportException("Transport is not connected"); - } - - string id = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - id = messageWithId.Id.ToString(); - } - - try - { - _logger.TransportSendingMessage(EndpointName, id); - - await JsonSerializer.SerializeAsync(_stdOutStream, message, _jsonOptions.GetTypeInfo(), cancellationToken).ConfigureAwait(false); - await _stdOutStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); - await _stdOutStream.FlushAsync(cancellationToken).ConfigureAwait(false);; - - _logger.TransportSentMessage(EndpointName, id); - } - catch (Exception ex) - { - _logger.TransportSendFailed(EndpointName, id, ex); - throw new McpTransportException("Failed to send message", ex); - } } - private async Task ReadMessagesAsync() - { - CancellationToken shutdownToken = _shutdownCts.Token; - try - { - _logger.TransportEnteringReadMessagesLoop(EndpointName); - - while (!shutdownToken.IsCancellationRequested) - { - _logger.TransportWaitingForMessage(EndpointName); - - var line = await _stdInReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); - if (string.IsNullOrWhiteSpace(line)) - { - if (line is null) - { - _logger.TransportEndOfStream(EndpointName); - break; - } - - continue; - } - - _logger.TransportReceivedMessage(EndpointName, line); - _logger.TransportMessageBytesUtf8(EndpointName, line); - - try - { - if (JsonSerializer.Deserialize(line, _jsonOptions.GetTypeInfo()) is { } message) - { - string messageId = "(no id)"; - if (message is IJsonRpcMessageWithId messageWithId) - { - messageId = messageWithId.Id.ToString(); - } - _logger.TransportReceivedMessageParsed(EndpointName, messageId); - - await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); - } - else - { - _logger.TransportMessageParseUnexpectedType(EndpointName, line); - } - } - catch (JsonException ex) - { - _logger.TransportMessageParseFailed(EndpointName, line, ex); - // Continue reading even if we fail to parse a message - } - } - - _logger.TransportExitingReadMessagesLoop(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportReadMessagesCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportReadMessagesFailed(EndpointName, ex); - } - finally - { - SetConnected(false); - } - } - - /// - public override async ValueTask DisposeAsync() + private static string GetServerName(McpServerOptions serverOptions) { - if (Interlocked.Exchange(ref _disposed, 1) != 0) - { - return; - } - - try - { - _logger.TransportCleaningUp(EndpointName); - - // Signal to the stdin reading loop to stop. - await _shutdownCts.CancelAsync().ConfigureAwait(false); - _shutdownCts.Dispose(); - - // Dispose of stdin/out. Cancellation may not be able to wake up operations - // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. - _stdInReader?.Dispose(); - _stdOutStream?.Dispose(); + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); + Throw.IfNull(serverOptions.ServerInfo.Name); - // Make sure the work has quiesced. - try - { - _logger.TransportWaitingForReadTask(EndpointName); - await _readLoopCompleted.ConfigureAwait(false); - _logger.TransportReadTaskCleanedUp(EndpointName); - } - catch (TimeoutException) - { - _logger.TransportCleanupReadTaskTimeout(EndpointName); - } - catch (OperationCanceledException) - { - _logger.TransportCleanupReadTaskCancelled(EndpointName); - } - catch (Exception ex) - { - _logger.TransportCleanupReadTaskFailed(EndpointName, ex); - } - } - finally - { - SetConnected(false); - _logger.TransportCleanedUp(EndpointName); - } + return serverOptions.ServerInfo.Name; } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs new file mode 100644 index 000000000..8359afa65 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -0,0 +1,187 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// Provides the client side of a stream-based session transport. +internal class StreamClientSessionTransport : TransportBase +{ + private readonly TextReader _serverOutput; + private readonly TextWriter _serverInput; + private readonly SemaphoreSlim _sendLock = new(1, 1); + private CancellationTokenSource? _shutdownCts = new(); + private Task? _readTask; + + /// + /// Initializes a new instance of the class. + /// + public StreamClientSessionTransport( + TextWriter serverInput, TextReader serverOutput, string endpointName, ILoggerFactory? loggerFactory) + : base(loggerFactory) + { + Logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _serverOutput = serverOutput; + _serverInput = serverInput; + EndpointName = endpointName; + + // Start reading messages in the background + Logger.TransportReadingMessages(endpointName); + _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); + + SetConnected(true); + } + + protected ILogger Logger { get; private set; } + + protected string EndpointName { get; } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected) + { + Logger.TransportNotConnected(EndpointName); + throw new McpTransportException("Transport is not connected"); + } + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))); + + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + try + { + Logger.TransportSendingMessage(EndpointName, id, json); + Logger.TransportMessageBytesUtf8(EndpointName, json); + + // Write the message followed by a newline using our UTF-8 writer + await _serverInput.WriteLineAsync(json).ConfigureAwait(false); + await _serverInput.FlushAsync(cancellationToken).ConfigureAwait(false); + + Logger.TransportSentMessage(EndpointName, id); + } + catch (Exception ex) + { + Logger.TransportSendFailed(EndpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + /// + public override ValueTask DisposeAsync() => + CleanupAsync(CancellationToken.None); + + private async Task ReadMessagesAsync(CancellationToken cancellationToken) + { + try + { + Logger.TransportEnteringReadMessagesLoop(EndpointName); + + while (!cancellationToken.IsCancellationRequested) + { + Logger.TransportWaitingForMessage(EndpointName); + if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line) + { + Logger.TransportEndOfStream(EndpointName); + break; + } + + if (string.IsNullOrWhiteSpace(line)) + { + continue; + } + + Logger.TransportReceivedMessage(EndpointName, line); + Logger.TransportMessageBytesUtf8(EndpointName, line); + + await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false); + } + Logger.TransportExitingReadMessagesLoop(EndpointName); + } + catch (OperationCanceledException) + { + Logger.TransportReadMessagesCancelled(EndpointName); + } + catch (Exception ex) + { + Logger.TransportReadMessagesFailed(EndpointName, ex); + } + finally + { + await CleanupAsync(cancellationToken).ConfigureAwait(false); + } + } + + private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken) + { + try + { + var message = (IJsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))); + if (message != null) + { + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + + Logger.TransportReceivedMessageParsed(EndpointName, messageId); + await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); + Logger.TransportMessageWritten(EndpointName, messageId); + } + else + { + Logger.TransportMessageParseUnexpectedType(EndpointName, line); + } + } + catch (JsonException ex) + { + Logger.TransportMessageParseFailed(EndpointName, line, ex); + } + } + + protected virtual async ValueTask CleanupAsync(CancellationToken cancellationToken) + { + Logger.TransportCleaningUp(EndpointName); + + if (Interlocked.Exchange(ref _shutdownCts, null) is { } shutdownCts) + { + await shutdownCts.CancelAsync().ConfigureAwait(false); + shutdownCts.Dispose(); + } + + if (Interlocked.Exchange(ref _readTask, null) is Task readTask) + { + try + { + Logger.TransportWaitingForReadTask(EndpointName); + await readTask.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false); + Logger.TransportReadTaskCleanedUp(EndpointName); + } + catch (TimeoutException) + { + Logger.TransportCleanupReadTaskTimeout(EndpointName); + } + catch (OperationCanceledException) + { + Logger.TransportCleanupReadTaskCancelled(EndpointName); + } + catch (Exception ex) + { + Logger.TransportCleanupReadTaskFailed(EndpointName, ex); + } + } + + SetConnected(false); + Logger.TransportCleanedUp(EndpointName); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs new file mode 100644 index 000000000..80bd61df5 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs @@ -0,0 +1,47 @@ +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides a client MCP transport implemented around a pair of input/output streams. +/// +public sealed class StreamClientTransport : IClientTransport +{ + private readonly Stream _serverInput; + private readonly Stream _serverOutput; + private readonly ILoggerFactory? _loggerFactory; + + /// + /// Initializes a new instance of the class. + /// + /// + /// The stream representing the connected server's input. + /// Writes to this stream will be sent to the server. + /// + /// + /// The stream representing the connected server's output. + /// Reads from this stream will receive messages from the server. + /// + /// A logger factory for creating loggers. + public StreamClientTransport( + Stream serverInput, Stream serverOutput, ILoggerFactory? loggerFactory = null) + { + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + + _serverInput = serverInput; + _serverOutput = serverOutput; + _loggerFactory = loggerFactory; + } + + /// + public Task ConnectAsync(CancellationToken cancellationToken = default) + { + return Task.FromResult(new StreamClientSessionTransport( + new StreamWriter(_serverInput), + new StreamReader(_serverOutput), + "Client (stream)", + _loggerFactory)); + } +} diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs new file mode 100644 index 000000000..ebdf36350 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -0,0 +1,209 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Logging; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.IO.Pipelines; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Transport; + +/// +/// Provides a server MCP transport implemented around a pair of input/output streams. +/// +public class StreamServerTransport : TransportBase, ITransport +{ + private static readonly byte[] s_newlineBytes = "\n"u8.ToArray(); + + private readonly ILogger _logger; + + private readonly TextReader _inputReader; + private readonly Stream _outputStream; + private readonly string _endpointName; + + private readonly SemaphoreSlim _sendLock = new(1, 1); + private readonly CancellationTokenSource _shutdownCts = new(); + + private readonly Task _readLoopCompleted; + private int _disposed = 0; + + /// + /// Initializes a new instance of the class with explicit input/output streams. + /// + /// The input to use as standard input. + /// The output to use as standard output. + /// Optional name of the server, used for diagnostic purposes, like logging. + /// Optional logger factory used for logging employed by the transport. + /// is . + /// is . + public StreamServerTransport(Stream inputStream, Stream outputStream, string? serverName = null, ILoggerFactory? loggerFactory = null) + : base(loggerFactory) + { + Throw.IfNull(inputStream); + Throw.IfNull(outputStream); + + _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; + + _inputReader = new StreamReader(inputStream, Encoding.UTF8); + _outputStream = outputStream; + + SetConnected(true); + _readLoopCompleted = Task.Run(ReadMessagesAsync, _shutdownCts.Token); + + _endpointName = serverName is not null ? $"Server (stream) ({serverName})" : "Server (stream)"; + } + + /// + public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + if (!IsConnected) + { + _logger.TransportNotConnected(_endpointName); + throw new McpTransportException("Transport is not connected"); + } + + using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); + + string id = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + id = messageWithId.Id.ToString(); + } + + try + { + _logger.TransportSendingMessage(_endpointName, id); + + await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), cancellationToken).ConfigureAwait(false); + await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false); + await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);; + + _logger.TransportSentMessage(_endpointName, id); + } + catch (Exception ex) + { + _logger.TransportSendFailed(_endpointName, id, ex); + throw new McpTransportException("Failed to send message", ex); + } + } + + private async Task ReadMessagesAsync() + { + CancellationToken shutdownToken = _shutdownCts.Token; + try + { + _logger.TransportEnteringReadMessagesLoop(_endpointName); + + while (!shutdownToken.IsCancellationRequested) + { + _logger.TransportWaitingForMessage(_endpointName); + + var line = await _inputReader.ReadLineAsync(shutdownToken).ConfigureAwait(false); + if (string.IsNullOrWhiteSpace(line)) + { + if (line is null) + { + _logger.TransportEndOfStream(_endpointName); + break; + } + + continue; + } + + _logger.TransportReceivedMessage(_endpointName, line); + _logger.TransportMessageBytesUtf8(_endpointName, line); + + try + { + if (JsonSerializer.Deserialize(line, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage))) is IJsonRpcMessage message) + { + string messageId = "(no id)"; + if (message is IJsonRpcMessageWithId messageWithId) + { + messageId = messageWithId.Id.ToString(); + } + _logger.TransportReceivedMessageParsed(_endpointName, messageId); + + await WriteMessageAsync(message, shutdownToken).ConfigureAwait(false); + _logger.TransportMessageWritten(_endpointName, messageId); + } + else + { + _logger.TransportMessageParseUnexpectedType(_endpointName, line); + } + } + catch (JsonException ex) + { + _logger.TransportMessageParseFailed(_endpointName, line, ex); + // Continue reading even if we fail to parse a message + } + } + + _logger.TransportExitingReadMessagesLoop(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportReadMessagesCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportReadMessagesFailed(_endpointName, ex); + } + finally + { + SetConnected(false); + } + } + + /// + public override async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + return; + } + + try + { + _logger.TransportCleaningUp(_endpointName); + + // Signal to the stdin reading loop to stop. + await _shutdownCts.CancelAsync().ConfigureAwait(false); + _shutdownCts.Dispose(); + + // Dispose of stdin/out. Cancellation may not be able to wake up operations + // synchronously blocked in a syscall; we need to forcefully close the handle / file descriptor. + _inputReader?.Dispose(); + _outputStream?.Dispose(); + + // Make sure the work has quiesced. + try + { + _logger.TransportWaitingForReadTask(_endpointName); + await _readLoopCompleted.ConfigureAwait(false); + _logger.TransportReadTaskCleanedUp(_endpointName); + } + catch (TimeoutException) + { + _logger.TransportCleanupReadTaskTimeout(_endpointName); + } + catch (OperationCanceledException) + { + _logger.TransportCleanupReadTaskCancelled(_endpointName); + } + catch (Exception ex) + { + _logger.TransportCleanupReadTaskFailed(_endpointName, ex); + } + } + finally + { + SetConnected(false); + _logger.TransportCleanedUp(_endpointName); + } + + GC.SuppressFinalize(this); + } +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 710679e9f..3a2f2ab77 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -23,7 +23,7 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper) sc.AddSingleton(LoggerFactory); sc.AddMcpServer().WithStdioServerTransport(); // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); for (int f = 0; f < 10; f++) { string name = $"Method{f}"; @@ -53,19 +53,17 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }; - return await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + serverOutput: _serverToClientPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 846ceebcc..6168c6648 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -30,7 +30,7 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper sc.AddSingleton(LoggerFactory); _builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts(); // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory)); + sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); _serviceProvider = sc.BuildServiceProvider(); @@ -55,19 +55,17 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }; - return await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + serverOutput: _serverToClientPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 73fdeadef..c615e12f1 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -34,7 +34,7 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) sc.AddSingleton(LoggerFactory); _builder = sc.AddMcpServer().WithStdioServerTransport().WithTools(); // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory)); + sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); _serviceProvider = sc.BuildServiceProvider(); @@ -59,19 +59,17 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { - var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); - var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }; - return await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + _serverToClientPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -117,23 +115,21 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdinPipe = new Pipe(); var stdoutPipe = new Pipe(); - await using var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); + await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); await using var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider); var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); - using var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream()); - using var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream()); - - var serverConfig = new McpServerConfig() - { - Id = $"TestServer_{i}", - Name = $"TestServer_{i}", - TransportType = "ignored", - }; - await using (var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + new McpServerConfig() + { + Id = $"TestServer_{i}", + Name = $"TestServer_{i}", + TransportType = "ignored", + }, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: stdinPipe.Writer.AsStream(), + serverOutput: stdoutPipe.Reader.AsStream(), + LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) { diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 5857f3c4d..6a0bd98a1 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -55,7 +55,7 @@ public void Constructor_Throws_For_Null_Options() [Fact] public async Task Should_Start_In_Connected_State() { - await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, new Pipe().Reader.AsStream(), Stream.Null, LoggerFactory); + await using var transport = new StreamServerTransport(new Pipe().Reader.AsStream(), Stream.Null, loggerFactory: LoggerFactory); Assert.True(transport.IsConnected); } @@ -65,11 +65,10 @@ public async Task SendMessageAsync_Should_Send_Message() { using var output = new MemoryStream(); - await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + await using var transport = new StreamServerTransport( new Pipe().Reader.AsStream(), output, - LoggerFactory); + loggerFactory: LoggerFactory); // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -87,7 +86,7 @@ public async Task SendMessageAsync_Should_Send_Message() [Fact] public async Task DisposeAsync_Should_Dispose_Resources() { - await using var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); + await using var transport = new StreamServerTransport(Stream.Null, Stream.Null, loggerFactory: LoggerFactory); await transport.DisposeAsync(); @@ -104,11 +103,10 @@ public async Task ReadMessagesAsync_Should_Read_Messages() Pipe pipe = new(); using var input = pipe.Reader.AsStream(); - await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + await using var transport = new StreamServerTransport( input, Stream.Null, - LoggerFactory); + loggerFactory: LoggerFactory); // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); @@ -128,7 +126,7 @@ public async Task ReadMessagesAsync_Should_Read_Messages() [Fact] public async Task CleanupAsync_Should_Cleanup_Resources() { - var transport = new StdioServerTransport(_serverOptions.ServerInfo.Name, Stream.Null, Stream.Null, LoggerFactory); + var transport = new StreamServerTransport(Stream.Null, Stream.Null, loggerFactory: LoggerFactory); await transport.DisposeAsync(); @@ -141,11 +139,10 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() // Use a reader that won't terminate using var output = new MemoryStream(); - await using var transport = new StdioServerTransport( - _serverOptions.ServerInfo.Name, + await using var transport = new StreamServerTransport( new Pipe().Reader.AsStream(), output, - LoggerFactory); + loggerFactory: LoggerFactory); // Verify transport is connected Assert.True(transport.IsConnected, "Transport should be connected after StartListeningAsync"); diff --git a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs b/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs deleted file mode 100644 index d41f0b979..000000000 --- a/tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs +++ /dev/null @@ -1,79 +0,0 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Messages; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Utils.Json; -using System.Text.Json; - -namespace ModelContextProtocol.Tests.Transport; - -internal sealed class StreamClientTransport : TransportBase, IClientTransport -{ - private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions; - private readonly Task? _readTask; - private readonly CancellationTokenSource _shutdownCts = new CancellationTokenSource(); - private readonly TextReader _serverStdoutReader; - private readonly TextWriter _serverStdinWriter; - - public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader, ILoggerFactory loggerFactory) - : base(loggerFactory) - { - _serverStdoutReader = serverStdoutReader; - _serverStdinWriter = serverStdinWriter; - _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); - SetConnected(true); - } - - public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this); - - public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) - { - string id = message is IJsonRpcMessageWithId messageWithId ? - messageWithId.Id.ToString() : - "(no id)"; - - await _serverStdinWriter.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false); - await _serverStdinWriter.FlushAsync(cancellationToken).ConfigureAwait(false); - } - - private async Task ReadMessagesAsync(CancellationToken cancellationToken) - { - try - { - while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line) - { - if (!string.IsNullOrWhiteSpace(line)) - { - try - { - if (JsonSerializer.Deserialize(line.Trim(), _jsonOptions) is { } message) - { - await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - } - } - catch (JsonException) - { - } - } - } - } - catch (OperationCanceledException) - { - } - } - - public override async ValueTask DisposeAsync() - { - if (_shutdownCts is { } shutdownCts) - { - await shutdownCts.CancelAsync().ConfigureAwait(false); - shutdownCts.Dispose(); - } - - if (_readTask is Task readTask) - { - await readTask.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false); - } - - SetConnected(false); - } -}