@@ -22,6 +22,12 @@ namespace ModelContextProtocol.Shared;
2222/// </summary>
2323internal abstract class McpJsonRpcEndpoint : IAsyncDisposable
2424{
25+ /// <summary>
26+ /// In-flight request handling, indexed by request ID. The value provides a <see cref="CancellationTokenSource"/>
27+ /// that can be used to request cancellation of the in-flight handler.
28+ /// </summary>
29+ private static readonly ConcurrentDictionary < RequestId , CancellationTokenSource > s_handlingRequests = new ( ) ;
30+
2531 private readonly string _id = Guid . NewGuid ( ) . ToString ( "N" ) ;
2632 private readonly ITransport _transport ;
2733 private readonly ConcurrentDictionary < RequestId , TaskCompletionSource < IJsonRpcMessage > > _pendingRequests ;
@@ -78,25 +84,69 @@ internal async Task ProcessMessagesAsync(CancellationToken cancellationToken)
7884 {
7985 _logger . TransportMessageRead ( EndpointName , message . GetType ( ) . Name ) ;
8086
81- // Fire and forget the message handling task to avoid blocking the transport
82- // If awaiting the task, the transport will not be able to read more messages,
83- // which could lead to a deadlock if the handler sends a message back
8487 _ = ProcessMessageAsync ( ) ;
8588 async Task ProcessMessageAsync ( )
8689 {
90+ IJsonRpcMessageWithId ? messageWithId = message as IJsonRpcMessageWithId ;
91+ CancellationTokenSource ? combinedCts = null ;
92+ try
93+ {
94+ // Register before we yield, so that the tracking is guaranteed to be there
95+ // when subsequent messages arrive, even if the asynchronous processing happens
96+ // out of order.
97+ if ( messageWithId is not null )
98+ {
99+ combinedCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
100+ s_handlingRequests [ messageWithId . Id ] = combinedCts ;
101+ }
102+
103+ // Fire and forget the message handling to avoid blocking the transport
104+ // If awaiting the task, the transport will not be able to read more messages,
105+ // which could lead to a deadlock if the handler sends a message back
87106#if NET
88- await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
107+ await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
89108#else
90- await default ( ForceYielding ) ;
109+ await default ( ForceYielding ) ;
91110#endif
92- try
93- {
94- await HandleMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
111+
112+ // Handle the message.
113+ await HandleMessageAsync ( message , combinedCts ? . Token ?? cancellationToken ) . ConfigureAwait ( false ) ;
95114 }
96115 catch ( Exception ex )
97116 {
98- var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
99- _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
117+ // Only send responses for request errors that aren't user-initiated cancellation.
118+ bool isUserCancellation =
119+ ex is OperationCanceledException &&
120+ ! cancellationToken . IsCancellationRequested &&
121+ combinedCts ? . IsCancellationRequested is true ;
122+
123+ if ( ! isUserCancellation && message is JsonRpcRequest request )
124+ {
125+ _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
126+ await _transport . SendMessageAsync ( new JsonRpcError
127+ {
128+ Id = request . Id ,
129+ JsonRpc = "2.0" ,
130+ Error = new JsonRpcErrorDetail
131+ {
132+ Code = ErrorCodes . InternalError ,
133+ Message = ex . Message
134+ }
135+ } , cancellationToken ) . ConfigureAwait ( false ) ;
136+ }
137+ else if ( ex is not OperationCanceledException )
138+ {
139+ var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
140+ _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
141+ }
142+ }
143+ finally
144+ {
145+ if ( messageWithId is not null )
146+ {
147+ s_handlingRequests . TryRemove ( messageWithId . Id , out _ ) ;
148+ combinedCts ! . Dispose ( ) ;
149+ }
100150 }
101151 }
102152 }
@@ -136,6 +186,24 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
136186
137187 private async Task HandleNotification ( JsonRpcNotification notification )
138188 {
189+ // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
190+ if ( notification . Method == NotificationMethods . CancelledNotification )
191+ {
192+ try
193+ {
194+ if ( GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
195+ s_handlingRequests . TryGetValue ( cn . RequestId , out var cts ) )
196+ {
197+ await cts . CancelAsync ( ) . ConfigureAwait ( false ) ;
198+ }
199+ }
200+ catch
201+ {
202+ // "Invalid cancellation notifications SHOULD be ignored"
203+ }
204+ }
205+
206+ // Handle user-defined notifications.
139207 if ( _notificationHandlers . TryGetValue ( notification . Method , out var handlers ) )
140208 {
141209 foreach ( var notificationHandler in handlers )
@@ -170,33 +238,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
170238 {
171239 if ( _requestHandlers . TryGetValue ( request . Method , out var handler ) )
172240 {
173- try
174- {
175- _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
176- var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
177- _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
178- await _transport . SendMessageAsync ( new JsonRpcResponse
179- {
180- Id = request . Id ,
181- JsonRpc = "2.0" ,
182- Result = result
183- } , cancellationToken ) . ConfigureAwait ( false ) ;
184- }
185- catch ( Exception ex )
241+ _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
242+ var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
243+ _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
244+ await _transport . SendMessageAsync ( new JsonRpcResponse
186245 {
187- _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
188- // Send error response
189- await _transport . SendMessageAsync ( new JsonRpcError
190- {
191- Id = request . Id ,
192- JsonRpc = "2.0" ,
193- Error = new JsonRpcErrorDetail
194- {
195- Code = - 32000 , // Implementation defined error
196- Message = ex . Message
197- }
198- } , cancellationToken ) . ConfigureAwait ( false ) ;
199- }
246+ Id = request . Id ,
247+ JsonRpc = "2.0" ,
248+ Result = result
249+ } , cancellationToken ) . ConfigureAwait ( false ) ;
200250 }
201251 else
202252 {
@@ -221,8 +271,11 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
221271 throw new McpClientException ( "Transport is not connected" ) ;
222272 }
223273
224- // Set request ID
225- request . Id = new ( $ "{ _id } -{ Interlocked . Increment ( ref _nextRequestId ) } ") ;
274+ // Set request ID if it's not already set to a valid identifier.
275+ if ( request . Id . IsDefault )
276+ {
277+ request . Id = new ( $ "{ _id } -{ Interlocked . Increment ( ref _nextRequestId ) } ") ;
278+ }
226279
227280 var tcs = new TaskCompletionSource < IJsonRpcMessage > ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
228281 _pendingRequests [ request . Id ] = tcs ;
@@ -279,7 +332,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
279332 }
280333 }
281334
282- public Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
335+ public async Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
283336 {
284337 Throw . IfNull ( message ) ;
285338
@@ -294,7 +347,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
294347 _logger . SendingMessage ( EndpointName , JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ) ;
295348 }
296349
297- return _transport . SendMessageAsync ( message , cancellationToken ) ;
350+ await _transport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
351+
352+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
353+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
354+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
355+ if ( message is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
356+ GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
357+ _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
358+ {
359+ tcs . TrySetCanceled ( default ) ;
360+ }
361+ }
362+
363+ private static CancelledNotification ? GetCancelledNotificationParams ( object ? notificationParams )
364+ {
365+ try
366+ {
367+ switch ( notificationParams )
368+ {
369+ case null :
370+ return null ;
371+
372+ case CancelledNotification cn :
373+ return cn ;
374+
375+ case JsonElement je :
376+ return JsonSerializer . Deserialize ( je , McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
377+
378+ default :
379+ return JsonSerializer . Deserialize (
380+ JsonSerializer . Serialize ( notificationParams , McpJsonUtilities . DefaultOptions . GetTypeInfo < object ? > ( ) ) ,
381+ McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
382+ }
383+ }
384+ catch
385+ {
386+ return null ;
387+ }
298388 }
299389
300390 /// <summary>
0 commit comments