1
1
use std:: { borrow:: Cow , sync:: Arc , time:: Duration } ;
2
2
3
- use futures:: { StreamExt , future:: BoxFuture , stream:: BoxStream } ;
3
+ use futures:: { Stream , StreamExt , future:: BoxFuture , stream:: BoxStream } ;
4
4
pub use sse_stream:: Error as SseError ;
5
5
use sse_stream:: Sse ;
6
6
use thiserror:: Error ;
@@ -193,8 +193,7 @@ impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
193
193
client : C :: default ( ) ,
194
194
config : StreamableHttpClientTransportConfig {
195
195
uri : url. into ( ) ,
196
- retry_config : Arc :: new ( ExponentialBackoff :: default ( ) ) ,
197
- channel_buffer_capacity : 16 ,
196
+ ..Default :: default ( )
198
197
} ,
199
198
}
200
199
}
@@ -208,7 +207,9 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
208
207
209
208
impl < C : StreamableHttpClient > StreamableHttpClientWorker < C > {
210
209
async fn execute_sse_stream (
211
- sse_stream : SseAutoReconnectStream < StreamableHttpClientReconnect < C > > ,
210
+ sse_stream : impl Stream < Item = Result < ServerJsonRpcMessage , StreamableHttpError < C :: Error > > >
211
+ + Send
212
+ + ' static ,
212
213
sse_worker_tx : tokio:: sync:: mpsc:: Sender < ServerJsonRpcMessage > ,
213
214
ct : CancellationToken ,
214
215
) -> Result < ( ) , StreamableHttpError < C :: Error > > {
@@ -277,16 +278,19 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
277
278
. map_err ( WorkerQuitReason :: fatal_context (
278
279
"process initialize response" ,
279
280
) ) ?;
280
- let Some ( session_id) = session_id else {
281
- return Err ( WorkerQuitReason :: fatal (
282
- "missing session id in initialize response" ,
283
- "process initialize response" ,
284
- ) ) ;
281
+ let session_id: Option < Arc < str > > = if let Some ( session_id) = session_id {
282
+ Some ( session_id. into ( ) )
283
+ } else {
284
+ if !self . config . allow_stateless {
285
+ return Err ( WorkerQuitReason :: fatal (
286
+ "missing session id in initialize response" ,
287
+ "process initialize response" ,
288
+ ) ) ;
289
+ }
290
+ None
285
291
} ;
286
- let session_id: Arc < str > = session_id. into ( ) ;
287
-
288
292
// delete session when drop guard is dropped
289
- {
293
+ if let Some ( session_id ) = & session_id {
290
294
let ct = transport_task_ct. clone ( ) ;
291
295
let client = self . client . clone ( ) ;
292
296
let session_id = session_id. clone ( ) ;
@@ -322,7 +326,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
322
326
. post_message (
323
327
config. uri . clone ( ) ,
324
328
initialized_notification. message ,
325
- Some ( session_id. clone ( ) ) ,
329
+ session_id. clone ( ) ,
326
330
None ,
327
331
)
328
332
. await
@@ -340,38 +344,40 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
340
344
StreamResult ( Result < ( ) , StreamableHttpError < E > > ) ,
341
345
}
342
346
let mut streams = tokio:: task:: JoinSet :: new ( ) ;
343
- match self
344
- . client
345
- . get_stream ( config. uri . clone ( ) , session_id. clone ( ) , None , None )
346
- . await
347
- {
348
- Ok ( stream) => {
349
- let sse_stream = SseAutoReconnectStream :: new (
350
- stream,
351
- StreamableHttpClientReconnect {
352
- client : self . client . clone ( ) ,
353
- session_id : session_id. clone ( ) ,
354
- uri : config. uri . clone ( ) ,
355
- } ,
356
- self . config . retry_config . clone ( ) ,
357
- ) ;
358
- streams. spawn ( Self :: execute_sse_stream (
359
- sse_stream,
360
- sse_worker_tx. clone ( ) ,
361
- transport_task_ct. child_token ( ) ,
362
- ) ) ;
363
- tracing:: debug!( "got common stream" ) ;
364
- }
365
- Err ( StreamableHttpError :: SeverDoesNotSupportSse ) => {
366
- tracing:: debug!( "server doesn't support sse, skip common stream" ) ;
367
- }
368
- Err ( e) => {
369
- // fail to get common stream
370
- tracing:: error!( "fail to get common stream: {e}" ) ;
371
- return Err ( WorkerQuitReason :: fatal (
372
- "fail to get general purpose event stream" ,
373
- "get general purpose event stream" ,
374
- ) ) ;
347
+ if let Some ( session_id) = & session_id {
348
+ match self
349
+ . client
350
+ . get_stream ( config. uri . clone ( ) , session_id. clone ( ) , None , None )
351
+ . await
352
+ {
353
+ Ok ( stream) => {
354
+ let sse_stream = SseAutoReconnectStream :: new (
355
+ stream,
356
+ StreamableHttpClientReconnect {
357
+ client : self . client . clone ( ) ,
358
+ session_id : session_id. clone ( ) ,
359
+ uri : config. uri . clone ( ) ,
360
+ } ,
361
+ self . config . retry_config . clone ( ) ,
362
+ ) ;
363
+ streams. spawn ( Self :: execute_sse_stream (
364
+ sse_stream,
365
+ sse_worker_tx. clone ( ) ,
366
+ transport_task_ct. child_token ( ) ,
367
+ ) ) ;
368
+ tracing:: debug!( "got common stream" ) ;
369
+ }
370
+ Err ( StreamableHttpError :: SeverDoesNotSupportSse ) => {
371
+ tracing:: debug!( "server doesn't support sse, skip common stream" ) ;
372
+ }
373
+ Err ( e) => {
374
+ // fail to get common stream
375
+ tracing:: error!( "fail to get common stream: {e}" ) ;
376
+ return Err ( WorkerQuitReason :: fatal (
377
+ "fail to get general purpose event stream" ,
378
+ "get general purpose event stream" ,
379
+ ) ) ;
380
+ }
375
381
}
376
382
}
377
383
loop {
@@ -407,7 +413,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
407
413
let WorkerSendRequest { message, responder } = send_request;
408
414
let response = self
409
415
. client
410
- . post_message ( config. uri . clone ( ) , message, Some ( session_id. clone ( ) ) , None )
416
+ . post_message ( config. uri . clone ( ) , message, session_id. clone ( ) , None )
411
417
. await ;
412
418
let send_result = match response {
413
419
Err ( e) => Err ( e) ,
@@ -420,20 +426,32 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
420
426
Ok ( ( ) )
421
427
}
422
428
Ok ( StreamableHttpPostResponse :: Sse ( stream, ..) ) => {
423
- let sse_stream = SseAutoReconnectStream :: new (
424
- stream,
425
- StreamableHttpClientReconnect {
426
- client : self . client . clone ( ) ,
427
- session_id : session_id. clone ( ) ,
428
- uri : config. uri . clone ( ) ,
429
- } ,
430
- self . config . retry_config . clone ( ) ,
431
- ) ;
432
- streams. spawn ( Self :: execute_sse_stream (
433
- sse_stream,
434
- sse_worker_tx. clone ( ) ,
435
- transport_task_ct. child_token ( ) ,
436
- ) ) ;
429
+ if let Some ( session_id) = & session_id {
430
+ let sse_stream = SseAutoReconnectStream :: new (
431
+ stream,
432
+ StreamableHttpClientReconnect {
433
+ client : self . client . clone ( ) ,
434
+ session_id : session_id. clone ( ) ,
435
+ uri : config. uri . clone ( ) ,
436
+ } ,
437
+ self . config . retry_config . clone ( ) ,
438
+ ) ;
439
+ streams. spawn ( Self :: execute_sse_stream (
440
+ sse_stream,
441
+ sse_worker_tx. clone ( ) ,
442
+ transport_task_ct. child_token ( ) ,
443
+ ) ) ;
444
+ } else {
445
+ let sse_stream = SseAutoReconnectStream :: never_reconnect (
446
+ stream,
447
+ StreamableHttpError :: < C :: Error > :: UnexpectedEndOfStream ,
448
+ ) ;
449
+ streams. spawn ( Self :: execute_sse_stream (
450
+ sse_stream,
451
+ sse_worker_tx. clone ( ) ,
452
+ transport_task_ct. child_token ( ) ,
453
+ ) ) ;
454
+ }
437
455
tracing:: trace!( "got new sse stream" ) ;
438
456
Ok ( ( ) )
439
457
}
@@ -470,6 +488,8 @@ pub struct StreamableHttpClientTransportConfig {
470
488
pub uri : Arc < str > ,
471
489
pub retry_config : Arc < dyn SseRetryPolicy > ,
472
490
pub channel_buffer_capacity : usize ,
491
+ /// if true, the transport will not require a session to be established
492
+ pub allow_stateless : bool ,
473
493
}
474
494
475
495
impl StreamableHttpClientTransportConfig {
@@ -487,6 +507,7 @@ impl Default for StreamableHttpClientTransportConfig {
487
507
uri : "localhost" . into ( ) ,
488
508
retry_config : Arc :: new ( ExponentialBackoff :: default ( ) ) ,
489
509
channel_buffer_capacity : 16 ,
510
+ allow_stateless : true ,
490
511
}
491
512
}
492
513
}
0 commit comments