Skip to content

Commit a069f64

Browse files
committed
Add UseMcpClient
1 parent 9f6fc36 commit a069f64

File tree

9 files changed

+634
-23
lines changed

9 files changed

+634
-23
lines changed
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
using System.Collections.Concurrent;
2+
using System.Runtime.CompilerServices;
3+
using Microsoft.Extensions.AI;
4+
using Microsoft.Extensions.Logging;
5+
using Microsoft.Extensions.Logging.Abstractions;
6+
using ModelContextProtocol.Client;
7+
#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
8+
9+
namespace ModelContextProtocol;
10+
11+
/// <summary>
12+
/// Extension methods for adding MCP client support to chat clients.
13+
/// </summary>
14+
public static class McpChatClientBuilderExtensions
15+
{
16+
/// <summary>
17+
/// Adds a chat client to the chat client pipeline that creates an <see cref="McpClient"/> for each <see cref="HostedMcpServerTool"/>
18+
/// in <see cref="ChatOptions.Tools"/> and augments it with the tools from MCP servers as <see cref="AIFunction"/> instances.
19+
/// </summary>
20+
/// <param name="builder">The <see cref="ChatClientBuilder"/> to configure.</param>
21+
/// <param name="httpClient">The <see cref="HttpClient"/> to use, or <see langword="null"/> to create a new instance.</param>
22+
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use, or <see langword="null"/> to resolve from services.</param>
23+
/// <returns>The <see cref="ChatClientBuilder"/> for method chaining.</returns>
24+
/// <remarks>
25+
/// <para>
26+
/// When a <c>HostedMcpServerTool</c> is encountered in the tools collection, the client
27+
/// connects to the MCP server, retrieves available tools, and expands them into callable AI functions.
28+
/// Connections are cached by server address to avoid redundant connections.
29+
/// </para>
30+
/// <para>
31+
/// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers.
32+
/// </para>
33+
/// </remarks>
34+
public static ChatClientBuilder UseMcpClient(
35+
this ChatClientBuilder builder,
36+
HttpClient? httpClient = null,
37+
ILoggerFactory? loggerFactory = null)
38+
{
39+
return builder.Use((innerClient, services) =>
40+
{
41+
loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!;
42+
var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory);
43+
return chatClient;
44+
});
45+
}
46+
47+
private class McpChatClient : DelegatingChatClient
48+
{
49+
private readonly ILoggerFactory? _loggerFactory;
50+
private readonly ILogger _logger;
51+
private readonly HttpClient _httpClient;
52+
private readonly bool _ownsHttpClient;
53+
private ConcurrentDictionary<string, Task<McpClient>>? _mcpClientTasks = null;
54+
55+
/// <summary>
56+
/// Initializes a new instance of the <see cref="McpChatClient"/> class.
57+
/// </summary>
58+
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
59+
/// <param name="httpClient">An optional <see cref="HttpClient"/> to use when connecting to MCP servers. If not provided, a new instance will be created.</param>
60+
/// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param>
61+
public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null)
62+
: base(innerClient)
63+
{
64+
_loggerFactory = loggerFactory;
65+
_logger = (ILogger?)loggerFactory?.CreateLogger<McpChatClient>() ?? NullLogger.Instance;
66+
_httpClient = httpClient ?? new HttpClient();
67+
_ownsHttpClient = httpClient is null;
68+
}
69+
70+
/// <inheritdoc/>
71+
public override async Task<ChatResponse> GetResponseAsync(
72+
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
73+
{
74+
if (options?.Tools is { Count: > 0 })
75+
{
76+
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false);
77+
options = options.Clone();
78+
options.Tools = downstreamTools;
79+
}
80+
81+
return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
82+
}
83+
84+
/// <inheritdoc/>
85+
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
86+
{
87+
if (options?.Tools is { Count: > 0 })
88+
{
89+
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false);
90+
options = options.Clone();
91+
options.Tools = downstreamTools;
92+
}
93+
94+
await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
95+
{
96+
yield return update;
97+
}
98+
}
99+
100+
private async Task<List<AITool>?> BuildDownstreamAIToolsAsync(IList<AITool>? inputTools, CancellationToken cancellationToken)
101+
{
102+
List<AITool>? downstreamTools = null;
103+
foreach (var tool in inputTools ?? [])
104+
{
105+
if (tool is not HostedMcpServerTool mcpTool)
106+
{
107+
// For other tools, we want to keep them in the list of tools.
108+
downstreamTools ??= new List<AITool>();
109+
downstreamTools.Add(tool);
110+
continue;
111+
}
112+
113+
if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) ||
114+
(parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps))
115+
{
116+
throw new InvalidOperationException(
117+
$"MCP server address must be an absolute HTTP or HTTPS URI. Invalid address: '{mcpTool.ServerAddress}'");
118+
}
119+
120+
// List all MCP functions from the specified MCP server.
121+
// This will need some caching in a real-world scenario to avoid repeated calls.
122+
var mcpClient = await CreateMcpClientAsync(parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false);
123+
var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
124+
125+
// Add the listed functions to our list of tools we'll pass to the inner client.
126+
foreach (var mcpFunction in mcpFunctions)
127+
{
128+
if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name))
129+
{
130+
_logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name);
131+
continue;
132+
}
133+
134+
downstreamTools ??= new List<AITool>();
135+
switch (mcpTool.ApprovalMode)
136+
{
137+
case HostedMcpServerToolAlwaysRequireApprovalMode alwaysRequireApproval:
138+
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
139+
break;
140+
case HostedMcpServerToolNeverRequireApprovalMode neverRequireApproval:
141+
downstreamTools.Add(mcpFunction);
142+
break;
143+
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.AlwaysRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
144+
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
145+
break;
146+
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
147+
downstreamTools.Add(mcpFunction);
148+
break;
149+
default:
150+
// Default to always require approval if no specific mode is set.
151+
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
152+
break;
153+
}
154+
}
155+
}
156+
157+
return downstreamTools;
158+
}
159+
160+
/// <inheritdoc/>
161+
protected override void Dispose(bool disposing)
162+
{
163+
if (disposing)
164+
{
165+
// Dispose of the HTTP client if it was created by this client.
166+
if (_ownsHttpClient)
167+
{
168+
_httpClient?.Dispose();
169+
}
170+
171+
if (_mcpClientTasks is not null)
172+
{
173+
// Dispose of all cached MCP clients.
174+
foreach (var clientTask in _mcpClientTasks.Values)
175+
{
176+
#if NETSTANDARD2_0
177+
if (clientTask.Status == TaskStatus.RanToCompletion)
178+
#else
179+
if (clientTask.IsCompletedSuccessfully)
180+
#endif
181+
{
182+
_ = clientTask.Result.DisposeAsync();
183+
}
184+
}
185+
186+
_mcpClientTasks.Clear();
187+
}
188+
}
189+
190+
base.Dispose(disposing);
191+
}
192+
193+
private Task<McpClient> CreateMcpClientAsync(Uri serverAddress, string serverName, string? authorizationToken)
194+
{
195+
if (_mcpClientTasks is null)
196+
{
197+
_mcpClientTasks = new ConcurrentDictionary<string, Task<McpClient>>(StringComparer.OrdinalIgnoreCase);
198+
}
199+
200+
// Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token.
201+
// Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently.
202+
return _mcpClientTasks.GetOrAdd(serverAddress.ToString(), _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None));
203+
}
204+
205+
private async Task<McpClient> CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken)
206+
{
207+
var serverAddressKey = serverAddress.ToString();
208+
try
209+
{
210+
var transport = new HttpClientTransport(new HttpClientTransportOptions
211+
{
212+
Endpoint = serverAddress,
213+
Name = serverName,
214+
AdditionalHeaders = authorizationToken is not null
215+
// Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available.
216+
? new Dictionary<string, string>() { { "Authorization", $"Bearer {authorizationToken}" } }
217+
: null,
218+
}, _httpClient, _loggerFactory);
219+
220+
return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
221+
}
222+
catch
223+
{
224+
// Remove the failed task from cache so subsequent requests can retry
225+
_mcpClientTasks?.TryRemove(serverAddressKey, out _);
226+
throw;
227+
}
228+
}
229+
}
230+
}

src/ModelContextProtocol/ModelContextProtocol.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
<ItemGroup>
2525
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" />
26+
<PackageReference Include="Microsoft.Extensions.AI" />
2627
</ItemGroup>
2728

2829
<ItemGroup>

tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
namespace ModelContextProtocol.AspNetCore.Tests;
66

7-
public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture<SseServerIntegrationTestFixture>
7+
public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture<SseServerWithXunitLoggerFixture>
88
{
9-
protected readonly SseServerIntegrationTestFixture _fixture;
9+
protected readonly SseServerWithXunitLoggerFixture _fixture;
1010

11-
public HttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
11+
public HttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
1212
: base(testOutputHelper)
1313
{
1414
_fixture = fixture;

tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,18 @@
77

88
namespace ModelContextProtocol.AspNetCore.Tests;
99

10-
public class SseServerIntegrationTestFixture : IAsyncDisposable
10+
public abstract class SseServerIntegrationTestFixture : IAsyncDisposable
1111
{
1212
private readonly KestrelInMemoryTransport _inMemoryTransport = new();
13-
1413
private readonly Task _serverTask;
1514
private readonly CancellationTokenSource _stopCts = new();
1615

17-
// XUnit's ITestOutputHelper is created per test, while this fixture is used for
18-
// multiple tests, so this dispatches the output to the current test.
19-
private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new();
20-
2116
private HttpClientTransportOptions DefaultTransportOptions { get; set; } = new()
2217
{
2318
Endpoint = new("http://localhost:5000/"),
2419
};
2520

26-
public SseServerIntegrationTestFixture()
21+
protected SseServerIntegrationTestFixture()
2722
{
2823
var socketsHttpHandler = new SocketsHttpHandler
2924
{
@@ -39,8 +34,10 @@ public SseServerIntegrationTestFixture()
3934
BaseAddress = new("http://localhost:5000/"),
4035
};
4136

42-
_serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token);
37+
_serverTask = Program.MainAsync([], CreateLoggerProvider(), _inMemoryTransport, _stopCts.Token);
4338
}
39+
40+
protected abstract ILoggerProvider CreateLoggerProvider();
4441

4542
public HttpClient HttpClient { get; }
4643

@@ -53,21 +50,17 @@ public Task<McpClient> ConnectMcpClientAsync(McpClientOptions? options, ILoggerF
5350
TestContext.Current.CancellationToken);
5451
}
5552

56-
public void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions)
53+
public virtual void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions)
5754
{
58-
_delegatingTestOutputHelper.CurrentTestOutputHelper = output;
5955
DefaultTransportOptions = clientTransportOptions;
6056
}
6157

62-
public void TestCompleted()
58+
public virtual void TestCompleted()
6359
{
64-
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
6560
}
6661

67-
public async ValueTask DisposeAsync()
62+
public virtual async ValueTask DisposeAsync()
6863
{
69-
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
70-
7164
HttpClient.Dispose();
7265
_stopCts.Cancel();
7366

@@ -82,3 +75,49 @@ public async ValueTask DisposeAsync()
8275
_stopCts.Dispose();
8376
}
8477
}
78+
79+
/// <summary>
80+
/// SSE server fixture that routes logs to xUnit test output.
81+
/// </summary>
82+
public class SseServerWithXunitLoggerFixture : SseServerIntegrationTestFixture
83+
{
84+
// XUnit's ITestOutputHelper is created per test, while this fixture is used for
85+
// multiple tests, so this dispatches the output to the current test.
86+
private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new();
87+
88+
protected override ILoggerProvider CreateLoggerProvider()
89+
=> new XunitLoggerProvider(_delegatingTestOutputHelper);
90+
91+
public override void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions)
92+
{
93+
_delegatingTestOutputHelper.CurrentTestOutputHelper = output;
94+
base.Initialize(output, clientTransportOptions);
95+
}
96+
97+
public override void TestCompleted()
98+
{
99+
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
100+
base.TestCompleted();
101+
}
102+
103+
public override async ValueTask DisposeAsync()
104+
{
105+
_delegatingTestOutputHelper.CurrentTestOutputHelper = null;
106+
await base.DisposeAsync();
107+
}
108+
}
109+
110+
/// <summary>
111+
/// Fixture for tests that need to inspect server logs using MockLoggerProvider.
112+
/// Use <see cref="SseServerWithXunitLoggerFixture"/> for tests that just need xUnit output.
113+
/// </summary>
114+
public class SseServerWithMockLoggerFixture : SseServerIntegrationTestFixture
115+
{
116+
private readonly MockLoggerProvider _mockLoggerProvider = new();
117+
118+
protected override ILoggerProvider CreateLoggerProvider()
119+
=> _mockLoggerProvider;
120+
121+
public IEnumerable<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> ServerLogs
122+
=> _mockLoggerProvider.LogMessages;
123+
}

tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace ModelContextProtocol.AspNetCore.Tests;
66

7-
public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
7+
public class SseServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
88
: HttpServerIntegrationTests(fixture, testOutputHelper)
99

1010
{

tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
namespace ModelContextProtocol.AspNetCore.Tests;
44

5-
public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
5+
public class StatelessServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
66
: StreamableHttpServerIntegrationTests(fixture, testOutputHelper)
77
{
88
protected override HttpClientTransportOptions ClientTransportOptions => new()

tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace ModelContextProtocol.AspNetCore.Tests;
55

6-
public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper)
6+
public class StreamableHttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper)
77
: HttpServerIntegrationTests(fixture, testOutputHelper)
88

99
{

0 commit comments

Comments
 (0)