diff --git a/src/ModelContextProtocol/AIContentExtensions.cs b/src/ModelContextProtocol/AIContentExtensions.cs index 6a3f17733..87f40daaa 100644 --- a/src/ModelContextProtocol/AIContentExtensions.cs +++ b/src/ModelContextProtocol/AIContentExtensions.cs @@ -1,7 +1,9 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; using System.Runtime.InteropServices; +using System.Text.Json; namespace ModelContextProtocol; @@ -101,4 +103,28 @@ internal static string GetBase64Data(this DataContent dataContent) Convert.ToBase64String(dataContent.Data.ToArray()); #endif } + + internal static Content ToContent(this AIContent content) => + content switch + { + TextContent textContent => new() + { + Text = textContent.Text, + Type = "text", + }, + DataContent dataContent => new() + { + Data = dataContent.GetBase64Data(), + MimeType = dataContent.MediaType, + Type = + dataContent.HasTopLevelMediaType("image") ? "image" : + dataContent.HasTopLevelMediaType("audio") ? "audio" : + "resource", + }, + _ => new() + { + Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))), + Type = "text", + } + }; } diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs index ff3f92887..56ad40410 100644 --- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs @@ -22,7 +22,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool public static new AIFunctionMcpServerTool Create( Delegate method, string? name, - string? description, + string? description, IServiceProvider? services) { Throw.IfNull(method); @@ -34,7 +34,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool /// Creates an instance for a method, specified via a instance. /// public static new AIFunctionMcpServerTool Create( - MethodInfo method, + MethodInfo method, object? target, string? name, string? description, @@ -195,57 +195,49 @@ public override async Task InvokeAsync( }; } - switch (result) + return result switch { - case null: - return new() - { - Content = [] - }; - - case string text: - return new() - { - Content = [new() { Text = text, Type = "text" }] - }; - - case TextContent textContent: - return new() - { - Content = [new() { Text = textContent.Text, Type = "text" }] - }; - - case DataContent dataContent: - return new() - { - Content = [new() - { - Data = dataContent.GetBase64Data(), - MimeType = dataContent.MediaType, - Type = dataContent.HasTopLevelMediaType("image") ? "image" : "resource", - }] - }; - - case string[] texts: - return new() - { - Content = texts - .Select(x => new Content() { Type = "text", Text = x ?? string.Empty }) - .ToList() - }; + AIContent aiContent => new() + { + Content = [aiContent.ToContent()] + }, + null => new() + { + Content = [] + }, + string text => new() + { + Content = [new() { Text = text, Type = "text" }] + }, + Content content => new() + { + Content = [content] + }, + IEnumerable texts => new() + { + Content = [.. texts.Select(x => new Content() { Type = "text", Text = x ?? string.Empty })] + }, + IEnumerable contentItems => new() + { + Content = [.. contentItems.Select(static item => item.ToContent())] + }, + IEnumerable contents => new() + { + Content = [.. contents] + }, + CallToolResponse callToolResponse => callToolResponse, // TODO https://github.com/modelcontextprotocol/csharp-sdk/issues/69: // Add specialization for annotations. - - default: - return new() - { - Content = [new() + _ => new() + { + Content = [new() { Text = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))), Type = "text" }] - }; - } + }, + }; } + } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolReturnTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolReturnTests.cs new file mode 100644 index 000000000..a263ab9c3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolReturnTests.cs @@ -0,0 +1,194 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using Moq; + +namespace ModelContextProtocol.Tests.Server; +public class McpServerToolReturnTests +{ + [Fact] + public async Task CanReturnCollectionOfAIContent() + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return new List() { + new TextContent("text"), + new DataContent("data:image/png;base64,1234"), + new DataContent("data:audio/wav;base64,1234") + }; + }); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + + Assert.Equal(3, result.Content.Count); + + Assert.Equal("text", result.Content[0].Text); + Assert.Equal("text", result.Content[0].Type); + + Assert.Equal("1234", result.Content[1].Data); + Assert.Equal("image/png", result.Content[1].MimeType); + Assert.Equal("image", result.Content[1].Type); + + Assert.Equal("1234", result.Content[2].Data); + Assert.Equal("audio/wav", result.Content[2].MimeType); + Assert.Equal("audio", result.Content[2].Type); + } + + [Theory] + [InlineData("text", "text")] + [InlineData("data:image/png;base64,1234", "image")] + [InlineData("data:audio/wav;base64,1234", "audio")] + public async Task CanReturnSingleAIContent(string data, string type) + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return type switch + { + "text" => (AIContent)new TextContent(data), + "image" => new DataContent(data), + "audio" => new DataContent(data), + _ => throw new ArgumentException("Invalid type") + }; + }); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + + Assert.Single(result.Content); + Assert.Equal(type, result.Content[0].Type); + + if (type != "text") + { + Assert.NotNull(result.Content[0].MimeType); + Assert.Equal(data.Split(',').Last(), result.Content[0].Data); + } + else + { + Assert.Null(result.Content[0].MimeType); + Assert.Equal(data, result.Content[0].Text); + } + } + + [Fact] + public async Task CanReturnNullAIContent() + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return (string?)null; + }); + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Empty(result.Content); + } + + [Fact] + public async Task CanReturnString() + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return "42"; + }); + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Single(result.Content); + Assert.Equal("42", result.Content[0].Text); + Assert.Equal("text", result.Content[0].Type); + } + + [Fact] + public async Task CanReturnCollectionOfStrings() + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return new List() { "42", "43" }; + }); + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Equal(2, result.Content.Count); + Assert.Equal("42", result.Content[0].Text); + Assert.Equal("text", result.Content[0].Type); + Assert.Equal("43", result.Content[1].Text); + Assert.Equal("text", result.Content[1].Type); + } + + [Fact] + public async Task CanReturnMcpContent() + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return new Content { Text = "42", Type = "text" }; + }); + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Single(result.Content); + Assert.Equal("42", result.Content[0].Text); + Assert.Equal("text", result.Content[0].Type); + } + + [Fact] + public async Task CanReturnCollectionOfMcpContent() + { + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return new List() { new() { Text = "42", Type = "text" }, new() { Data = "1234", Type = "image", MimeType = "image/png" } }; + }); + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Equal(2, result.Content.Count); + Assert.Equal("42", result.Content[0].Text); + Assert.Equal("text", result.Content[0].Type); + Assert.Equal("1234", result.Content[1].Data); + Assert.Equal("image", result.Content[1].Type); + Assert.Equal("image/png", result.Content[1].MimeType); + Assert.Null(result.Content[1].Text); + } + + [Fact] + public async Task CanReturnCallToolResponse() + { + CallToolResponse response = new() + { + Content = [new() { Text = "text", Type = "text" }, new() { Data = "1234", Type = "image" }] + }; + + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return response; + }); + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + + Assert.Same(response, result); + + Assert.Equal(2, result.Content.Count); + Assert.Equal("text", result.Content[0].Text); + Assert.Equal("text", result.Content[0].Type); + Assert.Equal("1234", result.Content[1].Data); + Assert.Equal("image", result.Content[1].Type); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 3f066dd5c..1ce4b36cb 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Moq;