Skip to content

Commit 42fbfe2

Browse files
committed
Update Protocol type inference
1 parent 3ea255f commit 42fbfe2

File tree

3 files changed

+63
-64
lines changed

3 files changed

+63
-64
lines changed

src/client/index.ts

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,9 @@ import {
55
ClientRequest,
66
ClientResult,
77
Implementation,
8-
InitializeResult,
8+
InitializeResultSchema,
99
PROTOCOL_VERSION,
1010
ServerCapabilities,
11-
ServerNotification,
12-
ServerRequest,
13-
ServerResult,
1411
} from "../types.js";
1512

1613
/**
@@ -19,9 +16,6 @@ import {
1916
* The client will automatically begin the initialization flow with the server when connect() is called.
2017
*/
2118
export class Client extends Protocol<
22-
ServerRequest,
23-
ServerNotification,
24-
ServerResult,
2519
ClientRequest,
2620
ClientNotification,
2721
ClientResult
@@ -39,14 +33,17 @@ export class Client extends Protocol<
3933
override async connect(transport: Transport): Promise<void> {
4034
await super.connect(transport);
4135

42-
const result = (await this.request({
43-
method: "initialize",
44-
params: {
45-
protocolVersion: 1,
46-
capabilities: {},
47-
clientInfo: this._clientInfo,
36+
const result = await this.request(
37+
{
38+
method: "initialize",
39+
params: {
40+
protocolVersion: 1,
41+
capabilities: {},
42+
clientInfo: this._clientInfo,
43+
},
4844
},
49-
})) as InitializeResult;
45+
InitializeResultSchema,
46+
);
5047

5148
if (result === undefined) {
5249
throw new Error(`Server sent invalid initialize result: ${result}`);

src/server/index.ts

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import { Protocol } from "../shared/protocol.js";
22
import {
33
ClientCapabilities,
4-
ClientNotification,
5-
ClientRequest,
6-
ClientResult,
74
Implementation,
5+
InitializedNotificationSchema,
86
InitializeRequest,
7+
InitializeRequestSchema,
98
InitializeResult,
109
PROTOCOL_VERSION,
1110
ServerNotification,
@@ -19,9 +18,6 @@ import {
1918
* This server will automatically respond to the initialization flow as initiated from the client.
2019
*/
2120
export class Server extends Protocol<
22-
ClientRequest,
23-
ClientNotification,
24-
ClientResult,
2521
ServerRequest,
2622
ServerNotification,
2723
ServerResult
@@ -40,10 +36,10 @@ export class Server extends Protocol<
4036
constructor(private _serverInfo: Implementation) {
4137
super();
4238

43-
this.setRequestHandler("initialize", (request) =>
44-
this._oninitialize(request as InitializeRequest),
39+
this.setRequestHandler(InitializeRequestSchema, (request) =>
40+
this._oninitialize(request),
4541
);
46-
this.setNotificationHandler("notification/initialized", () =>
42+
this.setNotificationHandler(InitializedNotificationSchema, () =>
4743
this.oninitialized?.(),
4844
);
4945
}

src/shared/protocol.ts

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { AnyZodObject, ZodLiteral, ZodObject, z } from "zod";
12
import {
23
ErrorCode,
34
JSONRPCError,
@@ -25,9 +26,6 @@ export type ProgressCallback = (progress: Progress) => void;
2526
* features like request/response linking, notifications, and progress.
2627
*/
2728
export class Protocol<
28-
ReceiveRequestT extends Request,
29-
ReceiveNotificationT extends Notification,
30-
ReceiveResultT extends Result,
3129
SendRequestT extends Request,
3230
SendNotificationT extends Notification,
3331
SendResultT extends Result,
@@ -36,15 +34,15 @@ export class Protocol<
3634
private _requestMessageId = 0;
3735
private _requestHandlers: Map<
3836
string,
39-
(request: ReceiveRequestT) => Promise<SendResultT>
37+
(request: JSONRPCRequest) => Promise<SendResultT>
4038
> = new Map();
4139
private _notificationHandlers: Map<
4240
string,
43-
(notification: ReceiveNotificationT) => Promise<void>
41+
(notification: JSONRPCNotification) => Promise<void>
4442
> = new Map();
4543
private _responseHandlers: Map<
4644
number,
47-
(response: ReceiveResultT | Error) => void
45+
(response: JSONRPCResponse | Error) => void
4846
> = new Map();
4947
private _progressHandlers: Map<number, ProgressCallback> = new Map();
5048

@@ -65,25 +63,20 @@ export class Protocol<
6563
/**
6664
* A handler to invoke for any request types that do not have their own handler installed.
6765
*/
68-
fallbackRequestHandler?: (request: ReceiveRequestT) => Promise<SendResultT>;
66+
fallbackRequestHandler?: (request: Request) => Promise<SendResultT>;
6967

7068
/**
7169
* A handler to invoke for any notification types that do not have their own handler installed.
7270
*/
73-
fallbackNotificationHandler?: (
74-
notification: ReceiveNotificationT,
75-
) => Promise<void>;
71+
fallbackNotificationHandler?: (notification: Notification) => Promise<void>;
7672

7773
constructor() {
78-
this.setNotificationHandler(
79-
ProgressNotificationSchema.shape.method.value,
80-
(notification) => {
81-
this._onprogress(notification as unknown as ProgressNotification);
82-
},
83-
);
74+
this.setNotificationHandler(ProgressNotificationSchema, (notification) => {
75+
this._onprogress(notification as unknown as ProgressNotification);
76+
});
8477

8578
this.setRequestHandler(
86-
PingRequestSchema.shape.method.value,
79+
PingRequestSchema,
8780
// Automatic pong by default.
8881
(_request) => ({}) as SendResultT,
8982
);
@@ -106,11 +99,11 @@ export class Protocol<
10699

107100
this._transport.onmessage = (message) => {
108101
if (!("method" in message)) {
109-
this._onresponse(message as JSONRPCResponse | JSONRPCError);
102+
this._onresponse(message);
110103
} else if ("id" in message) {
111-
this._onrequest(message as JSONRPCRequest);
104+
this._onrequest(message);
112105
} else {
113-
this._onnotification(message as JSONRPCNotification);
106+
this._onnotification(message);
114107
}
115108
};
116109
}
@@ -142,7 +135,7 @@ export class Protocol<
142135
return;
143136
}
144137

145-
handler(notification as unknown as ReceiveNotificationT).catch((error) =>
138+
handler(notification).catch((error) =>
146139
this._onerror(
147140
new Error(`Uncaught error in notification handler: ${error}`),
148141
),
@@ -171,7 +164,7 @@ export class Protocol<
171164
return;
172165
}
173166

174-
handler(request as unknown as ReceiveRequestT)
167+
handler(request)
175168
.then(
176169
(result) => {
177170
this._transport?.send({
@@ -228,7 +221,7 @@ export class Protocol<
228221
this._responseHandlers.delete(Number(messageId));
229222
this._progressHandlers.delete(Number(messageId));
230223
if ("result" in response) {
231-
handler(response.result as ReceiveResultT);
224+
handler(response);
232225
} else {
233226
const error = new McpError(
234227
response.error.code,
@@ -255,11 +248,11 @@ export class Protocol<
255248
*
256249
* Do not use this method to emit notifications! Use notification() instead.
257250
*/
258-
// TODO: This could infer a better response type based on the method
259-
request(
251+
request<T extends AnyZodObject>(
260252
request: SendRequestT,
253+
resultSchema: T,
261254
onprogress?: ProgressCallback,
262-
): Promise<ReceiveResultT> {
255+
): Promise<z.infer<T>> {
263256
return new Promise((resolve, reject) => {
264257
if (!this._transport) {
265258
reject(new Error("Not connected"));
@@ -283,9 +276,14 @@ export class Protocol<
283276

284277
this._responseHandlers.set(messageId, (response) => {
285278
if (response instanceof Error) {
286-
reject(response);
287-
} else {
288-
resolve(response);
279+
return reject(response);
280+
}
281+
282+
try {
283+
const result = resultSchema.parse(response.result);
284+
resolve(result);
285+
} catch (error) {
286+
reject(error);
289287
}
290288
});
291289

@@ -314,13 +312,16 @@ export class Protocol<
314312
*
315313
* Note that this will replace any previous request handler for the same method.
316314
*/
317-
// TODO: This could infer a better request type based on the method.
318-
setRequestHandler(
319-
method: string,
320-
handler: (request: ReceiveRequestT) => SendResultT | Promise<SendResultT>,
315+
setRequestHandler<
316+
T extends ZodObject<{
317+
method: ZodLiteral<string>;
318+
}>,
319+
>(
320+
requestSchema: T,
321+
handler: (request: z.infer<T>) => SendResultT | Promise<SendResultT>,
321322
): void {
322-
this._requestHandlers.set(method, (request) =>
323-
Promise.resolve(handler(request)),
323+
this._requestHandlers.set(requestSchema.shape.method.value, (request) =>
324+
Promise.resolve(handler(requestSchema.parse(request))),
324325
);
325326
}
326327

@@ -336,13 +337,18 @@ export class Protocol<
336337
*
337338
* Note that this will replace any previous notification handler for the same method.
338339
*/
339-
// TODO: This could infer a better notification type based on the method.
340-
setNotificationHandler<T extends ReceiveNotificationT>(
341-
method: string,
342-
handler: (notification: T) => void | Promise<void>,
340+
setNotificationHandler<
341+
T extends ZodObject<{
342+
method: ZodLiteral<string>;
343+
}>,
344+
>(
345+
notificationSchema: T,
346+
handler: (notification: z.infer<T>) => void | Promise<void>,
343347
): void {
344-
this._notificationHandlers.set(method, (notification) =>
345-
Promise.resolve(handler(notification as T)),
348+
this._notificationHandlers.set(
349+
notificationSchema.shape.method.value,
350+
(notification) =>
351+
Promise.resolve(handler(notificationSchema.parse(notification))),
346352
);
347353
}
348354

0 commit comments

Comments
 (0)