Skip to content

Commit 910d3b3

Browse files
authored
test(context): test context request handling and refactor for reusable client-server tests (#97)
1 parent 36834f3 commit 910d3b3

File tree

7 files changed

+902
-140
lines changed

7 files changed

+902
-140
lines changed

crates/rmcp/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ tracing-subscriber = { version = "0.3", features = [
8080
"std",
8181
"fmt",
8282
] }
83+
async-trait = "0.1"
8384
[[test]]
8485
name = "test_tool_macros"
8586
required-features = ["server"]
@@ -105,3 +106,8 @@ name = "test_logging"
105106
required-features = ["server", "client"]
106107
path = "tests/test_logging.rs"
107108

109+
[[test]]
110+
name = "test_message_protocol"
111+
required-features = ["client"]
112+
path = "tests/test_message_protocol.rs"
113+

crates/rmcp/src/handler/client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
8484
McpError::method_not_found::<CreateMessageRequestMethod>(),
8585
))
8686
}
87+
8788
fn list_roots(
8889
&self,
8990
context: RequestContext<RoleClient>,

crates/rmcp/src/model.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,16 @@ pub struct SamplingMessage {
713713
pub content: Content,
714714
}
715715

716+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
717+
pub enum ContextInclusion {
718+
#[serde(rename = "allServers")]
719+
AllServers,
720+
#[serde(rename = "none")]
721+
None,
722+
#[serde(rename = "thisServer")]
723+
ThisServer,
724+
}
725+
716726
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
717727
#[serde(rename_all = "camelCase")]
718728
pub struct CreateMessageRequestParam {
@@ -722,7 +732,7 @@ pub struct CreateMessageRequestParam {
722732
#[serde(skip_serializing_if = "Option::is_none")]
723733
pub system_prompt: Option<String>,
724734
#[serde(skip_serializing_if = "Option::is_none")]
725-
pub include_context: Option<String>,
735+
pub include_context: Option<ContextInclusion>,
726736
#[serde(skip_serializing_if = "Option::is_none")]
727737
pub temperature: Option<f32>,
728738
pub max_tokens: u32,

crates/rmcp/tests/common/handlers.rs

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
use std::{
2+
future::Future,
3+
sync::{Arc, Mutex},
4+
};
5+
6+
use rmcp::{
7+
ClientHandler, Error as McpError, RoleClient, RoleServer, ServerHandler,
8+
model::*,
9+
service::{Peer, RequestContext},
10+
};
11+
use serde_json::json;
12+
use tokio::sync::Notify;
13+
14+
#[derive(Clone)]
15+
pub struct TestClientHandler {
16+
pub peer: Option<Peer<RoleClient>>,
17+
pub honor_this_server: bool,
18+
pub honor_all_servers: bool,
19+
pub receive_signal: Arc<Notify>,
20+
pub received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
21+
}
22+
23+
impl TestClientHandler {
24+
#[allow(dead_code)]
25+
pub fn new(honor_this_server: bool, honor_all_servers: bool) -> Self {
26+
Self {
27+
peer: None,
28+
honor_this_server,
29+
honor_all_servers,
30+
receive_signal: Arc::new(Notify::new()),
31+
received_messages: Arc::new(Mutex::new(Vec::new())),
32+
}
33+
}
34+
35+
#[allow(dead_code)]
36+
pub fn with_notification(
37+
honor_this_server: bool,
38+
honor_all_servers: bool,
39+
receive_signal: Arc<Notify>,
40+
received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
41+
) -> Self {
42+
Self {
43+
peer: None,
44+
honor_this_server,
45+
honor_all_servers,
46+
receive_signal,
47+
received_messages,
48+
}
49+
}
50+
}
51+
52+
impl ClientHandler for TestClientHandler {
53+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
54+
self.peer.clone()
55+
}
56+
57+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
58+
self.peer = Some(peer);
59+
}
60+
61+
async fn create_message(
62+
&self,
63+
params: CreateMessageRequestParam,
64+
_context: RequestContext<RoleClient>,
65+
) -> Result<CreateMessageResult, McpError> {
66+
// First validate that there's at least one User message
67+
if !params.messages.iter().any(|msg| msg.role == Role::User) {
68+
return Err(McpError::invalid_request(
69+
"Message sequence must contain at least one user message",
70+
Some(json!({"messages": params.messages})),
71+
));
72+
}
73+
74+
// Create response based on context inclusion
75+
let response = match params.include_context {
76+
Some(ContextInclusion::ThisServer) if self.honor_this_server => {
77+
"Test response with context: test context"
78+
}
79+
Some(ContextInclusion::AllServers) if self.honor_all_servers => {
80+
"Test response with context: test context"
81+
}
82+
_ => "Test response without context",
83+
};
84+
85+
Ok(CreateMessageResult {
86+
message: SamplingMessage {
87+
role: Role::Assistant,
88+
content: Content::text(response.to_string()),
89+
},
90+
model: "test-model".to_string(),
91+
stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()),
92+
})
93+
}
94+
95+
fn on_logging_message(
96+
&self,
97+
params: LoggingMessageNotificationParam,
98+
) -> impl Future<Output = ()> + Send + '_ {
99+
let receive_signal = self.receive_signal.clone();
100+
let received_messages = self.received_messages.clone();
101+
102+
async move {
103+
println!("Client: Received log message: {:?}", params);
104+
let mut messages = received_messages.lock().unwrap();
105+
messages.push(params);
106+
receive_signal.notify_one();
107+
}
108+
}
109+
}
110+
111+
pub struct TestServer {}
112+
113+
impl TestServer {
114+
#[allow(dead_code)]
115+
pub fn new() -> Self {
116+
Self {}
117+
}
118+
}
119+
120+
impl ServerHandler for TestServer {
121+
fn get_info(&self) -> ServerInfo {
122+
ServerInfo {
123+
capabilities: ServerCapabilities::builder().enable_logging().build(),
124+
..Default::default()
125+
}
126+
}
127+
128+
fn set_level(
129+
&self,
130+
request: SetLevelRequestParam,
131+
context: RequestContext<RoleServer>,
132+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
133+
let peer = context.peer;
134+
async move {
135+
let (data, logger) = match request.level {
136+
LoggingLevel::Error => (
137+
serde_json::json!({
138+
"message": "Failed to process request",
139+
"error_code": "E1001",
140+
"error_details": "Connection timeout",
141+
"timestamp": chrono::Utc::now().to_rfc3339(),
142+
}),
143+
Some("error_handler".to_string()),
144+
),
145+
LoggingLevel::Debug => (
146+
serde_json::json!({
147+
"message": "Processing request",
148+
"function": "handle_request",
149+
"line": 42,
150+
"context": {
151+
"request_id": "req-123",
152+
"user_id": "user-456"
153+
},
154+
"timestamp": chrono::Utc::now().to_rfc3339(),
155+
}),
156+
Some("debug_logger".to_string()),
157+
),
158+
LoggingLevel::Info => (
159+
serde_json::json!({
160+
"message": "System status update",
161+
"status": "healthy",
162+
"metrics": {
163+
"requests_per_second": 150,
164+
"average_latency_ms": 45,
165+
"error_rate": 0.01
166+
},
167+
"timestamp": chrono::Utc::now().to_rfc3339(),
168+
}),
169+
Some("monitoring".to_string()),
170+
),
171+
_ => (
172+
serde_json::json!({
173+
"message": format!("Message at level {:?}", request.level),
174+
"timestamp": chrono::Utc::now().to_rfc3339(),
175+
}),
176+
None,
177+
),
178+
};
179+
180+
if let Err(e) = peer
181+
.notify_logging_message(LoggingMessageNotificationParam {
182+
level: request.level,
183+
data,
184+
logger,
185+
})
186+
.await
187+
{
188+
panic!("Failed to send notification: {}", e);
189+
}
190+
Ok(())
191+
}
192+
}
193+
}

crates/rmcp/tests/common/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
pub mod calculator;
2+
pub mod handlers;

0 commit comments

Comments
 (0)