Skip to content

Commit 90d58bf

Browse files
Add stop behaviour to GenServer (#22)
* First attempt for stream_listener * add test to stream listener * add util fn to convert receiver to stream * add bounded channel * add broadcast listener * fix `spawn_broadcast_listener` * unify spawn_listener & remove oneline functions * doc update * add impl of sync spawn listener * rename spawn_listener to spawn_listener_from_iter, and port spawn_listener * add bound channel to threads concurrency * merge duplicated code * add cancel token with 'flaky' test * unflaky the test * add cancellation to task impl of spawn_listener * docs & clippy * use futures select inside spawn listener * use genserver cancel token on stream * add cancelation token to timer * add cancellation token to gen server tasks impl * remove bounded channels from tasks impl * remove sync channels from threads impl * deprecate spawn_listener for threads impl * fix imports * remove impl for threads due to reprecation * revert more lines * rename `stop` to `teardown` * refactor teardown logic * remove commented code, reword tracing msg * improve default teardown function * mandatory token cancel * remove unused variable * bump crate version * update lock --------- Co-authored-by: Esteban Dimitroff Hodi <[email protected]>
1 parent 6322041 commit 90d58bf

File tree

7 files changed

+148
-38
lines changed

7 files changed

+148
-38
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ tracing = { version = "0.1.41", features = ["log"] }
2222
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
2323

2424
[workspace.package]
25-
version = "0.1.3"
25+
version = "0.1.4"
2626
license = "MIT"
2727
edition = "2021"

concurrency/src/tasks/gen_server.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11
//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
22
//! See examples/name_server for a usage example.
33
use futures::future::FutureExt as _;
4-
use spawned_rt::tasks::{self as rt, mpsc, oneshot};
4+
use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken};
55
use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe};
66

77
use crate::error::GenServerError;
88

9-
#[derive(Debug)]
109
pub struct GenServerHandle<G: GenServer + 'static> {
1110
pub tx: mpsc::Sender<GenServerInMsg<G>>,
11+
/// Cancellation token to stop the GenServer
12+
cancellation_token: CancellationToken,
1213
}
1314

1415
impl<G: GenServer> Clone for GenServerHandle<G> {
1516
fn clone(&self) -> Self {
1617
Self {
1718
tx: self.tx.clone(),
19+
cancellation_token: self.cancellation_token.clone(),
1820
}
1921
}
2022
}
2123

2224
impl<G: GenServer> GenServerHandle<G> {
2325
pub(crate) fn new(initial_state: G::State) -> Self {
2426
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
25-
let handle = GenServerHandle { tx };
27+
let cancellation_token = CancellationToken::new();
28+
let handle = GenServerHandle {
29+
tx,
30+
cancellation_token,
31+
};
2632
let mut gen_server: G = GenServer::new();
2733
let handle_clone = handle.clone();
2834
// Ignore the JoinHandle for now. Maybe we'll use it in the future
@@ -40,7 +46,11 @@ impl<G: GenServer> GenServerHandle<G> {
4046

4147
pub(crate) fn new_blocking(initial_state: G::State) -> Self {
4248
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
43-
let handle = GenServerHandle { tx };
49+
let cancellation_token = CancellationToken::new();
50+
let handle = GenServerHandle {
51+
tx,
52+
cancellation_token,
53+
};
4454
let mut gen_server: G = GenServer::new();
4555
let handle_clone = handle.clone();
4656
// Ignore the JoinHandle for now. Maybe we'll use it in the future
@@ -79,6 +89,10 @@ impl<G: GenServer> GenServerHandle<G> {
7989
.send(GenServerInMsg::Cast { message })
8090
.map_err(|_error| GenServerError::Server)
8191
}
92+
93+
pub fn cancellation_token(&self) -> CancellationToken {
94+
self.cancellation_token.clone()
95+
}
8296
}
8397

8498
pub enum GenServerInMsg<G: GenServer> {
@@ -168,12 +182,16 @@ where
168182
async {
169183
loop {
170184
let (new_state, cont) = self.receive(handle, rx, state).await?;
185+
state = new_state;
171186
if !cont {
172187
break;
173188
}
174-
state = new_state;
175189
}
176190
tracing::trace!("Stopping GenServer");
191+
handle.cancellation_token().cancel();
192+
if let Err(err) = self.teardown(handle, state).await {
193+
tracing::error!("Error during teardown: {err:?}");
194+
}
177195
Ok(())
178196
}
179197
}
@@ -269,6 +287,17 @@ where
269287
) -> impl Future<Output = CastResponse<Self>> + Send {
270288
async { CastResponse::Unused }
271289
}
290+
291+
/// Teardown function. It's called after the stop message is received.
292+
/// It can be overrided on implementations in case final steps are required,
293+
/// like closing streams, stopping timers, etc.
294+
fn teardown(
295+
&mut self,
296+
_handle: &GenServerHandle<Self>,
297+
_state: Self::State,
298+
) -> impl Future<Output = Result<(), Self::Error>> + Send {
299+
async { Ok(()) }
300+
}
272301
}
273302

274303
#[cfg(test)]

concurrency/src/tasks/stream.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::tasks::{GenServer, GenServerHandle};
22
use futures::{future::select, Stream, StreamExt};
3-
use spawned_rt::tasks::{CancellationToken, JoinHandle};
3+
use spawned_rt::tasks::JoinHandle;
44

55
/// Spawns a listener that listens to a stream and sends messages to a GenServer.
66
///
@@ -12,19 +12,18 @@ pub fn spawn_listener<T, F, S, I, E>(
1212
mut handle: GenServerHandle<T>,
1313
message_builder: F,
1414
mut stream: S,
15-
) -> (JoinHandle<()>, CancellationToken)
15+
) -> JoinHandle<()>
1616
where
1717
T: GenServer + 'static,
1818
F: Fn(I) -> T::CastMsg + Send + 'static + std::marker::Sync,
1919
I: Send,
2020
E: std::fmt::Debug + Send,
2121
S: Unpin + Send + Stream<Item = Result<I, E>> + 'static,
2222
{
23-
let cancelation_token = CancellationToken::new();
24-
let cloned_token = cancelation_token.clone();
23+
let cancelation_token = handle.cancellation_token();
2524
let join_handle = spawned_rt::tasks::spawn(async move {
2625
let result = select(
27-
Box::pin(cloned_token.cancelled()),
26+
Box::pin(cancelation_token.cancelled()),
2827
Box::pin(async {
2928
loop {
3029
match stream.next().await {
@@ -49,9 +48,9 @@ where
4948
)
5049
.await;
5150
match result {
52-
futures::future::Either::Left(_) => tracing::trace!("Listener cancelled"),
51+
futures::future::Either::Left(_) => tracing::trace!("GenServer stopped"),
5352
futures::future::Either::Right(_) => (), // Stream finished or errored out
5453
}
5554
});
56-
(join_handle, cancelation_token)
55+
join_handle
5756
}

concurrency/src/tasks/stream_tests.rs

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@ use std::time::Duration;
33
use spawned_rt::tasks::{self as rt, BroadcastStream, ReceiverStream};
44

55
use crate::tasks::{
6-
stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle,
6+
send_after, stream::spawn_listener, CallResponse, CastResponse, GenServer, GenServerHandle,
77
};
88

99
type SummatoryHandle = GenServerHandle<Summatory>;
1010

1111
struct Summatory;
1212

1313
type SummatoryState = u16;
14+
type SummatoryOutMessage = SummatoryState;
1415

1516
#[derive(Clone)]
16-
struct UpdateSumatory {
17-
added_value: u16,
17+
enum SummatoryCastMessage {
18+
Add(u16),
19+
Stop,
1820
}
1921

2022
impl Summatory {
@@ -25,8 +27,8 @@ impl Summatory {
2527

2628
impl GenServer for Summatory {
2729
type CallMsg = (); // We only handle one type of call, so there is no need for a specific message type.
28-
type CastMsg = UpdateSumatory;
29-
type OutMsg = SummatoryState;
30+
type CastMsg = SummatoryCastMessage;
31+
type OutMsg = SummatoryOutMessage;
3032
type State = SummatoryState;
3133
type Error = ();
3234

@@ -40,8 +42,13 @@ impl GenServer for Summatory {
4042
_handle: &GenServerHandle<Self>,
4143
state: Self::State,
4244
) -> CastResponse<Self> {
43-
let new_state = state + message.added_value;
44-
CastResponse::NoReply(new_state)
45+
match message {
46+
SummatoryCastMessage::Add(val) => {
47+
let new_state = state + val;
48+
CastResponse::NoReply(new_state)
49+
}
50+
SummatoryCastMessage::Stop => CastResponse::Stop,
51+
}
4552
}
4653

4754
async fn handle_call(
@@ -56,11 +63,9 @@ impl GenServer for Summatory {
5663
}
5764

5865
// In this example, the stream sends u8 values, which are converted to the type
59-
// supported by the GenServer (UpdateSumatory / u16).
60-
fn message_builder(value: u8) -> UpdateSumatory {
61-
UpdateSumatory {
62-
added_value: value as u16,
63-
}
66+
// supported by the GenServer (SummatoryCastMessage / u16).
67+
fn message_builder(value: u8) -> SummatoryCastMessage {
68+
SummatoryCastMessage::Add(value.into())
6469
}
6570

6671
#[test]
@@ -153,22 +158,34 @@ pub fn test_stream_cancellation() {
153158
}
154159
});
155160

156-
let (_handle, cancellation_token) = spawn_listener(
161+
let listener_handle = spawn_listener(
157162
summatory_handle.clone(),
158163
message_builder,
159164
ReceiverStream::new(rx),
160165
);
161166

162-
// Wait for 1 second so the whole stream is processed
163-
rt::sleep(Duration::from_millis(RUNNING_TIME)).await;
167+
// Start a timer to stop the stream after a certain time
168+
let summatory_handle_clone = summatory_handle.clone();
169+
let _ = send_after(
170+
Duration::from_millis(RUNNING_TIME + 10),
171+
summatory_handle_clone,
172+
SummatoryCastMessage::Stop,
173+
);
164174

165-
cancellation_token.cancel();
175+
// Just before the stream is cancelled we retrieve the current value.
176+
rt::sleep(Duration::from_millis(RUNNING_TIME)).await;
177+
let val = Summatory::get_value(&mut summatory_handle).await.unwrap();
166178

167179
// The reasoning for this assertion is that each message takes a quarter of the total time
168180
// to be processed, so having a stream of 5 messages, the last one won't be processed.
169181
// We could safely assume that it will get to process 4 messages, but in case of any extenal
170182
// slowdown, it could process less.
171-
let val = Summatory::get_value(&mut summatory_handle).await.unwrap();
172-
assert!(val > 0 && val < 15);
183+
assert!((1..=10).contains(&val));
184+
185+
assert!(listener_handle.await.is_ok());
186+
187+
// Finnally, we check that the server is stopped, by getting an error when trying to call it.
188+
rt::sleep(Duration::from_millis(10)).await;
189+
assert!(Summatory::get_value(&mut summatory_handle).await.is_err());
173190
})
174191
}

concurrency/src/tasks/time.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,16 @@ where
2222
{
2323
let cancellation_token = CancellationToken::new();
2424
let cloned_token = cancellation_token.clone();
25+
let gen_server_cancellation_token = handle.cancellation_token();
2526
let join_handle = rt::spawn(async move {
26-
let _ = select(
27+
// Timer action is ignored if it was either cancelled or the associated GenServer is no longer running.
28+
let cancel_conditions = select(
2729
Box::pin(cloned_token.cancelled()),
30+
Box::pin(gen_server_cancellation_token.cancelled()),
31+
);
32+
33+
let _ = select(
34+
cancel_conditions,
2835
Box::pin(async {
2936
rt::sleep(period).await;
3037
let _ = handle.cast(message.clone()).await;
@@ -49,10 +56,17 @@ where
4956
{
5057
let cancellation_token = CancellationToken::new();
5158
let cloned_token = cancellation_token.clone();
59+
let gen_server_cancellation_token = handle.cancellation_token();
5260
let join_handle = rt::spawn(async move {
5361
loop {
54-
let result = select(
62+
// Timer action is ignored if it was either cancelled or the associated GenServer is no longer running.
63+
let cancel_conditions = select(
5564
Box::pin(cloned_token.cancelled()),
65+
Box::pin(gen_server_cancellation_token.cancelled()),
66+
);
67+
68+
let result = select(
69+
Box::pin(cancel_conditions),
5670
Box::pin(async {
5771
rt::sleep(period).await;
5872
let _ = handle.cast(message.clone()).await;

concurrency/src/tasks/timer_tests.rs

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ enum DelayedCastMessage {
149149
#[derive(Clone)]
150150
enum DelayedCallMessage {
151151
GetCount,
152+
Stop,
152153
}
153154

154155
#[derive(PartialEq, Debug)]
@@ -165,6 +166,10 @@ impl Delayed {
165166
.await
166167
.map_err(|_| ())
167168
}
169+
170+
pub async fn stop(server: &mut DelayedHandle) -> Result<DelayedOutMessage, ()> {
171+
server.call(DelayedCallMessage::Stop).await.map_err(|_| ())
172+
}
168173
}
169174

170175
impl GenServer for Delayed {
@@ -180,12 +185,17 @@ impl GenServer for Delayed {
180185

181186
async fn handle_call(
182187
&mut self,
183-
_message: Self::CallMsg,
188+
message: Self::CallMsg,
184189
_handle: &DelayedHandle,
185190
state: Self::State,
186191
) -> CallResponse<Self> {
187-
let count = state.count;
188-
CallResponse::Reply(state, DelayedOutMessage::Count(count))
192+
match message {
193+
DelayedCallMessage::GetCount => {
194+
let count = state.count;
195+
CallResponse::Reply(state, DelayedOutMessage::Count(count))
196+
}
197+
DelayedCallMessage::Stop => CallResponse::Stop(DelayedOutMessage::Count(state.count)),
198+
}
189199
}
190200

191201
async fn handle_cast(
@@ -246,3 +256,44 @@ pub fn test_send_after_and_cancellation() {
246256
assert_eq!(DelayedOutMessage::Count(1), count2);
247257
});
248258
}
259+
260+
#[test]
261+
pub fn test_send_after_gen_server_teardown() {
262+
let runtime = rt::Runtime::new().unwrap();
263+
runtime.block_on(async move {
264+
// Start a Delayed
265+
let mut repeater = Delayed::start(DelayedState { count: 0 });
266+
267+
// Set a just once timed message
268+
let _ = send_after(
269+
Duration::from_millis(100),
270+
repeater.clone(),
271+
DelayedCastMessage::Inc,
272+
);
273+
274+
// Wait for 200 milliseconds
275+
rt::sleep(Duration::from_millis(200)).await;
276+
277+
// Check count
278+
let count = Delayed::get_count(&mut repeater).await.unwrap();
279+
280+
// Only one message (no repetition)
281+
assert_eq!(DelayedOutMessage::Count(1), count);
282+
283+
// New timer
284+
let _ = send_after(
285+
Duration::from_millis(100),
286+
repeater.clone(),
287+
DelayedCastMessage::Inc,
288+
);
289+
290+
// Stop the GenServer before timeout
291+
let count2 = Delayed::stop(&mut repeater).await.unwrap();
292+
293+
// Wait another 200 milliseconds
294+
rt::sleep(Duration::from_millis(200)).await;
295+
296+
// As timer was cancelled, count should remain at 1
297+
assert_eq!(DelayedOutMessage::Count(1), count2);
298+
});
299+
}

0 commit comments

Comments
 (0)