diff --git a/concurrency/src/error.rs b/concurrency/src/error.rs index 9faa0c4..c1a37db 100644 --- a/concurrency/src/error.rs +++ b/concurrency/src/error.rs @@ -10,6 +10,8 @@ pub enum GenServerError { CallMsgUnused, #[error("Unsupported Cast Messages on this GenServer")] CastMsgUnused, + #[error("Call to GenServer timed out")] + CallTimeout, } impl From> for GenServerError { diff --git a/concurrency/src/tasks/gen_server.rs b/concurrency/src/tasks/gen_server.rs index f04bb79..fd03f00 100644 --- a/concurrency/src/tasks/gen_server.rs +++ b/concurrency/src/tasks/gen_server.rs @@ -1,11 +1,13 @@ //! GenServer trait and structs to create an abstraction similar to Erlang gen_server. //! See examples/name_server for a usage example. use futures::future::FutureExt as _; -use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken}; -use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe}; +use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken}; +use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration}; use crate::error::GenServerError; +const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(5); + #[derive(Debug)] pub struct GenServerHandle { pub tx: mpsc::Sender>, @@ -74,14 +76,24 @@ impl GenServerHandle { } pub async fn call(&mut self, message: G::CallMsg) -> Result { + self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await + } + + pub async fn call_with_timeout( + &mut self, + message: G::CallMsg, + duration: Duration, + ) -> Result { let (oneshot_tx, oneshot_rx) = oneshot::channel::>(); self.tx.send(GenServerInMsg::Call { sender: oneshot_tx, message, })?; - match oneshot_rx.await { - Ok(result) => result, - Err(_) => Err(GenServerError::Server), + + match timeout(duration, oneshot_rx).await { + Ok(Ok(result)) => result, + Ok(Err(_)) => Err(GenServerError::Server), + Err(_) => Err(GenServerError::CallTimeout), } } @@ -434,4 +446,61 @@ mod tests { goodboy.call(InMessage::Stop).await.unwrap(); }); } + + const TIMEOUT_DURATION: Duration = Duration::from_millis(100); + + #[derive(Default)] + struct SomeTask; + + #[derive(Clone)] + enum SomeTaskCallMsg { + SlowOperation, + FastOperation, + } + + impl GenServer for SomeTask { + type CallMsg = SomeTaskCallMsg; + type CastMsg = (); + type OutMsg = (); + type State = (); + type Error = (); + + async fn handle_call( + &mut self, + message: Self::CallMsg, + _handle: &GenServerHandle, + _state: Self::State, + ) -> CallResponse { + match message { + SomeTaskCallMsg::SlowOperation => { + // Simulate a slow operation that will not resolve in time + rt::sleep(TIMEOUT_DURATION * 2).await; + CallResponse::Reply((), ()) + } + SomeTaskCallMsg::FastOperation => { + // Simulate a fast operation that resolves in time + rt::sleep(TIMEOUT_DURATION / 2).await; + CallResponse::Reply((), ()) + } + } + } + } + + #[test] + pub fn unresolving_task_times_out() { + let runtime = rt::Runtime::new().unwrap(); + runtime.block_on(async move { + let mut unresolving_task = SomeTask::start(()); + + let result = unresolving_task + .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION) + .await; + assert!(matches!(result, Ok(()))); + + let result = unresolving_task + .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION) + .await; + assert!(matches!(result, Err(GenServerError::CallTimeout))); + }); + } } diff --git a/rt/src/tasks/mod.rs b/rt/src/tasks/mod.rs index 10de5fd..5291f69 100644 --- a/rt/src/tasks/mod.rs +++ b/rt/src/tasks/mod.rs @@ -16,6 +16,7 @@ use crate::tracing::init_tracing; pub use crate::tasks::tokio::mpsc; pub use crate::tasks::tokio::oneshot; pub use crate::tasks::tokio::sleep; +pub use crate::tasks::tokio::timeout; pub use crate::tasks::tokio::CancellationToken; pub use crate::tasks::tokio::{spawn, spawn_blocking, JoinHandle, Runtime}; pub use crate::tasks::tokio::{BroadcastStream, ReceiverStream}; diff --git a/rt/src/tasks/tokio/mod.rs b/rt/src/tasks/tokio/mod.rs index 6abf60d..eac39e0 100644 --- a/rt/src/tasks/tokio/mod.rs +++ b/rt/src/tasks/tokio/mod.rs @@ -5,7 +5,7 @@ pub mod oneshot; pub use tokio::{ runtime::Runtime, task::{spawn, spawn_blocking, JoinHandle}, - time::sleep, + time::{sleep, timeout}, }; pub use tokio_stream::wrappers::{BroadcastStream, UnboundedReceiverStream as ReceiverStream}; pub use tokio_util::sync::CancellationToken;