Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
/// <summary>Adds <see cref="McpServerPrompt"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <typeparam name="TPromptType">The prompt type.</typeparam>
/// <param name="builder">The builder instance.</param>
/// <param name="serializerOptions">The serializer options governing prompt parameter marshalling.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <remarks>
Expand All @@ -154,7 +155,8 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
DynamicallyAccessedMemberTypes.PublicMethods |
DynamicallyAccessedMemberTypes.NonPublicMethods |
DynamicallyAccessedMemberTypes.PublicConstructors)] TPromptType>(
this IMcpServerBuilder builder)
this IMcpServerBuilder builder,
JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);

Expand All @@ -163,8 +165,8 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
if (promptMethod.GetCustomAttribute<McpServerPromptAttribute>() is not null)
{
builder.Services.AddSingleton((Func<IServiceProvider, McpServerPrompt>)(promptMethod.IsStatic ?
services => McpServerPrompt.Create(promptMethod, options: new() { Services = services }) :
services => McpServerPrompt.Create(promptMethod, typeof(TPromptType), new() { Services = services })));
services => McpServerPrompt.Create(promptMethod, options: new() { Services = services, SerializerOptions = serializerOptions }) :
services => McpServerPrompt.Create(promptMethod, typeof(TPromptType), new() { Services = services, SerializerOptions = serializerOptions })));
}
}

Expand All @@ -174,6 +176,7 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
/// <summary>Adds <see cref="McpServerPrompt"/> instances to the service collection backing <paramref name="builder"/>.</summary>
/// <param name="builder">The builder instance.</param>
/// <param name="promptTypes">Types with marked methods to add as prompts to the server.</param>
/// <param name="serializerOptions">The serializer options governing prompt parameter marshalling.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="promptTypes"/> is <see langword="null"/>.</exception>
Expand All @@ -183,7 +186,7 @@ where t.GetCustomAttribute<McpServerToolTypeAttribute>() is not null
/// instance for each. For instance methods, an instance will be constructed for each invocation of the prompt.
/// </remarks>
[RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)]
public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, params IEnumerable<Type> promptTypes)
public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnumerable<Type> promptTypes, JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);
Throw.IfNull(promptTypes);
Expand All @@ -197,8 +200,8 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, para
if (promptMethod.GetCustomAttribute<McpServerPromptAttribute>() is not null)
{
builder.Services.AddSingleton((Func<IServiceProvider, McpServerPrompt>)(promptMethod.IsStatic ?
services => McpServerPrompt.Create(promptMethod, options: new() { Services = services }) :
services => McpServerPrompt.Create(promptMethod, promptType, new() { Services = services })));
services => McpServerPrompt.Create(promptMethod, options: new() { Services = services, SerializerOptions = serializerOptions }) :
services => McpServerPrompt.Create(promptMethod, promptType, new() { Services = services, SerializerOptions = serializerOptions })));
}
}
}
Expand All @@ -211,6 +214,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, para
/// Adds types marked with the <see cref="McpServerPromptTypeAttribute"/> attribute from the given assembly as prompts to the server.
/// </summary>
/// <param name="builder">The builder instance.</param>
/// <param name="serializerOptions">The serializer options governing prompt parameter marshalling.</param>
/// <param name="promptAssembly">The assembly to load the types from. If <see langword="null"/>, the calling assembly will be used.</param>
/// <returns>The builder provided in <paramref name="builder"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is <see langword="null"/>.</exception>
Expand All @@ -235,7 +239,7 @@ public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, para
/// </para>
/// </remarks>
[RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)]
public static IMcpServerBuilder WithPromptsFromAssembly(this IMcpServerBuilder builder, Assembly? promptAssembly = null)
public static IMcpServerBuilder WithPromptsFromAssembly(this IMcpServerBuilder builder, Assembly? promptAssembly = null, JsonSerializerOptions? serializerOptions = null)
{
Throw.IfNull(builder);

Expand All @@ -244,7 +248,8 @@ public static IMcpServerBuilder WithPromptsFromAssembly(this IMcpServerBuilder b
return builder.WithPrompts(
from t in promptAssembly.GetTypes()
where t.GetCustomAttribute<McpServerPromptTypeAttribute>() is not null
select t);
select t,
serializerOptions);
}
#endregion

Expand Down
8 changes: 7 additions & 1 deletion src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Text.Json;
Expand Down Expand Up @@ -66,6 +67,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
Name = options?.Name ?? method.GetCustomAttribute<McpServerPromptAttribute>()?.Name,
Description = options?.Description,
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
ConfigureParameterBinding = pi =>
{
if (pi.ParameterType == typeof(RequestContext<GetPromptRequestParams>))
Expand Down Expand Up @@ -136,6 +138,10 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
Throw.IfNull(function);

List<PromptArgument> args = [];
HashSet<string>? requiredProps = function.JsonSchema.TryGetProperty("required", out JsonElement required)
? new(required.EnumerateArray().Select(p => p.GetString()!), StringComparer.Ordinal)
: null;

if (function.JsonSchema.TryGetProperty("properties", out JsonElement properties))
{
foreach (var param in properties.EnumerateObject())
Expand All @@ -144,7 +150,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
{
Name = param.Name,
Description = param.Value.TryGetProperty("description", out JsonElement description) ? description.GetString() : null,
Required = param.Value.TryGetProperty("required", out JsonElement required) && required.GetBoolean(),
Required = requiredProps?.Contains(param.Name) ?? false,
});
}
}
Expand Down
13 changes: 12 additions & 1 deletion src/ModelContextProtocol/Server/McpServerPromptCreateOptions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using ModelContextProtocol.Utils.Json;
using System.ComponentModel;
using System.Text.Json;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -45,14 +47,23 @@ public sealed class McpServerPromptCreateOptions
/// </remarks>
public string? Description { get; set; }

/// <summary>
/// Gets or sets the JSON serializer options to use when marshalling data to/from JSON.
/// </summary>
/// <remarks>
/// Defaults to <see cref="McpJsonUtilities.DefaultOptions"/> if left unspecified.
/// </remarks>
public JsonSerializerOptions? SerializerOptions { get; set; }

/// <summary>
/// Creates a shallow clone of the current <see cref="McpServerPromptCreateOptions"/> instance.
/// </summary>
internal McpServerPromptCreateOptions Clone() =>
new McpServerPromptCreateOptions()
new McpServerPromptCreateOptions
{
Services = Services,
Name = Name,
Description = Description,
SerializerOptions = SerializerOptions,
};
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;

Expand All @@ -21,7 +22,8 @@ namespace System.Text.Json.Serialization;
/// 9.x support for custom enum member naming. It will be replaced by the built-in functionality
/// once .NET 9 is fully adopted.
/// </remarks>
internal sealed class CustomizableJsonStringEnumConverter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum> :
[EditorBrowsable(EditorBrowsableState.Never)]
public sealed class CustomizableJsonStringEnumConverter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum> :
JsonStringEnumConverter<TEnum> where TEnum : struct, Enum
{
#if !NET9_0_OR_GREATER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using System.ComponentModel;
using System.Text.Json.Serialization;
using System.Threading.Channels;

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously

namespace ModelContextProtocol.Tests.Configuration;

public class McpServerBuilderExtensionsPromptsTests : ClientServerTestBase
public partial class McpServerBuilderExtensionsPromptsTests : ClientServerTestBase
{
public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper)
: base(testOutputHelper)
Expand Down Expand Up @@ -237,7 +238,7 @@ public void Register_Prompts_From_Multiple_Sources()
ServiceCollection sc = new();
sc.AddMcpServer()
.WithPrompts<SimplePrompts>()
.WithPrompts<MorePrompts>();
.WithPrompts<MorePrompts>(JsonContext4.Default.Options);
IServiceProvider services = sc.BuildServiceProvider();

Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages));
Expand Down Expand Up @@ -270,7 +271,7 @@ public string ReturnsString([Description("The first parameter")] string message)
public sealed class MorePrompts
{
[McpServerPrompt]
public static PromptMessage AnotherPrompt() =>
public static PromptMessage AnotherPrompt(ObjectWithId id) =>
new PromptMessage
{
Role = Role.User,
Expand All @@ -282,4 +283,8 @@ public class ObjectWithId
{
public string Id { get; set; } = Guid.NewGuid().ToString("N");
}

[JsonSerializable(typeof(ObjectWithId))]
[JsonSerializable(typeof(PromptMessage))]
partial class JsonContext4 : JsonSerializerContext;
}