Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ var response = await chatClient.GetResponseAsync(

Here is an example of how to create an MCP server and register all tools from the current application.
It includes a simple echo tool as an example (this is included in the same file here for easy of copy and paste, but it needn't be in the same file...
the employed overload of `WithTools` examines the current assembly for classes with the `McpToolType` attribute, and registers all methods with the
the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the
`McpTool` attribute as tools.)

```csharp
Expand Down
4 changes: 2 additions & 2 deletions src/ModelContextProtocol/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ var response = await chatClient.GetResponseAsync(

Here is an example of how to create an MCP server and register all tools from the current application.
It includes a simple echo tool as an example (this is included in the same file here for easy of copy and paste, but it needn't be in the same file...
the employed overload of `WithTools` examines the current assembly for classes with the `McpToolType` attribute, and registers all methods with the
the employed overload of `WithTools` examines the current assembly for classes with the `McpServerToolType` attribute, and registers all methods with the
`McpTool` attribute as tools.)

```csharp
Expand All @@ -101,7 +101,7 @@ builder.Services
.WithTools();
await builder.Build().RunAsync();

[McpToolType]
[McpServerToolType]
public static class EchoTool
{
[McpTool, Description("Echoes the message back to the client.")]
Expand Down
66 changes: 43 additions & 23 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer
{
private readonly IServerTransport? _serverTransport;
private readonly string _serverDescription;
private readonly EventHandler? _toolsChangedDelegate;

private volatile bool _isInitializing;

/// <summary>
Expand All @@ -32,36 +34,45 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
Throw.IfNull(options);

_serverTransport = transport as IServerTransport;
ServerInstructions = options.ServerInstructions;
ServerOptions = options;
Services = serviceProvider;
_serverDescription = $"{options.ServerInfo.Name} {options.ServerInfo.Version}";
_toolsChangedDelegate = delegate
{
_ = SendMessageAsync(new JsonRpcNotification()
{
Method = NotificationMethods.ToolListChangedNotification,
});
};

AddNotificationHandler("notifications/initialized", _ =>
{
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
{
tools.Changed += _toolsChangedDelegate;
}

IsInitialized = true;
return Task.CompletedTask;
});

SetToolsHandler(ref options);
SetToolsHandler(options);

SetInitializeHandler(options);
SetCompletionHandler(options);
SetPingHandler();
SetPromptsHandler(options);
SetResourcesHandler(options);
SetSetLoggingLevelHandler(options);

ServerOptions = options;
}

public ServerCapabilities? ServerCapabilities { get; set; }

public ClientCapabilities? ClientCapabilities { get; set; }

/// <inheritdoc />
public Implementation? ClientInfo { get; set; }

/// <inheritdoc />
public string? ServerInstructions { get; set; }

/// <inheritdoc />
public McpServerOptions ServerOptions { get; }

Expand Down Expand Up @@ -111,6 +122,15 @@ public async Task StartAsync(CancellationToken cancellationToken = default)
}
}

protected override Task CleanupAsync()
{
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
{
tools.Changed -= _toolsChangedDelegate;
}
return base.CleanupAsync();
}

private void SetPingHandler()
{
SetRequestHandler<JsonNode, PingResult>("ping",
Expand All @@ -127,9 +147,9 @@ private void SetInitializeHandler(McpServerOptions options)
return Task.FromResult(new InitializeResult()
{
ProtocolVersion = options.ProtocolVersion,
Instructions = ServerInstructions,
Instructions = options.ServerInstructions,
ServerInfo = options.ServerInfo,
Capabilities = options.Capabilities ?? new ServerCapabilities(),
Capabilities = ServerCapabilities ?? new(),
});
});
}
Expand Down Expand Up @@ -198,7 +218,7 @@ private void SetPromptsHandler(McpServerOptions options)
SetRequestHandler<GetPromptRequestParams, GetPromptResult>("prompts/get", (request, ct) => getPromptHandler(new(this, request), ct));
}

private void SetToolsHandler(ref McpServerOptions options)
private void SetToolsHandler(McpServerOptions options)
{
ToolsCapability? toolsCapability = options.Capabilities?.Tools;
var listToolsHandler = toolsCapability?.ListToolsHandler;
Expand Down Expand Up @@ -261,25 +281,25 @@ private void SetToolsHandler(ref McpServerOptions options)
return tool.InvokeAsync(request, cancellationToken);
};

toolsCapability ??= new();
toolsCapability.CallToolHandler = callToolHandler;
toolsCapability.ListToolsHandler = listToolsHandler;
toolsCapability.ToolCollection = tools;
toolsCapability.ListChanged = true;

options.Capabilities ??= new();
options.Capabilities.Tools = toolsCapability;

tools.Changed += delegate
ServerCapabilities = new()
{
_ = SendMessageAsync(new JsonRpcNotification()
Experimental = options.Capabilities?.Experimental,
Logging = options.Capabilities?.Logging,
Prompts = options.Capabilities?.Prompts,
Resources = options.Capabilities?.Resources,
Tools = new()
{
Method = NotificationMethods.ToolListChangedNotification,
});
ListToolsHandler = listToolsHandler,
CallToolHandler = callToolHandler,
ToolCollection = tools,
ListChanged = true,
}
};
}
else
{
ServerCapabilities = options.Capabilities;

if (toolsCapability is null)
{
// No tools, and no tools capability was declared, so nothing to do.
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ protected void SetRequestHandler<TRequest, TResponse>(string method, Func<TReque
/// Cleans up the endpoint and releases resources.
/// </summary>
/// <returns></returns>
protected async Task CleanupAsync()
protected virtual async Task CleanupAsync()
{
if (_isDisposed)
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ private async Task<IMcpClient> CreateMcpClientForServer()
{
await _server.StartAsync(TestContext.Current.CancellationToken);

var stdin = new StreamReader(_serverToClientPipe.Reader.AsStream());
var stdout = new StreamWriter(_clientToServerPipe.Writer.AsStream());
var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream());
var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream());

var serverConfig = new McpServerConfig()
{
Expand All @@ -50,7 +50,7 @@ private async Task<IMcpClient> CreateMcpClientForServer()

return await McpClientFactory.CreateAsync(
serverConfig,
createTransportFunc: (_, _) => new StreamClientTransport(stdin, stdout),
createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader),
cancellationToken: TestContext.Current.CancellationToken);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,45 @@
using Microsoft.Extensions.AI;
using System.Threading.Channels;
using ModelContextProtocol.Protocol.Messages;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Tests.Utils;
using Microsoft.Extensions.Logging;

namespace ModelContextProtocol.Tests.Configuration;

public class McpServerBuilderExtensionsToolsTests : IAsyncDisposable
public class McpServerBuilderExtensionsToolsTests : LoggedTest, IAsyncDisposable
{
private Pipe _clientToServerPipe = new();
private Pipe _serverToClientPipe = new();
private readonly Pipe _clientToServerPipe = new();
private readonly Pipe _serverToClientPipe = new();
private readonly ServiceProvider _serviceProvider;
private readonly IMcpServerBuilder _builder;
private readonly IMcpServer _server;

public McpServerBuilderExtensionsToolsTests()
public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper)
: base(testOutputHelper)
{
ServiceCollection sc = new();
sc.AddSingleton(LoggerFactory);
sc.AddSingleton<IServerTransport>(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()));
sc.AddSingleton(new ObjectWithId());
_builder = sc.AddMcpServer().WithTools<EchoTool>();
_server = sc.BuildServiceProvider().GetRequiredService<IMcpServer>();
_serviceProvider = sc.BuildServiceProvider();
_server = _serviceProvider.GetRequiredService<IMcpServer>();
}

public ValueTask DisposeAsync()
{
_clientToServerPipe.Writer.Complete();
_serverToClientPipe.Writer.Complete();
return _server.DisposeAsync();
return _serviceProvider.DisposeAsync();
}

private async Task<IMcpClient> CreateMcpClientForServer()
{
await _server.StartAsync(TestContext.Current.CancellationToken);

var stdin = new StreamReader(_serverToClientPipe.Reader.AsStream());
var stdout = new StreamWriter(_clientToServerPipe.Writer.AsStream());
var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream());
var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream());

var serverConfig = new McpServerConfig()
{
Expand All @@ -53,7 +60,7 @@ private async Task<IMcpClient> CreateMcpClientForServer()

return await McpClientFactory.CreateAsync(
serverConfig,
createTransportFunc: (_, _) => new StreamClientTransport(stdin, stdout),
createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader),
cancellationToken: TestContext.Current.CancellationToken);
}

Expand Down Expand Up @@ -86,6 +93,63 @@ public async Task Can_List_Registered_Tools()
Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description);
}


[Fact]
public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_Tools()
{
var options = _serviceProvider.GetRequiredService<IOptions<McpServerOptions>>().Value;
var loggerFactory = _serviceProvider.GetRequiredService<ILoggerFactory>();

for (int i = 0; i < 2; i++)
{
var stdinPipe = new Pipe();
var stdoutPipe = new Pipe();

try
{
var transport = new StdioServerTransport($"TestServer_{i}", stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream());
var server = McpServerFactory.Create(transport, options, loggerFactory, _serviceProvider);

await server.StartAsync(TestContext.Current.CancellationToken);

var serverStdinWriter = new StreamWriter(stdinPipe.Writer.AsStream());
var serverStdoutReader = new StreamReader(stdoutPipe.Reader.AsStream());

var serverConfig = new McpServerConfig()
{
Id = $"TestServer_{i}",
Name = $"TestServer_{i}",
TransportType = "ignored",
};

var client = await McpClientFactory.CreateAsync(
serverConfig,
createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader),
cancellationToken: TestContext.Current.CancellationToken);

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(11, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "Echo");
Assert.Equal("Echo", echoTool.Name);
Assert.Equal("Echoes the input back to the client.", echoTool.Description);
Assert.Equal("object", echoTool.JsonSchema.GetProperty("type").GetString());
Assert.Equal(JsonValueKind.Object, echoTool.JsonSchema.GetProperty("properties").GetProperty("message").ValueKind);
Assert.Equal("the echoes message", echoTool.JsonSchema.GetProperty("properties").GetProperty("message").GetProperty("description").GetString());
Assert.Equal(1, echoTool.JsonSchema.GetProperty("required").GetArrayLength());

McpClientTool doubleEchoTool = tools.First(t => t.Name == "double_echo");
Assert.Equal("double_echo", doubleEchoTool.Name);
Assert.Equal("Echoes the input back to the client.", doubleEchoTool.Description);
}
finally
{
stdinPipe.Writer.Complete();
stdoutPipe.Writer.Complete();
}
}
}

[Fact]
public async Task Can_Be_Notified_Of_Tool_Changes()
{
Expand Down
20 changes: 10 additions & 10 deletions tests/ModelContextProtocol.Tests/Transport/StreamClientTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ namespace ModelContextProtocol.Tests.Transport;
internal sealed class StreamClientTransport : TransportBase, IClientTransport
{
private readonly JsonSerializerOptions _jsonOptions = McpJsonUtilities.DefaultOptions;
private Task? _readTask;
private CancellationTokenSource _shutdownCts = new CancellationTokenSource();
private readonly TextReader _stdin;
private readonly TextWriter _stdout;
private readonly Task? _readTask;
private readonly CancellationTokenSource _shutdownCts = new CancellationTokenSource();
private readonly TextReader _serverStdoutReader;
private readonly TextWriter _serverStdinWriter;

public StreamClientTransport(TextReader stdin, TextWriter stdout)
public StreamClientTransport(TextWriter serverStdinWriter, TextReader serverStdoutReader)
: base(NullLoggerFactory.Instance)
{
_stdin = stdin;
_stdout = stdout;
_serverStdoutReader = serverStdoutReader;
_serverStdinWriter = serverStdinWriter;
_readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None);
SetConnected(true);
}
Expand All @@ -31,13 +31,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
messageWithId.Id.ToString() :
"(no id)";

await _stdout.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false);
await _stdout.FlushAsync(cancellationToken).ConfigureAwait(false);
await _serverStdinWriter.WriteLineAsync(JsonSerializer.Serialize(message)).ConfigureAwait(false);
await _serverStdinWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
}

private async Task ReadMessagesAsync(CancellationToken cancellationToken)
{
while (await _stdin.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line)
while (await _serverStdoutReader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is string line)
{
if (!string.IsNullOrWhiteSpace(line))
{
Expand Down