Skip to content

Commit 0a2f77d

Browse files
4t145Copilot
andauthored
feat: keep internal error in worker's quit reason (#372)
* feat: support downcast WorkerQuitReason::Fatal * feat(transport): expose internal worker error in fatal * Update crates/rmcp/src/transport/streamable_http_client.rs Co-authored-by: Copilot <[email protected]> * Update crates/rmcp/src/transport/streamable_http_client.rs Co-authored-by: Copilot <[email protected]> --------- Co-authored-by: Copilot <[email protected]>
1 parent 28f4781 commit 0a2f77d

File tree

4 files changed

+71
-40
lines changed

4 files changed

+71
-40
lines changed

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl StreamableHttpClient for reqwest::Client {
3636
}
3737
let response = request_builder.send().await?;
3838
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
39-
return Err(StreamableHttpError::SeverDoesNotSupportSse);
39+
return Err(StreamableHttpError::ServerDoesNotSupportSse);
4040
}
4141
let response = response.error_for_status()?;
4242
match response.headers().get(reqwest::header::CONTENT_TYPE) {

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,17 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
3333
#[error("Unexpected content type: {0:?}")]
3434
UnexpectedContentType(Option<String>),
3535
#[error("Server does not support SSE")]
36-
SeverDoesNotSupportSse,
36+
ServerDoesNotSupportSse,
3737
#[error("Server does not support delete session")]
38-
SeverDoesNotSupportDeleteSession,
38+
ServerDoesNotSupportDeleteSession,
3939
#[error("Tokio join error: {0}")]
4040
TokioJoinError(#[from] tokio::task::JoinError),
4141
#[error("Deserialize error: {0}")]
4242
Deserialize(#[from] serde_json::Error),
4343
#[error("Transport channel closed")]
4444
TransportChannelClosed,
45+
#[error("Missing session id in HTTP response")]
46+
MissingSessionIdInResponse,
4547
#[cfg(feature = "auth")]
4648
#[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
4749
#[error("Auth error: {0}")]
@@ -54,6 +56,11 @@ impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
5456
}
5557
}
5658

59+
#[derive(Debug, Clone, Error)]
60+
pub enum StreamableHttpProtocolError {
61+
#[error("Missing session id in response")]
62+
MissingSessionIdInResponse,
63+
}
5764
pub enum StreamableHttpPostResponse {
5865
Accepted,
5966
Json(ServerJsonRpcMessage, Option<String>),
@@ -261,7 +268,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
261268
async fn run(
262269
self,
263270
mut context: super::worker::WorkerContext<Self>,
264-
) -> Result<(), WorkerQuitReason> {
271+
) -> Result<(), WorkerQuitReason<Self::Error>> {
265272
let channel_buffer_capacity = self.config.channel_buffer_capacity;
266273
let (sse_worker_tx, mut sse_worker_rx) =
267274
tokio::sync::mpsc::channel::<ServerJsonRpcMessage>(channel_buffer_capacity);
@@ -278,7 +285,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
278285
.post_message(config.uri.clone(), initialize_request, None, None)
279286
.await
280287
.map_err(WorkerQuitReason::fatal_context("send initialize request"))?
281-
.expect_initialized::<Self::Error>()
288+
.expect_initialized::<C::Error>()
282289
.await
283290
.map_err(WorkerQuitReason::fatal_context(
284291
"process initialize response",
@@ -288,7 +295,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
288295
} else {
289296
if !self.config.allow_stateless {
290297
return Err(WorkerQuitReason::fatal(
291-
"missing session id in initialize response",
298+
StreamableHttpError::<C::Error>::MissingSessionIdInResponse,
292299
"process initialize response",
293300
));
294301
}
@@ -308,7 +315,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
308315
Ok(_) => {
309316
tracing::info!(session_id = session_id.as_ref(), "delete session success")
310317
}
311-
Err(StreamableHttpError::SeverDoesNotSupportDeleteSession) => {
318+
Err(StreamableHttpError::ServerDoesNotSupportDeleteSession) => {
312319
tracing::info!(
313320
session_id = session_id.as_ref(),
314321
"server doesn't support delete session"
@@ -338,7 +345,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
338345
.map_err(WorkerQuitReason::fatal_context(
339346
"send initialized notification",
340347
))?
341-
.expect_accepted::<Self::Error>()
348+
.expect_accepted::<C::Error>()
342349
.map_err(WorkerQuitReason::fatal_context(
343350
"process initialized notification response",
344351
))?;
@@ -373,14 +380,14 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
373380
));
374381
tracing::debug!("got common stream");
375382
}
376-
Err(StreamableHttpError::SeverDoesNotSupportSse) => {
383+
Err(StreamableHttpError::ServerDoesNotSupportSse) => {
377384
tracing::debug!("server doesn't support sse, skip common stream");
378385
}
379386
Err(e) => {
380387
// fail to get common stream
381388
tracing::error!("fail to get common stream: {e}");
382389
return Err(WorkerQuitReason::fatal(
383-
"fail to get general purpose event stream",
390+
e,
384391
"get general purpose event stream",
385392
));
386393
}

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,8 @@ pub enum SessionError {
296296
SessionServiceTerminated,
297297
#[error("Invalid event id")]
298298
InvalidEventId,
299-
#[error("Transport closed")]
300-
TransportClosed,
301299
#[error("IO error: {0}")]
302300
Io(#[from] std::io::Error),
303-
#[error("Tokio join error {0}")]
304-
TokioJoinError(#[from] tokio::task::JoinError),
305301
}
306302

307303
impl From<SessionError> for std::io::Error {
@@ -317,7 +313,7 @@ enum OutboundChannel {
317313
RequestWise { id: HttpRequestId, close: bool },
318314
Common,
319315
}
320-
316+
#[derive(Debug)]
321317
pub struct StreamableHttpMessageReceiver {
322318
pub http_request_id: Option<HttpRequestId>,
323319
pub inner: Receiver<ServerSseMessage>,
@@ -534,8 +530,8 @@ impl LocalSessionWorker {
534530
}
535531
}
536532
}
537-
538-
enum SessionEvent {
533+
#[derive(Debug)]
534+
pub enum SessionEvent {
539535
ClientMessage {
540536
message: ClientJsonRpcMessage,
541537
http_request_id: Option<HttpRequestId>,
@@ -695,14 +691,31 @@ impl LocalSessionHandle {
695691

696692
pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
697693

694+
#[derive(Debug, Error)]
695+
pub enum LocalSessionWorkerError {
696+
#[error("transport terminated")]
697+
TransportTerminated,
698+
#[error("unexpected message: {0:?}")]
699+
UnexpectedEvent(SessionEvent),
700+
#[error("fail to send initialize request {0}")]
701+
FailToSendInitializeRequest(SessionError),
702+
#[error("fail to handle message: {0}")]
703+
FailToHandleMessage(SessionError),
704+
#[error("keep alive timeout after {}ms", _0.as_millis())]
705+
KeepAliveTimeout(Duration),
706+
#[error("Transport closed")]
707+
TransportClosed,
708+
#[error("Tokio join error {0}")]
709+
TokioJoinError(#[from] tokio::task::JoinError),
710+
}
698711
impl Worker for LocalSessionWorker {
699-
type Error = SessionError;
712+
type Error = LocalSessionWorkerError;
700713
type Role = RoleServer;
701714
fn err_closed() -> Self::Error {
702-
SessionError::TransportClosed
715+
LocalSessionWorkerError::TransportClosed
703716
}
704717
fn err_join(e: tokio::task::JoinError) -> Self::Error {
705-
SessionError::TokioJoinError(e)
718+
LocalSessionWorkerError::TokioJoinError(e)
706719
}
707720
fn config(&self) -> crate::transport::worker::WorkerConfig {
708721
crate::transport::worker::WorkerConfig {
@@ -711,18 +724,24 @@ impl Worker for LocalSessionWorker {
711724
}
712725
}
713726
#[instrument(name = "streamable_http_session", skip_all, fields(id = self.id.as_ref()))]
714-
async fn run(mut self, mut context: WorkerContext<Self>) -> Result<(), WorkerQuitReason> {
727+
async fn run(
728+
mut self,
729+
mut context: WorkerContext<Self>,
730+
) -> Result<(), WorkerQuitReason<Self::Error>> {
715731
enum InnerEvent {
716732
FromHttpService(SessionEvent),
717733
FromHandler(WorkerSendRequest<LocalSessionWorker>),
718734
}
719735
// waiting for initialize request
720736
let evt = self.event_rx.recv().await.ok_or_else(|| {
721-
WorkerQuitReason::fatal("transport terminated", "get initialize request")
737+
WorkerQuitReason::fatal(
738+
LocalSessionWorkerError::TransportTerminated,
739+
"get initialize request",
740+
)
722741
})?;
723742
let SessionEvent::InitializeRequest { request, responder } = evt else {
724743
return Err(WorkerQuitReason::fatal(
725-
"unexpected message",
744+
LocalSessionWorkerError::UnexpectedEvent(evt),
726745
"get initialize request",
727746
));
728747
};
@@ -732,7 +751,9 @@ impl Worker for LocalSessionWorker {
732751
.send(Ok(send_initialize_response.message))
733752
.map_err(|_| {
734753
WorkerQuitReason::fatal(
735-
"failed to send initialize response to http service",
754+
LocalSessionWorkerError::FailToSendInitializeRequest(
755+
SessionError::SessionServiceTerminated,
756+
),
736757
"send initialize response",
737758
)
738759
})?;
@@ -749,7 +770,7 @@ impl Worker for LocalSessionWorker {
749770
if let Some(event) = event {
750771
InnerEvent::FromHttpService(event)
751772
} else {
752-
return Err(WorkerQuitReason::fatal("session dropped", "waiting next session event"))
773+
return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::TransportTerminated, "waiting next session event"))
753774
}
754775
},
755776
from_handler = context.recv_from_handler() => {
@@ -759,7 +780,7 @@ impl Worker for LocalSessionWorker {
759780
return Err(WorkerQuitReason::Cancelled)
760781
}
761782
_ = keep_alive_timeout => {
762-
return Err(WorkerQuitReason::fatal("keep live timeout", "poll next session event"))
783+
return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
763784
}
764785
};
765786
match event {
@@ -779,7 +800,10 @@ impl Worker for LocalSessionWorker {
779800
// no need to unregister resource
780801
}
781802
};
782-
let handle_result = self.handle_server_message(message).await;
803+
let handle_result = self
804+
.handle_server_message(message)
805+
.await
806+
.map_err(LocalSessionWorkerError::FailToHandleMessage);
783807
let _ = responder.send(handle_result).inspect_err(|error| {
784808
tracing::warn!(?error, "failed to send message to http service handler");
785809
});

crates/rmcp/src/transport/worker.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ use super::{IntoTransport, Transport};
77
use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};
88

99
#[derive(Debug, thiserror::Error)]
10-
pub enum WorkerQuitReason {
10+
pub enum WorkerQuitReason<E> {
1111
#[error("Join error {0}")]
1212
Join(#[from] tokio::task::JoinError),
1313
#[error("Transport fatal {error}, when {context}")]
1414
Fatal {
15-
error: Cow<'static, str>,
15+
error: E,
1616
context: Cow<'static, str>,
1717
},
1818
#[error("Transport canncelled")]
@@ -23,18 +23,16 @@ pub enum WorkerQuitReason {
2323
HandlerTerminated,
2424
}
2525

26-
impl WorkerQuitReason {
27-
pub fn fatal(msg: impl Into<Cow<'static, str>>, context: impl Into<Cow<'static, str>>) -> Self {
26+
impl<E: std::error::Error + Send + 'static> WorkerQuitReason<E> {
27+
pub fn fatal(error: E, context: impl Into<Cow<'static, str>>) -> Self {
2828
Self::Fatal {
29-
error: msg.into(),
29+
error,
3030
context: context.into(),
3131
}
3232
}
33-
pub fn fatal_context<E: std::error::Error>(
34-
context: impl Into<Cow<'static, str>>,
35-
) -> impl FnOnce(E) -> Self {
33+
pub fn fatal_context(context: impl Into<Cow<'static, str>>) -> impl FnOnce(E) -> Self {
3634
|e| Self::Fatal {
37-
error: Cow::Owned(format!("{e}")),
35+
error: e,
3836
context: context.into(),
3937
}
4038
}
@@ -48,7 +46,7 @@ pub trait Worker: Sized + Send + 'static {
4846
fn run(
4947
self,
5048
context: WorkerContext<Self>,
51-
) -> impl Future<Output = Result<(), WorkerQuitReason>> + Send;
49+
) -> impl Future<Output = Result<(), WorkerQuitReason<Self::Error>>> + Send;
5250
fn config(&self) -> WorkerConfig {
5351
WorkerConfig::default()
5452
}
@@ -62,7 +60,7 @@ pub struct WorkerSendRequest<W: Worker> {
6260
pub struct WorkerTransport<W: Worker> {
6361
rx: tokio::sync::mpsc::Receiver<RxJsonRpcMessage<W::Role>>,
6462
send_service: tokio::sync::mpsc::Sender<WorkerSendRequest<W>>,
65-
join_handle: Option<tokio::task::JoinHandle<Result<(), WorkerQuitReason>>>,
63+
join_handle: Option<tokio::task::JoinHandle<Result<(), WorkerQuitReason<W::Error>>>>,
6664
_drop_guard: tokio_util::sync::DropGuard,
6765
ct: CancellationToken,
6866
}
@@ -159,14 +157,16 @@ impl<W: Worker> WorkerContext<W> {
159157
pub async fn send_to_handler(
160158
&mut self,
161159
item: RxJsonRpcMessage<W::Role>,
162-
) -> Result<(), WorkerQuitReason> {
160+
) -> Result<(), WorkerQuitReason<W::Error>> {
163161
self.to_handler_tx
164162
.send(item)
165163
.await
166164
.map_err(|_| WorkerQuitReason::HandlerTerminated)
167165
}
168166

169-
pub async fn recv_from_handler(&mut self) -> Result<WorkerSendRequest<W>, WorkerQuitReason> {
167+
pub async fn recv_from_handler(
168+
&mut self,
169+
) -> Result<WorkerSendRequest<W>, WorkerQuitReason<W::Error>> {
170170
self.from_handler_rx
171171
.recv()
172172
.await

0 commit comments

Comments
 (0)