From 20be4c1147a08a0b184ff5f84841f94bbd100d22 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 16 Apr 2025 15:15:57 -0400 Subject: [PATCH 1/3] Clean up (use of) McpException --- samples/EverythingServer/Program.cs | 2 +- src/ModelContextProtocol/Client/McpClient.cs | 4 +- src/ModelContextProtocol/McpErrorCode.cs | 49 +++++++++++++++++++ src/ModelContextProtocol/McpException.cs | 24 +++++---- .../Protocol/Messages/ErrorCodes.cs | 32 ------------ .../Server/AIFunctionMcpServerPrompt.cs | 4 +- .../Server/AIFunctionMcpServerTool.cs | 12 +++-- src/ModelContextProtocol/Server/McpServer.cs | 34 +++++++++---- .../Server/McpServerExtensions.cs | 6 +-- src/ModelContextProtocol/Shared/McpSession.cs | 38 +++++++++----- .../Program.cs | 28 +++++------ .../Program.cs | 22 ++++----- .../McpServerBuilderExtensionsPromptsTests.cs | 30 ++++++------ .../McpServerBuilderExtensionsToolsTests.cs | 4 +- .../DiagnosticTests.cs | 4 +- .../Server/McpServerPromptTests.cs | 2 +- .../Server/McpServerResourceTests.cs | 2 +- .../Server/McpServerTests.cs | 12 ++--- 18 files changed, 181 insertions(+), 128 deletions(-) create mode 100644 src/ModelContextProtocol/McpErrorCode.cs delete mode 100644 src/ModelContextProtocol/Protocol/Messages/ErrorCodes.cs diff --git a/samples/EverythingServer/Program.cs b/samples/EverythingServer/Program.cs index c9bc12729..59d9b845d 100644 --- a/samples/EverythingServer/Program.cs +++ b/samples/EverythingServer/Program.cs @@ -176,7 +176,7 @@ await ctx.Server.RequestSamplingAsync([ { if (ctx.Params?.Level is null) { - throw new McpException("Missing required argument 'level'"); + throw new McpException("Missing required argument 'level'", McpErrorCode.InvalidParams); } _minimumLoggingLevel = ctx.Params.Level; diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index a8b5abd13..50c0a195f 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -155,10 +155,10 @@ await SendMessageAsync( new JsonRpcNotification { Method = NotificationMethods.InitializedNotification }, initializationCts.Token).ConfigureAwait(false); } - catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested) + catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) { LogClientInitializationTimeout(EndpointName); - throw new McpException("Initialization timed out", oce); + throw new TimeoutException("Initialization timed out", oce); } } catch (Exception e) diff --git a/src/ModelContextProtocol/McpErrorCode.cs b/src/ModelContextProtocol/McpErrorCode.cs new file mode 100644 index 000000000..f6cf4f516 --- /dev/null +++ b/src/ModelContextProtocol/McpErrorCode.cs @@ -0,0 +1,49 @@ +namespace ModelContextProtocol; + +/// +/// Represents standard JSON-RPC error codes as defined in the MCP specification. +/// +public enum McpErrorCode +{ + /// + /// Indicates that the JSON received could not be parsed. + /// + /// + /// This error occurs when the input contains malformed JSON or incorrect syntax. + /// + ParseError = -32700, + + /// + /// Indicates that the JSON payload does not conform to the expected Request object structure. + /// + /// + /// The request is considered invalid if it lacks required fields or fails to follow the JSON-RPC protocol. + /// + InvalidRequest = -32600, + + /// + /// Indicates that the requested method does not exist or is not available on the server. + /// + /// + /// This error is returned when the method name specified in the request cannot be found. + /// + MethodNotFound = -32601, + + /// + /// Indicates that one or more parameters provided in the request are invalid. + /// + /// + /// This error is returned when the parameters do not match the expected method signature or constraints. + /// This includes cases where required parameters are missing or not understood, such as when a name for + /// a tool or prompt is not recognized. + /// + InvalidParams = -32602, + + /// + /// Indicates that an internal error occurred while processing the request. + /// + /// + /// This error is used when the endpoint encounters an unexpected condition that prevents it from fulfilling the request. + /// + InternalError = -32603, +} diff --git a/src/ModelContextProtocol/McpException.cs b/src/ModelContextProtocol/McpException.cs index 5b210e815..3831dd688 100644 --- a/src/ModelContextProtocol/McpException.cs +++ b/src/ModelContextProtocol/McpException.cs @@ -1,8 +1,15 @@ namespace ModelContextProtocol; /// -/// Represents an exception that is thrown when a Model Context Protocol (MCP) error occurs. +/// Represents an exception that is thrown when an Model Context Protocol (MCP) error occurs. /// +/// +/// This exception is used to represent failures to do with protocol-level concerns, such as invalid JSON-RPC requests, +/// invalid parameters, or internal errors. It is not intended to be used for application-level errors. +/// or from a may be +/// propagated to the remote endpoint; sensitive information should not be included. If sensitive details need +/// to be included, a different exception type should be used. +/// public class McpException : Exception { /// @@ -33,8 +40,8 @@ public McpException(string message, Exception? innerException) : base(message, i /// Initializes a new instance of the class with a specified error message and JSON-RPC error code. /// /// The message that describes the error. - /// A JSON-RPC error code from class. - public McpException(string message, int? errorCode) : this(message, null, errorCode) + /// A . + public McpException(string message, McpErrorCode errorCode) : this(message, null, errorCode) { } @@ -43,18 +50,15 @@ public McpException(string message, int? errorCode) : this(message, null, errorC /// /// The message that describes the error. /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. - /// A JSON-RPC error code from class. - public McpException(string message, Exception? innerException, int? errorCode) : base(message, innerException) + /// A . + public McpException(string message, Exception? innerException, McpErrorCode errorCode) : base(message, innerException) { ErrorCode = errorCode; } /// - /// Gets the JSON-RPC error code associated with this exception. + /// Gets the error code associated with this exception. /// - /// - /// A standard JSON-RPC error code, or if the exception wasn't caused by a JSON-RPC error. - /// /// /// This property contains a standard JSON-RPC error code as defined in the MCP specification. Common error codes include: /// @@ -65,5 +69,5 @@ public McpException(string message, Exception? innerException, int? errorCode) : /// -32603: Internal error - Internal JSON-RPC error /// /// - public int? ErrorCode { get; } + public McpErrorCode ErrorCode { get; } = McpErrorCode.InternalError; } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Messages/ErrorCodes.cs b/src/ModelContextProtocol/Protocol/Messages/ErrorCodes.cs deleted file mode 100644 index f797ded92..000000000 --- a/src/ModelContextProtocol/Protocol/Messages/ErrorCodes.cs +++ /dev/null @@ -1,32 +0,0 @@ -namespace ModelContextProtocol.Protocol.Messages; - -/// -/// Standard JSON-RPC error codes as defined in the MCP specification. -/// -internal static class ErrorCodes -{ - /// - /// Invalid JSON was received by the server. - /// - public const int ParseError = -32700; - - /// - /// The JSON sent is not a valid Request object. - /// - public const int InvalidRequest = -32600; - - /// - /// The method does not exist / is not available. - /// - public const int MethodNotFound = -32601; - - /// - /// Invalid method parameter(s). - /// - public const int InvalidParams = -32602; - - /// - /// Internal JSON-RPC error. - /// - public const int InternalError = -32603; -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs index d3b48b6c0..b1671b906 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs @@ -101,7 +101,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( BindParameter = (pi, args) => GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), + throw new InvalidOperationException("No service of the requested type was found.")), }; } @@ -113,7 +113,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( BindParameter = (pi, args) => (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), + throw new InvalidOperationException("No service of the requested type was found.")), }; } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 76f0d6e5a..562a8476c 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -122,7 +122,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( BindParameter = (pi, args) => GetRequestContext(args)?.Services?.GetService(pi.ParameterType) ?? (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), + throw new InvalidOperationException("No service of the requested type was found.")), }; } @@ -134,7 +134,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( BindParameter = (pi, args) => (GetRequestContext(args)?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? (pi.HasDefaultValue ? null : - throw new ArgumentException("No service of the requested type was found.")), + throw new InvalidOperationException("No service of the requested type was found.")), }; } @@ -265,10 +265,14 @@ public override async ValueTask InvokeAsync( } catch (Exception e) when (e is not OperationCanceledException) { - return new CallToolResponse() + string errorMessage = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; + + return new() { IsError = true, - Content = [new() { Text = $"An error occurred invoking '{request.Params?.Name}'.", Type = "text" }], + Content = [new() { Text = errorMessage, Type = "text" }], }; } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 079bcb9b1..2f7b59a90 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -196,7 +196,9 @@ private void SetCompletionHandler(McpServerOptions options) } var completeHandler = completionsCapability.CompleteHandler ?? - throw new McpException("Completions capability was enabled, but Complete handler was not specified."); + throw new InvalidOperationException( + $"{nameof(ServerCapabilities)}.{nameof(ServerCapabilities.Completions)} was enabled, " + + $"but {nameof(CompletionsCapability.CompleteHandler)} was not specified."); // This capability is not optional, so return an empty result if there is no handler. RequestHandlers.Set( @@ -219,7 +221,9 @@ private void SetResourcesHandler(McpServerOptions options) if ((listResourcesHandler is not { } && listResourceTemplatesHandler is not { }) || resourcesCapability.ReadResourceHandler is not { } readResourceHandler) { - throw new McpException("Resources capability was enabled, but ListResources and/or ReadResource handlers were not specified."); + throw new InvalidOperationException( + $"{nameof(ServerCapabilities)}.{nameof(ServerCapabilities.Resources)} was enabled, " + + $"but {nameof(ResourcesCapability.ListResourcesHandler)} or {nameof(ResourcesCapability.ReadResourceHandler)} was not specified."); } listResourcesHandler ??= static async (_, _) => new ListResourcesResult(); @@ -252,7 +256,9 @@ private void SetResourcesHandler(McpServerOptions options) var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler; if (subscribeHandler is null || unsubscribeHandler is null) { - throw new McpException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); + throw new InvalidOperationException( + $"{nameof(ServerCapabilities)}.{nameof(ServerCapabilities.Resources)}.{nameof(ResourcesCapability.Subscribe)} is set, " + + $"but {nameof(ResourcesCapability.SubscribeToResourcesHandler)} or {nameof(ResourcesCapability.UnsubscribeFromResourcesHandler)} was not specified."); } RequestHandlers.Set( @@ -277,7 +283,10 @@ private void SetPromptsHandler(McpServerOptions options) if (listPromptsHandler is null != getPromptHandler is null) { - throw new McpException("ListPrompts and GetPrompt handlers should be specified together."); + throw new InvalidOperationException( + $"{nameof(PromptsCapability)}.{nameof(promptsCapability.ListPromptsHandler)} or " + + $"{nameof(PromptsCapability)}.{nameof(promptsCapability.GetPromptHandler)} was specified without the other. " + + $"Both or neither must be provided."); } // Handle prompts provided via DI. @@ -310,7 +319,7 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals return originalGetPromptHandler(request, cancellationToken); } - throw new McpException($"Unknown prompt '{request.Params?.Name}'"); + throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams); } return prompt.GetAsync(request, cancellationToken); @@ -344,7 +353,9 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals // Make sure the handlers are provided if the capability is enabled. if (listPromptsHandler is null || getPromptHandler is null) { - throw new McpException("ListPrompts and/or GetPrompt handlers were not specified but the Prompts capability was enabled."); + throw new InvalidOperationException( + $"{nameof(ServerCapabilities)}.{nameof(ServerCapabilities.Prompts)} was enabled, " + + $"but {nameof(PromptsCapability.ListPromptsHandler)} or {nameof(PromptsCapability.GetPromptHandler)} was not specified."); } } @@ -370,7 +381,10 @@ private void SetToolsHandler(McpServerOptions options) if (listToolsHandler is null != callToolHandler is null) { - throw new McpException("ListTools and CallTool handlers should be specified together."); + throw new InvalidOperationException( + $"{nameof(ToolsCapability)}.{nameof(ToolsCapability.ListToolsHandler)} or " + + $"{nameof(ToolsCapability)}.{nameof(ToolsCapability.CallToolHandler)} was specified without the other. " + + $"Both or neither must be provided."); } // Handle tools provided via DI. @@ -403,7 +417,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) return originalCallToolHandler(request, cancellationToken); } - throw new McpException($"Unknown tool '{request.Params?.Name}'"); + throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams); } return tool.InvokeAsync(request, cancellationToken); @@ -437,7 +451,9 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) // Make sure the handlers are provided if the capability is enabled. if (listToolsHandler is null || callToolHandler is null) { - throw new McpException("ListTools and/or CallTool handlers were not specified but the Tools capability was enabled."); + throw new InvalidOperationException( + $"{nameof(ServerCapabilities)}.{nameof(ServerCapabilities.Tools)} was enabled, " + + $"but {nameof(ToolsCapability.ListToolsHandler)} or {nameof(ToolsCapability.CallToolHandler)} was not specified."); } } diff --git a/src/ModelContextProtocol/Server/McpServerExtensions.cs b/src/ModelContextProtocol/Server/McpServerExtensions.cs index aa103c357..1b4643f47 100644 --- a/src/ModelContextProtocol/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol/Server/McpServerExtensions.cs @@ -36,7 +36,7 @@ public static Task RequestSamplingAsync( if (server.ClientCapabilities?.Sampling is null) { - throw new ArgumentException("Client connected to the server does not support sampling.", nameof(server)); + throw new InvalidOperationException("Client does not support sampling."); } return server.SendRequestAsync( @@ -166,7 +166,7 @@ public static IChatClient AsSamplingChatClient(this IMcpServer server) if (server.ClientCapabilities?.Sampling is null) { - throw new ArgumentException("Client connected to the server does not support sampling.", nameof(server)); + throw new InvalidOperationException("Client does not support sampling."); } return new SamplingChatClient(server); @@ -204,7 +204,7 @@ public static Task RequestRootsAsync( if (server.ClientCapabilities?.Roots is null) { - throw new ArgumentException("Client connected to the server does not support roots.", nameof(server)); + throw new InvalidOperationException("Client does not support roots."); } return server.SendRequestAsync( diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index a15618a2f..1b51fc8fd 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -142,15 +142,24 @@ ex is OperationCanceledException && if (!isUserCancellation && message is JsonRpcRequest request) { LogRequestHandlerException(EndpointName, request.Method, ex); + + JsonRpcErrorDetail detail = ex is McpException mcpe ? + new() + { + Code = (int)mcpe.ErrorCode, + Message = mcpe.Message, + } : + new() + { + Code = (int)McpErrorCode.InternalError, + Message = "An error occurred.", + }; + await _transport.SendMessageAsync(new JsonRpcError { Id = request.Id, JsonRpc = "2.0", - Error = new JsonRpcErrorDetail - { - Code = (ex as McpException)?.ErrorCode ?? ErrorCodes.InternalError, - Message = ex.Message - } + Error = detail, }, cancellationToken).ConfigureAwait(false); } else if (ex is not OperationCanceledException) @@ -287,7 +296,7 @@ private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId if (!_requestHandlers.TryGetValue(request.Method, out var handler)) { LogNoHandlerFoundForRequest(EndpointName, request.Method); - throw new McpException("The method does not exist or is not available.", ErrorCodes.MethodNotFound); + throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); } LogRequestHandlerCalled(EndpointName, request.Method); @@ -342,7 +351,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc { if (!_transport.IsConnected) { - throw new McpException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected"); } cancellationToken.ThrowIfCancellationRequested(); @@ -399,7 +408,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc if (response is JsonRpcError error) { LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); - throw new McpException($"Request failed (remote): {error.Error.Message}", error.Error.Code); + throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); } if (response is JsonRpcResponse success) @@ -443,7 +452,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca if (!_transport.IsConnected) { - throw new McpException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected"); } cancellationToken.ThrowIfCancellationRequested(); @@ -591,13 +600,16 @@ private static void AddExceptionTags(ref TagList tags, Activity? activity, Excep e = ae.InnerException; } - int? intErrorCode = (e as McpException)?.ErrorCode is int errorCode ? errorCode : - e is JsonException ? ErrorCodes.ParseError : null; + int? intErrorCode = + (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : + e is JsonException ? (int)McpErrorCode.ParseError : + null; - tags.Add("error.type", intErrorCode == null ? e.GetType().FullName : intErrorCode.ToString()); + string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; + tags.Add("error.type", errorType); if (intErrorCode is not null) { - tags.Add("rpc.jsonrpc.error_code", intErrorCode.ToString()); + tags.Add("rpc.jsonrpc.error_code", errorType); } if (activity is { IsAllDataRequested: true }) diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 4fe963d26..6be6025c8 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -166,7 +166,7 @@ private static ToolsCapability ConfigureTools() { if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) { - throw new McpException("Missing required argument 'message'"); + throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); } return new CallToolResponse() { @@ -179,7 +179,7 @@ private static ToolsCapability ConfigureTools() !request.Params.Arguments.TryGetValue("prompt", out var prompt) || !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) { - throw new McpException("Missing required arguments 'prompt' and 'maxTokens'"); + throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.GetRawText())), cancellationToken); @@ -191,7 +191,7 @@ private static ToolsCapability ConfigureTools() } else { - throw new McpException($"Unknown tool: {request.Params?.Name}"); + throw new McpException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); } } }; @@ -285,7 +285,7 @@ private static PromptsCapability ConfigurePrompts() } else { - throw new McpException($"Unknown prompt: {request.Params?.Name}"); + throw new McpException($"Unknown prompt: {request.Params?.Name}", McpErrorCode.InvalidParams); } return new GetPromptResult() @@ -306,7 +306,7 @@ private static LoggingCapability ConfigureLogging() { if (request.Params?.Level is null) { - throw new McpException("Missing required argument 'level'"); + throw new McpException("Missing required argument 'level'", McpErrorCode.InvalidParams); } _minimumLoggingLevel = request.Params.Level; @@ -388,7 +388,7 @@ private static ResourcesCapability ConfigureResources() } catch (Exception e) { - throw new McpException("Invalid cursor.", e); + throw new McpException($"Invalid cursor: '{request.Params.Cursor}'", e, McpErrorCode.InvalidParams); } } @@ -410,7 +410,7 @@ private static ResourcesCapability ConfigureResources() { if (request.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'"); + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (request.Params.Uri.StartsWith("test://dynamic/resource/")) @@ -418,7 +418,7 @@ private static ResourcesCapability ConfigureResources() var id = request.Params.Uri.Split('/').LastOrDefault(); if (string.IsNullOrEmpty(id)) { - throw new McpException("Invalid resource URI"); + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } return new ReadResourceResult() @@ -435,7 +435,7 @@ private static ResourcesCapability ConfigureResources() } ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) - ?? throw new McpException("Resource not found"); + ?? throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); return new ReadResourceResult() { @@ -447,12 +447,12 @@ private static ResourcesCapability ConfigureResources() { if (request?.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'"); + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (!request.Params.Uri.StartsWith("test://static/resource/") && !request.Params.Uri.StartsWith("test://dynamic/resource/")) { - throw new McpException("Invalid resource URI"); + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } _subscribedResources.TryAdd(request.Params.Uri, true); @@ -464,12 +464,12 @@ private static ResourcesCapability ConfigureResources() { if (request?.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'"); + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (!request.Params.Uri.StartsWith("test://static/resource/") && !request.Params.Uri.StartsWith("test://dynamic/resource/")) { - throw new McpException("Invalid resource URI"); + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } _subscribedResources.TryRemove(request.Params.Uri, out _); @@ -514,7 +514,7 @@ private static CompletionsCapability ConfigureCompletions() return new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }; } - throw new McpException($"Unknown reference type: {request.Params?.Ref.Type}"); + throw new McpException($"Unknown reference type: '{request.Params?.Ref.Type}'", McpErrorCode.InvalidParams); }; return new() { CompleteHandler = handler }; diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index f6ab20a7b..72a271cf9 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -157,13 +157,13 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { if (request.Params is null) { - throw new McpException("Missing required parameter 'name'"); + throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); } if (request.Params.Name == "echo") { if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) { - throw new McpException("Missing required argument 'message'"); + throw new McpException("Missing required argument 'message'", McpErrorCode.InvalidParams); } return new CallToolResponse() { @@ -176,7 +176,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st !request.Params.Arguments.TryGetValue("prompt", out var prompt) || !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) { - throw new McpException("Missing required arguments 'prompt' and 'maxTokens'"); + throw new McpException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); } var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt.ToString(), "sampleLLM", Convert.ToInt32(maxTokens.ToString())), cancellationToken); @@ -188,7 +188,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } else { - throw new McpException($"Unknown tool: {request.Params.Name}"); + throw new McpException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); } } }, @@ -220,9 +220,9 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(requestParams.Cursor)); startIndex = Convert.ToInt32(startIndexAsString); } - catch + catch (Exception e) { - throw new McpException("Invalid cursor"); + throw new McpException($"Invalid cursor: '{requestParams.Cursor}'", e, McpErrorCode.InvalidParams); } } @@ -244,7 +244,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { if (request.Params?.Uri is null) { - throw new McpException("Missing required argument 'uri'"); + throw new McpException("Missing required argument 'uri'", McpErrorCode.InvalidParams); } if (request.Params.Uri.StartsWith("test://dynamic/resource/")) @@ -252,7 +252,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st var id = request.Params.Uri.Split('/').LastOrDefault(); if (string.IsNullOrEmpty(id)) { - throw new McpException("Invalid resource URI"); + throw new McpException($"Invalid resource URI: '{request.Params.Uri}'", McpErrorCode.InvalidParams); } return new ReadResourceResult() @@ -269,7 +269,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } ResourceContents? contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) ?? - throw new McpException("Resource not found"); + throw new McpException($"Resource not found: '{request.Params.Uri}'", McpErrorCode.InvalidParams); return new ReadResourceResult() { @@ -316,7 +316,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st { if (request.Params is null) { - throw new McpException("Missing required parameter 'name'"); + throw new McpException("Missing required parameter 'name'", McpErrorCode.InvalidParams); } List messages = new(); if (request.Params.Name == "simple_prompt") @@ -366,7 +366,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } else { - throw new McpException($"Unknown prompt: {request.Params.Name}"); + throw new McpException($"Unknown prompt: {request.Params.Name}", McpErrorCode.InvalidParams); } return new GetPromptResult() diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 1bede93cf..d0cfd0663 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -31,10 +31,10 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer { NextCursor = "abc", Prompts = [new() - { - Name = "FirstCustomPrompt", - Description = "First prompt returned by custom handler", - }], + { + Name = "FirstCustomPrompt", + Description = "First prompt returned by custom handler", + }], }; case "abc": @@ -42,10 +42,10 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer { NextCursor = "def", Prompts = [new() - { - Name = "SecondCustomPrompt", - Description = "Second prompt returned by custom handler", - }], + { + Name = "SecondCustomPrompt", + Description = "Second prompt returned by custom handler", + }], }; case "def": @@ -53,14 +53,14 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer { NextCursor = null, Prompts = [new() - { - Name = "FinalCustomPrompt", - Description = "Final prompt returned by custom handler", - }], + { + Name = "FinalCustomPrompt", + Description = "Final prompt returned by custom handler", + }], }; default: - throw new Exception("Unexpected cursor"); + throw new McpException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); } }) .WithGetPromptHandler(async (request, cancellationToken) => @@ -76,7 +76,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new Exception($"Unknown prompt '{request.Params?.Name}'"); + throw new McpException($"Unknown prompt '{request.Params?.Name}'", McpErrorCode.InvalidParams); } }) .WithPrompts(); @@ -194,7 +194,7 @@ public async Task Throws_Exception_Missing_Parameter() nameof(SimplePrompts.ReturnsChatMessages), cancellationToken: TestContext.Current.CancellationToken)); - Assert.Contains("Missing required parameter", e.Message); + Assert.Equal(McpErrorCode.InternalError, e.ErrorCode); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 16a69cdf0..e0043ff05 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -88,7 +88,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new Exception("Unexpected cursor"); + throw new McpException($"Unexpected cursor: '{cursor}'", McpErrorCode.InvalidParams); } }) .WithCallToolHandler(async (request, cancellationToken) => @@ -104,7 +104,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer }; default: - throw new Exception($"Unknown tool '{request.Params?.Name}'"); + throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams); } }) .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 1d1689140..b6355f798 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -108,7 +108,7 @@ await RunConnected(async (client, server) => a.Kind == ActivityKind.Client); Assert.Equal(ActivityStatusCode.Error, doesNotExistToolClient.Status); - Assert.Equal("-32603", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); + Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); var doesNotExistToolServer = Assert.Single(activities, a => a.Tags.Any(t => t.Key == "mcp.tool.name" && t.Value == "does-not-exist") && @@ -117,7 +117,7 @@ await RunConnected(async (client, server) => a.Kind == ActivityKind.Server); Assert.Equal(ActivityStatusCode.Error, doesNotExistToolServer.Status); - Assert.Equal("-32603", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); + Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } private static async Task RunConnected(Func action) diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index e3ffd9d90..2e251d033 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -59,7 +59,7 @@ public async Task SupportsServiceFromDI() Assert.Contains("something", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); - await Assert.ThrowsAsync(async () => await prompt.GetAsync( + await Assert.ThrowsAsync(async () => await prompt.GetAsync( new RequestContext(new Mock().Object), TestContext.Current.CancellationToken)); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 6182ddc07..8cf4a3c9d 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -95,6 +95,6 @@ public void CreatingReadHandlerWithNoListHandlerFails() }; }); var sp = services.BuildServiceProvider(); - Assert.Throws(() => sp.GetRequiredService()); + Assert.Throws(sp.GetRequiredService); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 34ac9c8a0..d7ccb6e15 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -94,7 +94,7 @@ public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Run } [Fact] - public async Task RequestSamplingAsync_Should_Throw_McpException_If_Client_Does_Not_Support_Sampling() + public async Task RequestSamplingAsync_Should_Throw_Exception_If_Client_Does_Not_Support_Sampling() { // Arrange await using var transport = new TestServerTransport(); @@ -104,7 +104,7 @@ public async Task RequestSamplingAsync_Should_Throw_McpException_If_Client_Does_ var action = () => server.RequestSamplingAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); // Act & Assert - await Assert.ThrowsAsync("server", action); + await Assert.ThrowsAsync(action); } [Fact] @@ -130,7 +130,7 @@ public async Task RequestSamplingAsync_Should_SendRequest() } [Fact] - public async Task RequestRootsAsync_Should_Throw_McpException_If_Client_Does_Not_Support_Roots() + public async Task RequestRootsAsync_Should_Throw_Exception_If_Client_Does_Not_Support_Roots() { // Arrange await using var transport = new TestServerTransport(); @@ -138,7 +138,7 @@ public async Task RequestRootsAsync_Should_Throw_McpException_If_Client_Does_Not SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert - await Assert.ThrowsAsync("server", () => server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None)); + await Assert.ThrowsAsync(() => server.RequestRootsAsync(new ListRootsRequestParams(), CancellationToken.None)); } [Fact] @@ -507,7 +507,7 @@ private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities se await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory)); + Assert.Throws(() => McpServerFactory.Create(transport, options, LoggerFactory)); } [Fact] @@ -515,7 +515,7 @@ public async Task AsSamplingChatClient_NoSamplingSupport_Throws() { await using var server = new TestServerForIChatClient(supportsSampling: false); - Assert.Throws("server", () => server.AsSamplingChatClient()); + Assert.Throws(server.AsSamplingChatClient); } [Fact] From 997ebbc41105d52bc4553f36065ec74e271eea05 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 16 Apr 2025 18:04:22 -0400 Subject: [PATCH 2/3] Replace McpTransportException with InvalidOperationException --- .../Protocol/Transport/IClientTransport.cs | 2 +- .../Transport/McpTransportException.cs | 34 --------------- .../Transport/SseClientSessionTransport.cs | 42 ++++++------------- .../Transport/StdioClientSessionTransport.cs | 6 +-- .../Transport/StdioClientTransport.cs | 4 +- .../Transport/StreamClientSessionTransport.cs | 21 +--------- .../Transport/StreamServerTransport.cs | 4 +- .../Protocol/Transport/TransportBase.cs | 2 +- 8 files changed, 24 insertions(+), 91 deletions(-) delete mode 100644 src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs diff --git a/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs index b0911c8f9..ab608ef94 100644 --- a/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/IClientTransport.cs @@ -42,6 +42,6 @@ public interface IClientTransport /// This method is used by to initialize the connection. /// /// - /// The transport connection could not be established. + /// The transport connection could not be established. Task ConnectAsync(CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs b/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs deleted file mode 100644 index 647b908d4..000000000 --- a/src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs +++ /dev/null @@ -1,34 +0,0 @@ -namespace ModelContextProtocol.Protocol.Transport; - -/// -/// Represents errors that occur in MCP transport operations. -/// -public class McpTransportException : Exception -{ - /// - /// Initializes a new instance of the class. - /// - public McpTransportException() - { - } - - /// - /// Initializes a new instance of the class with a specified error message. - /// - /// The message that describes the error. - public McpTransportException(string message) - : base(message) - { - } - - /// - /// Initializes a new instance of the class with a specified error message - /// and a reference to the inner exception that is the cause of this exception. - /// - /// The message that describes the error. - /// The exception that is the cause of the current exception. - public McpTransportException(string message, Exception? innerException) - : base(message, innerException) - { - } -} diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 921867a20..e2c7cc5cc 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -58,11 +58,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) await _connectionEstablished.Task.WaitAsync(_options.ConnectionTimeout, cancellationToken).ConfigureAwait(false); } - catch (Exception ex) when (ex is not McpTransportException) // propagate transport exceptions + catch (Exception ex) { LogTransportConnectFailed(Name, ex); await CloseAsync().ConfigureAwait(false); - throw new McpTransportException("Failed to connect transport", ex); + throw new InvalidOperationException("Failed to connect transport", ex); } } @@ -110,7 +110,7 @@ public override async Task SendMessageAsync( else { JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? - throw new McpTransportException("Failed to initialize client"); + throw new InvalidOperationException("Failed to initialize client"); LogTransportReceivedMessage(Name, messageId); await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); @@ -136,7 +136,7 @@ public override async Task SendMessageAsync( LogRejectedPost(Name, messageId); } - throw new McpTransportException("Failed to send message"); + throw new InvalidOperationException("Failed to send message"); } } @@ -273,34 +273,18 @@ private async Task ProcessSseMessage(string data, CancellationToken cancellation private void HandleEndpointEvent(string data) { - try + if (string.IsNullOrEmpty(data)) { - if (string.IsNullOrEmpty(data)) - { - LogTransportEndpointEventInvalid(Name); - return; - } - - // If data is an absolute URL, the Uri will be constructed entirely from it and not the _sseEndpoint. - _messageEndpoint = new Uri(_sseEndpoint, data); - - // Set connected state - SetConnected(true); - _connectionEstablished.TrySetResult(true); + LogTransportEndpointEventInvalid(Name); + return; } - catch (JsonException ex) - { - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogTransportEndpointEventParseFailedSensitive(Name, data, ex); - } - else - { - LogTransportEndpointEventParseFailed(Name, ex); - } - throw new McpTransportException("Failed to parse endpoint event", ex); - } + // If data is an absolute URL, the Uri will be constructed entirely from it and not the _sseEndpoint. + _messageEndpoint = new Uri(_sseEndpoint, data); + + // Set connected state + SetConnected(true); + _connectionEstablished.TrySetResult(true); } private void CopyAdditionalHeaders(HttpRequestHeaders headers) diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs index 0f4b80dd4..aa92c9d4b 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs @@ -23,14 +23,14 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process /// /// For stdio-based transports, this implementation first verifies that the underlying process /// is still running before attempting to send the message. If the process has exited or cannot - /// be accessed, a is thrown with details about the failure. + /// be accessed, a is thrown with details about the failure. /// /// /// After verifying the process state, this method delegates to the base class implementation /// to handle the actual message serialization and transmission to the process's standard input stream. /// /// - /// + /// /// Thrown when the underlying process has exited or cannot be accessed. /// public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) @@ -49,7 +49,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio if (hasExited) { - throw new McpTransportException("Transport is not connected", processException); + throw new InvalidOperationException("Transport is not connected", processException); } await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index fff3786a6..3fb00f042 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -154,7 +154,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = if (!processStarted) { LogTransportProcessStartFailed(logger, endpointName); - throw new McpTransportException("Failed to start MCP server process"); + throw new InvalidOperationException("Failed to start MCP server process"); } LogTransportProcessStarted(logger, endpointName, process.Id); @@ -176,7 +176,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = LogTransportShutdownFailed(logger, endpointName, ex2); } - throw new McpTransportException("Failed to connect transport", ex); + throw new InvalidOperationException("Failed to connect transport", ex); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs index 6c6ebcd29..3c24416bb 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamClientSessionTransport.cs @@ -57,28 +57,11 @@ public StreamClientSessionTransport( } /// - /// - /// - /// For stream-based transports, this implementation serializes the JSON-RPC message to the - /// underlying output stream. The specific serialization format includes: - /// - /// A Content-Length header that specifies the byte length of the JSON message - /// A blank line separator - /// The UTF-8 encoded JSON representation of the message - /// - /// - /// - /// This implementation first checks if the transport is connected and throws a - /// if it's not. It then extracts the message ID (if present) for logging purposes, serializes the message, - /// and writes it to the output stream. - /// - /// - /// Thrown when the transport is not connected. public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { if (!IsConnected) { - throw new McpTransportException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected"); } string id = "(no id)"; @@ -99,7 +82,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new McpTransportException("Failed to send message", ex); + throw new InvalidOperationException("Failed to send message", ex); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs index 516eb4041..c94509955 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StreamServerTransport.cs @@ -60,7 +60,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio { if (!IsConnected) { - throw new McpTransportException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected"); } using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false); @@ -80,7 +80,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio catch (Exception ex) { LogTransportSendFailed(Name, id, ex); - throw new McpTransportException("Failed to send message", ex); + throw new InvalidOperationException("Failed to send message", ex); } } diff --git a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs index 9e253072f..0d496d0a6 100644 --- a/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs +++ b/src/ModelContextProtocol/Protocol/Transport/TransportBase.cs @@ -73,7 +73,7 @@ protected async Task WriteMessageAsync(IJsonRpcMessage message, CancellationToke { if (!IsConnected) { - throw new McpTransportException("Transport is not connected"); + throw new InvalidOperationException("Transport is not connected"); } await _messageChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false); From 3b7bbdd7fbc09a2acda955633850fcc55a6f1b35 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 17 Apr 2025 14:04:29 -0400 Subject: [PATCH 3/3] Add cancellation test --- .../Client/McpClientFactoryTests.cs | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index daee177c1..454222e6b 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils.Json; using Moq; +using System.IO.Pipelines; using System.Text.Json; using System.Threading.Channels; @@ -29,6 +30,31 @@ public async Task CreateAsync_NopTransport_ReturnsClient() Assert.NotNull(client); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task Cancellation_ThrowsCancellationException(bool preCanceled) + { + var cts = new CancellationTokenSource(); + + if (preCanceled) + { + cts.Cancel(); + } + + Task t = McpClientFactory.CreateAsync( + new StreamClientTransport(new Pipe().Writer.AsStream(), new Pipe().Reader.AsStream()), + cancellationToken: cts.Token); + Assert.False(t.IsCompleted); + + if (!preCanceled) + { + cts.Cancel(); + } + + await Assert.ThrowsAnyAsync(() => t); + } + [Theory] [InlineData(typeof(NopTransport))] [InlineData(typeof(FailureTransport))]