Skip to content

Commit c494dbf

Browse files
committed
fix: make stdio shutdown more graceful
According to the protocol specifications Signed-off-by: jokemanfire <[email protected]>
1 parent 209dbac commit c494dbf

File tree

1 file changed

+182
-28
lines changed

1 file changed

+182
-28
lines changed

crates/rmcp/src/transport/child_process.rs

Lines changed: 182 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
use std::process::Stdio;
1+
use std::{process::Stdio, sync::Arc};
22

3+
use futures::future::Future;
34
use process_wrap::tokio::{TokioChildWrapper, TokioCommandWrap};
45
use tokio::{
5-
io::AsyncRead,
6+
io::{AsyncRead, AsyncWriteExt},
67
process::{ChildStderr, ChildStdin, ChildStdout},
8+
sync::Mutex,
79
};
810

9-
use super::{IntoTransport, Transport};
11+
use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage};
1012
use crate::service::ServiceRole;
1113

14+
const MAX_WAIT_ON_DROP_SECS: u64 = 3;
1215
/// The parts of a child process.
1316
type ChildProcessParts = (
1417
Box<dyn TokioChildWrapper>,
@@ -36,18 +39,24 @@ fn child_process(mut child: Box<dyn TokioChildWrapper>) -> std::io::Result<Child
3639

3740
pub struct TokioChildProcess {
3841
child: ChildWithCleanup,
39-
child_stdin: ChildStdin,
42+
child_stdin: Arc<Mutex<ChildStdin>>,
4043
child_stdout: ChildStdout,
4144
}
4245

4346
pub struct ChildWithCleanup {
44-
inner: Box<dyn TokioChildWrapper>,
47+
inner: Option<Box<dyn TokioChildWrapper>>,
4548
}
4649

4750
impl Drop for ChildWithCleanup {
4851
fn drop(&mut self) {
49-
if let Err(e) = self.inner.start_kill() {
50-
tracing::warn!("Failed to kill child process: {e}");
52+
// We should not use start_kill(), instead we should use kill() to avoid zombies
53+
if let Some(mut inner) = self.inner.take() {
54+
// We don't care about the result, just try to kill it
55+
tokio::spawn(async move {
56+
if let Err(e) = Box::into_pin(inner.kill()).await {
57+
tracing::warn!("Error killing child process: {}", e);
58+
}
59+
});
5160
}
5261
}
5362
}
@@ -64,7 +73,7 @@ pin_project_lite::pin_project! {
6473
impl TokioChildProcessOut {
6574
/// Get the process ID of the child process.
6675
pub fn id(&self) -> Option<u32> {
67-
self.child.inner.id()
76+
self.child.inner.as_ref()?.id()
6877
}
6978
}
7079

@@ -92,23 +101,51 @@ impl TokioChildProcess {
92101

93102
/// Get the process ID of the child process.
94103
pub fn id(&self) -> Option<u32> {
95-
self.child.inner.id()
104+
self.child.inner.as_ref()?.id()
105+
}
106+
107+
/// Gracefully shutdown the child process
108+
///
109+
/// This will first wait for the child process to exit normally with a timeout.
110+
/// If the child process doesn't exit within the timeout, it will be killed.
111+
pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> {
112+
if let Some(mut child) = self.child.inner.take() {
113+
let wait_fut = Box::into_pin(child.wait());
114+
tokio::select! {
115+
_ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => {
116+
if let Err(e) = Box::into_pin(child.kill()).await {
117+
tracing::warn!("Error killing child: {e}");
118+
return Err(e);
119+
}
120+
},
121+
res = wait_fut => {
122+
match res {
123+
Ok(status) => {
124+
tracing::info!("Child exited gracefully {}", status);
125+
}
126+
Err(e) => {
127+
tracing::warn!("Error waiting for child: {e}");
128+
return Err(e);
129+
}
130+
}
131+
}
132+
}
133+
}
134+
Ok(())
135+
}
136+
137+
/// Take ownership of the inner child process
138+
pub fn into_inner(mut self) -> Option<Box<dyn TokioChildWrapper>> {
139+
self.child.inner.take()
96140
}
97141

98142
/// Split this helper into a reader (stdout) and writer (stdin).
143+
#[deprecated(
144+
since = "0.1.0",
145+
note = "use the Transport trait implementation instead"
146+
)]
99147
pub fn split(self) -> (TokioChildProcessOut, ChildStdin) {
100-
let TokioChildProcess {
101-
child,
102-
child_stdin,
103-
child_stdout,
104-
} = self;
105-
(
106-
TokioChildProcessOut {
107-
child,
108-
child_stdout,
109-
},
110-
child_stdin,
111-
)
148+
unimplemented!("This method is deprecated, use the Transport trait implementation instead");
112149
}
113150
}
114151

@@ -157,19 +194,61 @@ impl TokioChildProcessBuilder {
157194
let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?;
158195

159196
let proc = TokioChildProcess {
160-
child: ChildWithCleanup { inner: child },
161-
child_stdin: stdin,
197+
child: ChildWithCleanup { inner: Some(child) },
198+
child_stdin: Arc::new(Mutex::new(stdin)),
162199
child_stdout: stdout,
163200
};
164201
Ok((proc, stderr_opt))
165202
}
166203
}
167204

168-
impl<R: ServiceRole> IntoTransport<R, std::io::Error, ()> for TokioChildProcess {
169-
fn into_transport(self) -> impl Transport<R, Error = std::io::Error> + 'static {
170-
IntoTransport::<R, std::io::Error, super::async_rw::TransportAdapterAsyncRW>::into_transport(
171-
self.split(),
172-
)
205+
impl<R: ServiceRole> Transport<R> for TokioChildProcess {
206+
type Error = std::io::Error;
207+
208+
fn send(
209+
&mut self,
210+
item: TxJsonRpcMessage<R>,
211+
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
212+
let json = serde_json::to_string(&item).unwrap();
213+
let child_stdin = Arc::clone(&self.child_stdin);
214+
215+
async move {
216+
let mut child_stdin = child_stdin.lock().await;
217+
let serialized = format!("{}\n", json);
218+
child_stdin.write_all(serialized.as_bytes()).await?;
219+
child_stdin.flush().await?;
220+
Ok(())
221+
}
222+
}
223+
224+
fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<R>>> + Send {
225+
use tokio::io::{AsyncBufReadExt, BufReader};
226+
227+
// Create a new BufReader for each call to receive
228+
let stdout = &mut self.child_stdout;
229+
let mut buf_reader = BufReader::new(stdout);
230+
let mut line = String::new();
231+
232+
async move {
233+
match buf_reader.read_line(&mut line).await {
234+
Ok(0) => None, // EOF
235+
Ok(_) => match serde_json::from_str::<RxJsonRpcMessage<R>>(&line) {
236+
Ok(msg) => Some(msg),
237+
Err(e) => {
238+
tracing::error!("Failed to deserialize message: {}", e);
239+
None
240+
}
241+
},
242+
Err(e) => {
243+
tracing::error!("Error reading from child process: {}", e);
244+
None
245+
}
246+
}
247+
}
248+
}
249+
250+
fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
251+
self.graceful_shutdown()
173252
}
174253
}
175254

@@ -183,3 +262,78 @@ impl ConfigureCommandExt for tokio::process::Command {
183262
self
184263
}
185264
}
265+
266+
#[cfg(unix)]
267+
#[cfg(test)]
268+
mod tests {
269+
use tokio::process::Command;
270+
271+
use super::*;
272+
273+
#[tokio::test]
274+
async fn test_tokio_child_process_drop() {
275+
let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
276+
cmd.arg("30");
277+
}));
278+
assert!(r.is_ok());
279+
let child_process = r.unwrap();
280+
let id = child_process.id();
281+
assert!(id.is_some());
282+
let id = id.unwrap();
283+
// Drop the child process
284+
drop(child_process);
285+
// Wait a moment to allow the cleanup task to run
286+
tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
287+
// Check if the process is still running
288+
let status = Command::new("ps")
289+
.arg("-p")
290+
.arg(id.to_string())
291+
.status()
292+
.await;
293+
match status {
294+
Ok(status) => {
295+
assert!(
296+
!status.success(),
297+
"Process with PID {} is still running",
298+
id
299+
);
300+
}
301+
Err(e) => {
302+
panic!("Failed to check process status: {}", e);
303+
}
304+
}
305+
}
306+
307+
#[tokio::test]
308+
async fn test_tokio_child_process_graceful_shutdown() {
309+
let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
310+
cmd.arg("30");
311+
}));
312+
assert!(r.is_ok());
313+
let mut child_process = r.unwrap();
314+
let id = child_process.id();
315+
assert!(id.is_some());
316+
let id = id.unwrap();
317+
child_process.graceful_shutdown().await.unwrap();
318+
// Wait a moment to allow the cleanup task to run
319+
tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
320+
// Check if the process is still running
321+
let status = Command::new("ps")
322+
.arg("-p")
323+
.arg(id.to_string())
324+
.status()
325+
.await;
326+
match status {
327+
Ok(status) => {
328+
assert!(
329+
!status.success(),
330+
"Process with PID {} is still running",
331+
id
332+
);
333+
}
334+
Err(e) => {
335+
panic!("Failed to check process status: {}", e);
336+
}
337+
}
338+
}
339+
}

0 commit comments

Comments
 (0)