Skip to content

Commit 1b3de48

Browse files
committed
Address feedback
1 parent a069f64 commit 1b3de48

File tree

2 files changed

+50
-87
lines changed

2 files changed

+50
-87
lines changed

src/ModelContextProtocol/McpChatClientBuilderExtensions.cs

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using System.Collections.Concurrent;
2+
using System.Diagnostics.CodeAnalysis;
23
using System.Runtime.CompilerServices;
34
using Microsoft.Extensions.AI;
45
using Microsoft.Extensions.Logging;
56
using Microsoft.Extensions.Logging.Abstractions;
67
using ModelContextProtocol.Client;
7-
#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
88

99
namespace ModelContextProtocol;
1010

@@ -31,6 +31,7 @@ public static class McpChatClientBuilderExtensions
3131
/// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers.
3232
/// </para>
3333
/// </remarks>
34+
[Experimental("MEAI001")]
3435
public static ChatClientBuilder UseMcpClient(
3536
this ChatClientBuilder builder,
3637
HttpClient? httpClient = null,
@@ -44,13 +45,14 @@ public static ChatClientBuilder UseMcpClient(
4445
});
4546
}
4647

47-
private class McpChatClient : DelegatingChatClient
48+
[Experimental("MEAI001")]
49+
private sealed class McpChatClient : DelegatingChatClient
4850
{
4951
private readonly ILoggerFactory? _loggerFactory;
5052
private readonly ILogger _logger;
5153
private readonly HttpClient _httpClient;
5254
private readonly bool _ownsHttpClient;
53-
private ConcurrentDictionary<string, Task<McpClient>>? _mcpClientTasks = null;
55+
private readonly ConcurrentDictionary<string, Task<McpClient>> _mcpClientTasks = [];
5456

5557
/// <summary>
5658
/// Initializes a new instance of the <see cref="McpChatClient"/> class.
@@ -97,55 +99,48 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
9799
}
98100
}
99101

100-
private async Task<List<AITool>?> BuildDownstreamAIToolsAsync(IList<AITool>? inputTools, CancellationToken cancellationToken)
102+
private async Task<List<AITool>> BuildDownstreamAIToolsAsync(IList<AITool> inputTools, CancellationToken cancellationToken)
101103
{
102-
List<AITool>? downstreamTools = null;
103-
foreach (var tool in inputTools ?? [])
104+
List<AITool> downstreamTools = [];
105+
foreach (var tool in inputTools)
104106
{
105107
if (tool is not HostedMcpServerTool mcpTool)
106108
{
107109
// For other tools, we want to keep them in the list of tools.
108-
downstreamTools ??= new List<AITool>();
109110
downstreamTools.Add(tool);
110111
continue;
111112
}
112113

113114
if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) ||
114-
(parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps))
115+
(parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps))
115116
{
116-
throw new InvalidOperationException(
117-
$"MCP server address must be an absolute HTTP or HTTPS URI. Invalid address: '{mcpTool.ServerAddress}'");
117+
throw new InvalidOperationException(
118+
$"Invalid http(s) address: '{mcpTool.ServerAddress}'. MCP server address must be an absolute https(s) URL.");
118119
}
119120

120121
// List all MCP functions from the specified MCP server.
121-
// This will need some caching in a real-world scenario to avoid repeated calls.
122-
var mcpClient = await CreateMcpClientAsync(parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false);
122+
var mcpClient = await CreateMcpClientAsync(mcpTool.ServerAddress, parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false);
123123
var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
124124

125125
// Add the listed functions to our list of tools we'll pass to the inner client.
126126
foreach (var mcpFunction in mcpFunctions)
127127
{
128128
if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name))
129129
{
130-
_logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name);
130+
if (_logger.IsEnabled(LogLevel.Information))
131+
{
132+
_logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name);
133+
}
131134
continue;
132135
}
133136

134-
downstreamTools ??= new List<AITool>();
135137
switch (mcpTool.ApprovalMode)
136138
{
137-
case HostedMcpServerToolAlwaysRequireApprovalMode alwaysRequireApproval:
138-
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
139-
break;
140-
case HostedMcpServerToolNeverRequireApprovalMode neverRequireApproval:
141-
downstreamTools.Add(mcpFunction);
142-
break;
143-
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.AlwaysRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
144-
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
145-
break;
139+
case HostedMcpServerToolNeverRequireApprovalMode:
146140
case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true:
147141
downstreamTools.Add(mcpFunction);
148142
break;
143+
149144
default:
150145
// Default to always require approval if no specific mode is set.
151146
downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction));
@@ -173,11 +168,7 @@ protected override void Dispose(bool disposing)
173168
// Dispose of all cached MCP clients.
174169
foreach (var clientTask in _mcpClientTasks.Values)
175170
{
176-
#if NETSTANDARD2_0
177171
if (clientTask.Status == TaskStatus.RanToCompletion)
178-
#else
179-
if (clientTask.IsCompletedSuccessfully)
180-
#endif
181172
{
182173
_ = clientTask.Result.DisposeAsync();
183174
}
@@ -190,41 +181,45 @@ protected override void Dispose(bool disposing)
190181
base.Dispose(disposing);
191182
}
192183

193-
private Task<McpClient> CreateMcpClientAsync(Uri serverAddress, string serverName, string? authorizationToken)
184+
private async Task<McpClient> CreateMcpClientAsync(string key, Uri serverAddress, string serverName, string? authorizationToken)
194185
{
195-
if (_mcpClientTasks is null)
196-
{
197-
_mcpClientTasks = new ConcurrentDictionary<string, Task<McpClient>>(StringComparer.OrdinalIgnoreCase);
198-
}
199-
200186
// Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token.
201187
// Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently.
202-
return _mcpClientTasks.GetOrAdd(serverAddress.ToString(), _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None));
203-
}
188+
#if NET
189+
// Avoid closure allocation.
190+
Task<McpClient> task = _mcpClientTasks.GetOrAdd(key,
191+
static (_, state) => state.self.CreateMcpClientCoreAsync(state.serverAddress, state.serverName, state.authorizationToken, CancellationToken.None),
192+
(self: this, serverAddress, serverName, authorizationToken));
193+
#else
194+
Task<McpClient> task = _mcpClientTasks.GetOrAdd(key,
195+
_ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None));
196+
#endif
204197

205-
private async Task<McpClient> CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken)
206-
{
207-
var serverAddressKey = serverAddress.ToString();
208198
try
209199
{
210-
var transport = new HttpClientTransport(new HttpClientTransportOptions
211-
{
212-
Endpoint = serverAddress,
213-
Name = serverName,
214-
AdditionalHeaders = authorizationToken is not null
215-
// Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available.
216-
? new Dictionary<string, string>() { { "Authorization", $"Bearer {authorizationToken}" } }
217-
: null,
218-
}, _httpClient, _loggerFactory);
219-
220-
return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false);
200+
return await task.ConfigureAwait(false);
221201
}
222202
catch
223203
{
224-
// Remove the failed task from cache so subsequent requests can retry
225-
_mcpClientTasks?.TryRemove(serverAddressKey, out _);
204+
// Remove the failed task from cache so subsequent requests can retry.
205+
_mcpClientTasks.TryRemove(key, out _);
226206
throw;
227207
}
228208
}
209+
210+
private Task<McpClient> CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken)
211+
{
212+
var transport = new HttpClientTransport(new HttpClientTransportOptions
213+
{
214+
Endpoint = serverAddress,
215+
Name = serverName,
216+
AdditionalHeaders = authorizationToken is not null
217+
// Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available.
218+
? new Dictionary<string, string>() { { "Authorization", $"Bearer {authorizationToken}" } }
219+
: null,
220+
}, _httpClient, _loggerFactory);
221+
222+
return McpClient.CreateAsync(transport, cancellationToken: cancellationToken);
223+
}
229224
}
230225
}

tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -219,30 +219,6 @@ public async Task UseMcpClient_ApprovalsWorkCorrectly(
219219
Assert.Equivalent(expectedNormalAIFunctions, toolsNotRequiringApproval);
220220
}
221221

222-
[Theory]
223-
[InlineData(false)]
224-
[InlineData(true)]
225-
public async Task UseMcpClient_SupportsConnectorIdAsServer(bool streaming)
226-
{
227-
// Arrange
228-
IChatClient sut = CreateTestChatClient(out var callbackState);
229-
const string connectorId = "test-connector-123";
230-
var mcpTool = new HostedMcpServerTool(connectorId, _transportOptions.Endpoint);
231-
var options = new ChatOptions { Tools = [mcpTool] };
232-
233-
// Act
234-
await GetResponseAsync(sut, options, streaming);
235-
236-
// Assert
237-
Assert.NotNull(callbackState.CapturedOptions);
238-
Assert.NotNull(callbackState.CapturedOptions.Tools);
239-
var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList();
240-
Assert.Equal(3, toolNames.Count);
241-
Assert.Contains("echo", toolNames);
242-
Assert.Contains("echoSessionId", toolNames);
243-
Assert.Contains("sampleLLM", toolNames);
244-
}
245-
246222
[Theory]
247223
[InlineData(false)]
248224
[InlineData(true)]
@@ -280,18 +256,10 @@ public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, str
280256

281257
// Assert
282258
Assert.NotNull(callbackState.CapturedOptions);
283-
if (expectedTools.Length == 0)
284-
{
285-
// When all MCP tools are filtered out and no other tools exist, the Tools collection should be null
286-
Assert.Null(callbackState.CapturedOptions.Tools);
287-
}
288-
else
289-
{
290-
Assert.NotNull(callbackState.CapturedOptions.Tools);
291-
var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList();
292-
Assert.Equal(expectedTools.Length, toolNames.Count);
293-
Assert.Equivalent(expectedTools, toolNames);
294-
}
259+
Assert.NotNull(callbackState.CapturedOptions.Tools);
260+
var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList();
261+
Assert.Equal(expectedTools.Length, toolNames.Count);
262+
Assert.Equivalent(expectedTools, toolNames);
295263
}
296264

297265
[Theory]

0 commit comments

Comments
 (0)