Skip to content

Commit 3a8e98c

Browse files
authored
fix: timers wake immediately on actor shutdown in threads module (#142)
* feat: add unified ctrl_c() signal handling to spawned-rt * Formatted several files * fix: use JoinHandle for Actor::join() instead of polling * fix: timers wake immediately on actor shutdown in threads module * Removed message about reverting state that no longer holds * fix: allow multiple actors to subscribe to ctrl_c signal * test: update signal_test to use two actors * fix: use spawn_blocking for Thread backend join and fix flaky stream test * fix: handle poisoned mutex in ctrl_c subscriber registration * fix: race condition in CancellationToken::on_cancel * Fixed typo
1 parent 3cdbcda commit 3a8e98c

File tree

7 files changed

+183
-40
lines changed

7 files changed

+183
-40
lines changed

concurrency/src/tasks/actor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,10 +456,10 @@ where
456456
U: Future + Send + 'static,
457457
<U as Future>::Output: Send,
458458
{
459-
let cancelation_token = handle.cancellation_token();
459+
let cancellation_token = handle.cancellation_token();
460460
let mut handle_clone = handle.clone();
461461
let join_handle = rt::spawn(async move {
462-
let is_cancelled = pin!(cancelation_token.cancelled());
462+
let is_cancelled = pin!(cancellation_token.cancelled());
463463
let signal = pin!(future);
464464
match future::select(is_cancelled, signal).await {
465465
future::Either::Left(_) => tracing::debug!("Actor stopped"),

concurrency/src/tasks/stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ where
1313
T: Actor,
1414
S: Send + Stream<Item = T::Message> + 'static,
1515
{
16-
let cancelation_token = handle.cancellation_token();
16+
let cancellation_token = handle.cancellation_token();
1717
let join_handle = spawned_rt::tasks::spawn(async move {
1818
let mut pinned_stream = core::pin::pin!(stream);
19-
let is_cancelled = core::pin::pin!(cancelation_token.cancelled());
19+
let is_cancelled = core::pin::pin!(cancellation_token.cancelled());
2020
let listener_loop = core::pin::pin!(async {
2121
loop {
2222
match pinned_stream.next().await {

concurrency/src/threads/actor.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
//! Actor trait and structs to create an abstraction similar to Erlang gen_server.
22
//! See examples/name_server for a usage example.
3-
use spawned_rt::threads::{self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken};
3+
use spawned_rt::threads::{
4+
self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken,
5+
};
46
use std::{
57
fmt::Debug,
68
panic::{catch_unwind, AssertUnwindSafe},
@@ -164,7 +166,7 @@ pub trait Actor: Send + Sized {
164166
handle: &ActorRef<Self>,
165167
rx: &mut mpsc::Receiver<ActorInMsg<Self>>,
166168
) -> Result<(), ActorError> {
167-
let mut cancellation_token = handle.cancellation_token.clone();
169+
let cancellation_token = handle.cancellation_token.clone();
168170

169171
let res = match self.init(handle) {
170172
Ok(InitResult::Success(new_state)) => {
@@ -235,7 +237,7 @@ pub trait Actor: Send + Sized {
235237
}
236238
},
237239
Err(error) => {
238-
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
240+
tracing::error!("Error in callback: '{error:?}'");
239241
(true, Err(ActorError::Callback))
240242
}
241243
};
@@ -256,7 +258,7 @@ pub trait Actor: Send + Sized {
256258
}
257259
},
258260
Err(error) => {
259-
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
261+
tracing::error!("Error in callback: '{error:?}'");
260262
true
261263
}
262264
}
@@ -301,7 +303,7 @@ where
301303
T: Actor,
302304
F: FnOnce() + Send + 'static,
303305
{
304-
let mut cancellation_token = handle.cancellation_token();
306+
let cancellation_token = handle.cancellation_token();
305307
let mut handle_clone = handle.clone();
306308
rt::spawn(move || {
307309
f();

concurrency/src/threads/stream.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ where
1212
<I as IntoIterator>::IntoIter: std::marker::Send + 'static,
1313
{
1414
let mut iter = stream.into_iter();
15-
let mut cancelation_token = handle.cancellation_token();
15+
let cancellation_token = handle.cancellation_token();
1616
let join_handle = spawned_rt::threads::spawn(move || loop {
1717
match iter.next() {
1818
Some(msg) => match handle.send(msg) {
@@ -27,7 +27,7 @@ where
2727
break;
2828
}
2929
}
30-
if cancelation_token.is_cancelled() {
30+
if cancellation_token.is_cancelled() {
3131
tracing::trace!("Actor stopped");
3232
break;
3333
}

concurrency/src/threads/time.rs

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::sync::mpsc::{self, RecvTimeoutError};
12
use std::time::Duration;
23

34
use spawned_rt::threads::{self as rt, CancellationToken, JoinHandle};
@@ -9,29 +10,58 @@ pub struct TimerHandle {
910
pub cancellation_token: CancellationToken,
1011
}
1112

12-
// Sends a message after a given period to the specified Actor. The task terminates
13-
// once the send has completed
13+
/// Sends a message after a given period to the specified Actor.
14+
///
15+
/// The timer respects both its own cancellation token and the Actor's
16+
/// cancellation token. If either is cancelled, the timer wakes up immediately
17+
/// and exits without sending the message.
1418
pub fn send_after<T>(period: Duration, mut handle: ActorRef<T>, message: T::Message) -> TimerHandle
1519
where
1620
T: Actor + 'static,
1721
{
1822
let cancellation_token = CancellationToken::new();
19-
let mut cloned_token = cancellation_token.clone();
20-
let mut actor_cancellation_token = handle.cancellation_token();
23+
let timer_token = cancellation_token.clone();
24+
let actor_token = handle.cancellation_token();
25+
26+
// Channel to wake the timer thread on cancellation
27+
let (wake_tx, wake_rx) = mpsc::channel::<()>();
28+
29+
// Register wake-up on timer cancellation
30+
let wake_tx1 = wake_tx.clone();
31+
timer_token.on_cancel(Box::new(move || {
32+
let _ = wake_tx1.send(());
33+
}));
34+
35+
// Register wake-up on actor cancellation
36+
actor_token.on_cancel(Box::new(move || {
37+
let _ = wake_tx.send(());
38+
}));
39+
2140
let join_handle = rt::spawn(move || {
22-
rt::sleep(period);
23-
// Only send if neither the timer nor the actor has been cancelled
24-
if !cloned_token.is_cancelled() && !actor_cancellation_token.is_cancelled() {
25-
let _ = handle.send(message);
26-
};
41+
match wake_rx.recv_timeout(period) {
42+
Err(RecvTimeoutError::Timeout) => {
43+
// Timer expired - send if still valid
44+
if !timer_token.is_cancelled() && !actor_token.is_cancelled() {
45+
let _ = handle.send(message);
46+
}
47+
}
48+
Ok(()) | Err(RecvTimeoutError::Disconnected) => {
49+
// Woken early by cancellation - exit without sending
50+
}
51+
}
2752
});
53+
2854
TimerHandle {
2955
join_handle,
3056
cancellation_token,
3157
}
3258
}
3359

34-
// Sends a message to the specified Actor repeatedly after `Time` milliseconds.
60+
/// Sends a message to the specified Actor repeatedly at the given interval.
61+
///
62+
/// The timer respects both its own cancellation token and the Actor's
63+
/// cancellation token. If either is cancelled, the timer wakes up immediately
64+
/// and exits.
3565
pub fn send_interval<T>(
3666
period: Duration,
3767
mut handle: ActorRef<T>,
@@ -41,17 +71,34 @@ where
4171
T: Actor + 'static,
4272
{
4373
let cancellation_token = CancellationToken::new();
44-
let mut cloned_token = cancellation_token.clone();
45-
let mut actor_cancellation_token = handle.cancellation_token();
46-
let join_handle = rt::spawn(move || loop {
47-
rt::sleep(period);
48-
// Stop if either the timer or the actor has been cancelled
49-
if cloned_token.is_cancelled() || actor_cancellation_token.is_cancelled() {
50-
break;
51-
} else {
74+
let timer_token = cancellation_token.clone();
75+
let actor_token = handle.cancellation_token();
76+
77+
// Channel to wake the timer thread on cancellation
78+
let (wake_tx, wake_rx) = mpsc::channel::<()>();
79+
80+
// Register wake-up on timer cancellation
81+
let wake_tx1 = wake_tx.clone();
82+
timer_token.on_cancel(Box::new(move || {
83+
let _ = wake_tx1.send(());
84+
}));
85+
86+
// Register wake-up on actor cancellation
87+
actor_token.on_cancel(Box::new(move || {
88+
let _ = wake_tx.send(());
89+
}));
90+
91+
let join_handle = rt::spawn(move || {
92+
while let Err(RecvTimeoutError::Timeout) = wake_rx.recv_timeout(period) {
93+
// Timer expired - send if still valid
94+
if timer_token.is_cancelled() || actor_token.is_cancelled() {
95+
break;
96+
}
5297
let _ = handle.send(message.clone());
53-
};
98+
}
99+
// If we exit the loop via Ok(()) or Disconnected, cancellation occurred
54100
});
101+
55102
TimerHandle {
56103
join_handle,
57104
cancellation_token,

concurrency/src/threads/timer_tests.rs

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ impl Actor for Repeater {
8686
self.count += 1;
8787
}
8888
RepeaterCastMessage::StopTimer => {
89-
if let Some(mut ct) = self.cancellation_token.clone() {
89+
if let Some(ct) = self.cancellation_token.clone() {
9090
ct.cancel()
9191
};
9292
}
@@ -132,6 +132,7 @@ enum DelayedCastMessage {
132132
#[derive(Clone)]
133133
enum DelayedCallMessage {
134134
GetCount,
135+
Stop,
135136
}
136137

137138
#[derive(PartialEq, Debug)]
@@ -156,6 +157,10 @@ impl Delayed {
156157
pub fn get_count(server: &mut DelayedHandle) -> Result<DelayedOutMessage, ()> {
157158
server.request(DelayedCallMessage::GetCount).map_err(|_| ())
158159
}
160+
161+
pub fn stop(server: &mut DelayedHandle) -> Result<DelayedOutMessage, ()> {
162+
server.request(DelayedCallMessage::Stop).map_err(|_| ())
163+
}
159164
}
160165

161166
impl Actor for Delayed {
@@ -166,11 +171,17 @@ impl Actor for Delayed {
166171

167172
fn handle_request(
168173
&mut self,
169-
_message: Self::Request,
174+
message: Self::Request,
170175
_handle: &DelayedHandle,
171176
) -> RequestResponse<Self> {
172-
let count = self.count;
173-
RequestResponse::Reply(DelayedOutMessage::Count(count))
177+
match message {
178+
DelayedCallMessage::GetCount => {
179+
RequestResponse::Reply(DelayedOutMessage::Count(self.count))
180+
}
181+
DelayedCallMessage::Stop => {
182+
RequestResponse::Stop(DelayedOutMessage::Count(self.count))
183+
}
184+
}
174185
}
175186

176187
fn handle_message(
@@ -209,7 +220,7 @@ pub fn test_send_after_and_cancellation() {
209220
assert_eq!(DelayedOutMessage::Count(1), count);
210221

211222
// New timer
212-
let mut timer = send_after(
223+
let timer = send_after(
213224
Duration::from_millis(100),
214225
repeater.clone(),
215226
DelayedCastMessage::Inc,
@@ -227,3 +238,41 @@ pub fn test_send_after_and_cancellation() {
227238
// As timer was cancelled, count should remain at 1
228239
assert_eq!(DelayedOutMessage::Count(1), count2);
229240
}
241+
242+
#[test]
243+
pub fn test_send_after_actor_shutdown() {
244+
// Start a Delayed
245+
let mut actor = Delayed::new(0).start();
246+
247+
// Set a just once timed message
248+
let _ = send_after(
249+
Duration::from_millis(100),
250+
actor.clone(),
251+
DelayedCastMessage::Inc,
252+
);
253+
254+
// Wait for 200 milliseconds
255+
rt::sleep(Duration::from_millis(200));
256+
257+
// Check count
258+
let count = Delayed::get_count(&mut actor).unwrap();
259+
260+
// Only one message (no repetition)
261+
assert_eq!(DelayedOutMessage::Count(1), count);
262+
263+
// New timer with long delay
264+
let _ = send_after(
265+
Duration::from_millis(100),
266+
actor.clone(),
267+
DelayedCastMessage::Inc,
268+
);
269+
270+
// Stop the Actor before timeout - this should wake up the timer immediately
271+
let count2 = Delayed::stop(&mut actor).unwrap();
272+
273+
// Wait another 200 milliseconds
274+
rt::sleep(Duration::from_millis(200));
275+
276+
// As actor was stopped, count should remain at 1 (timer didn't fire)
277+
assert_eq!(DelayedOutMessage::Count(1), count2);
278+
}

rt/src/threads/mod.rs

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,69 @@ where
3737
spawn(f)
3838
}
3939

40-
#[derive(Clone, Debug, Default)]
40+
type CancelCallback = Box<dyn FnOnce() + Send>;
41+
42+
/// A token that can be used to signal cancellation.
43+
///
44+
/// Supports registering callbacks via `on_cancel()` that fire when
45+
/// the token is cancelled, enabling efficient waiting patterns.
46+
#[derive(Clone, Default)]
4147
pub struct CancellationToken {
4248
is_cancelled: Arc<AtomicBool>,
49+
callbacks: Arc<Mutex<Vec<CancelCallback>>>,
50+
}
51+
52+
impl std::fmt::Debug for CancellationToken {
53+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54+
f.debug_struct("CancellationToken")
55+
.field("is_cancelled", &self.is_cancelled())
56+
.finish()
57+
}
4358
}
4459

4560
impl CancellationToken {
4661
pub fn new() -> Self {
4762
CancellationToken {
4863
is_cancelled: Arc::new(false.into()),
64+
callbacks: Arc::new(Mutex::new(Vec::new())),
4965
}
5066
}
5167

52-
pub fn is_cancelled(&mut self) -> bool {
53-
self.is_cancelled.fetch_and(false, Ordering::SeqCst)
68+
pub fn is_cancelled(&self) -> bool {
69+
self.is_cancelled.load(Ordering::SeqCst)
5470
}
5571

56-
pub fn cancel(&mut self) {
57-
self.is_cancelled.fetch_or(true, Ordering::SeqCst);
72+
pub fn cancel(&self) {
73+
self.is_cancelled.store(true, Ordering::SeqCst);
74+
// Fire all registered callbacks
75+
let callbacks: Vec<_> = self
76+
.callbacks
77+
.lock()
78+
.unwrap_or_else(|e| e.into_inner())
79+
.drain(..)
80+
.collect();
81+
for cb in callbacks {
82+
cb();
83+
}
84+
}
85+
86+
/// Register a callback to be invoked when this token is cancelled.
87+
/// If already cancelled, the callback fires immediately.
88+
///
89+
/// This method is thread-safe: the callback is guaranteed to fire exactly
90+
/// once, either immediately (if already cancelled) or when `cancel()` is called.
91+
pub fn on_cancel(&self, callback: CancelCallback) {
92+
// Hold the lock while checking is_cancelled to avoid a race with cancel().
93+
// cancel() sets the flag BEFORE acquiring the lock, so if we see
94+
// is_cancelled=false while holding the lock, cancel() hasn't drained
95+
// callbacks yet and will drain ours after we release the lock.
96+
let mut callbacks = self.callbacks.lock().unwrap_or_else(|e| e.into_inner());
97+
if self.is_cancelled() {
98+
drop(callbacks);
99+
callback();
100+
} else {
101+
callbacks.push(callback);
102+
}
58103
}
59104
}
60105

0 commit comments

Comments
 (0)