Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
MethodInfo method, McpServerPromptCreateOptions? options) =>
new()
{
Name = options?.Name ?? method.GetCustomAttribute<McpServerPromptAttribute>()?.Name,
Name = options?.Name ?? method.GetCustomAttribute<McpServerPromptAttribute>()?.Name ?? AIFunctionMcpServerTool.DeriveName(method),
Description = options?.Description,
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
MethodInfo method, McpServerResourceCreateOptions? options) =>
new()
{
Name = options?.Name ?? method.GetCustomAttribute<McpServerResourceAttribute>()?.Name,
Name = options?.Name ?? method.GetCustomAttribute<McpServerResourceAttribute>()?.Name ?? AIFunctionMcpServerTool.DeriveName(method),
Description = options?.Description,
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
Expand Down
60 changes: 59 additions & 1 deletion src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.RegularExpressions;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -74,7 +75,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions(
MethodInfo method, McpServerToolCreateOptions? options) =>
new()
{
Name = options?.Name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name,
Name = options?.Name ?? method.GetCustomAttribute<McpServerToolAttribute>()?.Name ?? DeriveName(method),
Description = options?.Description,
MarshalResult = static (result, _, cancellationToken) => new ValueTask<object?>(result),
SerializerOptions = options?.SerializerOptions ?? McpJsonUtilities.DefaultOptions,
Expand Down Expand Up @@ -293,6 +294,63 @@ public override async ValueTask<CallToolResult> InvokeAsync(
};
}

/// <summary>Creates a name to use based on the supplied method and naming policy.</summary>
internal static string DeriveName(MethodInfo method, JsonNamingPolicy? policy = null)
{
string name = method.Name;

// Remove any "Async" suffix if the method is an async method and if the method name isn't just "Async".
const string AsyncSuffix = "Async";
if (IsAsyncMethod(method) &&
name.EndsWith(AsyncSuffix, StringComparison.Ordinal) &&
name.Length > AsyncSuffix.Length)
{
name = name.Substring(0, name.Length - AsyncSuffix.Length);
}

// Replace anything other than ASCII letters or digits with underscores, trim off any leading or trailing underscores.
name = NonAsciiLetterDigitsRegex().Replace(name, "_").Trim('_');

// If after all our transformations the name is empty, just use the original method name.
if (name.Length == 0)
{
name = method.Name;
}

// Case the name based on the provided naming policy.
return (policy ?? JsonNamingPolicy.SnakeCaseLower).ConvertName(name) ?? name;

static bool IsAsyncMethod(MethodInfo method)
{
Type t = method.ReturnType;

if (t == typeof(Task) || t == typeof(ValueTask))
{
return true;
}

if (t.IsGenericType)
{
t = t.GetGenericTypeDefinition();
if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>))
{
return true;
}
}

return false;
}
}

/// <summary>Regex that flags runs of characters other than ASCII digits or letters.</summary>
#if NET
[GeneratedRegex("[^0-9A-Za-z]+")]
private static partial Regex NonAsciiLetterDigitsRegex();
#else
private static Regex NonAsciiLetterDigitsRegex() => _nonAsciiLetterDigits;
private static readonly Regex _nonAsciiLetterDigits = new("[^0-9A-Za-z]+", RegexOptions.Compiled);
#endif

private static JsonElement? CreateOutputSchema(AIFunction function, McpServerToolCreateOptions? toolCreateOptions, out bool structuredOutputRequiresWrapping)
{
structuredOutputRequiresWrapping = false;
Expand Down
2 changes: 1 addition & 1 deletion tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT
await using var mcpClient = await ConnectAsync();

var response = await mcpClient.CallToolAsync(
"EchoWithUserName",
"echo_with_user_name",
new Dictionary<string, object?>() { ["message"] = "Hello world!" },
cancellationToken: TestContext.Current.CancellationToken);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ public async Task AddMcpServer_CanBeCalled_MultipleTimes()
var tools = await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);

Assert.Equal(2, tools.Count);
Assert.Contains(tools, tools => tools.Name == "Echo");
Assert.Contains(tools, tools => tools.Name == "echo");
Assert.Contains(tools, tools => tools.Name == "sampleLLM");

var echoResponse = await mcpClient.CallToolAsync(
"Echo",
"echo",
new Dictionary<string, object?>
{
["message"] = "from client!"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public async Task Can_List_And_Call_Registered_Prompts()
var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken);
Assert.Equal(6, prompts.Count);

var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages));
var prompt = prompts.First(t => t.Name == "returns_chat_messages");
Assert.Equal("Returns chat messages", prompt.Description);

var result = await prompt.GetAsync(new Dictionary<string, object?>() { ["message"] = "hello" }, cancellationToken: TestContext.Current.CancellationToken);
Expand Down Expand Up @@ -171,7 +171,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle()
Assert.NotNull(prompts);
Assert.NotEmpty(prompts);

McpClientPrompt prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsString));
McpClientPrompt prompt = prompts.First(t => t.Name == "returns_string");

Assert.Equal("This is a title", prompt.Title);
}
Expand Down Expand Up @@ -204,7 +204,7 @@ public async Task Throws_Exception_Missing_Parameter()
await using IMcpClient client = await CreateMcpClientForServer();

var e = await Assert.ThrowsAsync<McpException>(async () => await client.GetPromptAsync(
nameof(SimplePrompts.ReturnsChatMessages),
"returns_chat_messages",
cancellationToken: TestContext.Current.CancellationToken));

Assert.Equal(McpErrorCode.InternalError, e.ErrorCode);
Expand Down Expand Up @@ -242,7 +242,7 @@ public void Register_Prompts_From_Current_Assembly()
sc.AddMcpServer().WithPromptsFromAssembly();
IServiceProvider services = sc.BuildServiceProvider();

Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages));
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == "returns_chat_messages");
}

[Fact]
Expand All @@ -255,10 +255,10 @@ public void Register_Prompts_From_Multiple_Sources()
.WithPrompts([McpServerPrompt.Create(() => "42", new() { Name = "Returns42" })]);
IServiceProvider services = sc.BuildServiceProvider();

Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages));
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ThrowsException));
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsString));
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == nameof(MorePrompts.AnotherPrompt));
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == "returns_chat_messages");
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == "throws_exception");
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == "returns_string");
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == "another_prompt");
Assert.Contains(services.GetServices<McpServerPrompt>(), t => t.ProtocolPrompt.Name == "Returns42");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public async Task Can_List_And_Call_Registered_Resources()
var resources = await client.ListResourcesAsync(TestContext.Current.CancellationToken);
Assert.Equal(5, resources.Count);

var resource = resources.First(t => t.Name == nameof(SimpleResources.SomeNeatDirectResource));
var resource = resources.First(t => t.Name == "some_neat_direct_resource");
Assert.Equal("Some neat direct resource", resource.Description);

var result = await resource.ReadAsync(cancellationToken: TestContext.Current.CancellationToken);
Expand All @@ -146,7 +146,7 @@ public async Task Can_List_And_Call_Registered_ResourceTemplates()
var resources = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken);
Assert.Equal(3, resources.Count);

var resource = resources.First(t => t.Name == nameof(SimpleResources.SomeNeatTemplatedResource));
var resource = resources.First(t => t.Name == "some_neat_templated_resource");
Assert.Equal("Some neat resource with parameters", resource.Description);

var result = await resource.ReadAsync(new Dictionary<string, object?>() { ["name"] = "hello" }, cancellationToken: TestContext.Current.CancellationToken);
Expand Down Expand Up @@ -204,13 +204,13 @@ public async Task TitleAttributeProperty_PropagatedToTitle()
var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(resources);
Assert.NotEmpty(resources);
McpClientResource resource = resources.First(t => t.Name == nameof(SimpleResources.SomeNeatDirectResource));
McpClientResource resource = resources.First(t => t.Name == "some_neat_direct_resource");
Assert.Equal("This is a title", resource.Title);

var resourceTemplates = await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(resourceTemplates);
Assert.NotEmpty(resourceTemplates);
McpClientResourceTemplate resourceTemplate = resourceTemplates.First(t => t.Name == nameof(SimpleResources.SomeNeatTemplatedResource));
McpClientResourceTemplate resourceTemplate = resourceTemplates.First(t => t.Name == "some_neat_templated_resource");
Assert.Equal("This is another title", resourceTemplate.Title);
}

Expand Down Expand Up @@ -268,8 +268,8 @@ public void Register_Resources_From_Current_Assembly()
sc.AddMcpServer().WithResourcesFromAssembly();
IServiceProvider services = sc.BuildServiceProvider();

Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResource?.Uri == $"resource://mcp/{nameof(SimpleResources.SomeNeatDirectResource)}");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/{nameof(SimpleResources.SomeNeatTemplatedResource)}{{?name}}");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResource?.Uri == $"resource://mcp/some_neat_direct_resource");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/some_neat_templated_resource{{?name}}");
}

[Fact]
Expand All @@ -282,9 +282,9 @@ public void Register_Resources_From_Multiple_Sources()
.WithResources([McpServerResource.Create(() => "42", new() { UriTemplate = "myResources:///returns42/{something}" })]);
IServiceProvider services = sc.BuildServiceProvider();

Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResource?.Uri == $"resource://mcp/{nameof(SimpleResources.SomeNeatDirectResource)}");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/{nameof(SimpleResources.SomeNeatTemplatedResource)}{{?name}}");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/{nameof(MoreResources.AnotherNeatDirectResource)}");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResource?.Uri == $"resource://mcp/some_neat_direct_resource");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/some_neat_templated_resource{{?name}}");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate?.UriTemplate == $"resource://mcp/another_neat_direct_resource");
Assert.Contains(services.GetServices<McpServerResource>(), t => t.ProtocolResourceTemplate.UriTemplate == "myResources:///returns42/{something}");
}

Expand Down
Loading