Skip to content

Commit 257e263

Browse files
authored
fix: test for mcp client after move (#1731)
1 parent a0ea994 commit 257e263

File tree

5 files changed

+366
-1
lines changed

5 files changed

+366
-1
lines changed

crates/chat-cli/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ workspace = true
1414
default = []
1515
wayland = ["arboard/wayland-data-control"]
1616

17+
[[bin]]
18+
name = "test_mcp_server"
19+
path = "test_mcp_server/test_server.rs"
20+
test = true
21+
doc = false
22+
1723
[dependencies]
1824
amzn-codewhisperer-client = { path = "../amzn-codewhisperer-client" }
1925
amzn-codewhisperer-streaming-client = { path = "../amzn-codewhisperer-streaming-client" }

crates/chat-cli/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
//! This lib.rs is only here for testing purposes.
2+
//! `test_mcp_server/test_server.rs` is declared as a separate binary and would need a way to
3+
//! reference types defined inside of this crate, hence the export.
4+
pub mod mcp_client;
5+
6+
pub use mcp_client::*;

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,6 @@ mod tests {
540540
PathBuf::from(workspace_root)
541541
}
542542

543-
#[ignore = "TODO: support test binary"]
544543
#[tokio::test(flavor = "multi_thread")]
545544
async fn test_client_stdio() {
546545
std::process::Command::new("cargo")

crates/chat-cli/src/mcp_client/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ mod transport;
77

88
pub use client::*;
99
pub use facilitator_types::*;
10+
#[allow(unused_imports)]
11+
pub use server::*;
1012
pub use transport::*;
1113

1214
/// Error codes as defined in the MCP protocol.
Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
//! This is a bin used solely for testing the client
2+
use std::collections::HashMap;
3+
use std::str::FromStr;
4+
use std::sync::atomic::{
5+
AtomicU8,
6+
Ordering,
7+
};
8+
9+
use chat_cli::{
10+
self,
11+
JsonRpcRequest,
12+
JsonRpcResponse,
13+
JsonRpcStdioTransport,
14+
PreServerRequestHandler,
15+
Response,
16+
Server,
17+
ServerError,
18+
ServerRequestHandler,
19+
};
20+
use tokio::sync::Mutex;
21+
22+
#[derive(Default)]
23+
struct Handler {
24+
pending_request: Option<Box<dyn Fn(u64) -> Option<JsonRpcRequest> + Send + Sync>>,
25+
#[allow(clippy::type_complexity)]
26+
send_request: Option<Box<dyn Fn(&str, Option<serde_json::Value>) -> Result<(), ServerError> + Send + Sync>>,
27+
storage: Mutex<HashMap<String, serde_json::Value>>,
28+
tool_spec: Mutex<HashMap<String, Response>>,
29+
tool_spec_key_list: Mutex<Vec<String>>,
30+
prompts: Mutex<HashMap<String, Response>>,
31+
prompt_key_list: Mutex<Vec<String>>,
32+
prompt_list_call_no: AtomicU8,
33+
}
34+
35+
impl PreServerRequestHandler for Handler {
36+
fn register_pending_request_callback(
37+
&mut self,
38+
cb: impl Fn(u64) -> Option<JsonRpcRequest> + Send + Sync + 'static,
39+
) {
40+
self.pending_request = Some(Box::new(cb));
41+
}
42+
43+
fn register_send_request_callback(
44+
&mut self,
45+
cb: impl Fn(&str, Option<serde_json::Value>) -> Result<(), ServerError> + Send + Sync + 'static,
46+
) {
47+
self.send_request = Some(Box::new(cb));
48+
}
49+
}
50+
51+
#[async_trait::async_trait]
52+
impl ServerRequestHandler for Handler {
53+
async fn handle_initialize(&self, params: Option<serde_json::Value>) -> Result<Response, ServerError> {
54+
let mut storage = self.storage.lock().await;
55+
if let Some(params) = params {
56+
storage.insert("client_cap".to_owned(), params);
57+
}
58+
let capabilities = serde_json::json!({
59+
"protocolVersion": "2024-11-05",
60+
"capabilities": {
61+
"logging": {},
62+
"prompts": {
63+
"listChanged": true
64+
},
65+
"resources": {
66+
"subscribe": true,
67+
"listChanged": true
68+
},
69+
"tools": {
70+
"listChanged": true
71+
}
72+
},
73+
"serverInfo": {
74+
"name": "TestServer",
75+
"version": "1.0.0"
76+
}
77+
});
78+
Ok(Some(capabilities))
79+
}
80+
81+
async fn handle_incoming(&self, method: &str, params: Option<serde_json::Value>) -> Result<Response, ServerError> {
82+
match method {
83+
"notifications/initialized" => {
84+
{
85+
let mut storage = self.storage.lock().await;
86+
storage.insert(
87+
"init_ack_sent".to_owned(),
88+
serde_json::Value::from_str("true").expect("Failed to convert string to value"),
89+
);
90+
}
91+
Ok(None)
92+
},
93+
"verify_init_params_sent" => {
94+
let client_capabilities = {
95+
let storage = self.storage.lock().await;
96+
storage.get("client_cap").cloned()
97+
};
98+
Ok(client_capabilities)
99+
},
100+
"verify_init_ack_sent" => {
101+
let result = {
102+
let storage = self.storage.lock().await;
103+
storage.get("init_ack_sent").cloned()
104+
};
105+
Ok(result)
106+
},
107+
"store_mock_tool_spec" => {
108+
let Some(params) = params else {
109+
eprintln!("Params missing from store mock tool spec");
110+
return Ok(None);
111+
};
112+
// expecting a mock_specs: { key: String, value: serde_json::Value }[];
113+
let Ok(mock_specs) = serde_json::from_value::<Vec<serde_json::Value>>(params) else {
114+
eprintln!("Failed to convert to mock specs from value");
115+
return Ok(None);
116+
};
117+
let self_tool_specs = self.tool_spec.lock().await;
118+
let mut self_tool_spec_key_list = self.tool_spec_key_list.lock().await;
119+
let _ = mock_specs.iter().fold(self_tool_specs, |mut acc, spec| {
120+
let Some(key) = spec.get("key").cloned() else {
121+
return acc;
122+
};
123+
let Ok(key) = serde_json::from_value::<String>(key) else {
124+
eprintln!("Failed to convert serde value to string for key");
125+
return acc;
126+
};
127+
self_tool_spec_key_list.push(key.clone());
128+
acc.insert(key, spec.get("value").cloned());
129+
acc
130+
});
131+
Ok(None)
132+
},
133+
"tools/list" => {
134+
if let Some(params) = params {
135+
if let Some(cursor) = params.get("cursor").cloned() {
136+
let Ok(cursor) = serde_json::from_value::<String>(cursor) else {
137+
eprintln!("Failed to convert cursor to string: {:#?}", params);
138+
return Ok(None);
139+
};
140+
let self_tool_spec_key_list = self.tool_spec_key_list.lock().await;
141+
let self_tool_spec = self.tool_spec.lock().await;
142+
let (next_cursor, spec) = {
143+
'blk: {
144+
for (i, item) in self_tool_spec_key_list.iter().enumerate() {
145+
if item == &cursor {
146+
break 'blk (
147+
self_tool_spec_key_list.get(i + 1).cloned(),
148+
self_tool_spec.get(&cursor).cloned().unwrap(),
149+
);
150+
}
151+
}
152+
(None, None)
153+
}
154+
};
155+
if let Some(next_cursor) = next_cursor {
156+
return Ok(Some(serde_json::json!({
157+
"tools": [spec.unwrap()],
158+
"nextCursor": next_cursor,
159+
})));
160+
} else {
161+
return Ok(Some(serde_json::json!({
162+
"tools": [spec.unwrap()],
163+
})));
164+
}
165+
} else {
166+
eprintln!("Params exist but cursor is missing");
167+
return Ok(None);
168+
}
169+
} else {
170+
let first_key = self
171+
.tool_spec_key_list
172+
.lock()
173+
.await
174+
.first()
175+
.expect("First key missing from tool specs")
176+
.clone();
177+
let first_value = self
178+
.tool_spec
179+
.lock()
180+
.await
181+
.get(&first_key)
182+
.expect("First value missing from tool specs")
183+
.clone();
184+
let second_key = self
185+
.tool_spec_key_list
186+
.lock()
187+
.await
188+
.get(1)
189+
.expect("Second key missing from tool specs")
190+
.clone();
191+
return Ok(Some(serde_json::json!({
192+
"tools": [first_value],
193+
"nextCursor": second_key
194+
})));
195+
};
196+
},
197+
"get_env_vars" => {
198+
let kv = std::env::vars().fold(HashMap::<String, String>::new(), |mut acc, (k, v)| {
199+
acc.insert(k, v);
200+
acc
201+
});
202+
Ok(Some(serde_json::json!(kv)))
203+
},
204+
// This is a test path relevant only to sampling
205+
"trigger_server_request" => {
206+
let Some(ref send_request) = self.send_request else {
207+
return Err(ServerError::MissingMethod);
208+
};
209+
let params = Some(serde_json::json!({
210+
"messages": [
211+
{
212+
"role": "user",
213+
"content": {
214+
"type": "text",
215+
"text": "What is the capital of France?"
216+
}
217+
}
218+
],
219+
"modelPreferences": {
220+
"hints": [
221+
{
222+
"name": "claude-3-sonnet"
223+
}
224+
],
225+
"intelligencePriority": 0.8,
226+
"speedPriority": 0.5
227+
},
228+
"systemPrompt": "You are a helpful assistant.",
229+
"maxTokens": 100
230+
}));
231+
send_request("sampling/createMessage", params)?;
232+
Ok(None)
233+
},
234+
"store_mock_prompts" => {
235+
let Some(params) = params else {
236+
eprintln!("Params missing from store mock prompts");
237+
return Ok(None);
238+
};
239+
// expecting a mock_prompts: { key: String, value: serde_json::Value }[];
240+
let Ok(mock_prompts) = serde_json::from_value::<Vec<serde_json::Value>>(params) else {
241+
eprintln!("Failed to convert to mock specs from value");
242+
return Ok(None);
243+
};
244+
let self_prompts = self.prompts.lock().await;
245+
let mut self_prompt_key_list = self.prompt_key_list.lock().await;
246+
let _ = mock_prompts.iter().fold(self_prompts, |mut acc, spec| {
247+
let Some(key) = spec.get("key").cloned() else {
248+
return acc;
249+
};
250+
let Ok(key) = serde_json::from_value::<String>(key) else {
251+
eprintln!("Failed to convert serde value to string for key");
252+
return acc;
253+
};
254+
self_prompt_key_list.push(key.clone());
255+
acc.insert(key, spec.get("value").cloned());
256+
acc
257+
});
258+
Ok(None)
259+
},
260+
"prompts/list" => {
261+
self.prompt_list_call_no.fetch_add(1, Ordering::Relaxed);
262+
if let Some(params) = params {
263+
if let Some(cursor) = params.get("cursor").cloned() {
264+
let Ok(cursor) = serde_json::from_value::<String>(cursor) else {
265+
eprintln!("Failed to convert cursor to string: {:#?}", params);
266+
return Ok(None);
267+
};
268+
let self_prompt_key_list = self.prompt_key_list.lock().await;
269+
let self_prompts = self.prompts.lock().await;
270+
let (next_cursor, spec) = {
271+
'blk: {
272+
for (i, item) in self_prompt_key_list.iter().enumerate() {
273+
if item == &cursor {
274+
break 'blk (
275+
self_prompt_key_list.get(i + 1).cloned(),
276+
self_prompts.get(&cursor).cloned().unwrap(),
277+
);
278+
}
279+
}
280+
(None, None)
281+
}
282+
};
283+
if let Some(next_cursor) = next_cursor {
284+
return Ok(Some(serde_json::json!({
285+
"prompts": [spec.unwrap()],
286+
"nextCursor": next_cursor,
287+
})));
288+
} else {
289+
return Ok(Some(serde_json::json!({
290+
"prompts": [spec.unwrap()],
291+
})));
292+
}
293+
} else {
294+
eprintln!("Params exist but cursor is missing");
295+
return Ok(None);
296+
}
297+
} else {
298+
let first_key = self
299+
.prompt_key_list
300+
.lock()
301+
.await
302+
.first()
303+
.expect("First key missing from prompts")
304+
.clone();
305+
let first_value = self
306+
.prompts
307+
.lock()
308+
.await
309+
.get(&first_key)
310+
.expect("First value missing from prompts")
311+
.clone();
312+
let second_key = self
313+
.prompt_key_list
314+
.lock()
315+
.await
316+
.get(1)
317+
.expect("Second key missing from prompts")
318+
.clone();
319+
return Ok(Some(serde_json::json!({
320+
"prompts": [first_value],
321+
"nextCursor": second_key
322+
})));
323+
};
324+
},
325+
"get_prompt_list_call_no" => Ok(Some(
326+
serde_json::to_value::<u8>(self.prompt_list_call_no.load(Ordering::Relaxed))
327+
.expect("Failed to convert list call no to u8"),
328+
)),
329+
_ => Err(ServerError::MissingMethod),
330+
}
331+
}
332+
333+
// This is a test path relevant only to sampling
334+
async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError> {
335+
let JsonRpcResponse { id, .. } = resp;
336+
let _pending = self.pending_request.as_ref().and_then(|f| f(id));
337+
Ok(())
338+
}
339+
340+
async fn handle_shutdown(&self) -> Result<(), ServerError> {
341+
Ok(())
342+
}
343+
}
344+
345+
#[tokio::main]
346+
async fn main() {
347+
let handler = Handler::default();
348+
let stdin = tokio::io::stdin();
349+
let stdout = tokio::io::stdout();
350+
let test_server = Server::<JsonRpcStdioTransport, _>::new(handler, stdin, stdout).expect("Failed to create server");
351+
let _ = test_server.init().expect("Test server failed to init").await;
352+
}

0 commit comments

Comments
 (0)