Skip to content

Commit 298e76e

Browse files
committed
Add LRU cache
Add more tests Add Retry logic Add configureTransportOptions
1 parent 1b3de48 commit 298e76e

File tree

3 files changed

+1012
-62
lines changed

3 files changed

+1012
-62
lines changed
Lines changed: 123 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
using System.Collections.Concurrent;
1+
using System.Diagnostics;
22
using System.Diagnostics.CodeAnalysis;
33
using System.Runtime.CompilerServices;
44
using Microsoft.Extensions.AI;
55
using Microsoft.Extensions.Logging;
66
using Microsoft.Extensions.Logging.Abstractions;
7-
using ModelContextProtocol.Client;
87

9-
namespace ModelContextProtocol;
8+
namespace ModelContextProtocol.Client;
109

1110
/// <summary>
1211
/// Extension methods for adding MCP client support to chat clients.
@@ -20,6 +19,7 @@ public static class McpChatClientBuilderExtensions
2019
/// <param name="builder">The <see cref="ChatClientBuilder"/> to configure.</param>
2120
/// <param name="httpClient">The <see cref="HttpClient"/> to use, or <see langword="null"/> to create a new instance.</param>
2221
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use, or <see langword="null"/> to resolve from services.</param>
22+
/// <param name="configureTransportOptions">An optional callback to configure the <see cref="HttpClientTransportOptions"/> for each <see cref="HostedMcpServerTool"/>.</param>
2323
/// <returns>The <see cref="ChatClientBuilder"/> for method chaining.</returns>
2424
/// <remarks>
2525
/// <para>
@@ -35,12 +35,13 @@ public static class McpChatClientBuilderExtensions
3535
public static ChatClientBuilder UseMcpClient(
3636
this ChatClientBuilder builder,
3737
HttpClient? httpClient = null,
38-
ILoggerFactory? loggerFactory = null)
38+
ILoggerFactory? loggerFactory = null,
39+
Action<HostedMcpServerTool, HttpClientTransportOptions>? configureTransportOptions = null)
3940
{
4041
return builder.Use((innerClient, services) =>
4142
{
4243
loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!;
43-
var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory);
44+
var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory, configureTransportOptions);
4445
return chatClient;
4546
});
4647
}
@@ -52,43 +53,45 @@ private sealed class McpChatClient : DelegatingChatClient
5253
private readonly ILogger _logger;
5354
private readonly HttpClient _httpClient;
5455
private readonly bool _ownsHttpClient;
55-
private readonly ConcurrentDictionary<string, Task<McpClient>> _mcpClientTasks = [];
56+
private readonly McpClientTasksLruCache _lruCache;
57+
private readonly Action<HostedMcpServerTool, HttpClientTransportOptions>? _configureTransportOptions;
5658

5759
/// <summary>
5860
/// Initializes a new instance of the <see cref="McpChatClient"/> class.
5961
/// </summary>
6062
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
6163
/// <param name="httpClient">An optional <see cref="HttpClient"/> to use when connecting to MCP servers. If not provided, a new instance will be created.</param>
6264
/// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param>
63-
public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null)
65+
/// <param name="configureTransportOptions">An optional callback to configure the <see cref="HttpClientTransportOptions"/> for each <see cref="HostedMcpServerTool"/>.</param>
66+
public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null, Action<HostedMcpServerTool, HttpClientTransportOptions>? configureTransportOptions = null)
6467
: base(innerClient)
6568
{
6669
_loggerFactory = loggerFactory;
6770
_logger = (ILogger?)loggerFactory?.CreateLogger<McpChatClient>() ?? NullLogger.Instance;
6871
_httpClient = httpClient ?? new HttpClient();
6972
_ownsHttpClient = httpClient is null;
73+
_lruCache = new McpClientTasksLruCache(capacity: 20);
74+
_configureTransportOptions = configureTransportOptions;
7075
}
7176

72-
/// <inheritdoc/>
7377
public override async Task<ChatResponse> GetResponseAsync(
7478
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
7579
{
7680
if (options?.Tools is { Count: > 0 })
7781
{
78-
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false);
82+
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools).ConfigureAwait(false);
7983
options = options.Clone();
8084
options.Tools = downstreamTools;
8185
}
8286

8387
return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
8488
}
8589

86-
/// <inheritdoc/>
8790
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
8891
{
8992
if (options?.Tools is { Count: > 0 })
9093
{
91-
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false);
94+
var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools).ConfigureAwait(false);
9295
options = options.Clone();
9396
options.Tools = downstreamTools;
9497
}
@@ -99,51 +102,52 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
99102
}
100103
}
101104

102-
private async Task<List<AITool>> BuildDownstreamAIToolsAsync(IList<AITool> inputTools, CancellationToken cancellationToken)
105+
private async Task<List<AITool>> BuildDownstreamAIToolsAsync(IList<AITool> chatOptionsTools)
103106
{
104107
List<AITool> downstreamTools = [];
105-
foreach (var tool in inputTools)
108+
foreach (var chatOptionsTool in chatOptionsTools)
106109
{
107-
if (tool is not HostedMcpServerTool mcpTool)
110+
if (chatOptionsTool is not HostedMcpServerTool hostedMcpTool)
108111
{
109112
// For other tools, we want to keep them in the list of tools.
110-
downstreamTools.Add(tool);
113+
downstreamTools.Add(chatOptionsTool);
111114
continue;
112115
}
113116

114-
if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) ||
117+
if (!Uri.TryCreate(hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) ||
115118
(parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps))
116119
{
117120
throw new InvalidOperationException(
118-
$"Invalid http(s) address: '{mcpTool.ServerAddress}'. MCP server address must be an absolute https(s) URL.");
121+
$"Invalid http(s) address: '{hostedMcpTool.ServerAddress}'. MCP server address must be an absolute http(s) URL.");
119122
}
120123

121-
// List all MCP functions from the specified MCP server.
122-
var mcpClient = await CreateMcpClientAsync(mcpTool.ServerAddress, parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false);
123-
var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
124+
// Get MCP client and its tools from cache (both are fetched together on first access).
125+
var (_, mcpTools) = await GetClientAndToolsAsync(hostedMcpTool, parsedAddress).ConfigureAwait(false);
124126

125127
// Add the listed functions to our list of tools we'll pass to the inner client.
126-
foreach (var mcpFunction in mcpFunctions)
128+
foreach (var mcpTool in mcpTools)
127129
{
128-
if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name))
130+
if (hostedMcpTool.AllowedTools is not null && !hostedMcpTool.AllowedTools.Contains(mcpTool.Name))
129131
{
130132
if (_logger.IsEnabled(LogLevel.Information))
131133
{
132-
_logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name);
134+
_logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpTool.Name);
133135
}
134136
continue;
135137
}
136138

137-
switch (mcpTool.ApprovalMode)
139+
var wrappedFunction = new McpRetriableAIFunction(mcpTool, hostedMcpTool, parsedAddress, this);
140+
141+
switch (hostedMcpTool.ApprovalMode)
138142
{
139143
case HostedMcpServerToolNeverRequireApprovalMode:
140-
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
141-
downstreamTools.Add(mcpFunction);
144+
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpTool.Name) is true:
145+
downstreamTools.Add(wrappedFunction);
142146
break;
143147

144148
default:
145149
// Default to always require approval if no specific mode is set.
146-
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
150+
downstreamTools.Add(new ApprovalRequiredAIFunction(wrappedFunction));
147151
break;
148152
}
149153
}
@@ -152,74 +156,131 @@ private async Task<List<AITool>> BuildDownstreamAIToolsAsync(IList<AITool> input
152156
return downstreamTools;
153157
}
154158

155-
/// <inheritdoc/>
156159
protected override void Dispose(bool disposing)
157160
{
158161
if (disposing)
159162
{
160-
// Dispose of the HTTP client if it was created by this client.
161163
if (_ownsHttpClient)
162164
{
163165
_httpClient?.Dispose();
164166
}
165167

166-
if (_mcpClientTasks is not null)
167-
{
168-
// Dispose of all cached MCP clients.
169-
foreach (var clientTask in _mcpClientTasks.Values)
170-
{
171-
if (clientTask.Status == TaskStatus.RanToCompletion)
172-
{
173-
_ = clientTask.Result.DisposeAsync();
174-
}
175-
}
176-
177-
_mcpClientTasks.Clear();
178-
}
168+
_lruCache.Dispose();
179169
}
180170

181171
base.Dispose(disposing);
182172
}
183173

184-
private async Task<McpClient> CreateMcpClientAsync(string key, Uri serverAddress, string serverName, string? authorizationToken)
174+
internal async Task<(McpClient Client, IList<McpClientTool> Tools)> GetClientAndToolsAsync(HostedMcpServerTool hostedMcpTool, Uri serverAddressUri)
185175
{
186176
// Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token.
187177
// Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently.
188-
#if NET
189-
// Avoid closure allocation.
190-
Task<McpClient> task = _mcpClientTasks.GetOrAdd(key,
191-
static (_, state) => state.self.CreateMcpClientCoreAsync(state.serverAddress, state.serverName, state.authorizationToken, CancellationToken.None),
192-
(self: this, serverAddress, serverName, authorizationToken));
193-
#else
194-
Task<McpClient> task = _mcpClientTasks.GetOrAdd(key,
195-
_ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None));
196-
#endif
178+
Task<(McpClient, IList<McpClientTool> Tools)> task = _lruCache.GetOrAdd(
179+
hostedMcpTool.ServerAddress,
180+
static (_, state) => state.self.CreateMcpClientAndToolsAsync(state.hostedMcpTool, state.serverAddressUri, CancellationToken.None),
181+
(self: this, hostedMcpTool, serverAddressUri));
197182

198183
try
199184
{
200185
return await task.ConfigureAwait(false);
201186
}
202187
catch
203188
{
204-
// Remove the failed task from cache so subsequent requests can retry.
205-
_mcpClientTasks.TryRemove(key, out _);
189+
bool result = RemoveMcpClientFromCache(hostedMcpTool.ServerAddress, out var removedTask);
190+
Debug.Assert(result && removedTask!.Status != TaskStatus.RanToCompletion);
206191
throw;
207192
}
208193
}
209194

210-
private Task<McpClient> CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken)
195+
private async Task<(McpClient Client, IList<McpClientTool> Tools)> CreateMcpClientAndToolsAsync(HostedMcpServerTool hostedMcpTool, Uri serverAddressUri, CancellationToken cancellationToken)
211196
{
212-
var transport = new HttpClientTransport(new HttpClientTransportOptions
197+
var transportOptions = new HttpClientTransportOptions
213198
{
214-
Endpoint = serverAddress,
215-
Name = serverName,
216-
AdditionalHeaders = authorizationToken is not null
199+
Endpoint = serverAddressUri,
200+
Name = hostedMcpTool.ServerName,
201+
AdditionalHeaders = hostedMcpTool.AuthorizationToken is not null
217202
// Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available.
218-
? new Dictionary<string, string>() { { "Authorization", $"Bearer {authorizationToken}" } }
203+
? new Dictionary<string, string>() { { "Authorization", $"Bearer {hostedMcpTool.AuthorizationToken}" } }
219204
: null,
220-
}, _httpClient, _loggerFactory);
205+
};
206+
207+
_configureTransportOptions?.Invoke(new DummyHostedMcpServerTool(hostedMcpTool.ServerName, serverAddressUri), transportOptions);
208+
209+
var transport = new HttpClientTransport(transportOptions, _httpClient, _loggerFactory);
210+
var client = await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
211+
try
212+
{
213+
var tools = await client.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
214+
return (client, tools);
215+
}
216+
catch
217+
{
218+
try
219+
{
220+
await client.DisposeAsync().ConfigureAwait(false);
221+
}
222+
catch { } // allow the original exception to propagate
223+
224+
throw;
225+
}
226+
}
227+
228+
internal bool RemoveMcpClientFromCache(string key, out Task<(McpClient Client, IList<McpClientTool> Tools)>? removedTask)
229+
=> _lruCache.TryRemove(key, out removedTask);
230+
231+
/// <summary>
232+
/// A temporary <see cref="HostedMcpServerTool"/> instance passed to the configureTransportOptions callback.
233+
/// This prevents the callback from modifying the original tool instance.
234+
/// </summary>
235+
private sealed class DummyHostedMcpServerTool(string serverName, Uri serverAddress)
236+
: HostedMcpServerTool(serverName, serverAddress);
237+
}
238+
239+
/// <summary>
240+
/// An AI function wrapper that retries the invocation by recreating an MCP client when an <see cref="HttpRequestException"/> occurs.
241+
/// For example, this can happen if a session is revoked or a server error occurs. The retry evicts the cached MCP client.
242+
/// </summary>
243+
[Experimental("MEAI001")]
244+
private sealed class McpRetriableAIFunction : DelegatingAIFunction
245+
{
246+
private readonly HostedMcpServerTool _hostedMcpTool;
247+
private readonly Uri _serverAddressUri;
248+
private readonly McpChatClient _chatClient;
249+
250+
public McpRetriableAIFunction(AIFunction innerFunction, HostedMcpServerTool hostedMcpTool, Uri serverAddressUri, McpChatClient chatClient)
251+
: base(innerFunction)
252+
{
253+
_hostedMcpTool = hostedMcpTool;
254+
_serverAddressUri = serverAddressUri;
255+
_chatClient = chatClient;
256+
}
257+
258+
protected override async ValueTask<object?> InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken)
259+
{
260+
try
261+
{
262+
return await base.InvokeCoreAsync(arguments, cancellationToken).ConfigureAwait(false);
263+
}
264+
catch (HttpRequestException) { }
265+
266+
bool result = _chatClient.RemoveMcpClientFromCache(_hostedMcpTool.ServerAddress, out var removedTask);
267+
Debug.Assert(result && removedTask!.Status == TaskStatus.RanToCompletion);
268+
_ = removedTask!.Result.Client.DisposeAsync().AsTask();
269+
270+
var freshTool = await GetCurrentToolAsync().ConfigureAwait(false);
271+
return await freshTool.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false);
272+
}
273+
274+
private async Task<AIFunction> GetCurrentToolAsync()
275+
{
276+
Debug.Assert(Uri.TryCreate(_hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) &&
277+
(parsedAddress.Scheme == Uri.UriSchemeHttp || parsedAddress.Scheme == Uri.UriSchemeHttps),
278+
"Server address should have been validated before construction");
221279

222-
return McpClient.CreateAsync(transport, cancellationToken: cancellationToken);
280+
var (client, tools) = await _chatClient.GetClientAndToolsAsync(_hostedMcpTool, _serverAddressUri!).ConfigureAwait(false);
281+
282+
return tools.FirstOrDefault(t => t.Name == Name) ??
283+
throw new McpProtocolException($"Tool '{Name}' no longer exists on the MCP server.", McpErrorCode.InvalidParams);
223284
}
224285
}
225286
}

0 commit comments

Comments
 (0)