Skip to content

Commit 823ba05

Browse files
committed
fix: use JoinHandle for Actor::join() instead of polling
1 parent c9e1cca commit 823ba05

File tree

2 files changed

+88
-16
lines changed

2 files changed

+88
-16
lines changed

concurrency/src/tasks/actor.rs

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,39 @@ use spawned_rt::{
1010
tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken, JoinHandle},
1111
threads,
1212
};
13-
use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
13+
use std::{
14+
fmt::Debug,
15+
future::Future,
16+
panic::AssertUnwindSafe,
17+
sync::{Arc, Mutex},
18+
time::Duration,
19+
};
1420

1521
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
1622

23+
/// Wrapper for different JoinHandle types based on backend.
24+
#[derive(Debug)]
25+
enum ActorJoinHandle {
26+
/// Tokio task JoinHandle (for Async and Blocking backends)
27+
Task(JoinHandle<()>),
28+
/// OS thread JoinHandle (for Thread backend)
29+
Thread(threads::JoinHandle<()>),
30+
}
31+
32+
impl ActorJoinHandle {
33+
/// Waits for the actor to finish.
34+
async fn join(self) {
35+
match self {
36+
ActorJoinHandle::Task(h) => {
37+
let _ = h.await;
38+
}
39+
ActorJoinHandle::Thread(h) => {
40+
let _ = h.join();
41+
}
42+
}
43+
}
44+
}
45+
1746
/// Execution backend for Actor.
1847
///
1948
/// Determines how the Actor's async loop is executed. Choose based on
@@ -106,13 +135,16 @@ pub struct ActorRef<A: Actor + 'static> {
106135
pub tx: mpsc::Sender<ActorInMsg<A>>,
107136
/// Cancellation token to stop the Actor
108137
cancellation_token: CancellationToken,
138+
/// JoinHandle for waiting on actor completion
139+
join_handle: Arc<Mutex<Option<ActorJoinHandle>>>,
109140
}
110141

111142
impl<A: Actor> Clone for ActorRef<A> {
112143
fn clone(&self) -> Self {
113144
Self {
114145
tx: self.tx.clone(),
115146
cancellation_token: self.cancellation_token.clone(),
147+
join_handle: self.join_handle.clone(),
116148
}
117149
}
118150
}
@@ -121,9 +153,11 @@ impl<A: Actor> ActorRef<A> {
121153
fn new(actor: A) -> Self {
122154
let (tx, mut rx) = mpsc::channel::<ActorInMsg<A>>();
123155
let cancellation_token = CancellationToken::new();
156+
let join_handle = Arc::new(Mutex::new(None));
124157
let handle = ActorRef {
125158
tx,
126159
cancellation_token,
160+
join_handle: join_handle.clone(),
127161
};
128162
let handle_clone = handle.clone();
129163
let inner_future = async move {
@@ -136,47 +170,62 @@ impl<A: Actor> ActorRef<A> {
136170
// Optionally warn if the Actor future blocks for too much time
137171
let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
138172

139-
// Ignore the JoinHandle for now. Maybe we'll use it in the future
140-
let _join_handle = rt::spawn(inner_future);
173+
let task_handle = rt::spawn(inner_future);
174+
let mut guard = join_handle
175+
.lock()
176+
.unwrap_or_else(|poisoned| poisoned.into_inner());
177+
*guard = Some(ActorJoinHandle::Task(task_handle));
141178

142179
handle_clone
143180
}
144181

145182
fn new_blocking(actor: A) -> Self {
146183
let (tx, mut rx) = mpsc::channel::<ActorInMsg<A>>();
147184
let cancellation_token = CancellationToken::new();
185+
let join_handle = Arc::new(Mutex::new(None));
148186
let handle = ActorRef {
149187
tx,
150188
cancellation_token,
189+
join_handle: join_handle.clone(),
151190
};
152191
let handle_clone = handle.clone();
153-
// Ignore the JoinHandle for now. Maybe we'll use it in the future
154-
let _join_handle = rt::spawn_blocking(|| {
192+
let task_handle = rt::spawn_blocking(|| {
155193
rt::block_on(async move {
156194
if let Err(error) = actor.run(&handle, &mut rx).await {
157195
tracing::trace!(%error, "Actor crashed")
158196
};
159197
})
160198
});
199+
let mut guard = join_handle
200+
.lock()
201+
.unwrap_or_else(|poisoned| poisoned.into_inner());
202+
*guard = Some(ActorJoinHandle::Task(task_handle));
203+
161204
handle_clone
162205
}
163206

164207
fn new_on_thread(actor: A) -> Self {
165208
let (tx, mut rx) = mpsc::channel::<ActorInMsg<A>>();
166209
let cancellation_token = CancellationToken::new();
210+
let join_handle = Arc::new(Mutex::new(None));
167211
let handle = ActorRef {
168212
tx,
169213
cancellation_token,
214+
join_handle: join_handle.clone(),
170215
};
171216
let handle_clone = handle.clone();
172-
// Ignore the JoinHandle for now. Maybe we'll use it in the future
173-
let _join_handle = threads::spawn(|| {
217+
let thread_handle = threads::spawn(|| {
174218
threads::block_on(async move {
175219
if let Err(error) = actor.run(&handle, &mut rx).await {
176220
tracing::trace!(%error, "Actor crashed")
177221
};
178222
})
179223
});
224+
let mut guard = join_handle
225+
.lock()
226+
.unwrap_or_else(|poisoned| poisoned.into_inner());
227+
*guard = Some(ActorJoinHandle::Thread(thread_handle));
228+
180229
handle_clone
181230
}
182231

@@ -220,9 +269,19 @@ impl<A: Actor> ActorRef<A> {
220269
/// Waits for the actor to stop.
221270
///
222271
/// This method returns a future that completes when the actor has finished
223-
/// processing and exited its main loop.
272+
/// processing and exited its main loop. Can only be called once; subsequent
273+
/// calls return immediately.
224274
pub async fn join(&self) {
225-
self.cancellation_token.cancelled().await
275+
let handle = {
276+
let mut guard = self
277+
.join_handle
278+
.lock()
279+
.unwrap_or_else(|poisoned| poisoned.into_inner());
280+
guard.take()
281+
};
282+
if let Some(h) = handle {
283+
h.join().await;
284+
}
226285
}
227286
}
228287

concurrency/src/threads/actor.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
//! Actor trait and structs to create an abstraction similar to Erlang gen_server.
22
//! See examples/name_server for a usage example.
33
use spawned_rt::threads::{
4-
self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken,
4+
self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken, JoinHandle,
55
};
66
use std::{
77
fmt::Debug,
88
panic::{catch_unwind, AssertUnwindSafe},
9+
sync::{Arc, Mutex},
910
time::Duration,
1011
};
1112

@@ -17,13 +18,15 @@ const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);
1718
pub struct ActorRef<A: Actor + 'static> {
1819
pub tx: mpsc::Sender<ActorInMsg<A>>,
1920
cancellation_token: CancellationToken,
21+
join_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
2022
}
2123

2224
impl<A: Actor> Clone for ActorRef<A> {
2325
fn clone(&self) -> Self {
2426
Self {
2527
tx: self.tx.clone(),
2628
cancellation_token: self.cancellation_token.clone(),
29+
join_handle: self.join_handle.clone(),
2730
}
2831
}
2932
}
@@ -32,17 +35,22 @@ impl<A: Actor> ActorRef<A> {
3235
pub(crate) fn new(actor: A) -> Self {
3336
let (tx, mut rx) = mpsc::channel::<ActorInMsg<A>>();
3437
let cancellation_token = CancellationToken::new();
38+
let join_handle = Arc::new(Mutex::new(None));
3539
let handle = ActorRef {
3640
tx,
3741
cancellation_token,
42+
join_handle: join_handle.clone(),
3843
};
3944
let handle_clone = handle.clone();
40-
// Ignore the JoinHandle for now. Maybe we'll use it in the future
41-
let _join_handle = rt::spawn(move || {
45+
let thread_handle = rt::spawn(move || {
4246
if actor.run(&handle, &mut rx).is_err() {
4347
tracing::trace!("Actor crashed")
4448
};
4549
});
50+
let mut guard = join_handle
51+
.lock()
52+
.unwrap_or_else(|poisoned| poisoned.into_inner());
53+
*guard = Some(thread_handle);
4654
handle_clone
4755
}
4856

@@ -84,11 +92,16 @@ impl<A: Actor> ActorRef<A> {
8492
/// Blocks until the actor has stopped.
8593
///
8694
/// This method blocks the current thread until the actor has finished
87-
/// processing and exited its main loop.
95+
/// processing and exited its main loop. Can only be called once; subsequent
96+
/// calls return immediately.
8897
pub fn join(&self) {
89-
let mut token = self.cancellation_token.clone();
90-
while !token.is_cancelled() {
91-
std::thread::sleep(std::time::Duration::from_millis(10));
98+
// Recover from poisoned lock if another thread panicked while holding it
99+
let mut guard = self
100+
.join_handle
101+
.lock()
102+
.unwrap_or_else(|poisoned| poisoned.into_inner());
103+
if let Some(h) = guard.take() {
104+
let _ = h.join();
92105
}
93106
}
94107
}

0 commit comments

Comments
 (0)