Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using ModelContextProtocol.Utils;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Text.Json;

namespace Microsoft.Extensions.DependencyInjection;

Expand All @@ -24,6 +25,7 @@ public static partial class McpServerBuilderExtensions
/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <typeparam name="TToolType">The tool type.</typeparam>
/// <param name="builder">The builder instance.</param>
/// <param name="serializerOptions">The serializer options governing tool parameter marshalling.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <remarks>
Expand All @@ -35,7 +37,8 @@ public static partial class McpServerBuilderExtensions
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods |
DynamicallyAccessedMemberTypes.PublicConstructors)] TToolType>(
this IMcpServerBuilder builder)
this IMcpServerBuilder builder,
JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);

Expand All @@ -44,8 +47,8 @@ public static partial class McpServerBuilderExtensions
if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null)
{
builder.Services.AddSingleton((Func<IServiceProvider, McpServerTool>)(toolMethod.IsStatic ?
services => McpServerTool.Create(toolMethod, options: new() { Services = services }) :
services => McpServerTool.Create(toolMethod, typeof(TToolType), new() { Services = services })));
services => McpServerTool.Create(toolMethod, options: new() { Services = services, SerializerOptions = serializerOptions }) :
services => McpServerTool.Create(toolMethod, typeof(TToolType), new() { Services = services, SerializerOptions = serializerOptions })));
}
}

Expand All @@ -55,6 +58,7 @@ public static partial class McpServerBuilderExtensions
/// <summary>Adds <see cref="McpServerTool"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <param name="builder">The builder instance.</param>
/// <param name="toolTypes">Types with marked methods to add as tools to the server.</param>
/// <param name="serializerOptions">The serializer options governing tool parameter marshalling.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="toolTypes"/> is <see langword="null"/>.</exception>
Expand All @@ -64,7 +68,7 @@ public static partial class McpServerBuilderExtensions
/// instance for each. For instance methods, an instance will be constructed for each invocation of the tool.
/// </remarks>
[RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)]
public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable<Type> toolTypes)
public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnumerable<Type> toolTypes, JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);
Throw.IfNull(toolTypes);
Expand All @@ -78,8 +82,8 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params
if (toolMethod.GetCustomAttribute<McpServerToolAttribute>() is not null)
{
builder.Services.AddSingleton((Func<IServiceProvider, McpServerTool>)(toolMethod.IsStatic ?
services => McpServerTool.Create(toolMethod, options: new() { Services = services }) :
services => McpServerTool.Create(toolMethod, toolType, new() { Services = services })));
services => McpServerTool.Create(toolMethod, options: new() { Services = services , SerializerOptions = serializerOptions }) :
services => McpServerTool.Create(toolMethod, toolType, new() { Services = services , SerializerOptions = serializerOptions })));
}
}
}
Expand All @@ -92,6 +96,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params
/// Adds types marked with the <see cref="McpServerToolTypeAttribute"/> attribute from the given assembly as tools to the server.
/// </summary>
/// <param name="builder">The builder instance.</param>
/// <param name="serializerOptions">The serializer options governing tool parameter marshalling.</param>
/// <param name="toolAssembly">The assembly to load the types from. If <see langword="null"/>, the calling assembly will be used.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
Expand All @@ -116,7 +121,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params
/// </para>
/// </remarks>
[RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)]
public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder builder, Assembly? toolAssembly = null)
public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder builder, Assembly? toolAssembly = null, JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);

Expand All @@ -125,7 +130,8 @@ public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder bui
return builder.WithTools(
from t in toolAssembly.GetTypes()
where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
select t);
select t,
serializerOptions);
}
#endregion

Expand Down
3 changes: 2 additions & 1 deletion src/ModelContextProtocol/Server/AIFunctionMcpServerTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
Name = options?.Name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name,
Description = options?.Description,
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
ConfigureParameterBinding = pi =>
{
if (pi.ParameterType == typeof(RequestContext<CallToolRequestParams>))
Expand Down Expand Up @@ -314,7 +315,7 @@ public override async Task<CallToolResponse> InvokeAsync(
{
Content = [new()
{
Text = JsonSerializer.Serialize(result, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object))),
Text = JsonSerializer.Serialize(result, AIFunction.JsonSerializerOptions.GetTypeInfo(typeof(object))),
Type = "text"
}]
},
Expand Down
15 changes: 13 additions & 2 deletions src/ModelContextProtocol/Server/McpServerToolCreateOptions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using ModelContextProtocol.Utils.Json;
using System.ComponentModel;
using System.Text.Json;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -122,11 +124,19 @@ public sealed class McpServerToolCreateOptions
/// </remarks>
public bool? ReadOnly { get; set; }

/// <summary>
/// Gets or sets the JSON serializer options to use when marshalling data to/from JSON.
/// </summary>
/// <remarks>
/// Defaults to <see cref="McpJsonUtilities.DefaultOptions"/> if left unspecified.
/// </remarks>
public JsonSerializerOptions? SerializerOptions { get; set; }

/// <summary>
/// Creates a shallow clone of the current <see cref="McpServerToolCreateOptions"/> instance.
/// </summary>
internal McpServerToolCreateOptions Clone() =>
new McpServerToolCreateOptions()
new McpServerToolCreateOptions
{
Services = Services,
Name = Name,
Expand All @@ -135,6 +145,7 @@ internal McpServerToolCreateOptions Clone() =>
Destructive = Destructive,
Idempotent = Idempotent,
OpenWorld = OpenWorld,
ReadOnly = ReadOnly
ReadOnly = ReadOnly,
SerializerOptions = SerializerOptions,
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
<RootNamespace>ModelContextProtocol.AspNetCore.Tests</RootNamespace>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' == 'net9.0'">
<!-- For better test coverage, only disable reflection in one of the targets -->
<JsonSerializerIsReflectionEnabledByDefault>false</JsonSerializerIsReflectionEnabledByDefault>
</PropertyGroup>

<PropertyGroup>
<!-- Without this, tests are currently not showing results until all tests complete
https://xunit.net/docs/getting-started/v3/microsoft-testing-platform
Expand Down
20 changes: 15 additions & 5 deletions tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using ModelContextProtocol.Utils.Json;
using System.Text.Json.Serialization;
using TestServerWithHosting.Tools;

namespace ModelContextProtocol.Tests;

public class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper)
public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper)
{
private SseClientTransportOptions DefaultTransportOptions = new()
{
Expand All @@ -41,7 +42,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer()
await using var mcpClient = await ConnectMcpClient(httpClient);

// Send a test message through POST endpoint
await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken);

Assert.True(true);
}
Expand All @@ -57,7 +58,7 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU
await using var mcpClient = await ConnectMcpClient(httpClient);

// Send a test message through POST endpoint
await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
await mcpClient.SendNotificationAsync("test/message", new Envelope { Message = "Hello, SSE!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken);

Assert.True(true);
}
Expand All @@ -73,7 +74,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer()
mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) =>
{
Assert.Equal("Hello from client!", notification.Params?["message"]?.GetValue<string>());
await mcpServer.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: cancellationToken);
await mcpServer.SendNotificationAsync("test/notification", new Envelope { Message = "Hello from server!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: cancellationToken);
});
return mcpServer.RunAsync(cancellationToken);
});
Expand All @@ -90,7 +91,7 @@ public async Task ConnectAndReceiveNotification_InMemoryServer()
});

// Send a test message through POST endpoint
await mcpClient.SendNotificationAsync("test/notification", new { message = "Hello from client!" }, cancellationToken: TestContext.Current.CancellationToken);
await mcpClient.SendNotificationAsync("test/notification", new Envelope { Message = "Hello from client!" }, serializerOptions: JsonContext.Default.Options, cancellationToken: TestContext.Current.CancellationToken);

var message = await receivedNotification.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken);
Assert.Equal("Hello from server!", message);
Expand Down Expand Up @@ -205,4 +206,13 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints)
await context.Response.WriteAsync("Accepted");
});
}

public class Envelope
{
public required string Message { get; set; }
}

[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)]
[JsonSerializable(typeof(Envelope))]
partial class JsonContext : JsonSerializerContext;
}
5 changes: 3 additions & 2 deletions tests/ModelContextProtocol.TestSseServer/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.AspNetCore.Connections;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using Serilog;
using System.Text;
using System.Text.Json;
Expand Down Expand Up @@ -124,7 +125,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
},
"required": ["message"]
}
"""),
""", McpJsonUtilities.DefaultOptions),
},
new Tool()
{
Expand All @@ -145,7 +146,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
},
"required": ["prompt", "maxTokens"]
}
"""),
""", McpJsonUtilities.DefaultOptions),
}
]
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using Moq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
Expand Down Expand Up @@ -379,7 +380,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient()
await using (client.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification,
(notification, cancellationToken) =>
{
Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize<LoggingMessageNotificationParams>(notification.Params)));
Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize<LoggingMessageNotificationParams>(notification.Params, McpJsonUtilities.DefaultOptions)));
return Task.CompletedTask;
}))
{
Expand All @@ -398,7 +399,7 @@ public async Task AsClientLoggerProvider_MessagesSentToClient()

Assert.Equal("TestLogger", m.Logger);

string ? s = JsonSerializer.Deserialize<string>(m.Data.Value);
string ? s = JsonSerializer.Deserialize<string>(m.Data.Value, McpJsonUtilities.DefaultOptions);
Assert.NotNull(s);

if (s.Contains("Information"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils.Json;
using Moq;
using System.Text.Json;
using System.Threading.Channels;
Expand Down Expand Up @@ -107,7 +108,7 @@ public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken
Name = "NopTransport",
Version = "1.0.0"
},
}),
}, McpJsonUtilities.DefaultOptions),
});
break;
}
Expand Down
20 changes: 15 additions & 5 deletions tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Tests.Utils;
using ModelContextProtocol.Utils.Json;
using OpenAI;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Tests;

public class ClientIntegrationTests : LoggedTest, IClassFixture<ClientIntegrationTestFixture>
public partial class ClientIntegrationTests : LoggedTest, IClassFixture<ClientIntegrationTestFixture>
{
private static readonly string? s_openAIKey = Environment.GetEnvironmentVariable("AI:OpenAI:ApiKey");

Expand Down Expand Up @@ -261,7 +263,7 @@ public async Task SubscribeResource_Stdio()
[
new(NotificationMethods.ResourceUpdatedNotification, (notification, cancellationToken) =>
{
var notificationParams = JsonSerializer.Deserialize<ResourceUpdatedNotificationParams>(notification.Params);
var notificationParams = JsonSerializer.Deserialize<ResourceUpdatedNotificationParams>(notification.Params, McpJsonUtilities.DefaultOptions);
tcs.TrySetResult(true);
return Task.CompletedTask;
})
Expand Down Expand Up @@ -291,7 +293,7 @@ public async Task UnsubscribeResource_Stdio()
[
new(NotificationMethods.ResourceUpdatedNotification, (notification, cancellationToken) =>
{
var notificationParams = JsonSerializer.Deserialize<ResourceUpdatedNotificationParams>(notification.Params);
var notificationParams = JsonSerializer.Deserialize<ResourceUpdatedNotificationParams>(notification.Params, McpJsonUtilities.DefaultOptions);
receivedNotification.TrySetResult(true);
return Task.CompletedTask;
})
Expand Down Expand Up @@ -442,13 +444,18 @@ public async Task Notifications_Stdio(string clientId)

// Verify we can send notifications without errors
await client.SendNotificationAsync(NotificationMethods.RootsUpdatedNotification, cancellationToken: TestContext.Current.CancellationToken);
await client.SendNotificationAsync("test/notification", new { test = true }, cancellationToken: TestContext.Current.CancellationToken);
await client.SendNotificationAsync("test/notification", new TestNotification { Test = true }, cancellationToken: TestContext.Current.CancellationToken, serializerOptions: JsonContext3.Default.Options);

// assert
// no response to check, if no exception is thrown, it's a success
Assert.True(true);
}

class TestNotification
{
public required bool Test { get; set; }
}

[Fact]
public async Task CallTool_Stdio_MemoryServer()
{
Expand Down Expand Up @@ -557,7 +564,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId)
[
new(NotificationMethods.LoggingMessageNotification, (notification, cancellationToken) =>
{
var loggingMessageNotificationParameters = JsonSerializer.Deserialize<LoggingMessageNotificationParams>(notification.Params);
var loggingMessageNotificationParameters = JsonSerializer.Deserialize<LoggingMessageNotificationParams>(notification.Params, McpJsonUtilities.DefaultOptions);
if (loggingMessageNotificationParameters is not null)
{
receivedNotification.TrySetResult(true);
Expand All @@ -574,4 +581,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId)
// assert
await receivedNotification.Task;
}

[JsonSerializable(typeof(TestNotification))]
partial class JsonContext3 : JsonSerializerContext;
}
Loading