1- using System . Collections . Concurrent ;
1+ using System . Diagnostics ;
22using System . Diagnostics . CodeAnalysis ;
33using System . Runtime . CompilerServices ;
44using Microsoft . Extensions . AI ;
55using Microsoft . Extensions . Logging ;
66using Microsoft . Extensions . Logging . Abstractions ;
7- using ModelContextProtocol . Client ;
87
9- namespace ModelContextProtocol ;
8+ namespace ModelContextProtocol . Client ;
109
1110/// <summary>
1211/// Extension methods for adding MCP client support to chat clients.
@@ -20,6 +19,7 @@ public static class McpChatClientBuilderExtensions
2019 /// <param name="builder">The <see cref="ChatClientBuilder"/> to configure.</param>
2120 /// <param name="httpClient">The <see cref="HttpClient"/> to use, or <see langword="null"/> to create a new instance.</param>
2221 /// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use, or <see langword="null"/> to resolve from services.</param>
22+ /// <param name="configureTransportOptions">An optional callback to configure the <see cref="HttpClientTransportOptions"/> for each <see cref="HostedMcpServerTool"/>.</param>
2323 /// <returns>The <see cref="ChatClientBuilder"/> for method chaining.</returns>
2424 /// <remarks>
2525 /// <para>
@@ -35,12 +35,13 @@ public static class McpChatClientBuilderExtensions
3535 public static ChatClientBuilder UseMcpClient (
3636 this ChatClientBuilder builder ,
3737 HttpClient ? httpClient = null ,
38- ILoggerFactory ? loggerFactory = null )
38+ ILoggerFactory ? loggerFactory = null ,
39+ Action < HostedMcpServerTool , HttpClientTransportOptions > ? configureTransportOptions = null )
3940 {
4041 return builder . Use ( ( innerClient , services ) =>
4142 {
4243 loggerFactory ??= ( ILoggerFactory ) services . GetService ( typeof ( ILoggerFactory ) ) ! ;
43- var chatClient = new McpChatClient ( innerClient , httpClient , loggerFactory ) ;
44+ var chatClient = new McpChatClient ( innerClient , httpClient , loggerFactory , configureTransportOptions ) ;
4445 return chatClient ;
4546 } ) ;
4647 }
@@ -52,43 +53,45 @@ private sealed class McpChatClient : DelegatingChatClient
5253 private readonly ILogger _logger ;
5354 private readonly HttpClient _httpClient ;
5455 private readonly bool _ownsHttpClient ;
55- private readonly ConcurrentDictionary < string , Task < McpClient > > _mcpClientTasks = [ ] ;
56+ private readonly McpClientTasksLruCache _lruCache ;
57+ private readonly Action < HostedMcpServerTool , HttpClientTransportOptions > ? _configureTransportOptions ;
5658
5759 /// <summary>
5860 /// Initializes a new instance of the <see cref="McpChatClient"/> class.
5961 /// </summary>
6062 /// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
6163 /// <param name="httpClient">An optional <see cref="HttpClient"/> to use when connecting to MCP servers. If not provided, a new instance will be created.</param>
6264 /// <param name="loggerFactory">An <see cref="ILoggerFactory"/> to use for logging information about function invocation.</param>
63- public McpChatClient ( IChatClient innerClient , HttpClient ? httpClient = null , ILoggerFactory ? loggerFactory = null )
65+ /// <param name="configureTransportOptions">An optional callback to configure the <see cref="HttpClientTransportOptions"/> for each <see cref="HostedMcpServerTool"/>.</param>
66+ public McpChatClient ( IChatClient innerClient , HttpClient ? httpClient = null , ILoggerFactory ? loggerFactory = null , Action < HostedMcpServerTool , HttpClientTransportOptions > ? configureTransportOptions = null )
6467 : base ( innerClient )
6568 {
6669 _loggerFactory = loggerFactory ;
6770 _logger = ( ILogger ? ) loggerFactory ? . CreateLogger < McpChatClient > ( ) ?? NullLogger . Instance ;
6871 _httpClient = httpClient ?? new HttpClient ( ) ;
6972 _ownsHttpClient = httpClient is null ;
73+ _lruCache = new McpClientTasksLruCache ( capacity : 20 ) ;
74+ _configureTransportOptions = configureTransportOptions ;
7075 }
7176
72- /// <inheritdoc/>
7377 public override async Task < ChatResponse > GetResponseAsync (
7478 IEnumerable < ChatMessage > messages , ChatOptions ? options = null , CancellationToken cancellationToken = default )
7579 {
7680 if ( options ? . Tools is { Count : > 0 } )
7781 {
78- var downstreamTools = await BuildDownstreamAIToolsAsync ( options . Tools , cancellationToken ) . ConfigureAwait ( false ) ;
82+ var downstreamTools = await BuildDownstreamAIToolsAsync ( options . Tools ) . ConfigureAwait ( false ) ;
7983 options = options . Clone ( ) ;
8084 options . Tools = downstreamTools ;
8185 }
8286
8387 return await base . GetResponseAsync ( messages , options , cancellationToken ) . ConfigureAwait ( false ) ;
8488 }
8589
86- /// <inheritdoc/>
8790 public override async IAsyncEnumerable < ChatResponseUpdate > GetStreamingResponseAsync ( IEnumerable < ChatMessage > messages , ChatOptions ? options = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
8891 {
8992 if ( options ? . Tools is { Count : > 0 } )
9093 {
91- var downstreamTools = await BuildDownstreamAIToolsAsync ( options . Tools , cancellationToken ) . ConfigureAwait ( false ) ;
94+ var downstreamTools = await BuildDownstreamAIToolsAsync ( options . Tools ) . ConfigureAwait ( false ) ;
9295 options = options . Clone ( ) ;
9396 options . Tools = downstreamTools ;
9497 }
@@ -99,51 +102,52 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
99102 }
100103 }
101104
102- private async Task < List < AITool > > BuildDownstreamAIToolsAsync ( IList < AITool > inputTools , CancellationToken cancellationToken )
105+ private async Task < List < AITool > > BuildDownstreamAIToolsAsync ( IList < AITool > chatOptionsTools )
103106 {
104107 List < AITool > downstreamTools = [ ] ;
105- foreach ( var tool in inputTools )
108+ foreach ( var chatOptionsTool in chatOptionsTools )
106109 {
107- if ( tool is not HostedMcpServerTool mcpTool )
110+ if ( chatOptionsTool is not HostedMcpServerTool hostedMcpTool )
108111 {
109112 // For other tools, we want to keep them in the list of tools.
110- downstreamTools . Add ( tool ) ;
113+ downstreamTools . Add ( chatOptionsTool ) ;
111114 continue ;
112115 }
113116
114- if ( ! Uri . TryCreate ( mcpTool . ServerAddress , UriKind . Absolute , out var parsedAddress ) ||
117+ if ( ! Uri . TryCreate ( hostedMcpTool . ServerAddress , UriKind . Absolute , out var parsedAddress ) ||
115118 ( parsedAddress . Scheme != Uri . UriSchemeHttp && parsedAddress . Scheme != Uri . UriSchemeHttps ) )
116119 {
117120 throw new InvalidOperationException (
118- $ "Invalid http(s) address: '{ mcpTool . ServerAddress } '. MCP server address must be an absolute https (s) URL.") ;
121+ $ "Invalid http(s) address: '{ hostedMcpTool . ServerAddress } '. MCP server address must be an absolute http (s) URL.") ;
119122 }
120123
121- // List all MCP functions from the specified MCP server.
122- var mcpClient = await CreateMcpClientAsync ( mcpTool . ServerAddress , parsedAddress , mcpTool . ServerName , mcpTool . AuthorizationToken ) . ConfigureAwait ( false ) ;
123- var mcpFunctions = await mcpClient . ListToolsAsync ( cancellationToken : cancellationToken ) . ConfigureAwait ( false ) ;
124+ // Get MCP client and its tools from cache (both are fetched together on first access).
125+ var ( _, mcpTools ) = await GetClientAndToolsAsync ( hostedMcpTool , parsedAddress ) . ConfigureAwait ( false ) ;
124126
125127 // Add the listed functions to our list of tools we'll pass to the inner client.
126- foreach ( var mcpFunction in mcpFunctions )
128+ foreach ( var mcpTool in mcpTools )
127129 {
128- if ( mcpTool . AllowedTools is not null && ! mcpTool . AllowedTools . Contains ( mcpFunction . Name ) )
130+ if ( hostedMcpTool . AllowedTools is not null && ! hostedMcpTool . AllowedTools . Contains ( mcpTool . Name ) )
129131 {
130132 if ( _logger . IsEnabled ( LogLevel . Information ) )
131133 {
132- _logger . LogInformation ( "MCP function '{FunctionName}' is not allowed by the tool configuration." , mcpFunction . Name ) ;
134+ _logger . LogInformation ( "MCP function '{FunctionName}' is not allowed by the tool configuration." , mcpTool . Name ) ;
133135 }
134136 continue ;
135137 }
136138
137- switch ( mcpTool . ApprovalMode )
139+ var wrappedFunction = new McpRetriableAIFunction ( mcpTool , hostedMcpTool , parsedAddress , this ) ;
140+
141+ switch ( hostedMcpTool . ApprovalMode )
138142 {
139143 case HostedMcpServerToolNeverRequireApprovalMode :
140- case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode . NeverRequireApprovalToolNames ? . Contains ( mcpFunction . Name ) is true :
141- downstreamTools . Add ( mcpFunction ) ;
144+ case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode . NeverRequireApprovalToolNames ? . Contains ( mcpTool . Name ) is true :
145+ downstreamTools . Add ( wrappedFunction ) ;
142146 break ;
143147
144148 default :
145149 // Default to always require approval if no specific mode is set.
146- downstreamTools . Add ( new ApprovalRequiredAIFunction ( mcpFunction ) ) ;
150+ downstreamTools . Add ( new ApprovalRequiredAIFunction ( wrappedFunction ) ) ;
147151 break ;
148152 }
149153 }
@@ -152,74 +156,131 @@ private async Task<List<AITool>> BuildDownstreamAIToolsAsync(IList<AITool> input
152156 return downstreamTools ;
153157 }
154158
155- /// <inheritdoc/>
156159 protected override void Dispose ( bool disposing )
157160 {
158161 if ( disposing )
159162 {
160- // Dispose of the HTTP client if it was created by this client.
161163 if ( _ownsHttpClient )
162164 {
163165 _httpClient ? . Dispose ( ) ;
164166 }
165167
166- if ( _mcpClientTasks is not null )
167- {
168- // Dispose of all cached MCP clients.
169- foreach ( var clientTask in _mcpClientTasks . Values )
170- {
171- if ( clientTask . Status == TaskStatus . RanToCompletion )
172- {
173- _ = clientTask . Result . DisposeAsync ( ) ;
174- }
175- }
176-
177- _mcpClientTasks . Clear ( ) ;
178- }
168+ _lruCache . Dispose ( ) ;
179169 }
180170
181171 base . Dispose ( disposing ) ;
182172 }
183173
184- private async Task < McpClient > CreateMcpClientAsync ( string key , Uri serverAddress , string serverName , string ? authorizationToken )
174+ internal async Task < ( McpClient Client , IList < McpClientTool > Tools ) > GetClientAndToolsAsync ( HostedMcpServerTool hostedMcpTool , Uri serverAddressUri )
185175 {
186176 // Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token.
187177 // Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently.
188- #if NET
189- // Avoid closure allocation.
190- Task < McpClient > task = _mcpClientTasks . GetOrAdd ( key ,
191- static ( _ , state ) => state . self . CreateMcpClientCoreAsync ( state . serverAddress , state . serverName , state . authorizationToken , CancellationToken . None ) ,
192- ( self : this , serverAddress , serverName , authorizationToken ) ) ;
193- #else
194- Task < McpClient > task = _mcpClientTasks . GetOrAdd ( key ,
195- _ => CreateMcpClientCoreAsync ( serverAddress , serverName , authorizationToken , CancellationToken . None ) ) ;
196- #endif
178+ Task < ( McpClient , IList < McpClientTool > Tools ) > task = _lruCache . GetOrAdd (
179+ hostedMcpTool . ServerAddress ,
180+ static ( _ , state ) => state . self . CreateMcpClientAndToolsAsync ( state . hostedMcpTool , state . serverAddressUri , CancellationToken . None ) ,
181+ ( self : this , hostedMcpTool , serverAddressUri ) ) ;
197182
198183 try
199184 {
200185 return await task . ConfigureAwait ( false ) ;
201186 }
202187 catch
203188 {
204- // Remove the failed task from cache so subsequent requests can retry.
205- _mcpClientTasks . TryRemove ( key , out _ ) ;
189+ bool result = RemoveMcpClientFromCache ( hostedMcpTool . ServerAddress , out var removedTask ) ;
190+ Debug . Assert ( result && removedTask ! . Status != TaskStatus . RanToCompletion ) ;
206191 throw ;
207192 }
208193 }
209194
210- private Task < McpClient > CreateMcpClientCoreAsync ( Uri serverAddress , string serverName , string ? authorizationToken , CancellationToken cancellationToken )
195+ private async Task < ( McpClient Client , IList < McpClientTool > Tools ) > CreateMcpClientAndToolsAsync ( HostedMcpServerTool hostedMcpTool , Uri serverAddressUri , CancellationToken cancellationToken )
211196 {
212- var transport = new HttpClientTransport ( new HttpClientTransportOptions
197+ var transportOptions = new HttpClientTransportOptions
213198 {
214- Endpoint = serverAddress ,
215- Name = serverName ,
216- AdditionalHeaders = authorizationToken is not null
199+ Endpoint = serverAddressUri ,
200+ Name = hostedMcpTool . ServerName ,
201+ AdditionalHeaders = hostedMcpTool . AuthorizationToken is not null
217202 // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available.
218- ? new Dictionary < string , string > ( ) { { "Authorization" , $ "Bearer { authorizationToken } " } }
203+ ? new Dictionary < string , string > ( ) { { "Authorization" , $ "Bearer { hostedMcpTool . AuthorizationToken } " } }
219204 : null ,
220- } , _httpClient , _loggerFactory ) ;
205+ } ;
206+
207+ _configureTransportOptions ? . Invoke ( new DummyHostedMcpServerTool ( hostedMcpTool . ServerName , serverAddressUri ) , transportOptions ) ;
208+
209+ var transport = new HttpClientTransport ( transportOptions , _httpClient , _loggerFactory ) ;
210+ var client = await McpClient . CreateAsync ( transport , cancellationToken : cancellationToken ) . ConfigureAwait ( false ) ;
211+ try
212+ {
213+ var tools = await client . ListToolsAsync ( cancellationToken : cancellationToken ) . ConfigureAwait ( false ) ;
214+ return ( client , tools ) ;
215+ }
216+ catch
217+ {
218+ try
219+ {
220+ await client . DisposeAsync ( ) . ConfigureAwait ( false ) ;
221+ }
222+ catch { } // allow the original exception to propagate
223+
224+ throw ;
225+ }
226+ }
227+
228+ internal bool RemoveMcpClientFromCache ( string key , out Task < ( McpClient Client , IList < McpClientTool > Tools ) > ? removedTask )
229+ => _lruCache . TryRemove ( key , out removedTask ) ;
230+
231+ /// <summary>
232+ /// A temporary <see cref="HostedMcpServerTool"/> instance passed to the configureTransportOptions callback.
233+ /// This prevents the callback from modifying the original tool instance.
234+ /// </summary>
235+ private sealed class DummyHostedMcpServerTool ( string serverName , Uri serverAddress )
236+ : HostedMcpServerTool ( serverName , serverAddress ) ;
237+ }
238+
239+ /// <summary>
240+ /// An AI function wrapper that retries the invocation by recreating an MCP client when an <see cref="HttpRequestException"/> occurs.
241+ /// For example, this can happen if a session is revoked or a server error occurs. The retry evicts the cached MCP client.
242+ /// </summary>
243+ [ Experimental ( "MEAI001" ) ]
244+ private sealed class McpRetriableAIFunction : DelegatingAIFunction
245+ {
246+ private readonly HostedMcpServerTool _hostedMcpTool ;
247+ private readonly Uri _serverAddressUri ;
248+ private readonly McpChatClient _chatClient ;
249+
250+ public McpRetriableAIFunction ( AIFunction innerFunction , HostedMcpServerTool hostedMcpTool , Uri serverAddressUri , McpChatClient chatClient )
251+ : base ( innerFunction )
252+ {
253+ _hostedMcpTool = hostedMcpTool ;
254+ _serverAddressUri = serverAddressUri ;
255+ _chatClient = chatClient ;
256+ }
257+
258+ protected override async ValueTask < object ? > InvokeCoreAsync ( AIFunctionArguments arguments , CancellationToken cancellationToken )
259+ {
260+ try
261+ {
262+ return await base . InvokeCoreAsync ( arguments , cancellationToken ) . ConfigureAwait ( false ) ;
263+ }
264+ catch ( HttpRequestException ) { }
265+
266+ bool result = _chatClient . RemoveMcpClientFromCache ( _hostedMcpTool . ServerAddress , out var removedTask ) ;
267+ Debug . Assert ( result && removedTask ! . Status == TaskStatus . RanToCompletion ) ;
268+ _ = removedTask ! . Result . Client . DisposeAsync ( ) . AsTask ( ) ;
269+
270+ var freshTool = await GetCurrentToolAsync ( ) . ConfigureAwait ( false ) ;
271+ return await freshTool . InvokeAsync ( arguments , cancellationToken ) . ConfigureAwait ( false ) ;
272+ }
273+
274+ private async Task < AIFunction > GetCurrentToolAsync ( )
275+ {
276+ Debug . Assert ( Uri . TryCreate ( _hostedMcpTool . ServerAddress , UriKind . Absolute , out var parsedAddress ) &&
277+ ( parsedAddress . Scheme == Uri . UriSchemeHttp || parsedAddress . Scheme == Uri . UriSchemeHttps ) ,
278+ "Server address should have been validated before construction" ) ;
221279
222- return McpClient . CreateAsync ( transport , cancellationToken : cancellationToken ) ;
280+ var ( client , tools ) = await _chatClient . GetClientAndToolsAsync ( _hostedMcpTool , _serverAddressUri ! ) . ConfigureAwait ( false ) ;
281+
282+ return tools . FirstOrDefault ( t => t . Name == Name ) ??
283+ throw new McpProtocolException ( $ "Tool '{ Name } ' no longer exists on the MCP server.", McpErrorCode . InvalidParams ) ;
223284 }
224285 }
225286}
0 commit comments