diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index fd3aa1d5..c82cd177 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -36,7 +36,7 @@ impl StreamableHttpClient for reqwest::Client { } let response = request_builder.send().await?; if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { - return Err(StreamableHttpError::SeverDoesNotSupportSse); + return Err(StreamableHttpError::ServerDoesNotSupportSse); } let response = response.error_for_status()?; match response.headers().get(reqwest::header::CONTENT_TYPE) { diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 20159a01..22a6b0d6 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -33,15 +33,17 @@ pub enum StreamableHttpError { #[error("Unexpected content type: {0:?}")] UnexpectedContentType(Option), #[error("Server does not support SSE")] - SeverDoesNotSupportSse, + ServerDoesNotSupportSse, #[error("Server does not support delete session")] - SeverDoesNotSupportDeleteSession, + ServerDoesNotSupportDeleteSession, #[error("Tokio join error: {0}")] TokioJoinError(#[from] tokio::task::JoinError), #[error("Deserialize error: {0}")] Deserialize(#[from] serde_json::Error), #[error("Transport channel closed")] TransportChannelClosed, + #[error("Missing session id in HTTP response")] + MissingSessionIdInResponse, #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] #[error("Auth error: {0}")] @@ -54,6 +56,11 @@ impl From for StreamableHttpError { } } +#[derive(Debug, Clone, Error)] +pub enum StreamableHttpProtocolError { + #[error("Missing session id in response")] + MissingSessionIdInResponse, +} pub enum StreamableHttpPostResponse { Accepted, Json(ServerJsonRpcMessage, Option), @@ -261,7 +268,7 @@ impl Worker for StreamableHttpClientWorker { async fn run( self, mut context: super::worker::WorkerContext, - ) -> Result<(), WorkerQuitReason> { + ) -> Result<(), WorkerQuitReason> { let channel_buffer_capacity = self.config.channel_buffer_capacity; let (sse_worker_tx, mut sse_worker_rx) = tokio::sync::mpsc::channel::(channel_buffer_capacity); @@ -278,7 +285,7 @@ impl Worker for StreamableHttpClientWorker { .post_message(config.uri.clone(), initialize_request, None, None) .await .map_err(WorkerQuitReason::fatal_context("send initialize request"))? - .expect_initialized::() + .expect_initialized::() .await .map_err(WorkerQuitReason::fatal_context( "process initialize response", @@ -288,7 +295,7 @@ impl Worker for StreamableHttpClientWorker { } else { if !self.config.allow_stateless { return Err(WorkerQuitReason::fatal( - "missing session id in initialize response", + StreamableHttpError::::MissingSessionIdInResponse, "process initialize response", )); } @@ -308,7 +315,7 @@ impl Worker for StreamableHttpClientWorker { Ok(_) => { tracing::info!(session_id = session_id.as_ref(), "delete session success") } - Err(StreamableHttpError::SeverDoesNotSupportDeleteSession) => { + Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => { tracing::info!( session_id = session_id.as_ref(), "server doesn't support delete session" @@ -338,7 +345,7 @@ impl Worker for StreamableHttpClientWorker { .map_err(WorkerQuitReason::fatal_context( "send initialized notification", ))? - .expect_accepted::() + .expect_accepted::() .map_err(WorkerQuitReason::fatal_context( "process initialized notification response", ))?; @@ -373,14 +380,14 @@ impl Worker for StreamableHttpClientWorker { )); tracing::debug!("got common stream"); } - Err(StreamableHttpError::SeverDoesNotSupportSse) => { + Err(StreamableHttpError::ServerDoesNotSupportSse) => { tracing::debug!("server doesn't support sse, skip common stream"); } Err(e) => { // fail to get common stream tracing::error!("fail to get common stream: {e}"); return Err(WorkerQuitReason::fatal( - "fail to get general purpose event stream", + e, "get general purpose event stream", )); } diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index c1c4f893..5458c404 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -296,12 +296,8 @@ pub enum SessionError { SessionServiceTerminated, #[error("Invalid event id")] InvalidEventId, - #[error("Transport closed")] - TransportClosed, #[error("IO error: {0}")] Io(#[from] std::io::Error), - #[error("Tokio join error {0}")] - TokioJoinError(#[from] tokio::task::JoinError), } impl From for std::io::Error { @@ -317,7 +313,7 @@ enum OutboundChannel { RequestWise { id: HttpRequestId, close: bool }, Common, } - +#[derive(Debug)] pub struct StreamableHttpMessageReceiver { pub http_request_id: Option, pub inner: Receiver, @@ -534,8 +530,8 @@ impl LocalSessionWorker { } } } - -enum SessionEvent { +#[derive(Debug)] +pub enum SessionEvent { ClientMessage { message: ClientJsonRpcMessage, http_request_id: Option, @@ -695,14 +691,31 @@ impl LocalSessionHandle { pub type SessionTransport = WorkerTransport; +#[derive(Debug, Error)] +pub enum LocalSessionWorkerError { + #[error("transport terminated")] + TransportTerminated, + #[error("unexpected message: {0:?}")] + UnexpectedEvent(SessionEvent), + #[error("fail to send initialize request {0}")] + FailToSendInitializeRequest(SessionError), + #[error("fail to handle message: {0}")] + FailToHandleMessage(SessionError), + #[error("keep alive timeout after {}ms", _0.as_millis())] + KeepAliveTimeout(Duration), + #[error("Transport closed")] + TransportClosed, + #[error("Tokio join error {0}")] + TokioJoinError(#[from] tokio::task::JoinError), +} impl Worker for LocalSessionWorker { - type Error = SessionError; + type Error = LocalSessionWorkerError; type Role = RoleServer; fn err_closed() -> Self::Error { - SessionError::TransportClosed + LocalSessionWorkerError::TransportClosed } fn err_join(e: tokio::task::JoinError) -> Self::Error { - SessionError::TokioJoinError(e) + LocalSessionWorkerError::TokioJoinError(e) } fn config(&self) -> crate::transport::worker::WorkerConfig { crate::transport::worker::WorkerConfig { @@ -711,18 +724,24 @@ impl Worker for LocalSessionWorker { } } #[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))] - async fn run(mut self, mut context: WorkerContext) -> Result<(), WorkerQuitReason> { + async fn run( + mut self, + mut context: WorkerContext, + ) -> Result<(), WorkerQuitReason> { enum InnerEvent { FromHttpService(SessionEvent), FromHandler(WorkerSendRequest), } // waiting for initialize request let evt = self.event_rx.recv().await.ok_or_else(|| { - WorkerQuitReason::fatal("transport terminated", "get initialize request") + WorkerQuitReason::fatal( + LocalSessionWorkerError::TransportTerminated, + "get initialize request", + ) })?; let SessionEvent::InitializeRequest { request, responder } = evt else { return Err(WorkerQuitReason::fatal( - "unexpected message", + LocalSessionWorkerError::UnexpectedEvent(evt), "get initialize request", )); }; @@ -732,7 +751,9 @@ impl Worker for LocalSessionWorker { .send(Ok(send_initialize_response.message)) .map_err(|_| { WorkerQuitReason::fatal( - "failed to send initialize response to http service", + LocalSessionWorkerError::FailToSendInitializeRequest( + SessionError::SessionServiceTerminated, + ), "send initialize response", ) })?; @@ -749,7 +770,7 @@ impl Worker for LocalSessionWorker { if let Some(event) = event { InnerEvent::FromHttpService(event) } else { - return Err(WorkerQuitReason::fatal("session dropped", "waiting next session event")) + return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event")) } }, from_handler = context.recv_from_handler() => { @@ -759,7 +780,7 @@ impl Worker for LocalSessionWorker { return Err(WorkerQuitReason::Cancelled) } _ = keep_alive_timeout => { - return Err(WorkerQuitReason::fatal("keep live timeout", "poll next session event")) + return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event")) } }; match event { @@ -779,7 +800,10 @@ impl Worker for LocalSessionWorker { // no need to unregister resource } }; - let handle_result = self.handle_server_message(message).await; + let handle_result = self + .handle_server_message(message) + .await + .map_err(LocalSessionWorkerError::FailToHandleMessage); let _ = responder.send(handle_result).inspect_err(|error| { tracing::warn!(?error, "failed to send message to http service handler"); }); diff --git a/crates/rmcp/src/transport/worker.rs b/crates/rmcp/src/transport/worker.rs index 5ae9098e..eaabc506 100644 --- a/crates/rmcp/src/transport/worker.rs +++ b/crates/rmcp/src/transport/worker.rs @@ -7,12 +7,12 @@ use super::{IntoTransport, Transport}; use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; #[derive(Debug, thiserror::Error)] -pub enum WorkerQuitReason { +pub enum WorkerQuitReason { #[error("Join error {0}")] Join(#[from] tokio::task::JoinError), #[error("Transport fatal {error}, when {context}")] Fatal { - error: Cow<'static, str>, + error: E, context: Cow<'static, str>, }, #[error("Transport canncelled")] @@ -23,18 +23,16 @@ pub enum WorkerQuitReason { HandlerTerminated, } -impl WorkerQuitReason { - pub fn fatal(msg: impl Into>, context: impl Into>) -> Self { +impl WorkerQuitReason { + pub fn fatal(error: E, context: impl Into>) -> Self { Self::Fatal { - error: msg.into(), + error, context: context.into(), } } - pub fn fatal_context( - context: impl Into>, - ) -> impl FnOnce(E) -> Self { + pub fn fatal_context(context: impl Into>) -> impl FnOnce(E) -> Self { |e| Self::Fatal { - error: Cow::Owned(format!("{e}")), + error: e, context: context.into(), } } @@ -48,7 +46,7 @@ pub trait Worker: Sized + Send + 'static { fn run( self, context: WorkerContext, - ) -> impl Future> + Send; + ) -> impl Future>> + Send; fn config(&self) -> WorkerConfig { WorkerConfig::default() } @@ -62,7 +60,7 @@ pub struct WorkerSendRequest { pub struct WorkerTransport { rx: tokio::sync::mpsc::Receiver>, send_service: tokio::sync::mpsc::Sender>, - join_handle: Option>>, + join_handle: Option>>>, _drop_guard: tokio_util::sync::DropGuard, ct: CancellationToken, } @@ -159,14 +157,16 @@ impl WorkerContext { pub async fn send_to_handler( &mut self, item: RxJsonRpcMessage, - ) -> Result<(), WorkerQuitReason> { + ) -> Result<(), WorkerQuitReason> { self.to_handler_tx .send(item) .await .map_err(|_| WorkerQuitReason::HandlerTerminated) } - pub async fn recv_from_handler(&mut self) -> Result, WorkerQuitReason> { + pub async fn recv_from_handler( + &mut self, + ) -> Result, WorkerQuitReason> { self.from_handler_rx .recv() .await