diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 5610a6293..0d7eb9418 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -436,3 +436,58 @@ test("should typecheck", () => { }, }); }); + +test("should handle client cancelling a request", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + // Set up server to delay responding to listResources + server.setRequestHandler( + ListResourcesRequestSchema, + async (request, extra) => { + await new Promise((resolve) => setTimeout(resolve, 1000)); + return { + resources: [], + }; + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: {}, + }, + ); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const listResourcesPromise = client.listResources(undefined, { + signal: controller.signal, + }); + controller.abort("Cancelled by test"); + + // Request should be rejected + await expect(listResourcesPromise).rejects.toBe("Cancelled by test"); +}); diff --git a/src/client/index.ts b/src/client/index.ts index e0df322b4..2c08bcc09 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,7 +1,7 @@ import { - ProgressCallback, Protocol, ProtocolOptions, + RequestOptions, } from "../shared/protocol.js"; import { Transport } from "../shared/transport.js"; import { @@ -244,6 +244,10 @@ export class Client< // No specific capability required for initialized break; + case "notifications/cancelled": + // Cancellation notifications are always allowed + break; + case "notifications/progress": // Progress notifications are always allowed break; @@ -278,14 +282,11 @@ export class Client< return this.request({ method: "ping" }, EmptyResultSchema); } - async complete( - params: CompleteRequest["params"], - onprogress?: ProgressCallback, - ) { + async complete(params: CompleteRequest["params"], options?: RequestOptions) { return this.request( { method: "completion/complete", params }, CompleteResultSchema, - onprogress, + options, ); } @@ -298,56 +299,56 @@ export class Client< async getPrompt( params: GetPromptRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "prompts/get", params }, GetPromptResultSchema, - onprogress, + options, ); } async listPrompts( params?: ListPromptsRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "prompts/list", params }, ListPromptsResultSchema, - onprogress, + options, ); } async listResources( params?: ListResourcesRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "resources/list", params }, ListResourcesResultSchema, - onprogress, + options, ); } async listResourceTemplates( params?: ListResourceTemplatesRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "resources/templates/list", params }, ListResourceTemplatesResultSchema, - onprogress, + options, ); } async readResource( params: ReadResourceRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "resources/read", params }, ReadResourceResultSchema, - onprogress, + options, ); } @@ -370,23 +371,23 @@ export class Client< resultSchema: | typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "tools/call", params }, resultSchema, - onprogress, + options, ); } async listTools( params?: ListToolsRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "tools/list", params }, ListToolsResultSchema, - onprogress, + options, ); } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index d30c670bc..0697cc5cb 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -407,3 +407,71 @@ test("should typecheck", () => { }, ); }); + +test("should handle server cancelling a request", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + // Set up client to delay responding to createMessage + client.setRequestHandler( + CreateMessageRequestSchema, + async (_request, extra) => { + await new Promise((resolve) => setTimeout(resolve, 1000)); + return { + model: "test", + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }; + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const createMessagePromise = server.createMessage( + { + messages: [], + maxTokens: 10, + }, + { + signal: controller.signal, + }, + ); + controller.abort("Cancelled by test"); + + // Request should be rejected + await expect(createMessagePromise).rejects.toBe("Cancelled by test"); +}); diff --git a/src/server/index.ts b/src/server/index.ts index ecb525b5c..d15ad3c0d 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,7 +1,7 @@ import { - ProgressCallback, Protocol, ProtocolOptions, + RequestOptions, } from "../shared/protocol.js"; import { ClientCapabilities, @@ -157,6 +157,10 @@ export class Server< } break; + case "notifications/cancelled": + // Cancellation notifications are always allowed + break; + case "notifications/progress": // Progress notifications are always allowed break; @@ -257,23 +261,23 @@ export class Server< async createMessage( params: CreateMessageRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "sampling/createMessage", params }, CreateMessageResultSchema, - onprogress, + options, ); } async listRoots( params?: ListRootsRequest["params"], - onprogress?: ProgressCallback, + options?: RequestOptions, ) { return this.request( { method: "roots/list", params }, ListRootsResultSchema, - onprogress, + options, ); } diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 85610a9d6..8103695d1 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,5 +1,6 @@ import { ZodLiteral, ZodObject, ZodType, z } from "zod"; import { + CancelledNotificationSchema, ErrorCode, JSONRPCError, JSONRPCNotification, @@ -12,6 +13,7 @@ import { ProgressNotification, ProgressNotificationSchema, Request, + RequestId, Result, } from "../types.js"; import { Transport } from "./transport.js"; @@ -35,6 +37,33 @@ export type ProtocolOptions = { enforceStrictCapabilities?: boolean; }; +/** + * Options that can be given per request. + */ +export type RequestOptions = { + /** + * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + */ + onprogress?: ProgressCallback; + + /** + * Can be used to cancel an in-flight request. This will cause an AbortError to be raised from request(). + * + * Use abortAfterTimeout() to easily implement timeouts using this signal. + */ + signal?: AbortSignal; +}; + +/** + * Extra data given to request handlers. + */ +export type RequestHandlerExtra = { + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + signal: AbortSignal; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. @@ -46,10 +75,15 @@ export abstract class Protocol< > { private _transport?: Transport; private _requestMessageId = 0; - protected _requestHandlers: Map< + private _requestHandlers: Map< string, - (request: JSONRPCRequest) => Promise + ( + request: JSONRPCRequest, + extra: RequestHandlerExtra, + ) => Promise > = new Map(); + private _requestHandlerAbortControllers: Map = + new Map(); private _notificationHandlers: Map< string, (notification: JSONRPCNotification) => Promise @@ -85,6 +119,13 @@ export abstract class Protocol< fallbackNotificationHandler?: (notification: Notification) => Promise; constructor(private _options?: ProtocolOptions) { + this.setNotificationHandler(CancelledNotificationSchema, (notification) => { + const controller = this._requestHandlerAbortControllers.get( + notification.params.requestId, + ); + controller?.abort(notification.params.reason); + }); + this.setNotificationHandler(ProgressNotificationSchema, (notification) => { this._onprogress(notification as unknown as ProgressNotification); }); @@ -151,11 +192,14 @@ export abstract class Protocol< return; } - handler(notification).catch((error) => - this._onerror( - new Error(`Uncaught error in notification handler: ${error}`), - ), - ); + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. + Promise.resolve() + .then(() => handler(notification)) + .catch((error) => + this._onerror( + new Error(`Uncaught error in notification handler: ${error}`), + ), + ); } private _onrequest(request: JSONRPCRequest): void { @@ -180,16 +224,29 @@ export abstract class Protocol< return; } - handler(request) + const abortController = new AbortController(); + this._requestHandlerAbortControllers.set(request.id, abortController); + + // Starting with Promise.resolve() puts any synchronous errors into the monad as well. + Promise.resolve() + .then(() => handler(request, { signal: abortController.signal })) .then( (result) => { - this._transport?.send({ + if (abortController.signal.aborted) { + return; + } + + return this._transport?.send({ result, jsonrpc: "2.0", id: request.id, }); }, (error) => { + if (abortController.signal.aborted) { + return; + } + return this._transport?.send({ jsonrpc: "2.0", id: request.id, @@ -204,7 +261,10 @@ export abstract class Protocol< ) .catch((error) => this._onerror(new Error(`Failed to send response: ${error}`)), - ); + ) + .finally(() => { + this._requestHandlerAbortControllers.delete(request.id); + }); } private _onprogress(notification: ProgressNotification): void { @@ -285,14 +345,14 @@ export abstract class Protocol< protected abstract assertRequestHandlerCapability(method: string): void; /** - * Sends a request and wait for a response, with optional progress notifications in the meantime (if supported by the server). + * Sends a request and wait for a response. * * Do not use this method to emit notifications! Use notification() instead. */ request>( request: SendRequestT, resultSchema: T, - onprogress?: ProgressCallback, + options?: RequestOptions, ): Promise> { return new Promise((resolve, reject) => { if (!this._transport) { @@ -304,6 +364,8 @@ export abstract class Protocol< this.assertCapabilityForMethod(request.method); } + options?.signal?.throwIfAborted(); + const messageId = this._requestMessageId++; const jsonrpcRequest: JSONRPCRequest = { ...request, @@ -311,8 +373,8 @@ export abstract class Protocol< id: messageId, }; - if (onprogress) { - this._progressHandlers.set(messageId, onprogress); + if (options?.onprogress) { + this._progressHandlers.set(messageId, options.onprogress); jsonrpcRequest.params = { ...request.params, _meta: { progressToken: messageId }, @@ -320,6 +382,10 @@ export abstract class Protocol< } this._responseHandlers.set(messageId, (response) => { + if (options?.signal?.aborted) { + return; + } + if (response instanceof Error) { return reject(response); } @@ -332,6 +398,23 @@ export abstract class Protocol< } }); + options?.signal?.addEventListener("abort", () => { + const reason = options?.signal?.reason; + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + + this._transport?.send({ + jsonrpc: "2.0", + method: "cancelled", + params: { + requestId: messageId, + reason: String(reason), + }, + }); + + reject(reason); + }); + this._transport.send(jsonrpcRequest).catch(reject); }); } @@ -365,12 +448,15 @@ export abstract class Protocol< }>, >( requestSchema: T, - handler: (request: z.infer) => SendResultT | Promise, + handler: ( + request: z.infer, + extra: RequestHandlerExtra, + ) => SendResultT | Promise, ): void { const method = requestSchema.shape.method.value; this.assertRequestHandlerCapability(method); - this._requestHandlers.set(method, (request) => - Promise.resolve(handler(requestSchema.parse(request))), + this._requestHandlers.set(method, (request, extra) => + Promise.resolve(handler(requestSchema.parse(request), extra)), ); } diff --git a/src/types.ts b/src/types.ts index 0d55a75bb..5b9fb664e 100644 --- a/src/types.ts +++ b/src/types.ts @@ -39,18 +39,18 @@ export const RequestSchema = z.object({ params: z.optional(BaseRequestParamsSchema), }); +const BaseNotificationParamsSchema = z + .object({ + /** + * This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. + */ + _meta: z.optional(z.object({}).passthrough()), + }) + .passthrough(); + export const NotificationSchema = z.object({ method: z.string(), - params: z.optional( - z - .object({ - /** - * This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. - */ - _meta: z.optional(z.object({}).passthrough()), - }) - .passthrough(), - ), + params: z.optional(BaseNotificationParamsSchema), }); export const ResultSchema = z @@ -151,6 +151,33 @@ export const JSONRPCMessageSchema = z.union([ */ export const EmptyResultSchema = ResultSchema.strict(); +/* Cancellation */ +/** + * This notification can be sent by either side to indicate that it is cancelling a previously-issued request. + * + * The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished. + * + * This notification indicates that the result will be unused, so any associated processing SHOULD cease. + * + * A client MUST NOT attempt to cancel its `initialize` request. + */ +export const CancelledNotificationSchema = NotificationSchema.extend({ + method: z.literal("notifications/cancelled"), + params: BaseNotificationParamsSchema.extend({ + /** + * The ID of the request to cancel. + * + * This MUST correspond to the ID of a request previously issued in the same direction. + */ + requestId: RequestIdSchema, + + /** + * An optional string describing the reason for the cancellation. This MAY be logged or presented to the user. + */ + reason: z.string().optional(), + }), +}); + /* Initialization */ /** * Describes the name and version of an MCP implementation. @@ -312,7 +339,7 @@ export const ProgressSchema = z */ export const ProgressNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/progress"), - params: ProgressSchema.extend({ + params: BaseNotificationParamsSchema.merge(ProgressSchema).extend({ /** * The progress token which was given in the initial request, used to associate this notification with the request that is proceeding. */ @@ -522,14 +549,12 @@ export const UnsubscribeRequestSchema = RequestSchema.extend({ */ export const ResourceUpdatedNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/resources/updated"), - params: z - .object({ - /** - * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. - */ - uri: z.string(), - }) - .passthrough(), + params: BaseNotificationParamsSchema.extend({ + /** + * The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + */ + uri: z.string(), + }), }); /* Prompts */ @@ -786,22 +811,20 @@ export const SetLevelRequestSchema = RequestSchema.extend({ */ export const LoggingMessageNotificationSchema = NotificationSchema.extend({ method: z.literal("notifications/message"), - params: z - .object({ - /** - * The severity of this log message. - */ - level: LoggingLevelSchema, - /** - * An optional name of the logger issuing this message. - */ - logger: z.optional(z.string()), - /** - * The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. - */ - data: z.unknown(), - }) - .passthrough(), + params: BaseNotificationParamsSchema.extend({ + /** + * The severity of this log message. + */ + level: LoggingLevelSchema, + /** + * An optional name of the logger issuing this message. + */ + logger: z.optional(z.string()), + /** + * The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. + */ + data: z.unknown(), + }), }); /* Sampling */ @@ -1034,6 +1057,7 @@ export const ClientRequestSchema = z.union([ ]); export const ClientNotificationSchema = z.union([ + CancelledNotificationSchema, ProgressNotificationSchema, InitializedNotificationSchema, RootsListChangedNotificationSchema, @@ -1053,6 +1077,7 @@ export const ServerRequestSchema = z.union([ ]); export const ServerNotificationSchema = z.union([ + CancelledNotificationSchema, ProgressNotificationSchema, LoggingMessageNotificationSchema, ResourceUpdatedNotificationSchema, @@ -1100,6 +1125,9 @@ export type JSONRPCMessage = z.infer; /* Empty result */ export type EmptyResult = z.infer; +/* Cancellation */ +export type CancelledNotification = z.infer; + /* Initialization */ export type Implementation = z.infer; export type ClientCapabilities = z.infer; diff --git a/src/utils.test.ts b/src/utils.test.ts new file mode 100644 index 000000000..e4aa4e5fc --- /dev/null +++ b/src/utils.test.ts @@ -0,0 +1,15 @@ +import { abortAfterTimeout } from "./utils.js"; + +describe("abortAfterTimeout", () => { + it("should abort after timeout", () => { + const signal = abortAfterTimeout(0); + expect(signal.aborted).toBe(false); + + return new Promise((resolve) => { + setTimeout(() => { + expect(signal.aborted).toBe(true); + resolve(true); + }, 0); + }); + }); +}); diff --git a/src/utils.ts b/src/utils.ts new file mode 100644 index 000000000..11672ecd2 --- /dev/null +++ b/src/utils.ts @@ -0,0 +1,11 @@ +/** + * Returns an AbortSignal that will enter aborted state after `timeoutMs` milliseconds. + */ +export function abortAfterTimeout(timeoutMs: number): AbortSignal { + const controller = new AbortController(); + setTimeout(() => { + controller.abort(); + }, timeoutMs); + + return controller.signal; +}