diff --git a/src/ModelContextProtocol/AIContentExtensions.cs b/src/ModelContextProtocol/AIContentExtensions.cs
index 6a3f17733..87f40daaa 100644
--- a/src/ModelContextProtocol/AIContentExtensions.cs
+++ b/src/ModelContextProtocol/AIContentExtensions.cs
@@ -1,7 +1,9 @@
using Microsoft.Extensions.AI;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
+using ModelContextProtocol.Utils.Json;
using System.Runtime.InteropServices;
+using System.Text.Json;
namespace ModelContextProtocol;
@@ -101,4 +103,28 @@ internal static string GetBase64Data(this DataContent dataContent)
Convert.ToBase64String(dataContent.Data.ToArray());
#endif
}
+
+ internal static Content ToContent(this AIContent content) =>
+ content switch
+ {
+ TextContent textContent => new()
+ {
+ Text = textContent.Text,
+ Type = "text",
+ },
+ DataContent dataContent => new()
+ {
+ Data = dataContent.GetBase64Data(),
+ MimeType = dataContent.MediaType,
+ Type =
+ dataContent.HasTopLevelMediaType("image") ? "image" :
+ dataContent.HasTopLevelMediaType("audio") ? "audio" :
+ "resource",
+ },
+ _ => new()
+ {
+ Text = JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))),
+ Type = "text",
+ }
+ };
}
diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
index ff3f92887..56ad40410 100644
--- a/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
+++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
@@ -22,7 +22,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool
public static new AIFunctionMcpServerTool Create(
Delegate method,
string? name,
- string? description,
+ string? description,
IServiceProvider? services)
{
Throw.IfNull(method);
@@ -34,7 +34,7 @@ internal sealed class AIFunctionMcpServerTool : McpServerTool
/// Creates an instance for a method, specified via a instance.
///
public static new AIFunctionMcpServerTool Create(
- MethodInfo method,
+ MethodInfo method,
object? target,
string? name,
string? description,
@@ -195,57 +195,49 @@ public override async Task InvokeAsync(
};
}
- switch (result)
+ return result switch
{
- case null:
- return new()
- {
- Content = []
- };
-
- case string text:
- return new()
- {
- Content = [new() { Text = text, Type = "text" }]
- };
-
- case TextContent textContent:
- return new()
- {
- Content = [new() { Text = textContent.Text, Type = "text" }]
- };
-
- case DataContent dataContent:
- return new()
- {
- Content = [new()
- {
- Data = dataContent.GetBase64Data(),
- MimeType = dataContent.MediaType,
- Type = dataContent.HasTopLevelMediaType("image") ? "image" : "resource",
- }]
- };
-
- case string[] texts:
- return new()
- {
- Content = texts
- .Select(x => new Content() { Type = "text", Text = x ?? string.Empty })
- .ToList()
- };
+ AIContent aiContent => new()
+ {
+ Content = [aiContent.ToContent()]
+ },
+ null => new()
+ {
+ Content = []
+ },
+ string text => new()
+ {
+ Content = [new() { Text = text, Type = "text" }]
+ },
+ Content content => new()
+ {
+ Content = [content]
+ },
+ IEnumerable texts => new()
+ {
+ Content = [.. texts.Select(x => new Content() { Type = "text", Text = x ?? string.Empty })]
+ },
+ IEnumerable contentItems => new()
+ {
+ Content = [.. contentItems.Select(static item => item.ToContent())]
+ },
+ IEnumerable contents => new()
+ {
+ Content = [.. contents]
+ },
+ CallToolResponse callToolResponse => callToolResponse,
// TODO https://github.com/modelcontextprotocol/csharp-sdk/issues/69:
// Add specialization for annotations.
-
- default:
- return new()
- {
- Content = [new()
+ _ => new()
+ {
+ Content = [new()
{
Text = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))),
Type = "text"
}]
- };
- }
+ },
+ };
}
+
}
\ No newline at end of file
diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolReturnTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolReturnTests.cs
new file mode 100644
index 000000000..a263ab9c3
--- /dev/null
+++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolReturnTests.cs
@@ -0,0 +1,194 @@
+using Microsoft.Extensions.AI;
+using ModelContextProtocol.Protocol.Types;
+using ModelContextProtocol.Server;
+using Moq;
+
+namespace ModelContextProtocol.Tests.Server;
+public class McpServerToolReturnTests
+{
+ [Fact]
+ public async Task CanReturnCollectionOfAIContent()
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return new List() {
+ new TextContent("text"),
+ new DataContent("data:image/png;base64,1234"),
+ new DataContent("data:audio/wav;base64,1234")
+ };
+ });
+
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+
+ Assert.Equal(3, result.Content.Count);
+
+ Assert.Equal("text", result.Content[0].Text);
+ Assert.Equal("text", result.Content[0].Type);
+
+ Assert.Equal("1234", result.Content[1].Data);
+ Assert.Equal("image/png", result.Content[1].MimeType);
+ Assert.Equal("image", result.Content[1].Type);
+
+ Assert.Equal("1234", result.Content[2].Data);
+ Assert.Equal("audio/wav", result.Content[2].MimeType);
+ Assert.Equal("audio", result.Content[2].Type);
+ }
+
+ [Theory]
+ [InlineData("text", "text")]
+ [InlineData("data:image/png;base64,1234", "image")]
+ [InlineData("data:audio/wav;base64,1234", "audio")]
+ public async Task CanReturnSingleAIContent(string data, string type)
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return type switch
+ {
+ "text" => (AIContent)new TextContent(data),
+ "image" => new DataContent(data),
+ "audio" => new DataContent(data),
+ _ => throw new ArgumentException("Invalid type")
+ };
+ });
+
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+
+ Assert.Single(result.Content);
+ Assert.Equal(type, result.Content[0].Type);
+
+ if (type != "text")
+ {
+ Assert.NotNull(result.Content[0].MimeType);
+ Assert.Equal(data.Split(',').Last(), result.Content[0].Data);
+ }
+ else
+ {
+ Assert.Null(result.Content[0].MimeType);
+ Assert.Equal(data, result.Content[0].Text);
+ }
+ }
+
+ [Fact]
+ public async Task CanReturnNullAIContent()
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return (string?)null;
+ });
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+ Assert.Empty(result.Content);
+ }
+
+ [Fact]
+ public async Task CanReturnString()
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return "42";
+ });
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+ Assert.Single(result.Content);
+ Assert.Equal("42", result.Content[0].Text);
+ Assert.Equal("text", result.Content[0].Type);
+ }
+
+ [Fact]
+ public async Task CanReturnCollectionOfStrings()
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return new List() { "42", "43" };
+ });
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+ Assert.Equal(2, result.Content.Count);
+ Assert.Equal("42", result.Content[0].Text);
+ Assert.Equal("text", result.Content[0].Type);
+ Assert.Equal("43", result.Content[1].Text);
+ Assert.Equal("text", result.Content[1].Type);
+ }
+
+ [Fact]
+ public async Task CanReturnMcpContent()
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return new Content { Text = "42", Type = "text" };
+ });
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+ Assert.Single(result.Content);
+ Assert.Equal("42", result.Content[0].Text);
+ Assert.Equal("text", result.Content[0].Type);
+ }
+
+ [Fact]
+ public async Task CanReturnCollectionOfMcpContent()
+ {
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return new List() { new() { Text = "42", Type = "text" }, new() { Data = "1234", Type = "image", MimeType = "image/png" } };
+ });
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+ Assert.Equal(2, result.Content.Count);
+ Assert.Equal("42", result.Content[0].Text);
+ Assert.Equal("text", result.Content[0].Type);
+ Assert.Equal("1234", result.Content[1].Data);
+ Assert.Equal("image", result.Content[1].Type);
+ Assert.Equal("image/png", result.Content[1].MimeType);
+ Assert.Null(result.Content[1].Text);
+ }
+
+ [Fact]
+ public async Task CanReturnCallToolResponse()
+ {
+ CallToolResponse response = new()
+ {
+ Content = [new() { Text = "text", Type = "text" }, new() { Data = "1234", Type = "image" }]
+ };
+
+ Mock mockServer = new();
+ McpServerTool tool = McpServerTool.Create((IMcpServer server) =>
+ {
+ Assert.Same(mockServer.Object, server);
+ return response;
+ });
+ var result = await tool.InvokeAsync(
+ new RequestContext(mockServer.Object, null),
+ TestContext.Current.CancellationToken);
+
+ Assert.Same(response, result);
+
+ Assert.Equal(2, result.Content.Count);
+ Assert.Equal("text", result.Content[0].Text);
+ Assert.Equal("text", result.Content[0].Type);
+ Assert.Equal("1234", result.Content[1].Data);
+ Assert.Equal("image", result.Content[1].Type);
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs
index 3f066dd5c..1ce4b36cb 100644
--- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs
+++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs
@@ -1,4 +1,5 @@
-using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using Moq;