From 80e63a75321d9f7b8b9f9ce6d2e789ba684f2ebf Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 7 Apr 2025 11:22:48 -0400 Subject: [PATCH] Use strong-typing of params in most remaining McpClientExtensions methods --- .../Client/McpClientExtensions.cs | 116 +++++++++--------- .../Protocol/Types/CallToolRequestParams.cs | 7 +- .../Protocol/Types/GetPromptRequestParams.cs | 11 +- .../Server/AIFunctionMcpServerPrompt.cs | 5 +- .../Program.cs | 4 +- .../Program.cs | 4 +- 6 files changed, 74 insertions(+), 73 deletions(-) diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index db3801d17..2850bd1d2 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -22,7 +22,7 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat Throw.IfNull(client); return client.SendRequestAsync( - RequestMethods.Ping, + RequestMethods.Ping, parameters: null, McpJsonUtilities.JsonContext.Default.Object!, McpJsonUtilities.JsonContext.Default.Object, @@ -51,9 +51,9 @@ public static async Task> ListToolsAsync( do { var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -95,9 +95,9 @@ public static async IAsyncEnumerable EnumerateToolsAsync( do { var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -127,9 +127,9 @@ public static async Task> ListPromptsAsync( do { var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -166,8 +166,8 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( { var promptResults = await client.SendRequestAsync( RequestMethods.PromptsList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -202,12 +202,10 @@ public static Task GetPromptAsync( serializerOptions ??= McpJsonUtilities.DefaultOptions; serializerOptions.MakeReadOnly(); - var parametersTypeInfo = serializerOptions.GetTypeInfo>(); - return client.SendRequestAsync( RequestMethods.PromptsGet, - CreateParametersDictionary(name, arguments), - parametersTypeInfo, + new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, McpJsonUtilities.JsonContext.Default.GetPromptResult, cancellationToken: cancellationToken); } @@ -229,9 +227,9 @@ public static async Task> ListResourceTemplatesAsync( do { var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -270,9 +268,9 @@ public static async IAsyncEnumerable EnumerateResourceTemplate do { var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -303,9 +301,9 @@ public static async Task> ListResourcesAsync( do { var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -344,9 +342,9 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( do { var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - CreateCursorDictionary(cursor)!, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult, cancellationToken: cancellationToken).ConfigureAwait(false); @@ -373,9 +371,9 @@ public static Task ReadResourceAsync( Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new Dictionary { ["uri"] = uri }, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ResourcesRead, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult, cancellationToken: cancellationToken); } @@ -400,13 +398,13 @@ public static Task GetCompletionAsync(this IMcpClient client, Re } return client.SendRequestAsync( - RequestMethods.CompletionComplete, - new Dictionary + RequestMethods.CompletionComplete, + new() { - ["ref"] = reference, - ["argument"] = new Argument { Name = argumentName, Value = argumentValue } + Ref = reference, + Argument = new Argument { Name = argumentName, Value = argumentValue } }, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, McpJsonUtilities.JsonContext.Default.CompleteResult, cancellationToken: cancellationToken); } @@ -423,9 +421,9 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - RequestMethods.ResourcesSubscribe, - new Dictionary { ["uri"] = uri }, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + RequestMethods.ResourcesSubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult, cancellationToken: cancellationToken); } @@ -443,8 +441,8 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u return client.SendRequestAsync( RequestMethods.ResourcesUnsubscribe, - new Dictionary { ["uri"] = uri }, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult, cancellationToken: cancellationToken); } @@ -470,12 +468,10 @@ public static Task CallToolAsync( serializerOptions ??= McpJsonUtilities.DefaultOptions; serializerOptions.MakeReadOnly(); - var parametersTypeInfo = serializerOptions.GetTypeInfo>(); - return client.SendRequestAsync( - RequestMethods.ToolsCall, - CreateParametersDictionary(toolName, arguments), - parametersTypeInfo, + RequestMethods.ToolsCall, + new() { Name = toolName, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, McpJsonUtilities.JsonContext.Default.CallToolResponse, cancellationToken: cancellationToken); } @@ -629,28 +625,28 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C return client.SendRequestAsync( RequestMethods.LoggingSetLevel, - new Dictionary { ["level"] = level }, - McpJsonUtilities.JsonContext.Default.DictionaryStringObject, + new() { Level = level }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult, cancellationToken: cancellationToken); } - private static Dictionary? CreateCursorDictionary(string? cursor) => - cursor != null ? new() { ["cursor"] = cursor } : null; - - private static Dictionary CreateParametersDictionary( - string nameParameter, IReadOnlyDictionary? arguments) + /// Convers a dictionary with values to a dictionary with values. + private static IReadOnlyDictionary? ToArgumentsDictionary( + IReadOnlyDictionary? arguments, JsonSerializerOptions options) { - Dictionary parameters = new() - { - ["name"] = nameParameter - }; + var typeInfo = options.GetTypeInfo(); - if (arguments != null) + Dictionary? result = null; + if (arguments is not null) { - parameters["arguments"] = arguments; + result = new(arguments.Count); + foreach (var kvp in arguments) + { + result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); + } } - return parameters; + return result; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs index 08c6e5770..d73f73feb 100644 --- a/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/CallToolRequestParams.cs @@ -1,4 +1,5 @@ using System.Text.Json; +using System.Text.Json.Serialization; namespace ModelContextProtocol.Protocol.Types; @@ -11,12 +12,12 @@ public class CallToolRequestParams : RequestParams /// /// Tool name. /// - [System.Text.Json.Serialization.JsonPropertyName("name")] + [JsonPropertyName("name")] public required string Name { get; init; } /// /// Optional arguments to pass to the tool. /// - [System.Text.Json.Serialization.JsonPropertyName("arguments")] - public Dictionary? Arguments { get; init; } + [JsonPropertyName("arguments")] + public IReadOnlyDictionary? Arguments { get; init; } } diff --git a/src/ModelContextProtocol/Protocol/Types/GetPromptRequestParams.cs b/src/ModelContextProtocol/Protocol/Types/GetPromptRequestParams.cs index 08207667b..902cc6bf7 100644 --- a/src/ModelContextProtocol/Protocol/Types/GetPromptRequestParams.cs +++ b/src/ModelContextProtocol/Protocol/Types/GetPromptRequestParams.cs @@ -1,4 +1,7 @@ -namespace ModelContextProtocol.Protocol.Types; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Types; /// /// Used by the client to get a prompt provided by the server. @@ -9,12 +12,12 @@ public class GetPromptRequestParams : RequestParams /// /// he name of the prompt or prompt template. /// - [System.Text.Json.Serialization.JsonPropertyName("name")] + [JsonPropertyName("name")] public required string Name { get; init; } /// /// Arguments to use for templating the prompt. /// - [System.Text.Json.Serialization.JsonPropertyName("arguments")] - public Dictionary? Arguments { get; init; } + [JsonPropertyName("arguments")] + public IReadOnlyDictionary? Arguments { get; init; } } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs index cc2bae5d4..3e560dbfe 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; using System.Diagnostics.CodeAnalysis; using System.Reflection; using System.Text.Json; @@ -207,8 +208,8 @@ public override async Task GetAsync( cancellationToken.ThrowIfCancellationRequested(); // TODO: Once we shift to the real AIFunctionFactory, the request should be passed via AIFunctionArguments.Context. - Dictionary arguments = request.Params?.Arguments is IDictionary existingArgs ? - new(existingArgs) : + Dictionary arguments = request.Params?.Arguments is { } paramArgs ? + paramArgs.ToDictionary(entry => entry.Key, entry => entry.Value.AsObject()) : []; arguments[RequestContextKey] = request; diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index ba6cca21e..b728e263e 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -249,8 +249,8 @@ private static PromptsCapability ConfigurePrompts() } else if (request.Params?.Name == "complex_prompt") { - string temperature = request.Params.Arguments?["temperature"]?.ToString() ?? "unknown"; - string style = request.Params.Arguments?["style"]?.ToString() ?? "unknown"; + string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; + string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; messages.Add(new PromptMessage() { Role = Role.User, diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index bc3398999..0bae52e6a 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -328,8 +328,8 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st } else if (request.Params.Name == "complex_prompt") { - string temperature = request.Params.Arguments?["temperature"]?.ToString() ?? "unknown"; - string style = request.Params.Arguments?["style"]?.ToString() ?? "unknown"; + string temperature = request.Params.Arguments?["temperature"].ToString() ?? "unknown"; + string style = request.Params.Arguments?["style"].ToString() ?? "unknown"; messages.Add(new PromptMessage() { Role = Role.User,