diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs
index 2e7c034f..21435959 100644
--- a/crates/rmcp/src/transport/child_process.rs
+++ b/crates/rmcp/src/transport/child_process.rs
@@ -1,14 +1,19 @@
-use std::process::Stdio;
+use std::{process::Stdio, sync::Arc};
 
+use futures::future::Future;
 use process_wrap::tokio::{TokioChildWrapper, TokioCommandWrap};
 use tokio::{
-    io::AsyncRead,
+    io::{AsyncBufReadExt, AsyncRead, AsyncWriteExt, BufReader},
     process::{ChildStderr, ChildStdin, ChildStdout},
+    sync::Mutex,
 };
 
-use super::{IntoTransport, Transport};
+use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage};
 use crate::service::ServiceRole;
 
+/// Seconds that `graceful_shutdown` waits for the child to exit before killing it.
+const MAX_WAIT_ON_DROP_SECS: u64 = 3;
+
 /// The parts of a child process.
 type ChildProcessParts = (
     Box<dyn TokioChildWrapper>,
@@ -36,18 +41,44 @@ fn child_process(mut child: Box<dyn TokioChildWrapper>) -> std::io::Result<Child
 
 pub struct TokioChildProcess {
     child: ChildWithCleanup,
-    child_stdin: ChildStdin,
-    child_stdout: ChildStdout,
+    /// Shared so the `'static` future returned by `send` can own a handle to stdin.
+    child_stdin: Arc<Mutex<ChildStdin>>,
+    /// Buffered once for the lifetime of the transport: bytes read ahead by one
+    /// `receive` call must still be available to the next one.
+    child_stdout: BufReader<ChildStdout>,
+    /// Bytes of a line whose read has not finished yet; kept here so a dropped
+    /// `receive` future resumes the same line instead of losing it.
+    read_buf: Vec<u8>,
 }
 
 pub struct ChildWithCleanup {
-    inner: Box<dyn TokioChildWrapper>,
+    /// `None` once the child was taken by `graceful_shutdown`, `into_inner` or `drop`.
+    inner: Option<Box<dyn TokioChildWrapper>>,
 }
 
 impl Drop for ChildWithCleanup {
     fn drop(&mut self) {
-        if let Err(e) = self.inner.start_kill() {
-            tracing::warn!("Failed to kill child process: {e}");
+        // Nothing to do if the child was already taken (shut down or moved out).
+        if let Some(mut inner) = self.inner.take() {
+            match tokio::runtime::Handle::try_current() {
+                // Prefer kill() over start_kill() to avoid zombies. `drop` cannot
+                // await, so the kill runs as a detached task on the current runtime.
+                Ok(runtime) => {
+                    runtime.spawn(async move {
+                        if let Err(e) = Box::into_pin(inner.kill()).await {
+                            tracing::warn!("Error killing child process: {}", e);
+                        }
+                    });
+                }
+                // No runtime on this thread: `tokio::spawn` would panic here, so only
+                // request the kill without awaiting it.
+                Err(_) => {
+                    if let Err(e) = inner.start_kill() {
+                        tracing::warn!("Failed to kill child process: {e}");
+                    }
+                }
+            }
         }
     }
 }
@@ -64,7 +95,7 @@ pin_project_lite::pin_project! {
 impl TokioChildProcessOut {
     /// Get the process ID of the child process.
     pub fn id(&self) -> Option<u32> {
-        self.child.inner.id()
+        self.child.inner.as_ref()?.id()
     }
 }
 
@@ -92,23 +123,79 @@ impl TokioChildProcess {
 
     /// Get the process ID of the child process.
     pub fn id(&self) -> Option<u32> {
-        self.child.inner.id()
+        self.child.inner.as_ref()?.id()
+    }
+
+    /// Gracefully shutdown the child process
+    ///
+    /// This will first wait for the child process to exit normally with a timeout.
+    /// If the child process doesn't exit within the timeout, it will be killed.
+    /// Does nothing if the child has already been shut down or taken.
+    ///
+    /// # Errors
+    ///
+    /// Returns the I/O error reported while waiting for or killing the child.
+    pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> {
+        if let Some(mut child) = self.child.inner.take() {
+            let wait_fut = Box::into_pin(child.wait());
+            tokio::select! {
+                _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => {
+                    if let Err(e) = Box::into_pin(child.kill()).await {
+                        tracing::warn!("Error killing child: {e}");
+                        return Err(e);
+                    }
+                },
+                res = wait_fut => {
+                    match res {
+                        Ok(status) => {
+                            tracing::info!("Child exited gracefully {}", status);
+                        }
+                        Err(e) => {
+                            tracing::warn!("Error waiting for child: {e}");
+                            return Err(e);
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Take ownership of the inner child process
+    pub fn into_inner(mut self) -> Option<Box<dyn TokioChildWrapper>> {
+        self.child.inner.take()
     }
 
     /// Split this helper into a reader (stdout) and writer (stdin).
+    ///
+    /// # Panics
+    ///
+    /// Panics if a future returned by `send` is still alive, because it shares
+    /// ownership of the child's stdin.
+    #[deprecated(
+        since = "0.5.0",
+        note = "use the Transport trait implementation instead"
+    )]
     pub fn split(self) -> (TokioChildProcessOut, ChildStdin) {
         let TokioChildProcess {
             child,
             child_stdin,
             child_stdout,
+            ..
         } = self;
+        // Deprecated, but existing callers must keep working: hand the raw handles
+        // back instead of panicking unconditionally.
+        let child_stdin = Arc::try_unwrap(child_stdin)
+            .unwrap_or_else(|_| panic!("child stdin is still in use by a pending send"))
+            .into_inner();
         (
             TokioChildProcessOut {
                 child,
-                child_stdout,
+                // Bytes already buffered from stdout (if any) are discarded.
+                child_stdout: child_stdout.into_inner(),
             },
             child_stdin,
         )
     }
 }
 
@@ -157,19 +244,75 @@ impl TokioChildProcessBuilder {
         let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?;
 
         let proc = TokioChildProcess {
-            child: ChildWithCleanup { inner: child },
-            child_stdin: stdin,
-            child_stdout: stdout,
+            child: ChildWithCleanup { inner: Some(child) },
+            child_stdin: Arc::new(Mutex::new(stdin)),
+            child_stdout: BufReader::new(stdout),
+            read_buf: Vec::new(),
         };
         Ok((proc, stderr_opt))
     }
 }
 
-impl<R: ServiceRole> IntoTransport<R, std::io::Error, ()> for TokioChildProcess {
-    fn into_transport(self) -> impl Transport<R, Error = std::io::Error> + 'static {
-        IntoTransport::<R, std::io::Error, super::async_rw::TransportAdapterAsyncRW>::into_transport(
-            self.split(),
-        )
+impl<R: ServiceRole> Transport<R> for TokioChildProcess {
+    type Error = std::io::Error;
+
+    /// Write `item` to the child's stdin as one line of JSON and flush it.
+    fn send(
+        &mut self,
+        item: TxJsonRpcMessage<R>,
+    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
+        // Serialize before building the future so it does not have to own `item`.
+        // A serialization failure is surfaced as an I/O error, not a panic.
+        let serialized = serde_json::to_string(&item)
+            .map(|json| json + "\n")
+            .map_err(std::io::Error::from);
+        let child_stdin = Arc::clone(&self.child_stdin);
+
+        async move {
+            let serialized = serialized?;
+            let mut child_stdin = child_stdin.lock().await;
+            child_stdin.write_all(serialized.as_bytes()).await?;
+            child_stdin.flush().await?;
+            Ok(())
+        }
+    }
+
+    /// Read the next newline-delimited JSON message from the child's stdout.
+    ///
+    /// Returns `None` on EOF, on a read error, or when a line fails to deserialize.
+    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<R>>> + Send {
+        // Reader and partial-line buffer both live in `self`, so neither read-ahead
+        // data nor a half-read line is lost when this future is dropped early.
+        let stdout = &mut self.child_stdout;
+        let line = &mut self.read_buf;
+
+        async move {
+            let read = stdout.read_until(b'\n', line).await;
+            match read {
+                Ok(0) if line.is_empty() => None, // EOF
+                Ok(_) => {
+                    let parsed = serde_json::from_slice::<RxJsonRpcMessage<R>>(line.as_slice());
+                    // The line is consumed whether or not it parsed.
+                    line.clear();
+                    match parsed {
+                        Ok(msg) => Some(msg),
+                        Err(e) => {
+                            tracing::error!("Failed to deserialize message: {}", e);
+                            None
+                        }
+                    }
+                }
+                Err(e) => {
+                    tracing::error!("Error reading from child process: {}", e);
+                    None
+                }
+            }
+        }
+    }
+
+    /// Shut the child down via `graceful_shutdown`.
+    fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
+        self.graceful_shutdown()
     }
 }
 
@@ -183,3 +326,78 @@ impl ConfigureCommandExt for tokio::process::Command {
         self
     }
 }
+
+#[cfg(unix)]
+#[cfg(test)]
+mod tests {
+    use tokio::process::Command;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_tokio_child_process_drop() {
+        let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
+            cmd.arg("30");
+        }));
+        assert!(r.is_ok());
+        let child_process = r.unwrap();
+        let id = child_process.id();
+        assert!(id.is_some());
+        let id = id.unwrap();
+        // Drop the child process
+        drop(child_process);
+        // Wait a moment to allow the cleanup task to run
+        tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
+        // Check if the process is still running
+        let status = Command::new("ps")
+            .arg("-p")
+            .arg(id.to_string())
+            .status()
+            .await;
+        match status {
+            Ok(status) => {
+                assert!(
+                    !status.success(),
+                    "Process with PID {} is still running",
+                    id
+                );
+            }
+            Err(e) => {
+                panic!("Failed to check process status: {}", e);
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_tokio_child_process_graceful_shutdown() {
+        let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
+            cmd.arg("30");
+        }));
+        assert!(r.is_ok());
+        let mut child_process = r.unwrap();
+        let id = child_process.id();
+        assert!(id.is_some());
+        let id = id.unwrap();
+        child_process.graceful_shutdown().await.unwrap();
+        // Wait a moment to allow the cleanup task to run
+        tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
+        // Check if the process is still running
+        let status = Command::new("ps")
+            .arg("-p")
+            .arg(id.to_string())
+            .status()
+            .await;
+        match status {
+            Ok(status) => {
+                assert!(
+                    !status.success(),
+                    "Process with PID {} is still running",
+                    id
+                );
+            }
+            Err(e) => {
+                panic!("Failed to check process status: {}", e);
+            }
+        }
+    }
+}