Skip to content

fix: make stdio shutdown more graceful #364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 182 additions & 28 deletions crates/rmcp/src/transport/child_process.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::process::Stdio;
use std::{process::Stdio, sync::Arc};

use futures::future::Future;
use process_wrap::tokio::{TokioChildWrapper, TokioCommandWrap};
use tokio::{
io::AsyncRead,
io::{AsyncRead, AsyncWriteExt},
process::{ChildStderr, ChildStdin, ChildStdout},
sync::Mutex,
};

use super::{IntoTransport, Transport};
use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage};
use crate::service::ServiceRole;

/// Upper bound, in seconds, that `graceful_shutdown` waits for the child to exit on its
/// own before killing it. The tests also size their post-shutdown sleep from this value.
/// Despite the name, the `Drop` impl of `ChildWithCleanup` does not consult it.
const MAX_WAIT_ON_DROP_SECS: u64 = 3;
/// The parts of a child process.
type ChildProcessParts = (
Box<dyn TokioChildWrapper>,
Expand Down Expand Up @@ -36,18 +39,24 @@ fn child_process(mut child: Box<dyn TokioChildWrapper>) -> std::io::Result<Child

/// A spawned child process bundled with its stdin and stdout handles.
pub struct TokioChildProcess {
// Cleanup guard that owns the child process handle.
child: ChildWithCleanup,
child_stdin: ChildStdin,
// Shared behind `Arc<Mutex<..>>` so a `send` future can hold its own clone of the
// handle and lock it while writing.
child_stdin: Arc<Mutex<ChildStdin>>,
// Read side; `receive` reads newline-terminated messages from it.
child_stdout: ChildStdout,
}

/// Guard around the child process handle; its `Drop` impl kills the child.
pub struct ChildWithCleanup {
inner: Box<dyn TokioChildWrapper>,
// `None` once the child has been taken out (by `drop`, `graceful_shutdown` or
// `into_inner`), in which case there is nothing left to kill.
inner: Option<Box<dyn TokioChildWrapper>>,
}

// Kills the wrapped child when the guard is dropped, unless the child was already taken.
impl Drop for ChildWithCleanup {
fn drop(&mut self) {
if let Err(e) = self.inner.start_kill() {
tracing::warn!("Failed to kill child process: {e}");
// Use `kill()` rather than `start_kill()`; the stated intent is to avoid leaving
// a zombie process behind.
if let Some(mut inner) = self.inner.take() {
// `drop` cannot await, so the kill runs in a detached task. The task is not joined,
// so the child may still be alive when `drop` returns; a failure is only logged.
// NOTE(review): `tokio::spawn` needs an active Tokio runtime - confirm this guard
// is never dropped outside one.
tokio::spawn(async move {
if let Err(e) = Box::into_pin(inner.kill()).await {
tracing::warn!("Error killing child process: {}", e);
}
});
}
}
}
Expand All @@ -64,7 +73,7 @@ pin_project_lite::pin_project! {
impl TokioChildProcessOut {
/// Get the process ID of the child process.
///
/// With the `Option`-wrapped child, this returns `None` when the child has already
/// been taken out of the cleanup guard.
pub fn id(&self) -> Option<u32> {
self.child.inner.id()
// `as_ref()?` short-circuits to `None` if the guard no longer holds a child.
self.child.inner.as_ref()?.id()
}
}

Expand Down Expand Up @@ -92,23 +101,51 @@ impl TokioChildProcess {

/// Get the process ID of the child process.
///
/// With the `Option`-wrapped child, this returns `None` once the child has been taken
/// out of the cleanup guard (for example after `graceful_shutdown`).
pub fn id(&self) -> Option<u32> {
self.child.inner.id()
// `as_ref()?` short-circuits to `None` if the guard no longer holds a child.
self.child.inner.as_ref()?.id()
}

/// Shut the child process down, giving it a chance to exit on its own first.
///
/// The child is taken out of the cleanup guard and raced against a timer of
/// `MAX_WAIT_ON_DROP_SECS` seconds. If the child exits first, its status is logged;
/// if the timer fires first, the child is killed. When the child has already been
/// taken, this does nothing and returns `Ok(())`.
///
/// # Errors
///
/// Returns the I/O error reported while waiting for the child, or while killing it
/// after the timeout.
pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> {
    let Some(mut process) = self.child.inner.take() else {
        return Ok(());
    };

    // Created up front so the `select!` branch can take the future by value.
    let exited = Box::into_pin(process.wait());
    tokio::select! {
        _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => {
            // Timed out waiting: force the child to terminate.
            match Box::into_pin(process.kill()).await {
                Ok(_) => Ok(()),
                Err(e) => {
                    tracing::warn!("Error killing child: {e}");
                    Err(e)
                }
            }
        },
        res = exited => {
            match res {
                Err(e) => {
                    tracing::warn!("Error waiting for child: {e}");
                    Err(e)
                }
                Ok(status) => {
                    tracing::info!("Child exited gracefully {}", status);
                    Ok(())
                }
            }
        }
    }
}

/// Consume this value and hand the wrapped child process over to the caller.
///
/// Returns `None` if the child was already taken. Because the cleanup guard is left
/// empty, dropping the rest of this value does not kill the returned child.
pub fn into_inner(mut self) -> Option<Box<dyn TokioChildWrapper>> {
    let guard = &mut self.child;
    std::mem::take(&mut guard.inner)
}

/// Split this helper into a reader (stdout) and writer (stdin).
///
/// # Panics
///
/// The `unimplemented!` body panics with a message pointing callers at the `Transport`
/// trait implementation.
#[deprecated(
since = "0.5.0",
note = "use the Transport trait implementation instead"
)]
pub fn split(self) -> (TokioChildProcessOut, ChildStdin) {
let TokioChildProcess {
child,
child_stdin,
child_stdout,
} = self;
(
TokioChildProcessOut {
child,
child_stdout,
},
child_stdin,
)
// NOTE(review): a deprecated method normally keeps working until it is removed; this
// body turns every remaining call into a runtime panic - confirm that is intended.
unimplemented!("This method is deprecated, use the Transport trait implementation instead");
}
}

Expand Down Expand Up @@ -157,19 +194,61 @@ impl TokioChildProcessBuilder {
let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?;

let proc = TokioChildProcess {
child: ChildWithCleanup { inner: child },
child_stdin: stdin,
child: ChildWithCleanup { inner: Some(child) },
child_stdin: Arc::new(Mutex::new(stdin)),
child_stdout: stdout,
};
Ok((proc, stderr_opt))
}
}

impl<R: ServiceRole> IntoTransport<R, std::io::Error, ()> for TokioChildProcess {
fn into_transport(self) -> impl Transport<R, Error = std::io::Error> + 'static {
IntoTransport::<R, std::io::Error, super::async_rw::TransportAdapterAsyncRW>::into_transport(
self.split(),
)
/// Newline-delimited JSON transport over the child's stdin (outgoing) and stdout
/// (incoming).
impl<R: ServiceRole> Transport<R> for TokioChildProcess {
    type Error = std::io::Error;

    /// Serialize `item` as one JSON line and write it to the child's stdin.
    ///
    /// Serialization happens before the future is created. A serialization failure is
    /// reported through the returned future as an `std::io::Error` rather than
    /// panicking. The future locks the shared stdin handle, writes the line and flushes.
    fn send(
        &mut self,
        item: TxJsonRpcMessage<R>,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
        // Carry a serialization failure into the future so the caller sees an error
        // instead of a panic.
        let payload = serde_json::to_string(&item)
            .map(|mut json| {
                json.push('\n');
                json
            })
            .map_err(std::io::Error::from);
        let child_stdin = Arc::clone(&self.child_stdin);

        async move {
            let payload = payload?;
            let mut child_stdin = child_stdin.lock().await;
            child_stdin.write_all(payload.as_bytes()).await?;
            child_stdin.flush().await?;
            Ok(())
        }
    }

    /// Read one line from the child's stdout and parse it as a JSON message.
    ///
    /// Returns `None` on EOF, on a read error, and on a line that fails to deserialize
    /// (the latter two are logged).
    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<R>>> + Send {
        use tokio::io::{AsyncBufReadExt, BufReader};

        // The reader is rebuilt on every call and dropped when the future finishes, so
        // any bytes it had buffered past the newline would be lost and the next message
        // would be truncated or missing. A one-byte buffer never holds more than
        // `read_line` consumes, so nothing is read ahead.
        // NOTE(review): this costs one read per byte; keeping a persistent `BufReader`
        // in the struct would restore buffering without losing data.
        let stdout = &mut self.child_stdout;
        let mut buf_reader = BufReader::with_capacity(1, stdout);
        let mut line = String::new();

        async move {
            match buf_reader.read_line(&mut line).await {
                Ok(0) => None, // EOF
                Ok(_) => match serde_json::from_str::<RxJsonRpcMessage<R>>(&line) {
                    Ok(msg) => Some(msg),
                    Err(e) => {
                        tracing::error!("Failed to deserialize message: {}", e);
                        None
                    }
                },
                Err(e) => {
                    tracing::error!("Error reading from child process: {}", e);
                    None
                }
            }
        }
    }

    /// Close the transport by delegating to `graceful_shutdown`.
    fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.graceful_shutdown()
    }
}

Expand All @@ -183,3 +262,78 @@ impl ConfigureCommandExt for tokio::process::Command {
self
}
}

#[cfg(unix)]
#[cfg(test)]
mod tests {
    use tokio::process::Command;

    use super::*;

    /// Spawn `sleep 30` as a child process and return it together with its PID.
    fn spawn_sleeper() -> (TokioChildProcess, u32) {
        let spawned = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
            cmd.arg("30");
        }));
        assert!(spawned.is_ok());
        let sleeper = spawned.unwrap();
        let pid = sleeper.id();
        assert!(pid.is_some());
        (sleeper, pid.unwrap())
    }

    /// Sleep `MAX_WAIT_ON_DROP_SECS + 1` seconds, then assert that `ps -p <id>` no
    /// longer finds the process.
    async fn assert_gone_after_wait(id: u32) {
        tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
        let probe = Command::new("ps")
            .arg("-p")
            .arg(id.to_string())
            .status()
            .await;
        match probe {
            Err(e) => {
                panic!("Failed to check process status: {}", e);
            }
            Ok(status) => {
                assert!(
                    !status.success(),
                    "Process with PID {} is still running",
                    id
                );
            }
        }
    }

    /// Dropping the transport must end up killing the child.
    #[tokio::test]
    async fn test_tokio_child_process_drop() {
        let (sleeper, pid) = spawn_sleeper();
        drop(sleeper);
        assert_gone_after_wait(pid).await;
    }

    /// `graceful_shutdown` must leave no running child behind.
    #[tokio::test]
    async fn test_tokio_child_process_graceful_shutdown() {
        let (mut sleeper, pid) = spawn_sleeper();
        sleeper.graceful_shutdown().await.unwrap();
        assert_gone_after_wait(pid).await;
    }
}