Skip to content

Commit 6497937

Browse files
committed
adds safe guard against premature dropping of clients from its clones
1 parent 7dfa998 commit 6497937

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

crates/mcp_client/src/client.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ pub struct Client<T: Transport> {
8181
server_name: String,
8282
transport: Arc<T>,
8383
timeout: u64,
84-
server_process_id: Pid,
84+
server_process_id: Option<Pid>,
8585
init_params: serde_json::Value,
8686
current_id: Arc<AtomicU64>,
8787
prompts: Arc<SyncRwLock<HashMap<String, Prompt>>>,
@@ -93,7 +93,9 @@ impl<T: Transport> Clone for Client<T> {
9393
server_name: self.server_name.clone(),
9494
transport: self.transport.clone(),
9595
timeout: self.timeout,
96-
server_process_id: self.server_process_id,
96+
// Note that we cannot have an id for the clone because we would kill the original
97+
// process when we drop the clone
98+
server_process_id: None,
9799
init_params: self.init_params.clone(),
98100
current_id: self.current_id.clone(),
99101
prompts: self.prompts.clone(),
@@ -133,6 +135,7 @@ impl Client<StdioTransport> {
133135
.try_into()
134136
.map_err(|_| ClientError::MissingProcessId)?,
135137
);
138+
let server_process_id = Some(server_process_id);
136139
let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?);
137140
Ok(Self {
138141
server_name,
@@ -150,8 +153,12 @@ impl<T> Drop for Client<T>
150153
where
151154
T: Transport,
152155
{
156+
// IF the servers are implemented well, they will shutdown once the pipe closes.
157+
// This drop trait is here as a fail safe to ensure we don't leave behind any orphans.
153158
fn drop(&mut self) {
154-
let _ = nix::sys::signal::kill(self.server_process_id, Signal::SIGTERM);
159+
if let Some(process_id) = self.server_process_id {
160+
let _ = nix::sys::signal::kill(process_id, Signal::SIGTERM);
161+
}
155162
}
156163
}
157164

@@ -188,7 +195,6 @@ where
188195

189196
let server_capabilities = self.request("initialize", Some(self.init_params.clone())).await?;
190197
if let Err(e) = examine_server_capabilities(&server_capabilities) {
191-
let _ = nix::sys::signal::kill(self.server_process_id, Signal::SIGTERM);
192198
return Err(ClientError::NegotiationError(format!(
193199
"Client {} has failed to negotiate server capabilities with server: {:?}",
194200
self.server_name, e

0 commit comments

Comments
 (0)