Skip to content

Commit bafd9e7

Browse files
committed
Add ways to associate related requests and notifications
1 parent e9caa5a commit bafd9e7

File tree

4 files changed

+77
-33
lines changed

4 files changed

+77
-33
lines changed

src/server/mcp.test.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ describe("ResourceTemplate", () => {
8585
const abortController = new AbortController();
8686
const result = await template.listCallback?.({
8787
signal: abortController.signal,
88+
sendRequest: () => { throw new Error("Not implemented") },
89+
sendNotification: () => { throw new Error("Not implemented") }
8890
});
8991
expect(result?.resources).toHaveLength(1);
9092
expect(list).toHaveBeenCalled();
@@ -318,7 +320,7 @@ describe("tool()", () => {
318320

319321
// This should succeed
320322
mcpServer.tool("tool1", () => ({ content: [] }));
321-
323+
322324
// This should also succeed and not throw about request handlers
323325
mcpServer.tool("tool2", () => ({ content: [] }));
324326
});
@@ -815,7 +817,7 @@ describe("resource()", () => {
815817
},
816818
],
817819
}));
818-
820+
819821
// This should also succeed and not throw about request handlers
820822
mcpServer.resource("resource2", "test://resource2", async () => ({
821823
contents: [
@@ -1321,7 +1323,7 @@ describe("prompt()", () => {
13211323
},
13221324
],
13231325
}));
1324-
1326+
13251327
// This should also succeed and not throw about request handlers
13261328
mcpServer.prompt("prompt2", async () => ({
13271329
messages: [

src/server/mcp.ts

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import {
3737
PromptArgument,
3838
GetPromptResult,
3939
ReadResourceResult,
40+
ServerRequest,
41+
ServerNotification,
4042
} from "../types.js";
4143
import { Completable, CompletableDef } from "./completable.js";
4244
import { UriTemplate, Variables } from "../shared/uriTemplate.js";
@@ -694,9 +696,9 @@ export type ToolCallback<Args extends undefined | ZodRawShape = undefined> =
694696
Args extends ZodRawShape
695697
? (
696698
args: z.objectOutputType<Args, ZodTypeAny>,
697-
extra: RequestHandlerExtra,
699+
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
698700
) => CallToolResult | Promise<CallToolResult>
699-
: (extra: RequestHandlerExtra) => CallToolResult | Promise<CallToolResult>;
701+
: (extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => CallToolResult | Promise<CallToolResult>;
700702

701703
type RegisteredTool = {
702704
description?: string;
@@ -717,15 +719,15 @@ export type ResourceMetadata = Omit<Resource, "uri" | "name">;
717719
* Callback to list all resources matching a given template.
718720
*/
719721
export type ListResourcesCallback = (
720-
extra: RequestHandlerExtra,
722+
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
721723
) => ListResourcesResult | Promise<ListResourcesResult>;
722724

723725
/**
724726
* Callback to read a resource at a given URI.
725727
*/
726728
export type ReadResourceCallback = (
727729
uri: URL,
728-
extra: RequestHandlerExtra,
730+
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
729731
) => ReadResourceResult | Promise<ReadResourceResult>;
730732

731733
type RegisteredResource = {
@@ -740,7 +742,7 @@ type RegisteredResource = {
740742
export type ReadResourceTemplateCallback = (
741743
uri: URL,
742744
variables: Variables,
743-
extra: RequestHandlerExtra,
745+
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
744746
) => ReadResourceResult | Promise<ReadResourceResult>;
745747

746748
type RegisteredResourceTemplate = {
@@ -760,9 +762,9 @@ export type PromptCallback<
760762
> = Args extends PromptArgsRawShape
761763
? (
762764
args: z.objectOutputType<Args, ZodTypeAny>,
763-
extra: RequestHandlerExtra,
765+
extra: RequestHandlerExtra<ServerRequest, ServerNotification>,
764766
) => GetPromptResult | Promise<GetPromptResult>
765-
: (extra: RequestHandlerExtra) => GetPromptResult | Promise<GetPromptResult>;
767+
: (extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => GetPromptResult | Promise<GetPromptResult>;
766768

767769
type RegisteredPrompt = {
768770
description?: string;

src/shared/protocol.ts

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,52 @@ export type RequestOptions = {
7878
* If not specified, there is no maximum total timeout.
7979
*/
8080
maxTotalTimeout?: number;
81+
82+
/**
83+
* May be used to indicate to the transport which incoming request to associate this outgoing request with.
84+
*/
85+
relatedRequestId?: RequestId;
8186
};
8287

8388
/**
84-
* Extra data given to request handlers.
89+
* Options that can be given per notification.
8590
*/
86-
export type RequestHandlerExtra = {
91+
export type NotificationOptions = {
8792
/**
88-
* An abort signal used to communicate if the request was cancelled from the sender's side.
93+
* May be used to indicate to the transport which incoming request to associate this outgoing notification with.
8994
*/
90-
signal: AbortSignal;
95+
relatedRequestId?: RequestId;
96+
}
9197

92-
/**
93-
* The session ID from the transport, if available.
94-
*/
95-
sessionId?: string;
96-
};
98+
/**
99+
* Extra data given to request handlers.
100+
*/
101+
export type RequestHandlerExtra<SendRequestT extends Request,
102+
SendNotificationT extends Notification> = {
103+
/**
104+
* An abort signal used to communicate if the request was cancelled from the sender's side.
105+
*/
106+
signal: AbortSignal;
107+
108+
/**
109+
* The session ID from the transport, if available.
110+
*/
111+
sessionId?: string;
112+
113+
/**
114+
* Sends a notification that relates to the current request being handled.
115+
*
116+
* This is used by certain transports to correctly associate related messages.
117+
*/
118+
sendNotification: (notification: SendNotificationT) => Promise<void>;
119+
120+
/**
121+
* Sends a request that relates to the current request being handled.
122+
*
123+
* This is used by certain transports to correctly associate related messages.
124+
*/
125+
sendRequest: <U extends ZodType<object>>(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise<z.infer<U>>;
126+
};
97127

98128
/**
99129
* Information about a request's timeout state
@@ -122,7 +152,7 @@ export abstract class Protocol<
122152
string,
123153
(
124154
request: JSONRPCRequest,
125-
extra: RequestHandlerExtra,
155+
extra: RequestHandlerExtra<SendRequestT, SendNotificationT>,
126156
) => Promise<SendResultT>
127157
> = new Map();
128158
private _requestHandlerAbortControllers: Map<RequestId, AbortController> =
@@ -316,9 +346,14 @@ export abstract class Protocol<
316346
this._requestHandlerAbortControllers.set(request.id, abortController);
317347

318348
// Create extra object with both abort signal and sessionId from transport
319-
const extra: RequestHandlerExtra = {
349+
const extra: RequestHandlerExtra<SendRequestT, SendNotificationT> = {
320350
signal: abortController.signal,
321351
sessionId: this._transport?.sessionId,
352+
sendNotification:
353+
(notification) =>
354+
this.notification(notification, { relatedRequestId: request.id }),
355+
sendRequest: (r, resultSchema, options?) =>
356+
this.request(r, resultSchema, { ...options, relatedRequestId: request.id })
322357
};
323358

324359
// Starting with Promise.resolve() puts any synchronous errors into the monad as well.
@@ -364,7 +399,7 @@ export abstract class Protocol<
364399
private _onprogress(notification: ProgressNotification): void {
365400
const { progressToken, ...params } = notification.params;
366401
const messageId = Number(progressToken);
367-
402+
368403
const handler = this._progressHandlers.get(messageId);
369404
if (!handler) {
370405
this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`));
@@ -373,7 +408,7 @@ export abstract class Protocol<
373408

374409
const responseHandler = this._responseHandlers.get(messageId);
375410
const timeoutInfo = this._timeoutInfo.get(messageId);
376-
411+
377412
if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) {
378413
try {
379414
this._resetTimeout(messageId);
@@ -460,6 +495,8 @@ export abstract class Protocol<
460495
resultSchema: T,
461496
options?: RequestOptions,
462497
): Promise<z.infer<T>> {
498+
const { relatedRequestId } = options ?? {};
499+
463500
return new Promise((resolve, reject) => {
464501
if (!this._transport) {
465502
reject(new Error("Not connected"));
@@ -500,7 +537,7 @@ export abstract class Protocol<
500537
requestId: messageId,
501538
reason: String(reason),
502539
},
503-
})
540+
}, { relatedRequestId })
504541
.catch((error) =>
505542
this._onerror(new Error(`Failed to send cancellation: ${error}`)),
506543
);
@@ -538,7 +575,7 @@ export abstract class Protocol<
538575

539576
this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false);
540577

541-
this._transport.send(jsonrpcRequest).catch((error) => {
578+
this._transport.send(jsonrpcRequest, { relatedRequestId }).catch((error) => {
542579
this._cleanupTimeout(messageId);
543580
reject(error);
544581
});
@@ -548,7 +585,7 @@ export abstract class Protocol<
548585
/**
549586
* Emits a notification, which is a one-way message that does not expect a response.
550587
*/
551-
async notification(notification: SendNotificationT): Promise<void> {
588+
async notification(notification: SendNotificationT, options?: NotificationOptions): Promise<void> {
552589
if (!this._transport) {
553590
throw new Error("Not connected");
554591
}
@@ -560,7 +597,7 @@ export abstract class Protocol<
560597
jsonrpc: "2.0",
561598
};
562599

563-
await this._transport.send(jsonrpcNotification);
600+
await this._transport.send(jsonrpcNotification, options);
564601
}
565602

566603
/**
@@ -576,14 +613,15 @@ export abstract class Protocol<
576613
requestSchema: T,
577614
handler: (
578615
request: z.infer<T>,
579-
extra: RequestHandlerExtra,
616+
extra: RequestHandlerExtra<SendRequestT, SendNotificationT>,
580617
) => SendResultT | Promise<SendResultT>,
581618
): void {
582619
const method = requestSchema.shape.method.value;
583620
this.assertRequestHandlerCapability(method);
584-
this._requestHandlers.set(method, (request, extra) =>
585-
Promise.resolve(handler(requestSchema.parse(request), extra)),
586-
);
621+
622+
this._requestHandlers.set(method, (request, extra) => {
623+
return Promise.resolve(handler(requestSchema.parse(request), extra));
624+
});
587625
}
588626

589627
/**

src/shared/transport.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { JSONRPCMessage } from "../types.js";
1+
import { JSONRPCMessage, RequestId } from "../types.js";
22

33
/**
44
* Describes the minimal contract for a MCP transport that a client or server can communicate over.
@@ -15,8 +15,10 @@ export interface Transport {
1515

1616
/**
1717
* Sends a JSON-RPC message (request or response).
18+
*
19+
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
1820
*/
19-
send(message: JSONRPCMessage): Promise<void>;
21+
send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise<void>;
2022

2123
/**
2224
* Closes the connection.

0 commit comments

Comments
 (0)