diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 75cdc8fa..e6991f38 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -8,7 +8,7 @@ use crate::{ ArgumentInfo, CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult, - CompletionContext, CompletionInfo, GetPromptRequest, GetPromptRequestParam, + CompletionContext, CompletionInfo, ErrorData, GetPromptRequest, GetPromptRequestParam, GetPromptResult, InitializeRequest, InitializedNotification, JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, @@ -44,6 +44,9 @@ pub enum ClientInitializeError { context: Cow<'static, str>, }, + #[error("JSON-RPC error: {0}")] + JsonRpcError(ErrorData), + #[error("Cancelled")] Cancelled, } @@ -92,6 +95,10 @@ where ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => { break Ok((result, id)); } + // Handle JSON-RPC error responses + ServerJsonRpcMessage::Error(error) => { + break Err(ClientInitializeError::JsonRpcError(error.error)); + } // Server could send logging messages before handshake ServerJsonRpcMessage::Notification(mut notification) => { let ServerNotification::LoggingMessageNotification(logging) = diff --git a/crates/rmcp/tests/test_client_initialization.rs b/crates/rmcp/tests/test_client_initialization.rs new file mode 100644 index 00000000..0080a620 --- /dev/null +++ b/crates/rmcp/tests/test_client_initialization.rs @@ -0,0 +1,50 @@ +// cargo test --features "server client" --package rmcp test_client_initialization +mod common; + +use common::handlers::TestClientHandler; +use rmcp::{ + ServiceExt, + model::{ + ErrorCode, ErrorData, JsonRpcError, JsonRpcVersion2_0, RequestId, ServerJsonRpcMessage, + }, + transport::{IntoTransport, Transport}, +}; +use std::borrow::Cow; + +#[tokio::test] +async fn test_client_init_handles_jsonrpc_error() { + let (server_transport, client_transport) = tokio::io::duplex(1024); + let mut server = IntoTransport::::into_transport(server_transport); + + let client_handle = tokio::spawn(async move { + TestClientHandler::new(true, true) + .serve(client_transport) + .await + }); + + tokio::spawn(async move { + let _init_request = server.receive().await; + + let error_msg = ServerJsonRpcMessage::Error(JsonRpcError { + jsonrpc: JsonRpcVersion2_0, + id: RequestId::Number(1), + error: ErrorData { + code: ErrorCode(-32600), + message: Cow::Borrowed("Invalid Request"), + data: None, + }, + }); + let _: Result<(), _> = server.send(error_msg).await; + }); + + let result = client_handle.await.unwrap(); + + assert!(result.is_err()); + match result { + Err(rmcp::service::ClientInitializeError::JsonRpcError(error_data)) => { + assert_eq!(error_data.code, ErrorCode(-32600)); + assert_eq!(error_data.message, "Invalid Request"); + } + _ => panic!("Expected ClientInitializeError::JsonRpcError"), + } +}