diff --git a/README.md b/README.md index 364354bcf..0bddac3a9 100644 --- a/README.md +++ b/README.md @@ -31,18 +31,15 @@ To get started writing a client, the `McpClientFactory.CreateAsync` method is us to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp -var client = await McpClientFactory.CreateAsync(new() +var clientTransport = new StdioClientTransport(new() { - Id = "everything", Name = "Everything", - TransportType = TransportTypes.StdIo, - TransportOptions = new() - { - ["command"] = "npx", - ["arguments"] = "-y @modelcontextprotocol/server-everything", - } + Command = "npx", + Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); +var client = await McpClientFactory.CreateAsync(clientTransport); + // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) { diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index 493806749..355a802bb 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -6,16 +6,12 @@ // Connect to an MCP server Console.WriteLine("Connecting client to MCP 'everything' server"); var mcpClient = await McpClientFactory.CreateAsync( - new() + new StdioClientTransport(new() { - Id = "everything", + Command = "npx", + Arguments = ["-y", "--verbose", "@modelcontextprotocol/server-everything"], Name = "Everything", - TransportType = TransportTypes.StdIo, - TransportOptions = new() - { - ["command"] = "npx", ["arguments"] = "-y @modelcontextprotocol/server-everything", - } - }); + })); // Get all available tools Console.WriteLine("Tools available:"); diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs index 1ecd40c25..99a218bdf 100644 --- a/samples/QuickstartClient/Program.cs +++ b/samples/QuickstartClient/Program.cs @@ -13,18 +13,15 @@ var (command, arguments) = GetCommandAndArguments(args); -await using var mcpClient = await McpClientFactory.CreateAsync(new() +var clientTransport = new StdioClientTransport(new() { - Id = "demo-server", Name = "Demo Server", - TransportType = TransportTypes.StdIo, - TransportOptions = new() - { - ["command"] = command, - ["arguments"] = arguments, - } + Command = command, + Arguments = arguments, }); +await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport); + var tools = await mcpClient.ListToolsAsync(); foreach (var tool in tools) { @@ -86,13 +83,13 @@ static void PromptForInput() /// /// This method would only be required if you're creating a generic client, such as we use for the quickstart. /// -static (string command, string arguments) GetCommandAndArguments(string[] args) +static (string command, string[] arguments) GetCommandAndArguments(string[] args) { return args switch { - [var script] when script.EndsWith(".py") => ("python", script), - [var script] when script.EndsWith(".js") => ("node", script), - [var script] when Directory.Exists(script) || (File.Exists(script) && script.EndsWith(".csproj")) => ("dotnet", $"run --project {script} --no-build"), - _ => ("dotnet", "run --project ../../../../QuickstartWeatherServer --no-build") + [var script] when script.EndsWith(".py") => ("python", args), + [var script] when script.EndsWith(".js") => ("node", args), + [var script] when Directory.Exists(script) || (File.Exists(script) && script.EndsWith(".csproj")) => ("dotnet", ["run", "--project", script, "--no-build"]), + _ => ("dotnet", ["run", "--project", "../../../../QuickstartWeatherServer", "--no-build"]) }; } \ No newline at end of file diff --git a/src/Common/Polyfills/System/PasteArguments.cs b/src/Common/Polyfills/System/PasteArguments.cs new file mode 100644 index 000000000..32eb4c69f --- /dev/null +++ b/src/Common/Polyfills/System/PasteArguments.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Copied from: +// https://github.com/dotnet/runtime/blob/d2650b6ae7023a2d9d2c74c56116f1f18472ab04/src/libraries/System.Private.CoreLib/src/System/PasteArguments.cs +// and changed from using ValueStringBuilder to StringBuilder. + +using System.Text; + +namespace System; + +internal static partial class PasteArguments +{ + internal static void AppendArgument(StringBuilder stringBuilder, string argument) + { + if (stringBuilder.Length != 0) + { + stringBuilder.Append(' '); + } + + // Parsing rules for non-argv[0] arguments: + // - Backslash is a normal character except followed by a quote. + // - 2N backslashes followed by a quote ==> N literal backslashes followed by unescaped quote + // - 2N+1 backslashes followed by a quote ==> N literal backslashes followed by a literal quote + // - Parsing stops at first whitespace outside of quoted region. + // - (post 2008 rule): A closing quote followed by another quote ==> literal quote, and parsing remains in quoting mode. + if (argument.Length != 0 && ContainsNoWhitespaceOrQuotes(argument)) + { + // Simple case - no quoting or changes needed. + stringBuilder.Append(argument); + } + else + { + stringBuilder.Append(Quote); + int idx = 0; + while (idx < argument.Length) + { + char c = argument[idx++]; + if (c == Backslash) + { + int numBackSlash = 1; + while (idx < argument.Length && argument[idx] == Backslash) + { + idx++; + numBackSlash++; + } + + if (idx == argument.Length) + { + // We'll emit an end quote after this so must double the number of backslashes. + stringBuilder.Append(Backslash, numBackSlash * 2); + } + else if (argument[idx] == Quote) + { + // Backslashes will be followed by a quote. Must double the number of backslashes. + stringBuilder.Append(Backslash, numBackSlash * 2 + 1); + stringBuilder.Append(Quote); + idx++; + } + else + { + // Backslash will not be followed by a quote, so emit as normal characters. + stringBuilder.Append(Backslash, numBackSlash); + } + + continue; + } + + if (c == Quote) + { + // Escape the quote so it appears as a literal. This also guarantees that we won't end up generating a closing quote followed + // by another quote (which parses differently pre-2008 vs. post-2008.) + stringBuilder.Append(Backslash); + stringBuilder.Append(Quote); + continue; + } + + stringBuilder.Append(c); + } + + stringBuilder.Append(Quote); + } + } + + private static bool ContainsNoWhitespaceOrQuotes(string s) + { + for (int i = 0; i < s.Length; i++) + { + char c = s[i]; + if (char.IsWhiteSpace(c) || c == Quote) + { + return false; + } + } + + return true; + } + + private const char Quote = '\"'; + private const char Backslash = '\\'; +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index cc43cfa9f..a5636f3d3 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -33,9 +33,8 @@ internal sealed class McpClient : McpEndpoint, IMcpClient /// /// The transport to use for communication with the server. /// Options for the client, defining protocol version and capabilities. - /// The server configuration. /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) + public McpClient(IClientTransport clientTransport, McpClientOptions? options, ILoggerFactory? loggerFactory) : base(loggerFactory) { options ??= new(); @@ -43,7 +42,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, Mc _clientTransport = clientTransport; _options = options; - EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; + EndpointName = clientTransport.Name; if (options.Capabilities is { } capabilities) { diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index 751df190d..ca9fcc4d3 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -12,128 +12,39 @@ namespace ModelContextProtocol.Client; public static class McpClientFactory { /// Creates an , connecting it to the specified server. - /// Configuration for the target server to which the client should connect. + /// The transport instance used to communicate with the server. /// /// A client configuration object which specifies client capabilities and protocol version. /// If , details based on the current process will be employed. /// - /// An optional factory method which returns transport implementations based on a server configuration. /// A logger factory for creating loggers for clients. /// The to monitor for cancellation requests. The default is . /// An that's connected to the specified server. - /// is . + /// is . /// is . - /// contains invalid information. - /// returns an invalid transport. public static async Task CreateAsync( - McpServerConfig serverConfig, + IClientTransport clientTransport, McpClientOptions? clientOptions = null, - Func? createTransportFunc = null, ILoggerFactory? loggerFactory = null, CancellationToken cancellationToken = default) { - Throw.IfNull(serverConfig); - - createTransportFunc ??= CreateTransport; - - string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; + Throw.IfNull(clientTransport); + string endpointName = clientTransport.Name; var logger = loggerFactory?.CreateLogger(typeof(McpClientFactory)) ?? NullLogger.Instance; logger.CreatingClient(endpointName); - var transport = - createTransportFunc(serverConfig, loggerFactory) ?? - throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport."); - + McpClient client = new(clientTransport, clientOptions, loggerFactory); try { - McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); - try - { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - logger.ClientCreated(endpointName); - return client; - } - catch - { - await client.DisposeAsync().ConfigureAwait(false); - throw; - } + await client.ConnectAsync(cancellationToken).ConfigureAwait(false); + logger.ClientCreated(endpointName); + return client; } catch { - if (transport is IAsyncDisposable asyncDisposableTransport) - { - await asyncDisposableTransport.DisposeAsync().ConfigureAwait(false); - } - else if (transport is IDisposable disposableTransport) - { - disposableTransport.Dispose(); - } + await client.DisposeAsync().ConfigureAwait(false); throw; } } - - private static IClientTransport CreateTransport(McpServerConfig serverConfig, ILoggerFactory? loggerFactory) - { - if (string.Equals(serverConfig.TransportType, TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase)) - { - string? command = serverConfig.TransportOptions?.GetValueOrDefault("command"); - if (string.IsNullOrWhiteSpace(command)) - { - command = serverConfig.Location; - if (string.IsNullOrWhiteSpace(command)) - { - throw new ArgumentException("Command is required for stdio transport.", nameof(serverConfig)); - } - } - - string? arguments = serverConfig.TransportOptions?.GetValueOrDefault("arguments"); - - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && - serverConfig.TransportType.Equals(TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase) && - !string.IsNullOrEmpty(command) && - !string.Equals(Path.GetFileName(command), "cmd.exe", StringComparison.OrdinalIgnoreCase)) - { - // On Windows, for stdio, we need to wrap non-shell commands with cmd.exe /c {command} (usually npx or uvicorn). - // The stdio transport will not work correctly if the command is not run in a shell. - arguments = string.IsNullOrWhiteSpace(arguments) ? - $"/c {command}" : - $"/c {command} {arguments}"; - command = "cmd.exe"; - } - - return new StdioClientTransport(new StdioClientTransportOptions - { - Command = command!, - Arguments = arguments, - WorkingDirectory = serverConfig.TransportOptions?.GetValueOrDefault("workingDirectory"), - EnvironmentVariables = serverConfig.TransportOptions? - .Where(kv => kv.Key.StartsWith("env:", StringComparison.Ordinal)) - .ToDictionary(kv => kv.Key.Substring("env:".Length), kv => kv.Value), - ShutdownTimeout = TimeSpan.TryParse(serverConfig.TransportOptions?.GetValueOrDefault("shutdownTimeout"), CultureInfo.InvariantCulture, out var timespan) ? timespan : StdioClientTransportOptions.DefaultShutdownTimeout - }, serverConfig, loggerFactory); - } - - if (string.Equals(serverConfig.TransportType, TransportTypes.Sse, StringComparison.OrdinalIgnoreCase) || - string.Equals(serverConfig.TransportType, "http", StringComparison.OrdinalIgnoreCase)) - { - return new SseClientTransport(new SseClientTransportOptions - { - ConnectionTimeout = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "connectionTimeout", 30)), - MaxReconnectAttempts = ParseInt32OrDefault(serverConfig.TransportOptions, "maxReconnectAttempts", 3), - ReconnectDelay = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "reconnectDelay", 5)), - AdditionalHeaders = serverConfig.TransportOptions? - .Where(kv => kv.Key.StartsWith("header.", StringComparison.Ordinal)) - .ToDictionary(kv => kv.Key.Substring("header.".Length), kv => kv.Value) - }, serverConfig, loggerFactory); - - static int ParseInt32OrDefault(Dictionary? options, string key, int defaultValue) => - options?.TryGetValue(key, out var value) is not true ? defaultValue : - int.TryParse(value, out var result) ? result : - throw new ArgumentException($"Invalid value '{value}' for option '{key}' in transport options.", nameof(serverConfig)); - } - - throw new ArgumentException($"Unsupported transport type '{serverConfig.TransportType}'.", nameof(serverConfig)); - } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Configuration/McpServerConfig.cs b/src/ModelContextProtocol/Configuration/McpServerConfig.cs deleted file mode 100644 index 27cd39e41..000000000 --- a/src/ModelContextProtocol/Configuration/McpServerConfig.cs +++ /dev/null @@ -1,34 +0,0 @@ -namespace ModelContextProtocol; - -/// -/// Configuration for an MCP server connection. -/// This is passed to the client factory to create a client for a specific server. -/// -public record McpServerConfig -{ - /// - /// Unique identifier for this server configuration. - /// - public required string Id { get; init; } - - /// - /// Display name for the server. - /// - public required string Name { get; init; } - - /// - /// The type of transport to use. - /// - public required string TransportType { get; init; } - - /// - /// For stdio transport: path to the executable - /// For HTTP transport: base URL of the server - /// - public string? Location { get; set; } - - /// - /// Additional transport-specific configuration. - /// - public Dictionary? TransportOptions { get; init; } -} \ No newline at end of file diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index c120269bd..e56aa7390 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -7,6 +7,7 @@ ModelContextProtocol .NET SDK for the Model Context Protocol (MCP) README.md + preview diff --git a/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs index 48ec1dbb2..21024fc75 100644 --- a/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs @@ -5,6 +5,11 @@ /// public interface IClientTransport { + /// + /// Specifies a transport identifier used for logging purposes. + /// + string Name { get; } + /// /// Asynchronously establishes a transport session with an MCP server and returns an interface for the duplex JSON-RPC message stream. /// diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index c7542da8c..b45f9f236 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -16,6 +16,7 @@ namespace ModelContextProtocol.Protocol.Transport; /// internal sealed class SseClientSessionTransport : TransportBase { + private readonly string _endpointName; private readonly HttpClient _httpClient; private readonly SseClientTransportOptions _options; private readonly Uri _sseEndpoint; @@ -23,33 +24,29 @@ internal sealed class SseClientSessionTransport : TransportBase private readonly CancellationTokenSource _connectionCts; private Task? _receiveTask; private readonly ILogger _logger; - private readonly McpServerConfig _serverConfig; private readonly TaskCompletionSource _connectionEstablished; - private string EndpointName => $"Client (SSE) for ({_serverConfig.Id}: {_serverConfig.Name})"; - /// /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server. /// The HTTP server can be local or remote, and must support the SSE protocol. /// /// Configuration options for the transport. - /// The configuration object indicating which server to connect to. /// The HTTP client instance used for requests. /// Logger factory for creating loggers. - public SseClientSessionTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory) + /// The endpoint name used for logging purposes. + public SseClientSessionTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory, string endpointName) : base(loggerFactory) { Throw.IfNull(transportOptions); - Throw.IfNull(serverConfig); Throw.IfNull(httpClient); _options = transportOptions; - _serverConfig = serverConfig; - _sseEndpoint = new Uri(serverConfig.Location!); + _sseEndpoint = transportOptions.Endpoint; _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _connectionEstablished = new TaskCompletionSource(); + _endpointName = endpointName; } /// @@ -59,14 +56,14 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) { if (IsConnected) { - _logger.TransportAlreadyConnected(EndpointName); + _logger.TransportAlreadyConnected(_endpointName); throw new McpTransportException("Transport is already connected"); } // Start message receiving loop _receiveTask = ReceiveMessagesAsync(_connectionCts.Token); - _logger.TransportReadingMessages(EndpointName); + _logger.TransportReadingMessages(_endpointName); await _connectionEstablished.Task.WaitAsync(_options.ConnectionTimeout, cancellationToken).ConfigureAwait(false); } @@ -77,7 +74,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) } catch (Exception ex) { - _logger.TransportConnectFailed(EndpointName, ex); + _logger.TransportConnectFailed(_endpointName, ex); await CloseAsync().ConfigureAwait(false); throw new McpTransportException("Failed to connect transport", ex); } @@ -120,7 +117,7 @@ public override async Task SendMessageAsync( // If the response is not a JSON-RPC response, it is an SSE message if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) { - _logger.SSETransportPostAccepted(EndpointName, messageId); + _logger.SSETransportPostAccepted(_endpointName, messageId); // The response will arrive as an SSE message } else @@ -128,9 +125,9 @@ public override async Task SendMessageAsync( JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? throw new McpTransportException("Failed to initialize client"); - _logger.TransportReceivedMessageParsed(EndpointName, messageId); + _logger.TransportReceivedMessageParsed(_endpointName, messageId); await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); + _logger.TransportMessageWritten(_endpointName, messageId); } return; } @@ -138,11 +135,11 @@ public override async Task SendMessageAsync( // Otherwise, check if the response was accepted (the response will come as an SSE message) if (responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) { - _logger.SSETransportPostAccepted(EndpointName, messageId); + _logger.SSETransportPostAccepted(_endpointName, messageId); } else { - _logger.SSETransportPostNotAccepted(EndpointName, messageId, responseContent); + _logger.SSETransportPostNotAccepted(_endpointName, messageId, responseContent); throw new McpTransportException("Failed to send message"); } } @@ -220,17 +217,17 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { - _logger.TransportReadMessagesCancelled(EndpointName); + _logger.TransportReadMessagesCancelled(_endpointName); // Normal shutdown } catch (IOException) when (cancellationToken.IsCancellationRequested) { - _logger.TransportReadMessagesCancelled(EndpointName); + _logger.TransportReadMessagesCancelled(_endpointName); // Normal shutdown } catch (Exception ex) when (!cancellationToken.IsCancellationRequested) { - _logger.TransportConnectionError(EndpointName, ex); + _logger.TransportConnectionError(_endpointName, ex); reconnectAttempts++; if (reconnectAttempts >= _options.MaxReconnectAttempts) @@ -249,7 +246,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation { if (!IsConnected) { - _logger.TransportMessageReceivedBeforeConnected(EndpointName, data); + _logger.TransportMessageReceivedBeforeConnected(_endpointName, data); return; } @@ -258,7 +255,7 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation var message = JsonSerializer.Deserialize(data, McpJsonUtilities.JsonContext.Default.IJsonRpcMessage); if (message == null) { - _logger.TransportMessageParseUnexpectedType(EndpointName, data); + _logger.TransportMessageParseUnexpectedType(_endpointName, data); return; } @@ -268,13 +265,13 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation messageId = messageWithId.Id.ToString(); } - _logger.TransportReceivedMessageParsed(EndpointName, messageId); + _logger.TransportReceivedMessageParsed(_endpointName, messageId); await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false); - _logger.TransportMessageWritten(EndpointName, messageId); + _logger.TransportMessageWritten(_endpointName, messageId); } catch (JsonException ex) { - _logger.TransportMessageParseFailed(EndpointName, data, ex); + _logger.TransportMessageParseFailed(_endpointName, data, ex); } } @@ -284,7 +281,7 @@ private void HandleEndpointEvent(string data) { if (string.IsNullOrEmpty(data)) { - _logger.TransportEndpointEventInvalid(EndpointName, data); + _logger.TransportEndpointEventInvalid(_endpointName, data); return; } @@ -312,7 +309,7 @@ private void HandleEndpointEvent(string data) } catch (JsonException ex) { - _logger.TransportEndpointEventParseFailed(EndpointName, data, ex); + _logger.TransportEndpointEventParseFailed(_endpointName, data, ex); throw new McpTransportException("Failed to parse endpoint event", ex); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index b6e76844e..8da23b12f 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -9,7 +9,6 @@ namespace ModelContextProtocol.Protocol.Transport; public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { private readonly SseClientTransportOptions _options; - private readonly McpServerConfig _serverConfig; private readonly HttpClient _httpClient; private readonly ILoggerFactory? _loggerFactory; private readonly bool _ownsHttpClient; @@ -19,10 +18,9 @@ public sealed class SseClientTransport : IClientTransport, IAsyncDisposable /// The HTTP server can be local or remote, and must support the SSE protocol. /// /// Configuration options for the transport. - /// The configuration object indicating which server to connect to. /// Logger factory for creating loggers. - public SseClientTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) - : this(transportOptions, serverConfig, new HttpClient(), loggerFactory, true) + public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFactory? loggerFactory = null) + : this(transportOptions, new HttpClient(), loggerFactory, true) { } @@ -31,27 +29,28 @@ public SseClientTransport(SseClientTransportOptions transportOptions, McpServerC /// The HTTP server can be local or remote, and must support the SSE protocol. /// /// Configuration options for the transport. - /// The configuration object indicating which server to connect to. /// The HTTP client instance used for requests. /// Logger factory for creating loggers. /// True to dispose HTTP client on close connection. - public SseClientTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory, bool ownsHttpClient = false) + public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false) { Throw.IfNull(transportOptions); - Throw.IfNull(serverConfig); Throw.IfNull(httpClient); _options = transportOptions; - _serverConfig = serverConfig; _httpClient = httpClient; _loggerFactory = loggerFactory; _ownsHttpClient = ownsHttpClient; + Name = transportOptions.Name ?? transportOptions.Endpoint.ToString(); } + /// + public string Name { get; } + /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - var sessionTransport = new SseClientSessionTransport(_options, _serverConfig, _httpClient, _loggerFactory); + var sessionTransport = new SseClientSessionTransport(_options, _httpClient, _loggerFactory, Name); try { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index 69a0628a4..0a42068df 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -5,6 +5,36 @@ /// public record SseClientTransportOptions { + /// + /// The base address of the server for SSE connections. + /// + public required Uri Endpoint + { + get; + init + { + if (value is null) + { + throw new ArgumentNullException(nameof(value), "Endpoint cannot be null."); + } + if (!value.IsAbsoluteUri) + { + throw new ArgumentException("Endpoint must be an absolute URI.", nameof(value)); + } + if (value.Scheme != Uri.UriSchemeHttp && value.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException("Endpoint must use HTTP or HTTPS scheme.", nameof(value)); + } + + field = value; + } + } + + /// + /// Specifies a transport identifier used for logging purposes. + /// + public string? Name { get; init; } + /// /// Timeout for initial connection and endpoint event. /// diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index 11a0c1371..a9ed6cbff 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -3,7 +3,9 @@ using ModelContextProtocol.Logging; using ModelContextProtocol.Utils; using System.Diagnostics; +using System.Runtime.InteropServices; using System.Text; +using System.Text.RegularExpressions; #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously @@ -15,33 +17,44 @@ namespace ModelContextProtocol.Protocol.Transport; public sealed class StdioClientTransport : IClientTransport { private readonly StdioClientTransportOptions _options; - private readonly McpServerConfig _serverConfig; private readonly ILoggerFactory? _loggerFactory; /// /// Initializes a new instance of the class. /// /// Configuration options for the transport. - /// The server configuration for the transport. /// A logger factory for creating loggers. - public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) + public StdioClientTransport(StdioClientTransportOptions options, ILoggerFactory? loggerFactory = null) { Throw.IfNull(options); - Throw.IfNull(serverConfig); _options = options; - _serverConfig = serverConfig; _loggerFactory = loggerFactory; + Name = options.Name ?? $"stdio-{Regex.Replace(Path.GetFileName(options.Command), @"[\s\.]+", "-")}"; } + /// + public string Name { get; } + /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { - string endpointName = $"Client (stdio) for ({_serverConfig.Id}: {_serverConfig.Name})"; + string endpointName = Name; Process? process = null; bool processStarted = false; + string command = _options.Command; + IList? arguments = _options.Arguments; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && + !string.Equals(Path.GetFileName(command), "cmd.exe", StringComparison.OrdinalIgnoreCase)) + { + // On Windows, for stdio, we need to wrap non-shell commands with cmd.exe /c {command} (usually npx or uvicorn). + // The stdio transport will not work correctly if the command is not run in a shell. + arguments = arguments is null or [] ? ["/c", command] : ["/c", command, ..arguments]; + command = "cmd.exe"; + } + ILogger logger = (ILogger?)_loggerFactory?.CreateLogger() ?? NullLogger.Instance; try { @@ -51,7 +64,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = ProcessStartInfo startInfo = new() { - FileName = _options.Command, + FileName = command, RedirectStandardInput = true, RedirectStandardOutput = true, RedirectStandardError = true, @@ -65,9 +78,22 @@ public async Task ConnectAsync(CancellationToken cancellationToken = #endif }; - if (!string.IsNullOrWhiteSpace(_options.Arguments)) + if (arguments is not null) { - startInfo.Arguments = _options.Arguments; +#if NET + foreach (string arg in arguments) + { + startInfo.ArgumentList.Add(arg); + } +#else + StringBuilder argsBuilder = new(); + foreach (string arg in arguments) + { + PasteArguments.AppendArgument(argsBuilder, arg); + } + + startInfo.Arguments = argsBuilder.ToString(); +#endif } if (_options.EnvironmentVariables != null) diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransportOptions.cs index 0b59555ba..6101adee6 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransportOptions.cs @@ -1,5 +1,4 @@ -// Protocol/Transport/StdioTransport.cs -namespace ModelContextProtocol.Protocol.Transport; +namespace ModelContextProtocol.Protocol.Transport; /// /// Represents configuration options for the stdio transport. @@ -14,12 +13,29 @@ public record StdioClientTransportOptions /// /// The command to execute to start the server process. /// - public required string Command { get; set; } + public required string Command + { + get; + set + { + if (string.IsNullOrWhiteSpace(value)) + { + throw new ArgumentException("Command cannot be null or empty.", nameof(value)); + } + + field = value; + } + } /// /// Arguments to pass to the server process. /// - public string? Arguments { get; set; } + public IList? Arguments { get; set; } + + /// + /// Specifies a transport identifier used for logging purposes. + /// + public string? Name { get; set; } /// /// The working directory for the server process. diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs index 80bd61df5..371036a9d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientTransport.cs @@ -35,6 +35,10 @@ public StreamClientTransport( _loggerFactory = loggerFactory; } + /// + public string Name => $"in-memory-stream"; + + /// public Task ConnectAsync(CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs b/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs deleted file mode 100644 index 02de0a6b5..000000000 --- a/src/ModelContextProtocol/Protocol/Transport/TransportTypes.cs +++ /dev/null @@ -1,17 +0,0 @@ -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// List all transport types -/// -public static class TransportTypes -{ - /// - /// The name of the Standard IO transport. - /// - public const string StdIo = "stdio"; - - /// - /// The name of the ServerSideEvents transport. - /// - public const string Sse = "sse"; -} diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index cc2277786..d0ad20572 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -5,7 +5,6 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Reflection; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index aa538fa04..be60c90f4 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -234,13 +234,7 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer() { return await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }, - createTransportFunc: (_, _) => new StreamClientTransport( + new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), serverOutput: _serverToClientPipe.Reader.AsStream(), LoggerFactory), diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index aa2f773eb..1414f6563 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -14,186 +14,26 @@ public class McpClientFactoryTests [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, cancellationToken: TestContext.Current.CancellationToken)); - - await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync(new McpServerConfig() - { - Name = "name", - Id = "id", - TransportType = "somethingunsupported", - }, cancellationToken: TestContext.Current.CancellationToken)); - - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(new McpServerConfig() - { - Name = "name", - Id = "id", - TransportType = TransportTypes.StdIo, - }, createTransportFunc: (_, __) => null!, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync("clientTransport", () => McpClientFactory.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] - public async Task CreateAsync_NullOptions_EntryAssemblyInferred() + public async Task CreateAsync_NopTransport_ReturnsClient() { - // Arrange - var serverConfig = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.StdIo, - Location = "/path/to/server", - }; - // Act await using var client = await McpClientFactory.CreateAsync( - serverConfig, - null, - (_, __) => new NopTransport(), + new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(client); } - [Fact] - public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient() - { - // Arrange - var serverConfig = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.StdIo, - Location = "/path/to/server", - TransportOptions = new Dictionary - { - ["arguments"] = "--test arg", - ["workingDirectory"] = "/working/dir" - } - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, __) => new NopTransport(), - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(client); - // We could add more assertions here about the client's configuration - } - - [Fact] - public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient() - { - // Arrange - var serverConfig = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.StdIo, - Location = "/path/to/server", - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, __) => new NopTransport(), - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(client); - // We could add more assertions here about the client's configuration - } - - [Fact] - public async Task CreateAsync_WithValidSseConfig_CreatesNewClient() - { - // Arrange - var serverConfig = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080" - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, __) => new NopTransport(), - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(client); - // We could add more assertions here about the client's configuration - } - - [Fact] - public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions() - { - // Arrange - var serverConfig = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080", - TransportOptions = new Dictionary - { - ["connectionTimeout"] = "10", - ["maxReconnectAttempts"] = "2", - ["reconnectDelay"] = "5", - ["header.test"] = "the_header_value" - } - }; - - // Act - await using var client = await McpClientFactory.CreateAsync( - serverConfig, - createTransportFunc: (_, __) => new NopTransport(), - cancellationToken: TestContext.Current.CancellationToken); - - // Assert - Assert.NotNull(client); - // We could add more assertions here about the client's configuration - } - - [Theory] - [InlineData("connectionTimeout", "not_a_number")] - [InlineData("maxReconnectAttempts", "invalid")] - [InlineData("reconnectDelay", "bad_value")] - public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(string key, string value) - { - // arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080", - TransportOptions = new Dictionary - { - [key] = value - } - }; - - // act & assert - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, cancellationToken: TestContext.Current.CancellationToken)); - } - [Theory] [InlineData(typeof(NopTransport))] [InlineData(typeof(FailureTransport))] public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) { // Arrange - var serverConfig = new McpServerConfig - { - Id = "TestServer", - Name = "TestServer", - TransportType = "stdio", - Location = "test-location" - }; - var clientOptions = new McpClientOptions { Capabilities = new ClientCapabilities @@ -216,10 +56,10 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) } }; - var clientTransport = (IClientTransport?)Activator.CreateInstance(transportType); + var clientTransport = (IClientTransport)Activator.CreateInstance(transportType)!; IMcpClient? client = null; - var actionTask = McpClientFactory.CreateAsync(serverConfig, clientOptions, (config, logger) => clientTransport ?? new NopTransport(), new Mock().Object, CancellationToken.None); + var actionTask = McpClientFactory.CreateAsync(clientTransport, clientOptions, new Mock().Object, CancellationToken.None); // Act if (clientTransport is FailureTransport) @@ -248,6 +88,8 @@ private class NopTransport : ITransport, IClientTransport public ValueTask DisposeAsync() => default; + public string Name => "Test Nop Transport"; + public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { switch (message) diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index 400caade7..5cd78da97 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -8,41 +8,31 @@ public class ClientIntegrationTestFixture { private ILoggerFactory? _loggerFactory; - public McpServerConfig EverythingServerConfig { get; } - public McpServerConfig TestServerConfig { get; } + public StdioClientTransportOptions EverythingServerTransportOptions { get; } + public StdioClientTransportOptions TestServerTransportOptions { get; } public static IEnumerable ClientIds => ["everything", "test_server"]; public ClientIntegrationTestFixture() { - EverythingServerConfig = new() + EverythingServerTransportOptions = new() { - Id = "everything", + Command = "npx", + // Change to Arguments = ["mcp-server-everything"] if you want to run the server locally after creating a symlink + Arguments = ["-y", "--verbose", "@modelcontextprotocol/server-everything"], Name = "Everything", - TransportType = TransportTypes.StdIo, - TransportOptions = new Dictionary - { - ["command"] = "npx", - // Change to ["arguments"] = "mcp-server-everything" if you want to run the server locally after creating a symlink - ["arguments"] = "-y --verbose @modelcontextprotocol/server-everything" - } }; - TestServerConfig = new() + TestServerTransportOptions = new() { - Id = "test_server", + Command = OperatingSystem.IsWindows() ? "TestServer.exe" : "dotnet", Name = "TestServer", - TransportType = TransportTypes.StdIo, - TransportOptions = new Dictionary - { - ["command"] = OperatingSystem.IsWindows() ? "TestServer.exe" : "dotnet", - // Change to ["arguments"] = "mcp-server-everything" if you want to run the server locally after creating a symlink - } }; if (!OperatingSystem.IsWindows()) { - TestServerConfig.TransportOptions["arguments"] = "TestServer.dll"; + // Change to Arguments to "mcp-server-everything" if you want to run the server locally after creating a symlink + TestServerTransportOptions.Arguments = ["TestServer.dll"]; } } @@ -52,10 +42,10 @@ public void Initialize(ILoggerFactory loggerFactory) } public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => - McpClientFactory.CreateAsync(clientId switch + McpClientFactory.CreateAsync(new StdioClientTransport(clientId switch { - "everything" => EverythingServerConfig, - "test_server" => TestServerConfig, + "everything" => EverythingServerTransportOptions, + "test_server" => TestServerTransportOptions, _ => throw new ArgumentException($"Unknown client ID: {clientId}") - }, clientOptions, loggerFactory: _loggerFactory); + }), clientOptions, loggerFactory: _loggerFactory); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index baa7e55b0..3fbd8f16c 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -453,16 +453,11 @@ public async Task Notifications_Stdio(string clientId) public async Task CallTool_Stdio_MemoryServer() { // arrange - McpServerConfig serverConfig = new() + StdioClientTransportOptions stdioOptions = new() { - Id = "memory", + Command = "npx", + Arguments = ["-y", "@modelcontextprotocol/server-memory"], Name = "memory", - TransportType = TransportTypes.StdIo, - TransportOptions = new Dictionary - { - ["command"] = "npx", - ["arguments"] = "-y @modelcontextprotocol/server-memory" - } }; McpClientOptions clientOptions = new() @@ -471,7 +466,7 @@ public async Task CallTool_Stdio_MemoryServer() }; await using var client = await McpClientFactory.CreateAsync( - serverConfig, + new StdioClientTransport(stdioOptions), clientOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -495,7 +490,7 @@ public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() { // Get the MCP client and tools from it. await using var client = await McpClientFactory.CreateAsync( - _fixture.EverythingServerConfig, + new StdioClientTransport(_fixture.EverythingServerTransportOptions), cancellationToken: TestContext.Current.CancellationToken); var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -526,7 +521,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() var samplingHandler = new OpenAIClient(s_openAIKey) .AsChatClient("gpt-4o-mini") .CreateSamplingHandler(); - await using var client = await McpClientFactory.CreateAsync(_fixture.EverythingServerConfig, new() + await using var client = await McpClientFactory.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 64dc9fb56..a6069f593 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -118,14 +118,7 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }, - options, - createTransportFunc: (_, _) => new StreamClientTransport( + new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), serverOutput: _serverToClientPipe.Reader.AsStream(), LoggerFactory), diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 075ab8f80..10d79c7fa 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -142,14 +142,7 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }, - options, - createTransportFunc: (_, _) => new StreamClientTransport( + new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), LoggerFactory), @@ -203,13 +196,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); await using (var client = await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = $"TestServer_{i}", - Name = $"TestServer_{i}", - TransportType = "ignored", - }, - createTransportFunc: (_, _) => new StreamClientTransport( + new StreamClientTransport( serverInput: stdinPipe.Writer.AsStream(), serverOutput: stdoutPipe.Reader.AsStream(), LoggerFactory), diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index fd93eff6d..199f426c4 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -59,13 +59,8 @@ private static async Task RunConnected(Func action { serverTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (IMcpClient client = await McpClientFactory.CreateAsync(new() - { - Id = "TestServer", - Name = "TestServer", - TransportType = TransportTypes.StdIo, - }, - createTransportFunc: (_, __) => clientTransport, + await using (IMcpClient client = await McpClientFactory.CreateAsync( + clientTransport, cancellationToken: TestContext.Current.CancellationToken)) { await action(client, server); diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index 3dbcfe7c2..b5947eaba 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -37,18 +37,15 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - var defaultConfig = new McpServerConfig + var defaultConfig = new SseClientTransportOptions { - Id = "everything", + Endpoint = new Uri($"http://localhost:{port}/sse"), Name = "Everything", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{port}/sse" }; // Create client and run tests await using var client = await McpClientFactory.CreateAsync( - defaultConfig, + new SseClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -67,13 +64,10 @@ public async Task Sampling_Sse_EverythingServer() await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); - var defaultConfig = new McpServerConfig + var defaultConfig = new SseClientTransportOptions { - Id = "everything", + Endpoint = new Uri($"http://localhost:{port}/sse"), Name = "Everything", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost:{port}/sse" }; int samplingHandlerCalls = 0; @@ -102,7 +96,7 @@ public async Task Sampling_Sse_EverythingServer() }; await using var client = await McpClientFactory.CreateAsync( - defaultConfig, + new SseClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 27ad60371..5e2500ac9 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -50,14 +50,7 @@ public async ValueTask DisposeAsync() private async Task CreateMcpClientForServer(McpClientOptions? options = null) { return await McpClientFactory.CreateAsync( - new McpServerConfig() - { - Id = "TestServer", - Name = "TestServer", - TransportType = "ignored", - }, - options, - createTransportFunc: (_, _) => new StreamClientTransport( + new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), LoggerFactory), diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 41c3d343f..b769d2c08 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -16,20 +16,16 @@ namespace ModelContextProtocol.Tests; public class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { - private McpServerConfig DefaultServerConfig = new() + private SseClientTransportOptions DefaultTransportOptions = new() { - Id = "test_server", + Endpoint = new Uri("http://localhost/sse"), Name = "In-memory Test Server", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost/sse" }; private Task ConnectMcpClient(HttpClient httpClient, McpClientOptions? clientOptions = null) => McpClientFactory.CreateAsync( - DefaultServerConfig, + new SseClientTransport(DefaultTransportOptions, httpClient, LoggerFactory), clientOptions, - (_, _) => new SseClientTransport(new(), DefaultServerConfig, httpClient, LoggerFactory), LoggerFactory, TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 6b7b474ef..286ae3cfa 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -18,13 +18,10 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable // multiple tests, so this dispatches the output to the current test. private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - private McpServerConfig DefaultServerConfig { get; } = new McpServerConfig + private SseClientTransportOptions DefaultTransportOptions { get; } = new() { - Id = "test_server", + Endpoint = new Uri("http://localhost/sse"), Name = "TestServer", - TransportType = TransportTypes.Sse, - TransportOptions = [], - Location = $"http://localhost/sse" }; public SseServerIntegrationTestFixture() @@ -40,7 +37,7 @@ public SseServerIntegrationTestFixture() HttpClient = new HttpClient(socketsHttpHandler) { - BaseAddress = new Uri(DefaultServerConfig.Location), + BaseAddress = DefaultTransportOptions.Endpoint, }; _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); } @@ -50,9 +47,8 @@ public SseServerIntegrationTestFixture() public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) { return McpClientFactory.CreateAsync( - DefaultServerConfig, + new SseClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), options, - (_, _) => new SseClientTransport(new(), DefaultServerConfig, HttpClient, loggerFactory), loggerFactory, TestContext.Current.CancellationToken); } diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 63b86a4a6..db753c6b1 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -8,25 +8,18 @@ namespace ModelContextProtocol.Tests.Transport; public class SseClientTransportTests : LoggedTest { - private readonly McpServerConfig _serverConfig; private readonly SseClientTransportOptions _transportOptions; public SseClientTransportTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _serverConfig = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080" - }; - _transportOptions = new SseClientTransportOptions { + Endpoint = new Uri("http://localhost:8080"), ConnectionTimeout = TimeSpan.FromSeconds(2), MaxReconnectAttempts = 3, ReconnectDelay = TimeSpan.FromMilliseconds(50), + Name = "Test Server", AdditionalHeaders = new Dictionary { ["test"] = "header" @@ -37,21 +30,14 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper) [Fact] public void Constructor_Throws_For_Null_Options() { - var exception = Assert.Throws(() => new SseClientTransport(null!, _serverConfig, LoggerFactory)); + var exception = Assert.Throws(() => new SseClientTransport(null!, LoggerFactory)); Assert.Equal("transportOptions", exception.ParamName); } [Fact] - public void Constructor_Throws_For_Null_Config() + public void Constructor_Throws_For_Null_HttpClient() { var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, null!, LoggerFactory)); - Assert.Equal("serverConfig", exception.ParamName); - } - - [Fact] - public void Constructor_Throws_For_Null_HttpClientg() - { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, _serverConfig, null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); } @@ -60,7 +46,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); bool firstCall = true; @@ -84,7 +70,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); var retries = 0; mockHttpHandler.RequestHandler = (request) => @@ -106,7 +92,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -144,7 +130,7 @@ public async Task SendMessageAsync_Handles_Accepted_Json_RPC_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); var eventSourcePipe = new Pipe(); var eventSourceData = "event: endpoint\r\ndata: /sseendpoint\r\n\r\n"u8; @@ -190,7 +176,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); var callIndex = 0; mockHttpHandler.RequestHandler = (request) => @@ -230,7 +216,7 @@ public async Task DisposeAsync_Should_Dispose_Resources() }); }; - await using var transport = new SseClientTransport(_transportOptions, _serverConfig, httpClient, LoggerFactory); + await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); await session.DisposeAsync();