Skip to content

Commit 90ba895

Browse files
committed
Pass an AbortSignal to request handlers
1 parent 54d146b commit 90ba895

File tree

1 file changed

+49
-8
lines changed

1 file changed

+49
-8
lines changed

src/shared/protocol.ts

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { ZodLiteral, ZodObject, ZodType, z } from "zod";
22
import {
3+
CancelledNotificationSchema,
34
ErrorCode,
45
JSONRPCError,
56
JSONRPCNotification,
@@ -12,6 +13,7 @@ import {
1213
ProgressNotification,
1314
ProgressNotificationSchema,
1415
Request,
16+
RequestId,
1517
Result,
1618
} from "../types.js";
1719
import { Transport } from "./transport.js";
@@ -50,6 +52,16 @@ export type RequestOptions = {
5052
signal?: AbortSignal;
5153
};
5254

55+
/**
56+
* Extra data given to request handlers.
57+
*/
58+
export type RequestHandlerExtra = {
59+
/**
60+
* An abort signal used to communicate if the request was cancelled from the sender's side.
61+
*/
62+
signal: AbortSignal;
63+
};
64+
5365
/**
5466
* Implements MCP protocol framing on top of a pluggable transport, including
5567
* features like request/response linking, notifications, and progress.
@@ -61,10 +73,15 @@ export abstract class Protocol<
6173
> {
6274
private _transport?: Transport;
6375
private _requestMessageId = 0;
64-
protected _requestHandlers: Map<
76+
private _requestHandlers: Map<
6577
string,
66-
(request: JSONRPCRequest) => Promise<SendResultT>
78+
(
79+
request: JSONRPCRequest,
80+
extra: RequestHandlerExtra,
81+
) => Promise<SendResultT>
6782
> = new Map();
83+
private _requestHandlerAbortControllers: Map<RequestId, AbortController> =
84+
new Map();
6885
private _notificationHandlers: Map<
6986
string,
7087
(notification: JSONRPCNotification) => Promise<void>
@@ -100,6 +117,13 @@ export abstract class Protocol<
100117
fallbackNotificationHandler?: (notification: Notification) => Promise<void>;
101118

102119
constructor(private _options?: ProtocolOptions) {
120+
this.setNotificationHandler(CancelledNotificationSchema, (notification) => {
121+
const controller = this._requestHandlerAbortControllers.get(
122+
notification.params.requestId,
123+
);
124+
controller?.abort(notification.params.reason);
125+
});
126+
103127
this.setNotificationHandler(ProgressNotificationSchema, (notification) => {
104128
this._onprogress(notification as unknown as ProgressNotification);
105129
});
@@ -195,16 +219,27 @@ export abstract class Protocol<
195219
return;
196220
}
197221

198-
handler(request)
222+
const abortController = new AbortController();
223+
this._requestHandlerAbortControllers.set(request.id, abortController);
224+
225+
handler(request, { signal: abortController.signal })
199226
.then(
200227
(result) => {
201-
this._transport?.send({
228+
if (abortController.signal.aborted) {
229+
return;
230+
}
231+
232+
return this._transport?.send({
202233
result,
203234
jsonrpc: "2.0",
204235
id: request.id,
205236
});
206237
},
207238
(error) => {
239+
if (abortController.signal.aborted) {
240+
return;
241+
}
242+
208243
return this._transport?.send({
209244
jsonrpc: "2.0",
210245
id: request.id,
@@ -219,7 +254,10 @@ export abstract class Protocol<
219254
)
220255
.catch((error) =>
221256
this._onerror(new Error(`Failed to send response: ${error}`)),
222-
);
257+
)
258+
.finally(() => {
259+
this._requestHandlerAbortControllers.delete(request.id);
260+
});
223261
}
224262

225263
private _onprogress(notification: ProgressNotification): void {
@@ -403,12 +441,15 @@ export abstract class Protocol<
403441
}>,
404442
>(
405443
requestSchema: T,
406-
handler: (request: z.infer<T>) => SendResultT | Promise<SendResultT>,
444+
handler: (
445+
request: z.infer<T>,
446+
extra: RequestHandlerExtra,
447+
) => SendResultT | Promise<SendResultT>,
407448
): void {
408449
const method = requestSchema.shape.method.value;
409450
this.assertRequestHandlerCapability(method);
410-
this._requestHandlers.set(method, (request) =>
411-
Promise.resolve(handler(requestSchema.parse(request))),
451+
this._requestHandlers.set(method, (request, extra) =>
452+
Promise.resolve(handler(requestSchema.parse(request), extra)),
412453
);
413454
}
414455

0 commit comments

Comments
 (0)