|
1 | 1 | //! GenServer trait and structs to create an abstraction similar to Erlang gen_server. |
2 | 2 | //! See examples/name_server for a usage example. |
3 | 3 | use futures::future::FutureExt as _; |
4 | | -use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken}; |
5 | | -use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe}; |
| 4 | +use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken}; |
| 5 | +use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration}; |
6 | 6 |
|
7 | 7 | use crate::error::GenServerError; |
8 | 8 |
|
| 9 | +const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(5); |
| 10 | + |
9 | 11 | #[derive(Debug)] |
10 | 12 | pub struct GenServerHandle<G: GenServer + 'static> { |
11 | 13 | pub tx: mpsc::Sender<GenServerInMsg<G>>, |
@@ -74,14 +76,24 @@ impl<G: GenServer> GenServerHandle<G> { |
74 | 76 | } |
75 | 77 |
|
76 | 78 | pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> { |
| 79 | + self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await |
| 80 | + } |
| 81 | + |
| 82 | + pub async fn call_with_timeout( |
| 83 | + &mut self, |
| 84 | + message: G::CallMsg, |
| 85 | + duration: Duration, |
| 86 | + ) -> Result<G::OutMsg, GenServerError> { |
77 | 87 | let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>(); |
78 | 88 | self.tx.send(GenServerInMsg::Call { |
79 | 89 | sender: oneshot_tx, |
80 | 90 | message, |
81 | 91 | })?; |
82 | | - match oneshot_rx.await { |
83 | | - Ok(result) => result, |
84 | | - Err(_) => Err(GenServerError::Server), |
| 92 | + |
| 93 | + match timeout(duration, oneshot_rx).await { |
| 94 | + Ok(Ok(result)) => result, |
| 95 | + Ok(Err(_)) => Err(GenServerError::Server), |
| 96 | + Err(_) => Err(GenServerError::CallTimeout), |
85 | 97 | } |
86 | 98 | } |
87 | 99 |
|
@@ -434,4 +446,61 @@ mod tests { |
434 | 446 | goodboy.call(InMessage::Stop).await.unwrap(); |
435 | 447 | }); |
436 | 448 | } |
| 449 | + |
| 450 | + const TIMEOUT_DURATION: Duration = Duration::from_millis(100); |
| 451 | + |
| 452 | + #[derive(Default)] |
| 453 | + struct SomeTask; |
| 454 | + |
| 455 | + #[derive(Clone)] |
| 456 | + enum SomeTaskCallMsg { |
| 457 | + SlowOperation, |
| 458 | + FastOperation, |
| 459 | + } |
| 460 | + |
| 461 | + impl GenServer for SomeTask { |
| 462 | + type CallMsg = SomeTaskCallMsg; |
| 463 | + type CastMsg = (); |
| 464 | + type OutMsg = (); |
| 465 | + type State = (); |
| 466 | + type Error = (); |
| 467 | + |
| 468 | + async fn handle_call( |
| 469 | + &mut self, |
| 470 | + message: Self::CallMsg, |
| 471 | + _handle: &GenServerHandle<Self>, |
| 472 | + _state: Self::State, |
| 473 | + ) -> CallResponse<Self> { |
| 474 | + match message { |
| 475 | + SomeTaskCallMsg::SlowOperation => { |
| 476 | + // Simulate a slow operation that will not resolve in time |
| 477 | + rt::sleep(TIMEOUT_DURATION * 2).await; |
| 478 | + CallResponse::Reply((), ()) |
| 479 | + } |
| 480 | + SomeTaskCallMsg::FastOperation => { |
| 481 | + // Simulate a fast operation that resolves in time |
| 482 | + rt::sleep(TIMEOUT_DURATION / 2).await; |
| 483 | + CallResponse::Reply((), ()) |
| 484 | + } |
| 485 | + } |
| 486 | + } |
| 487 | + } |
| 488 | + |
| 489 | + #[test] |
| 490 | + pub fn unresolving_task_times_out() { |
| 491 | + let runtime = rt::Runtime::new().unwrap(); |
| 492 | + runtime.block_on(async move { |
| 493 | + let mut unresolving_task = SomeTask::start(()); |
| 494 | + |
| 495 | + let result = unresolving_task |
| 496 | + .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION) |
| 497 | + .await; |
| 498 | + assert!(matches!(result, Ok(()))); |
| 499 | + |
| 500 | + let result = unresolving_task |
| 501 | + .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION) |
| 502 | + .await; |
| 503 | + assert!(matches!(result, Err(GenServerError::CallTimeout))); |
| 504 | + }); |
| 505 | + } |
437 | 506 | } |
0 commit comments