1
1
import { ZodLiteral , ZodObject , ZodType , z } from "zod" ;
2
2
import {
3
+ CancelledNotificationSchema ,
3
4
ErrorCode ,
4
5
JSONRPCError ,
5
6
JSONRPCNotification ,
@@ -12,6 +13,7 @@ import {
12
13
ProgressNotification ,
13
14
ProgressNotificationSchema ,
14
15
Request ,
16
+ RequestId ,
15
17
Result ,
16
18
} from "../types.js" ;
17
19
import { Transport } from "./transport.js" ;
@@ -50,6 +52,16 @@ export type RequestOptions = {
50
52
signal ?: AbortSignal ;
51
53
} ;
52
54
55
+ /**
56
+ * Extra data given to request handlers.
57
+ */
58
+ export type RequestHandlerExtra = {
59
+ /**
60
+ * An abort signal used to communicate if the request was cancelled from the sender's side.
61
+ */
62
+ signal : AbortSignal ;
63
+ } ;
64
+
53
65
/**
54
66
* Implements MCP protocol framing on top of a pluggable transport, including
55
67
* features like request/response linking, notifications, and progress.
@@ -61,10 +73,15 @@ export abstract class Protocol<
61
73
> {
62
74
private _transport ?: Transport ;
63
75
private _requestMessageId = 0 ;
64
- protected _requestHandlers : Map <
76
+ private _requestHandlers : Map <
65
77
string ,
66
- ( request : JSONRPCRequest ) => Promise < SendResultT >
78
+ (
79
+ request : JSONRPCRequest ,
80
+ extra : RequestHandlerExtra ,
81
+ ) => Promise < SendResultT >
67
82
> = new Map ( ) ;
83
+ private _requestHandlerAbortControllers : Map < RequestId , AbortController > =
84
+ new Map ( ) ;
68
85
private _notificationHandlers : Map <
69
86
string ,
70
87
( notification : JSONRPCNotification ) => Promise < void >
@@ -100,6 +117,13 @@ export abstract class Protocol<
100
117
fallbackNotificationHandler ?: ( notification : Notification ) => Promise < void > ;
101
118
102
119
constructor ( private _options ?: ProtocolOptions ) {
120
+ this . setNotificationHandler ( CancelledNotificationSchema , ( notification ) => {
121
+ const controller = this . _requestHandlerAbortControllers . get (
122
+ notification . params . requestId ,
123
+ ) ;
124
+ controller ?. abort ( notification . params . reason ) ;
125
+ } ) ;
126
+
103
127
this . setNotificationHandler ( ProgressNotificationSchema , ( notification ) => {
104
128
this . _onprogress ( notification as unknown as ProgressNotification ) ;
105
129
} ) ;
@@ -195,16 +219,27 @@ export abstract class Protocol<
195
219
return ;
196
220
}
197
221
198
- handler ( request )
222
+ const abortController = new AbortController ( ) ;
223
+ this . _requestHandlerAbortControllers . set ( request . id , abortController ) ;
224
+
225
+ handler ( request , { signal : abortController . signal } )
199
226
. then (
200
227
( result ) => {
201
- this . _transport ?. send ( {
228
+ if ( abortController . signal . aborted ) {
229
+ return ;
230
+ }
231
+
232
+ return this . _transport ?. send ( {
202
233
result,
203
234
jsonrpc : "2.0" ,
204
235
id : request . id ,
205
236
} ) ;
206
237
} ,
207
238
( error ) => {
239
+ if ( abortController . signal . aborted ) {
240
+ return ;
241
+ }
242
+
208
243
return this . _transport ?. send ( {
209
244
jsonrpc : "2.0" ,
210
245
id : request . id ,
@@ -219,7 +254,10 @@ export abstract class Protocol<
219
254
)
220
255
. catch ( ( error ) =>
221
256
this . _onerror ( new Error ( `Failed to send response: ${ error } ` ) ) ,
222
- ) ;
257
+ )
258
+ . finally ( ( ) => {
259
+ this . _requestHandlerAbortControllers . delete ( request . id ) ;
260
+ } ) ;
223
261
}
224
262
225
263
private _onprogress ( notification : ProgressNotification ) : void {
@@ -403,12 +441,15 @@ export abstract class Protocol<
403
441
} > ,
404
442
> (
405
443
requestSchema : T ,
406
- handler : ( request : z . infer < T > ) => SendResultT | Promise < SendResultT > ,
444
+ handler : (
445
+ request : z . infer < T > ,
446
+ extra : RequestHandlerExtra ,
447
+ ) => SendResultT | Promise < SendResultT > ,
407
448
) : void {
408
449
const method = requestSchema . shape . method . value ;
409
450
this . assertRequestHandlerCapability ( method ) ;
410
- this . _requestHandlers . set ( method , ( request ) =>
411
- Promise . resolve ( handler ( requestSchema . parse ( request ) ) ) ,
451
+ this . _requestHandlers . set ( method , ( request , extra ) =>
452
+ Promise . resolve ( handler ( requestSchema . parse ( request ) , extra ) ) ,
412
453
) ;
413
454
}
414
455
0 commit comments