Skip to content

Commit 9190457

Browse files
committed
feat: support downcast WorkerQuitReason::Fatal
1 parent 209dbac commit 9190457

File tree

4 files changed

+47
-20
lines changed

4 files changed

+47
-20
lines changed

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl StreamableHttpClient for reqwest::Client {
3636
}
3737
let response = request_builder.send().await?;
3838
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
39-
return Err(StreamableHttpError::SeverDoesNotSupportSse);
39+
return Err(StreamableHttpError::ServerDoesNotSupportSse);
4040
}
4141
let response = response.error_for_status()?;
4242
match response.headers().get(reqwest::header::CONTENT_TYPE) {

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,17 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
3333
#[error("Unexpected content type: {0:?}")]
3434
UnexpectedContentType(Option<String>),
3535
#[error("Server does not support SSE")]
36-
SeverDoesNotSupportSse,
36+
ServerDoesNotSupportSse,
3737
#[error("Server does not support delete session")]
38-
SeverDoesNotSupportDeleteSession,
38+
ServerDoesNotSupportDeleteSession,
3939
#[error("Tokio join error: {0}")]
4040
TokioJoinError(#[from] tokio::task::JoinError),
4141
#[error("Deserialize error: {0}")]
4242
Deserialize(#[from] serde_json::Error),
4343
#[error("Transport channel closed")]
4444
TransportChannelClosed,
45+
#[error("Missing session id in response")]
46+
MissingSessionIdInResponse,
4547
#[cfg(feature = "auth")]
4648
#[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
4749
#[error("Auth error: {0}")]
@@ -54,6 +56,12 @@ impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
5456
}
5557
}
5658

59+
#[derive(Debug, Clone, Error)]
60+
pub enum StreamableHttpProtocolError {
61+
#[error("Missing session id in response")]
62+
MissingSessionIdInResponse,
63+
}
64+
5765
pub enum StreamableHttpPostResponse {
5866
Accepted,
5967
Json(ServerJsonRpcMessage, Option<String>),
@@ -288,7 +296,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
288296
} else {
289297
if !self.config.allow_stateless {
290298
return Err(WorkerQuitReason::fatal(
291-
"missing session id in initialize response",
299+
StreamableHttpError::<C::Error>::MissingSessionIdInResponse,
292300
"process initialize response",
293301
));
294302
}
@@ -308,7 +316,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
308316
Ok(_) => {
309317
tracing::info!(session_id = session_id.as_ref(), "delete session success")
310318
}
311-
Err(StreamableHttpError::SeverDoesNotSupportDeleteSession) => {
319+
Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => {
312320
tracing::info!(
313321
session_id = session_id.as_ref(),
314322
"server doesn't support delete session"
@@ -373,14 +381,14 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
373381
));
374382
tracing::debug!("got common stream");
375383
}
376-
Err(StreamableHttpError::SeverDoesNotSupportSse) => {
384+
Err(StreamableHttpError::ServerDoesNotSupportSse) => {
377385
tracing::debug!("server doesn't support sse, skip common stream");
378386
}
379387
Err(e) => {
380388
// fail to get common stream
381389
tracing::error!("fail to get common stream: {e}");
382390
return Err(WorkerQuitReason::fatal(
383-
"fail to get general purpose event stream",
391+
e,
384392
"get general purpose event stream",
385393
));
386394
}

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ enum OutboundChannel {
317317
RequestWise { id: HttpRequestId, close: bool },
318318
Common,
319319
}
320-
320+
#[derive(Debug)]
321321
pub struct StreamableHttpMessageReceiver {
322322
pub http_request_id: Option<HttpRequestId>,
323323
pub inner: Receiver<ServerSseMessage>,
@@ -534,8 +534,8 @@ impl LocalSessionWorker {
534534
}
535535
}
536536
}
537-
538-
enum SessionEvent {
537+
#[derive(Debug)]
538+
pub enum SessionEvent {
539539
ClientMessage {
540540
message: ClientJsonRpcMessage,
541541
http_request_id: Option<HttpRequestId>,
@@ -695,6 +695,17 @@ impl LocalSessionHandle {
695695

696696
pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
697697

698+
#[derive(Debug, Error)]
699+
pub enum LocalSessionError {
700+
#[error("transport terminated")]
701+
TransportTerminated,
702+
#[error("unexpected message: {0:?}")]
703+
UnexpectedEvent(SessionEvent),
704+
#[error("fail to send initialize request {0}")]
705+
FailToSendInitializeRequest(SessionError),
706+
#[error("keep alive timeout")]
707+
KeepAliveTimeout,
708+
}
698709
impl Worker for LocalSessionWorker {
699710
type Error = SessionError;
700711
type Role = RoleServer;
@@ -718,11 +729,14 @@ impl Worker for LocalSessionWorker {
718729
}
719730
// waiting for initialize request
720731
let evt = self.event_rx.recv().await.ok_or_else(|| {
721-
WorkerQuitReason::fatal("transport terminated", "get initialize request")
732+
WorkerQuitReason::fatal(
733+
LocalSessionError::TransportTerminated,
734+
"get initialize request",
735+
)
722736
})?;
723737
let SessionEvent::InitializeRequest { request, responder } = evt else {
724738
return Err(WorkerQuitReason::fatal(
725-
"unexpected message",
739+
LocalSessionError::UnexpectedEvent(evt),
726740
"get initialize request",
727741
));
728742
};
@@ -732,7 +746,9 @@ impl Worker for LocalSessionWorker {
732746
.send(Ok(send_initialize_response.message))
733747
.map_err(|_| {
734748
WorkerQuitReason::fatal(
735-
"failed to send initialize response to http service",
749+
LocalSessionError::FailToSendInitializeRequest(
750+
SessionError::SessionServiceTerminated,
751+
),
736752
"send initialize response",
737753
)
738754
})?;
@@ -749,7 +765,7 @@ impl Worker for LocalSessionWorker {
749765
if let Some(event) = event {
750766
InnerEvent::FromHttpService(event)
751767
} else {
752-
return Err(WorkerQuitReason::fatal("session dropped", "waiting next session event"))
768+
return Err(WorkerQuitReason::fatal(LocalSessionError::TransportTerminated, "waiting next session event"))
753769
}
754770
},
755771
from_handler = context.recv_from_handler() => {
@@ -759,7 +775,7 @@ impl Worker for LocalSessionWorker {
759775
return Err(WorkerQuitReason::Cancelled)
760776
}
761777
_ = keep_alive_timeout => {
762-
return Err(WorkerQuitReason::fatal("keep live timeout", "poll next session event"))
778+
return Err(WorkerQuitReason::fatal(LocalSessionError::KeepAliveTimeout, "poll next session event"))
763779
}
764780
};
765781
match event {

crates/rmcp/src/transport/worker.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub enum WorkerQuitReason {
1212
Join(#[from] tokio::task::JoinError),
1313
#[error("Transport fatal {error}, when {context}")]
1414
Fatal {
15-
error: Cow<'static, str>,
15+
error: Box<dyn std::error::Error + Send>,
1616
context: Cow<'static, str>,
1717
},
1818
#[error("Transport canncelled")]
@@ -24,17 +24,20 @@ pub enum WorkerQuitReason {
2424
}
2525

2626
impl WorkerQuitReason {
27-
pub fn fatal(msg: impl Into<Cow<'static, str>>, context: impl Into<Cow<'static, str>>) -> Self {
27+
pub fn fatal(
28+
error: impl std::error::Error + Send + 'static,
29+
context: impl Into<Cow<'static, str>>,
30+
) -> Self {
2831
Self::Fatal {
29-
error: msg.into(),
32+
error: Box::new(error),
3033
context: context.into(),
3134
}
3235
}
33-
pub fn fatal_context<E: std::error::Error>(
36+
pub fn fatal_context<E: std::error::Error + Send + 'static>(
3437
context: impl Into<Cow<'static, str>>,
3538
) -> impl FnOnce(E) -> Self {
3639
|e| Self::Fatal {
37-
error: Cow::Owned(format!("{e}")),
40+
error: Box::new(e),
3841
context: context.into(),
3942
}
4043
}

0 commit comments

Comments
 (0)