diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index 47b8514de..03d01ee42 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -1,7 +1,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Diagnostics.CodeAnalysis; diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index b95483196..8aa32965e 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; @@ -214,41 +213,26 @@ private void SetPromptsHandler(McpServerOptions options) throw new McpServerException("ListPrompts and GetPrompt handlers should be specified together."); } - // Handle tools provided via DI. + // Handle prompts provided via DI. if (prompts is { IsEmpty: false }) { + // Synthesize the handlers, making sure a PromptsCapability is specified. var originalListPromptsHandler = listPromptsHandler; - var originalGetPromptHandler = getPromptHandler; - - // Synthesize the handlers, making sure a ToolsCapability is specified. listPromptsHandler = async (request, cancellationToken) => { - ListPromptsResult result = new(); - foreach (McpServerPrompt prompt in prompts) - { - result.Prompts.Add(prompt.ProtocolPrompt); - } + ListPromptsResult result = originalListPromptsHandler is not null ? + await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); - if (originalListPromptsHandler is not null) + if (request.Params?.Cursor is null) { - string? nextCursor = null; - do - { - ListPromptsResult extraResults = await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false); - result.Prompts.AddRange(extraResults.Prompts); - - nextCursor = extraResults.NextCursor; - if (nextCursor is not null) - { - request = request with { Params = new() { Cursor = nextCursor } }; - } - } - while (nextCursor is not null); + result.Prompts.AddRange(prompts.Select(t => t.ProtocolPrompt)); } return result; }; + var originalGetPromptHandler = getPromptHandler; getPromptHandler = (request, cancellationToken) => { if (request.Params is null || @@ -316,38 +300,23 @@ private void SetToolsHandler(McpServerOptions options) // Handle tools provided via DI. if (tools is { IsEmpty: false }) { - var originalListToolsHandler = listToolsHandler; - var originalCallToolHandler = callToolHandler; - // Synthesize the handlers, making sure a ToolsCapability is specified. + var originalListToolsHandler = listToolsHandler; listToolsHandler = async (request, cancellationToken) => { - ListToolsResult result = new(); - foreach (McpServerTool tool in tools) - { - result.Tools.Add(tool.ProtocolTool); - } + ListToolsResult result = originalListToolsHandler is not null ? + await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); - if (originalListToolsHandler is not null) + if (request.Params?.Cursor is null) { - string? nextCursor = null; - do - { - ListToolsResult extraResults = await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false); - result.Tools.AddRange(extraResults.Tools); - - nextCursor = extraResults.NextCursor; - if (nextCursor is not null) - { - request = request with { Params = new() { Cursor = nextCursor } }; - } - } - while (nextCursor is not null); + result.Tools.AddRange(tools.Select(t => t.ProtocolTool)); } return result; }; + var originalCallToolHandler = callToolHandler; callToolHandler = (request, cancellationToken) => { if (request.Params is null || diff --git a/src/ModelContextProtocol/TokenProgress.cs b/src/ModelContextProtocol/TokenProgress.cs index 7cc97236a..46af03f4f 100644 --- a/src/ModelContextProtocol/TokenProgress.cs +++ b/src/ModelContextProtocol/TokenProgress.cs @@ -1,6 +1,5 @@ using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Server; -using ModelContextProtocol.Shared; namespace ModelContextProtocol; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 3a2f2ab77..7daea2074 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -2,7 +2,6 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; using System.IO.Pipelines; diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 53f7d05c3..c19469afc 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -6,12 +6,13 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; using System.ComponentModel; using System.IO.Pipelines; using System.Threading.Channels; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Tests.Configuration; public class McpServerBuilderExtensionsPromptsTests : LoggedTest, IAsyncDisposable @@ -28,7 +29,70 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper { ServiceCollection sc = new(); sc.AddSingleton(LoggerFactory); - _builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts(); + _builder = sc + .AddMcpServer() + .WithStdioServerTransport() + .WithListPromptsHandler(async (request, cancellationToken) => + { + var cursor = request.Params?.Cursor; + switch (cursor) + { + case null: + return new() + { + NextCursor = "abc", + Prompts = [new() + { + Name = "FirstCustomPrompt", + Description = "First prompt returned by custom handler", + }], + }; + + case "abc": + return new() + { + NextCursor = "def", + Prompts = [new() + { + Name = "SecondCustomPrompt", + Description = "Second prompt returned by custom handler", + }], + }; + + case "def": + return new() + { + NextCursor = null, + Prompts = [new() + { + Name = "FinalCustomPrompt", + Description = "Final prompt returned by custom handler", + }], + }; + + default: + throw new Exception("Unexpected cursor"); + } + }) + .WithGetPromptHandler(async (request, cancellationToken) => + { + switch (request.Params?.Name) + { + case "FirstCustomPrompt": + case "SecondCustomPrompt": + case "FinalCustomPrompt": + return new GetPromptResult() + { + Messages = [new() { Role = Role.User, Content = new() { Text = $"hello from {request.Params.Name}", Type = "text" } }], + }; + + default: + throw new Exception($"Unknown prompt '{request.Params?.Name}'"); + } + }) + .WithPrompts(); + + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); @@ -85,7 +149,7 @@ public async Task Can_List_And_Call_Registered_Prompts() IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(3, prompts.Count); + Assert.Equal(6, prompts.Count); var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages)); Assert.Equal("Returns chat messages", prompt.Description); @@ -98,6 +162,14 @@ public async Task Can_List_And_Call_Registered_Prompts() Assert.Equal(2, chatMessages.Count); Assert.Equal("The prompt is: hello", chatMessages[0].Text); Assert.Equal("Summarize.", chatMessages[1].Text); + + prompt = prompts.First(t => t.Name == "SecondCustomPrompt"); + Assert.Equal("Second prompt returned by custom handler", prompt.Description); + result = await prompt.GetAsync(cancellationToken: TestContext.Current.CancellationToken); + chatMessages = result.ToChatMessages(); + Assert.NotNull(chatMessages); + Assert.Single(chatMessages); + Assert.Equal("hello from SecondCustomPrompt", chatMessages[0].Text); } [Fact] @@ -106,7 +178,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(3, prompts.Count); + Assert.Equal(6, prompts.Count); Channel listChanged = Channel.CreateUnbounded(); client.AddNotificationHandler("notifications/prompts/list_changed", notification => @@ -127,7 +199,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() await notificationRead; prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(4, prompts.Count); + Assert.Equal(7, prompts.Count); Assert.Contains(prompts, t => t.Name == "NewPrompt"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -136,7 +208,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() await notificationRead; prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(3, prompts.Count); + Assert.Equal(6, prompts.Count); Assert.DoesNotContain(prompts, t => t.Name == "NewPrompt"); } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index c615e12f1..2b160c52e 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -7,7 +7,6 @@ using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Transport; using ModelContextProtocol.Tests.Utils; using System.Collections.Concurrent; using System.ComponentModel; @@ -16,6 +15,8 @@ using System.Text.RegularExpressions; using System.Threading.Channels; +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + namespace ModelContextProtocol.Tests.Configuration; public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable @@ -32,7 +33,90 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { ServiceCollection sc = new(); sc.AddSingleton(LoggerFactory); - _builder = sc.AddMcpServer().WithStdioServerTransport().WithTools(); + _builder = sc + .AddMcpServer() + .WithStdioServerTransport() + .WithListToolsHandler(async (request, cancellationToken) => + { + var cursor = request.Params?.Cursor; + switch (cursor) + { + case null: + return new() + { + NextCursor = "abc", + Tools = [new() + { + Name = "FirstCustomTool", + Description = "First tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], + }; + + case "abc": + return new() + { + NextCursor = "def", + Tools = [new() + { + Name = "SecondCustomTool", + Description = "Second tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], + }; + + case "def": + return new() + { + NextCursor = null, + Tools = [new() + { + Name = "FinalCustomTool", + Description = "Third tool returned by custom handler", + InputSchema = JsonSerializer.Deserialize(""" + { + "type": "object", + "properties": {}, + "required": [] + } + """), + }], + }; + + default: + throw new Exception("Unexpected cursor"); + } + }) + .WithCallToolHandler(async (request, cancellationToken) => + { + switch (request.Params?.Name) + { + case "FirstCustomTool": + case "SecondCustomTool": + case "FinalCustomTool": + return new CallToolResponse() + { + Content = [new Content() { Text = $"{request.Params.Name}Result", Type = "text" }], + }; + + default: + throw new Exception($"Unknown tool '{request.Params?.Name}'"); + } + }) + .WithTools(); + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); @@ -89,7 +173,7 @@ public async Task Can_List_Registered_Tools() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + Assert.Equal(16, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -134,7 +218,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T cancellationToken: TestContext.Current.CancellationToken)) { var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + Assert.Equal(16, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -161,7 +245,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + Assert.Equal(16, tools.Count); Channel listChanged = Channel.CreateUnbounded(); client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification => @@ -182,7 +266,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(14, tools.Count); + Assert.Equal(17, tools.Count); Assert.Contains(tools, t => t.Name == "NewTool"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -191,7 +275,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + Assert.Equal(16, tools.Count); Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } @@ -225,9 +309,16 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); - Assert.Equal("hello Peter", result.Content[0].Text); Assert.Equal("hello2 Peter", result.Content[1].Text); + + result = await client.CallToolAsync( + "SecondCustomTool", + cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Content); + Assert.NotEmpty(result.Content); + Assert.Equal("SecondCustomToolResult", result.Content[0].Text); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 65ed07dff..a3290049c 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 75aefc086..893ac793b 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -1,5 +1,4 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils;