|
| 1 | +using System.Collections.Concurrent; |
| 2 | +using System.Runtime.CompilerServices; |
| 3 | +using Microsoft.Extensions.AI; |
| 4 | +using Microsoft.Extensions.Logging; |
| 5 | +using Microsoft.Extensions.Logging.Abstractions; |
| 6 | +using ModelContextProtocol.Client; |
| 7 | +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. |
| 8 | + |
| 9 | +namespace ModelContextProtocol; |
| 10 | + |
| 11 | +/// <summary> |
| 12 | +/// Extension methods for adding MCP client support to chat clients. |
| 13 | +/// </summary> |
| 14 | +public static class McpChatClientBuilderExtensions |
| 15 | +{ |
| 16 | + /// <summary> |
| 17 | + /// Adds a chat client to the chat client pipeline that creates an <see cref="McpClient"/> for each <see cref="HostedMcpServerTool"/> |
| 18 | + /// in <see cref="ChatOptions.Tools"/> and augments it with the tools from MCP servers as <see cref="AIFunction"/> instances. |
| 19 | + /// </summary> |
| 20 | + /// <param name="builder">The <see cref="ChatClientBuilder"/> to configure.</param> |
| 21 | + /// <param name="httpClient">The <see cref="HttpClient"/> to use, or <see langword="null"/> to create a new instance.</param> |
| 22 | + /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use, or <see langword="null"/> to resolve from services.</param> |
| 23 | + /// <returns>The <see cref="ChatClientBuilder"/> for method chaining.</returns> |
| 24 | + /// <remarks> |
| 25 | + /// <para> |
| 26 | + /// When a <c>HostedMcpServerTool</c> is encountered in the tools collection, the client |
| 27 | + /// connects to the MCP server, retrieves available tools, and expands them into callable AI functions. |
| 28 | + /// Connections are cached by server address to avoid redundant connections. |
| 29 | + /// </para> |
| 30 | + /// <para> |
| 31 | + /// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers. |
| 32 | + /// </para> |
| 33 | + /// </remarks> |
| 34 | + public static ChatClientBuilder UseMcpClient( |
| 35 | + this ChatClientBuilder builder, |
| 36 | + HttpClient? httpClient = null, |
| 37 | + ILoggerFactory? loggerFactory = null) |
| 38 | + { |
| 39 | + return builder.Use((innerClient, services) => |
| 40 | + { |
| 41 | + loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!; |
| 42 | + var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory); |
| 43 | + return chatClient; |
| 44 | + }); |
| 45 | + } |
| 46 | + |
| 47 | + private class McpChatClient : DelegatingChatClient |
| 48 | + { |
| 49 | + private readonly ILoggerFactory? _loggerFactory; |
| 50 | + private readonly ILogger _logger; |
| 51 | + private readonly HttpClient _httpClient; |
| 52 | + private readonly bool _ownsHttpClient; |
| 53 | + private ConcurrentDictionary<string, Task<McpClient>>? _mcpClientTasks = null; |
| 54 | + |
| 55 | + /// <summary> |
| 56 | + /// Initializes a new instance of the <see cref="McpChatClient"/> class. |
| 57 | + /// </summary> |
| 58 | + /// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param> |
| 59 | + /// <param name="httpClient">An optional <see cref="HttpClient"/> to use when connecting to MCP servers. If not provided, a new instance will be created.</param> |
| 60 | + /// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param> |
| 61 | + public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) |
| 62 | + : base(innerClient) |
| 63 | + { |
| 64 | + _loggerFactory = loggerFactory; |
| 65 | + _logger = (ILogger?)loggerFactory?.CreateLogger<McpChatClient>() ?? NullLogger.Instance; |
| 66 | + _httpClient = httpClient ?? new HttpClient(); |
| 67 | + _ownsHttpClient = httpClient is null; |
| 68 | + } |
| 69 | + |
| 70 | + /// <inheritdoc/> |
| 71 | + public override async Task<ChatResponse> GetResponseAsync( |
| 72 | + IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default) |
| 73 | + { |
| 74 | + if (options?.Tools is { Count: > 0 }) |
| 75 | + { |
| 76 | + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false); |
| 77 | + options = options.Clone(); |
| 78 | + options.Tools = downstreamTools; |
| 79 | + } |
| 80 | + |
| 81 | + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); |
| 82 | + } |
| 83 | + |
| 84 | + /// <inheritdoc/> |
| 85 | + public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) |
| 86 | + { |
| 87 | + if (options?.Tools is { Count: > 0 }) |
| 88 | + { |
| 89 | + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false); |
| 90 | + options = options.Clone(); |
| 91 | + options.Tools = downstreamTools; |
| 92 | + } |
| 93 | + |
| 94 | + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) |
| 95 | + { |
| 96 | + yield return update; |
| 97 | + } |
| 98 | + } |
| 99 | + |
| 100 | + private async Task<List<AITool>?> BuildDownstreamAIToolsAsync(IList<AITool>? inputTools, CancellationToken cancellationToken) |
| 101 | + { |
| 102 | + List<AITool>? downstreamTools = null; |
| 103 | + foreach (var tool in inputTools ?? []) |
| 104 | + { |
| 105 | + if (tool is not HostedMcpServerTool mcpTool) |
| 106 | + { |
| 107 | + // For other tools, we want to keep them in the list of tools. |
| 108 | + downstreamTools ??= new List<AITool>(); |
| 109 | + downstreamTools.Add(tool); |
| 110 | + continue; |
| 111 | + } |
| 112 | + |
| 113 | + if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) || |
| 114 | + (parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps)) |
| 115 | + { |
| 116 | + throw new InvalidOperationException( |
| 117 | + $"MCP server address must be an absolute HTTP or HTTPS URI. Invalid address: '{mcpTool.ServerAddress}'"); |
| 118 | + } |
| 119 | + |
| 120 | + // List all MCP functions from the specified MCP server. |
| 121 | + // This will need some caching in a real-world scenario to avoid repeated calls. |
| 122 | + var mcpClient = await CreateMcpClientAsync(parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false); |
| 123 | + var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); |
| 124 | + |
| 125 | + // Add the listed functions to our list of tools we'll pass to the inner client. |
| 126 | + foreach (var mcpFunction in mcpFunctions) |
| 127 | + { |
| 128 | + if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name)) |
| 129 | + { |
| 130 | + _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name); |
| 131 | + continue; |
| 132 | + } |
| 133 | + |
| 134 | + downstreamTools ??= new List<AITool>(); |
| 135 | + switch (mcpTool.ApprovalMode) |
| 136 | + { |
| 137 | + case HostedMcpServerToolAlwaysRequireApprovalMode alwaysRequireApproval: |
| 138 | + downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); |
| 139 | + break; |
| 140 | + case HostedMcpServerToolNeverRequireApprovalMode neverRequireApproval: |
| 141 | + downstreamTools.Add(mcpFunction); |
| 142 | + break; |
| 143 | + case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.AlwaysRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: |
| 144 | + downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); |
| 145 | + break; |
| 146 | + case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: |
| 147 | + downstreamTools.Add(mcpFunction); |
| 148 | + break; |
| 149 | + default: |
| 150 | + // Default to always require approval if no specific mode is set. |
| 151 | + downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); |
| 152 | + break; |
| 153 | + } |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + return downstreamTools; |
| 158 | + } |
| 159 | + |
| 160 | + /// <inheritdoc/> |
| 161 | + protected override void Dispose(bool disposing) |
| 162 | + { |
| 163 | + if (disposing) |
| 164 | + { |
| 165 | + // Dispose of the HTTP client if it was created by this client. |
| 166 | + if (_ownsHttpClient) |
| 167 | + { |
| 168 | + _httpClient?.Dispose(); |
| 169 | + } |
| 170 | + |
| 171 | + if (_mcpClientTasks is not null) |
| 172 | + { |
| 173 | + // Dispose of all cached MCP clients. |
| 174 | + foreach (var clientTask in _mcpClientTasks.Values) |
| 175 | + { |
| 176 | +#if NETSTANDARD2_0 |
| 177 | + if (clientTask.Status == TaskStatus.RanToCompletion) |
| 178 | +#else |
| 179 | + if (clientTask.IsCompletedSuccessfully) |
| 180 | +#endif |
| 181 | + { |
| 182 | + _ = clientTask.Result.DisposeAsync(); |
| 183 | + } |
| 184 | + } |
| 185 | + |
| 186 | + _mcpClientTasks.Clear(); |
| 187 | + } |
| 188 | + } |
| 189 | + |
| 190 | + base.Dispose(disposing); |
| 191 | + } |
| 192 | + |
| 193 | + private Task<McpClient> CreateMcpClientAsync(Uri serverAddress, string serverName, string? authorizationToken) |
| 194 | + { |
| 195 | + if (_mcpClientTasks is null) |
| 196 | + { |
| 197 | + _mcpClientTasks = new ConcurrentDictionary<string, Task<McpClient>>(StringComparer.OrdinalIgnoreCase); |
| 198 | + } |
| 199 | + |
| 200 | + // Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token. |
| 201 | + // Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently. |
| 202 | + return _mcpClientTasks.GetOrAdd(serverAddress.ToString(), _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None)); |
| 203 | + } |
| 204 | + |
| 205 | + private async Task<McpClient> CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken) |
| 206 | + { |
| 207 | + var serverAddressKey = serverAddress.ToString(); |
| 208 | + try |
| 209 | + { |
| 210 | + var transport = new HttpClientTransport(new HttpClientTransportOptions |
| 211 | + { |
| 212 | + Endpoint = serverAddress, |
| 213 | + Name = serverName, |
| 214 | + AdditionalHeaders = authorizationToken is not null |
| 215 | + // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available. |
| 216 | + ? new Dictionary<string, string>() { { "Authorization", $"Bearer {authorizationToken}" } } |
| 217 | + : null, |
| 218 | + }, _httpClient, _loggerFactory); |
| 219 | + |
| 220 | + return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false); |
| 221 | + } |
| 222 | + catch |
| 223 | + { |
| 224 | + // Remove the failed task from cache so subsequent requests can retry |
| 225 | + _mcpClientTasks?.TryRemove(serverAddressKey, out _); |
| 226 | + throw; |
| 227 | + } |
| 228 | + } |
| 229 | + } |
| 230 | +} |
0 commit comments