1
+ import { AnyZodObject , ZodLiteral , ZodObject , z } from "zod" ;
1
2
import {
2
3
ErrorCode ,
3
4
JSONRPCError ,
@@ -25,9 +26,6 @@ export type ProgressCallback = (progress: Progress) => void;
25
26
* features like request/response linking, notifications, and progress.
26
27
*/
27
28
export class Protocol <
28
- ReceiveRequestT extends Request ,
29
- ReceiveNotificationT extends Notification ,
30
- ReceiveResultT extends Result ,
31
29
SendRequestT extends Request ,
32
30
SendNotificationT extends Notification ,
33
31
SendResultT extends Result ,
@@ -36,15 +34,15 @@ export class Protocol<
36
34
private _requestMessageId = 0 ;
37
35
private _requestHandlers : Map <
38
36
string ,
39
- ( request : ReceiveRequestT ) => Promise < SendResultT >
37
+ ( request : JSONRPCRequest ) => Promise < SendResultT >
40
38
> = new Map ( ) ;
41
39
private _notificationHandlers : Map <
42
40
string ,
43
- ( notification : ReceiveNotificationT ) => Promise < void >
41
+ ( notification : JSONRPCNotification ) => Promise < void >
44
42
> = new Map ( ) ;
45
43
private _responseHandlers : Map <
46
44
number ,
47
- ( response : ReceiveResultT | Error ) => void
45
+ ( response : JSONRPCResponse | Error ) => void
48
46
> = new Map ( ) ;
49
47
private _progressHandlers : Map < number , ProgressCallback > = new Map ( ) ;
50
48
@@ -65,25 +63,20 @@ export class Protocol<
65
63
/**
66
64
* A handler to invoke for any request types that do not have their own handler installed.
67
65
*/
68
- fallbackRequestHandler ?: ( request : ReceiveRequestT ) => Promise < SendResultT > ;
66
+ fallbackRequestHandler ?: ( request : Request ) => Promise < SendResultT > ;
69
67
70
68
/**
71
69
* A handler to invoke for any notification types that do not have their own handler installed.
72
70
*/
73
- fallbackNotificationHandler ?: (
74
- notification : ReceiveNotificationT ,
75
- ) => Promise < void > ;
71
+ fallbackNotificationHandler ?: ( notification : Notification ) => Promise < void > ;
76
72
77
73
constructor ( ) {
78
- this . setNotificationHandler (
79
- ProgressNotificationSchema . shape . method . value ,
80
- ( notification ) => {
81
- this . _onprogress ( notification as unknown as ProgressNotification ) ;
82
- } ,
83
- ) ;
74
+ this . setNotificationHandler ( ProgressNotificationSchema , ( notification ) => {
75
+ this . _onprogress ( notification as unknown as ProgressNotification ) ;
76
+ } ) ;
84
77
85
78
this . setRequestHandler (
86
- PingRequestSchema . shape . method . value ,
79
+ PingRequestSchema ,
87
80
// Automatic pong by default.
88
81
( _request ) => ( { } ) as SendResultT ,
89
82
) ;
@@ -106,11 +99,11 @@ export class Protocol<
106
99
107
100
this . _transport . onmessage = ( message ) => {
108
101
if ( ! ( "method" in message ) ) {
109
- this . _onresponse ( message as JSONRPCResponse | JSONRPCError ) ;
102
+ this . _onresponse ( message ) ;
110
103
} else if ( "id" in message ) {
111
- this . _onrequest ( message as JSONRPCRequest ) ;
104
+ this . _onrequest ( message ) ;
112
105
} else {
113
- this . _onnotification ( message as JSONRPCNotification ) ;
106
+ this . _onnotification ( message ) ;
114
107
}
115
108
} ;
116
109
}
@@ -142,7 +135,7 @@ export class Protocol<
142
135
return ;
143
136
}
144
137
145
- handler ( notification as unknown as ReceiveNotificationT ) . catch ( ( error ) =>
138
+ handler ( notification ) . catch ( ( error ) =>
146
139
this . _onerror (
147
140
new Error ( `Uncaught error in notification handler: ${ error } ` ) ,
148
141
) ,
@@ -171,7 +164,7 @@ export class Protocol<
171
164
return ;
172
165
}
173
166
174
- handler ( request as unknown as ReceiveRequestT )
167
+ handler ( request )
175
168
. then (
176
169
( result ) => {
177
170
this . _transport ?. send ( {
@@ -228,7 +221,7 @@ export class Protocol<
228
221
this . _responseHandlers . delete ( Number ( messageId ) ) ;
229
222
this . _progressHandlers . delete ( Number ( messageId ) ) ;
230
223
if ( "result" in response ) {
231
- handler ( response . result as ReceiveResultT ) ;
224
+ handler ( response ) ;
232
225
} else {
233
226
const error = new McpError (
234
227
response . error . code ,
@@ -255,11 +248,11 @@ export class Protocol<
255
248
*
256
249
* Do not use this method to emit notifications! Use notification() instead.
257
250
*/
258
- // TODO: This could infer a better response type based on the method
259
- request (
251
+ request < T extends AnyZodObject > (
260
252
request : SendRequestT ,
253
+ resultSchema : T ,
261
254
onprogress ?: ProgressCallback ,
262
- ) : Promise < ReceiveResultT > {
255
+ ) : Promise < z . infer < T > > {
263
256
return new Promise ( ( resolve , reject ) => {
264
257
if ( ! this . _transport ) {
265
258
reject ( new Error ( "Not connected" ) ) ;
@@ -283,9 +276,14 @@ export class Protocol<
283
276
284
277
this . _responseHandlers . set ( messageId , ( response ) => {
285
278
if ( response instanceof Error ) {
286
- reject ( response ) ;
287
- } else {
288
- resolve ( response ) ;
279
+ return reject ( response ) ;
280
+ }
281
+
282
+ try {
283
+ const result = resultSchema . parse ( response . result ) ;
284
+ resolve ( result ) ;
285
+ } catch ( error ) {
286
+ reject ( error ) ;
289
287
}
290
288
} ) ;
291
289
@@ -314,13 +312,16 @@ export class Protocol<
314
312
*
315
313
* Note that this will replace any previous request handler for the same method.
316
314
*/
317
- // TODO: This could infer a better request type based on the method.
318
- setRequestHandler (
319
- method : string ,
320
- handler : ( request : ReceiveRequestT ) => SendResultT | Promise < SendResultT > ,
315
+ setRequestHandler <
316
+ T extends ZodObject < {
317
+ method : ZodLiteral < string > ;
318
+ } > ,
319
+ > (
320
+ requestSchema : T ,
321
+ handler : ( request : z . infer < T > ) => SendResultT | Promise < SendResultT > ,
321
322
) : void {
322
- this . _requestHandlers . set ( method , ( request ) =>
323
- Promise . resolve ( handler ( request ) ) ,
323
+ this . _requestHandlers . set ( requestSchema . shape . method . value , ( request ) =>
324
+ Promise . resolve ( handler ( requestSchema . parse ( request ) ) ) ,
324
325
) ;
325
326
}
326
327
@@ -336,13 +337,18 @@ export class Protocol<
336
337
*
337
338
* Note that this will replace any previous notification handler for the same method.
338
339
*/
339
- // TODO: This could infer a better notification type based on the method.
340
- setNotificationHandler < T extends ReceiveNotificationT > (
341
- method : string ,
342
- handler : ( notification : T ) => void | Promise < void > ,
340
+ setNotificationHandler <
341
+ T extends ZodObject < {
342
+ method : ZodLiteral < string > ;
343
+ } > ,
344
+ > (
345
+ notificationSchema : T ,
346
+ handler : ( notification : z . infer < T > ) => void | Promise < void > ,
343
347
) : void {
344
- this . _notificationHandlers . set ( method , ( notification ) =>
345
- Promise . resolve ( handler ( notification as T ) ) ,
348
+ this . _notificationHandlers . set (
349
+ notificationSchema . shape . method . value ,
350
+ ( notification ) =>
351
+ Promise . resolve ( handler ( notificationSchema . parse ( notification ) ) ) ,
346
352
) ;
347
353
}
348
354
0 commit comments