@@ -78,22 +78,52 @@ export type RequestOptions = {
78
78
* If not specified, there is no maximum total timeout.
79
79
*/
80
80
maxTotalTimeout ?: number ;
81
+
82
+ /**
83
+ * May be used to indicate to the transport which incoming request to associate this outgoing request with.
84
+ */
85
+ relatedRequestId ?: RequestId ;
81
86
} ;
82
87
83
88
/**
84
- * Extra data given to request handlers .
89
+ * Options that can be given per notification .
85
90
*/
86
- export type RequestHandlerExtra = {
91
+ export type NotificationOptions = {
87
92
/**
88
- * An abort signal used to communicate if the request was cancelled from the sender's side .
93
+ * May be used to indicate to the transport which incoming request to associate this outgoing notification with .
89
94
*/
90
- signal : AbortSignal ;
95
+ relatedRequestId ?: RequestId ;
96
+ }
91
97
92
- /**
93
- * The session ID from the transport, if available.
94
- */
95
- sessionId ?: string ;
96
- } ;
98
+ /**
99
+ * Extra data given to request handlers.
100
+ */
101
+ export type RequestHandlerExtra < SendRequestT extends Request ,
102
+ SendNotificationT extends Notification > = {
103
+ /**
104
+ * An abort signal used to communicate if the request was cancelled from the sender's side.
105
+ */
106
+ signal : AbortSignal ;
107
+
108
+ /**
109
+ * The session ID from the transport, if available.
110
+ */
111
+ sessionId ?: string ;
112
+
113
+ /**
114
+ * Sends a notification that relates to the current request being handled.
115
+ *
116
+ * This is used by certain transports to correctly associate related messages.
117
+ */
118
+ sendNotification : ( notification : SendNotificationT ) => Promise < void > ;
119
+
120
+ /**
121
+ * Sends a request that relates to the current request being handled.
122
+ *
123
+ * This is used by certain transports to correctly associate related messages.
124
+ */
125
+ sendRequest : < U extends ZodType < object > > ( request : SendRequestT , resultSchema : U , options ?: RequestOptions ) => Promise < z . infer < U > > ;
126
+ } ;
97
127
98
128
/**
99
129
* Information about a request's timeout state
@@ -122,7 +152,7 @@ export abstract class Protocol<
122
152
string ,
123
153
(
124
154
request : JSONRPCRequest ,
125
- extra : RequestHandlerExtra ,
155
+ extra : RequestHandlerExtra < SendRequestT , SendNotificationT > ,
126
156
) => Promise < SendResultT >
127
157
> = new Map ( ) ;
128
158
private _requestHandlerAbortControllers : Map < RequestId , AbortController > =
@@ -316,9 +346,14 @@ export abstract class Protocol<
316
346
this . _requestHandlerAbortControllers . set ( request . id , abortController ) ;
317
347
318
348
// Create extra object with both abort signal and sessionId from transport
319
- const extra : RequestHandlerExtra = {
349
+ const extra : RequestHandlerExtra < SendRequestT , SendNotificationT > = {
320
350
signal : abortController . signal ,
321
351
sessionId : this . _transport ?. sessionId ,
352
+ sendNotification :
353
+ ( notification ) =>
354
+ this . notification ( notification , { relatedRequestId : request . id } ) ,
355
+ sendRequest : ( r , resultSchema , options ?) =>
356
+ this . request ( r , resultSchema , { ...options , relatedRequestId : request . id } )
322
357
} ;
323
358
324
359
// Starting with Promise.resolve() puts any synchronous errors into the monad as well.
@@ -364,7 +399,7 @@ export abstract class Protocol<
364
399
private _onprogress ( notification : ProgressNotification ) : void {
365
400
const { progressToken, ...params } = notification . params ;
366
401
const messageId = Number ( progressToken ) ;
367
-
402
+
368
403
const handler = this . _progressHandlers . get ( messageId ) ;
369
404
if ( ! handler ) {
370
405
this . _onerror ( new Error ( `Received a progress notification for an unknown token: ${ JSON . stringify ( notification ) } ` ) ) ;
@@ -373,7 +408,7 @@ export abstract class Protocol<
373
408
374
409
const responseHandler = this . _responseHandlers . get ( messageId ) ;
375
410
const timeoutInfo = this . _timeoutInfo . get ( messageId ) ;
376
-
411
+
377
412
if ( timeoutInfo && responseHandler && timeoutInfo . resetTimeoutOnProgress ) {
378
413
try {
379
414
this . _resetTimeout ( messageId ) ;
@@ -460,6 +495,8 @@ export abstract class Protocol<
460
495
resultSchema : T ,
461
496
options ?: RequestOptions ,
462
497
) : Promise < z . infer < T > > {
498
+ const { relatedRequestId } = options ?? { } ;
499
+
463
500
return new Promise ( ( resolve , reject ) => {
464
501
if ( ! this . _transport ) {
465
502
reject ( new Error ( "Not connected" ) ) ;
@@ -500,7 +537,7 @@ export abstract class Protocol<
500
537
requestId : messageId ,
501
538
reason : String ( reason ) ,
502
539
} ,
503
- } )
540
+ } , { relatedRequestId } )
504
541
. catch ( ( error ) =>
505
542
this . _onerror ( new Error ( `Failed to send cancellation: ${ error } ` ) ) ,
506
543
) ;
@@ -538,7 +575,7 @@ export abstract class Protocol<
538
575
539
576
this . _setupTimeout ( messageId , timeout , options ?. maxTotalTimeout , timeoutHandler , options ?. resetTimeoutOnProgress ?? false ) ;
540
577
541
- this . _transport . send ( jsonrpcRequest ) . catch ( ( error ) => {
578
+ this . _transport . send ( jsonrpcRequest , { relatedRequestId } ) . catch ( ( error ) => {
542
579
this . _cleanupTimeout ( messageId ) ;
543
580
reject ( error ) ;
544
581
} ) ;
@@ -548,7 +585,7 @@ export abstract class Protocol<
548
585
/**
549
586
* Emits a notification, which is a one-way message that does not expect a response.
550
587
*/
551
- async notification ( notification : SendNotificationT ) : Promise < void > {
588
+ async notification ( notification : SendNotificationT , options ?: NotificationOptions ) : Promise < void > {
552
589
if ( ! this . _transport ) {
553
590
throw new Error ( "Not connected" ) ;
554
591
}
@@ -560,7 +597,7 @@ export abstract class Protocol<
560
597
jsonrpc : "2.0" ,
561
598
} ;
562
599
563
- await this . _transport . send ( jsonrpcNotification ) ;
600
+ await this . _transport . send ( jsonrpcNotification , options ) ;
564
601
}
565
602
566
603
/**
@@ -576,14 +613,15 @@ export abstract class Protocol<
576
613
requestSchema : T ,
577
614
handler : (
578
615
request : z . infer < T > ,
579
- extra : RequestHandlerExtra ,
616
+ extra : RequestHandlerExtra < SendRequestT , SendNotificationT > ,
580
617
) => SendResultT | Promise < SendResultT > ,
581
618
) : void {
582
619
const method = requestSchema . shape . method . value ;
583
620
this . assertRequestHandlerCapability ( method ) ;
584
- this . _requestHandlers . set ( method , ( request , extra ) =>
585
- Promise . resolve ( handler ( requestSchema . parse ( request ) , extra ) ) ,
586
- ) ;
621
+
622
+ this . _requestHandlers . set ( method , ( request , extra ) => {
623
+ return Promise . resolve ( handler ( requestSchema . parse ( request ) , extra ) ) ;
624
+ } ) ;
587
625
}
588
626
589
627
/**
0 commit comments