Skip to content

Commit ff71a52

Browse files
authored
feat: stateless mode of streamable http client (#233)
1 parent 209be7b commit ff71a52

File tree

2 files changed

+117
-60
lines changed

2 files changed

+117
-60
lines changed

crates/rmcp/src/transport/common/client_side_sse.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,28 @@ impl SseRetryPolicy for ExponentialBackoff {
7676
}
7777
}
7878

79+
#[derive(Debug, Clone, Copy, Default)]
80+
pub struct NeverRetry;
81+
82+
impl SseRetryPolicy for NeverRetry {
83+
fn retry(&self, _current_times: usize) -> Option<Duration> {
84+
None
85+
}
86+
}
87+
88+
#[derive(Debug, Default)]
89+
pub struct NeverReconnect<E> {
90+
error: Option<E>,
91+
}
92+
93+
impl<E: std::error::Error + Send> SseStreamReconnect for NeverReconnect<E> {
94+
type Error = E;
95+
type Future = futures::future::Ready<Result<BoxedSseResponse, Self::Error>>;
96+
fn retry_connection(&mut self, _last_event_id: Option<&str>) -> Self::Future {
97+
futures::future::ready(Err(self.error.take().expect("should not be called again")))
98+
}
99+
}
100+
79101
pub(crate) trait SseStreamReconnect {
80102
type Error: std::error::Error;
81103
type Future: Future<Output = Result<BoxedSseResponse, Self::Error>> + Send;
@@ -111,6 +133,20 @@ impl<R: SseStreamReconnect> SseAutoReconnectStream<R> {
111133
}
112134
}
113135

136+
impl<E: std::error::Error + Send> SseAutoReconnectStream<NeverReconnect<E>> {
137+
pub fn never_reconnect(stream: BoxedSseResponse, error_when_reconnect: E) -> Self {
138+
Self {
139+
retry_policy: Arc::new(NeverRetry),
140+
last_event_id: None,
141+
server_retry_interval: None,
142+
connector: NeverReconnect {
143+
error: Some(error_when_reconnect),
144+
},
145+
state: SseAutoReconnectStreamState::Connected { stream },
146+
}
147+
}
148+
}
149+
114150
pin_project_lite::pin_project! {
115151
#[project = SseAutoReconnectStreamStateProj]
116152
pub enum SseAutoReconnectStreamState<F> {

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 81 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{borrow::Cow, sync::Arc, time::Duration};
22

3-
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
3+
use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
44
pub use sse_stream::Error as SseError;
55
use sse_stream::Sse;
66
use thiserror::Error;
@@ -193,8 +193,7 @@ impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
193193
client: C::default(),
194194
config: StreamableHttpClientTransportConfig {
195195
uri: url.into(),
196-
retry_config: Arc::new(ExponentialBackoff::default()),
197-
channel_buffer_capacity: 16,
196+
..Default::default()
198197
},
199198
}
200199
}
@@ -208,7 +207,9 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
208207

209208
impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
210209
async fn execute_sse_stream(
211-
sse_stream: SseAutoReconnectStream<StreamableHttpClientReconnect<C>>,
210+
sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
211+
+ Send
212+
+ 'static,
212213
sse_worker_tx: tokio::sync::mpsc::Sender<ServerJsonRpcMessage>,
213214
ct: CancellationToken,
214215
) -> Result<(), StreamableHttpError<C::Error>> {
@@ -277,16 +278,19 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
277278
.map_err(WorkerQuitReason::fatal_context(
278279
"process initialize response",
279280
))?;
280-
let Some(session_id) = session_id else {
281-
return Err(WorkerQuitReason::fatal(
282-
"missing session id in initialize response",
283-
"process initialize response",
284-
));
281+
let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
282+
Some(session_id.into())
283+
} else {
284+
if !self.config.allow_stateless {
285+
return Err(WorkerQuitReason::fatal(
286+
"missing session id in initialize response",
287+
"process initialize response",
288+
));
289+
}
290+
None
285291
};
286-
let session_id: Arc<str> = session_id.into();
287-
288292
// delete session when drop guard is dropped
289-
{
293+
if let Some(session_id) = &session_id {
290294
let ct = transport_task_ct.clone();
291295
let client = self.client.clone();
292296
let session_id = session_id.clone();
@@ -322,7 +326,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
322326
.post_message(
323327
config.uri.clone(),
324328
initialized_notification.message,
325-
Some(session_id.clone()),
329+
session_id.clone(),
326330
None,
327331
)
328332
.await
@@ -340,38 +344,40 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
340344
StreamResult(Result<(), StreamableHttpError<E>>),
341345
}
342346
let mut streams = tokio::task::JoinSet::new();
343-
match self
344-
.client
345-
.get_stream(config.uri.clone(), session_id.clone(), None, None)
346-
.await
347-
{
348-
Ok(stream) => {
349-
let sse_stream = SseAutoReconnectStream::new(
350-
stream,
351-
StreamableHttpClientReconnect {
352-
client: self.client.clone(),
353-
session_id: session_id.clone(),
354-
uri: config.uri.clone(),
355-
},
356-
self.config.retry_config.clone(),
357-
);
358-
streams.spawn(Self::execute_sse_stream(
359-
sse_stream,
360-
sse_worker_tx.clone(),
361-
transport_task_ct.child_token(),
362-
));
363-
tracing::debug!("got common stream");
364-
}
365-
Err(StreamableHttpError::SeverDoesNotSupportSse) => {
366-
tracing::debug!("server doesn't support sse, skip common stream");
367-
}
368-
Err(e) => {
369-
// fail to get common stream
370-
tracing::error!("fail to get common stream: {e}");
371-
return Err(WorkerQuitReason::fatal(
372-
"fail to get general purpose event stream",
373-
"get general purpose event stream",
374-
));
347+
if let Some(session_id) = &session_id {
348+
match self
349+
.client
350+
.get_stream(config.uri.clone(), session_id.clone(), None, None)
351+
.await
352+
{
353+
Ok(stream) => {
354+
let sse_stream = SseAutoReconnectStream::new(
355+
stream,
356+
StreamableHttpClientReconnect {
357+
client: self.client.clone(),
358+
session_id: session_id.clone(),
359+
uri: config.uri.clone(),
360+
},
361+
self.config.retry_config.clone(),
362+
);
363+
streams.spawn(Self::execute_sse_stream(
364+
sse_stream,
365+
sse_worker_tx.clone(),
366+
transport_task_ct.child_token(),
367+
));
368+
tracing::debug!("got common stream");
369+
}
370+
Err(StreamableHttpError::SeverDoesNotSupportSse) => {
371+
tracing::debug!("server doesn't support sse, skip common stream");
372+
}
373+
Err(e) => {
374+
// fail to get common stream
375+
tracing::error!("fail to get common stream: {e}");
376+
return Err(WorkerQuitReason::fatal(
377+
"fail to get general purpose event stream",
378+
"get general purpose event stream",
379+
));
380+
}
375381
}
376382
}
377383
loop {
@@ -407,7 +413,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
407413
let WorkerSendRequest { message, responder } = send_request;
408414
let response = self
409415
.client
410-
.post_message(config.uri.clone(), message, Some(session_id.clone()), None)
416+
.post_message(config.uri.clone(), message, session_id.clone(), None)
411417
.await;
412418
let send_result = match response {
413419
Err(e) => Err(e),
@@ -420,20 +426,32 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
420426
Ok(())
421427
}
422428
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
423-
let sse_stream = SseAutoReconnectStream::new(
424-
stream,
425-
StreamableHttpClientReconnect {
426-
client: self.client.clone(),
427-
session_id: session_id.clone(),
428-
uri: config.uri.clone(),
429-
},
430-
self.config.retry_config.clone(),
431-
);
432-
streams.spawn(Self::execute_sse_stream(
433-
sse_stream,
434-
sse_worker_tx.clone(),
435-
transport_task_ct.child_token(),
436-
));
429+
if let Some(session_id) = &session_id {
430+
let sse_stream = SseAutoReconnectStream::new(
431+
stream,
432+
StreamableHttpClientReconnect {
433+
client: self.client.clone(),
434+
session_id: session_id.clone(),
435+
uri: config.uri.clone(),
436+
},
437+
self.config.retry_config.clone(),
438+
);
439+
streams.spawn(Self::execute_sse_stream(
440+
sse_stream,
441+
sse_worker_tx.clone(),
442+
transport_task_ct.child_token(),
443+
));
444+
} else {
445+
let sse_stream = SseAutoReconnectStream::never_reconnect(
446+
stream,
447+
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
448+
);
449+
streams.spawn(Self::execute_sse_stream(
450+
sse_stream,
451+
sse_worker_tx.clone(),
452+
transport_task_ct.child_token(),
453+
));
454+
}
437455
tracing::trace!("got new sse stream");
438456
Ok(())
439457
}
@@ -470,6 +488,8 @@ pub struct StreamableHttpClientTransportConfig {
470488
pub uri: Arc<str>,
471489
pub retry_config: Arc<dyn SseRetryPolicy>,
472490
pub channel_buffer_capacity: usize,
491+
/// if true, the transport will not require a session to be established
492+
pub allow_stateless: bool,
473493
}
474494

475495
impl StreamableHttpClientTransportConfig {
@@ -487,6 +507,7 @@ impl Default for StreamableHttpClientTransportConfig {
487507
uri: "localhost".into(),
488508
retry_config: Arc::new(ExponentialBackoff::default()),
489509
channel_buffer_capacity: 16,
510+
allow_stateless: true,
490511
}
491512
}
492513
}

0 commit comments

Comments
 (0)