@@ -10,10 +10,39 @@ use spawned_rt::{
1010 tasks:: { self as rt, mpsc, oneshot, timeout, CancellationToken , JoinHandle } ,
1111 threads,
1212} ;
13- use std:: { fmt:: Debug , future:: Future , panic:: AssertUnwindSafe , time:: Duration } ;
13+ use std:: {
14+ fmt:: Debug ,
15+ future:: Future ,
16+ panic:: AssertUnwindSafe ,
17+ sync:: { Arc , Mutex } ,
18+ time:: Duration ,
19+ } ;
1420
1521const DEFAULT_REQUEST_TIMEOUT : Duration = Duration :: from_secs ( 5 ) ;
1622
23+ /// Wrapper for different JoinHandle types based on backend.
24+ #[ derive( Debug ) ]
25+ enum ActorJoinHandle {
26+ /// Tokio task JoinHandle (for Async and Blocking backends)
27+ Task ( JoinHandle < ( ) > ) ,
28+ /// OS thread JoinHandle (for Thread backend)
29+ Thread ( threads:: JoinHandle < ( ) > ) ,
30+ }
31+
32+ impl ActorJoinHandle {
33+ /// Waits for the actor to finish.
34+ async fn join ( self ) {
35+ match self {
36+ ActorJoinHandle :: Task ( h) => {
37+ let _ = h. await ;
38+ }
39+ ActorJoinHandle :: Thread ( h) => {
40+ let _ = h. join ( ) ;
41+ }
42+ }
43+ }
44+ }
45+
1746/// Execution backend for Actor.
1847///
1948/// Determines how the Actor's async loop is executed. Choose based on
@@ -106,13 +135,16 @@ pub struct ActorRef<A: Actor + 'static> {
106135 pub tx : mpsc:: Sender < ActorInMsg < A > > ,
107136 /// Cancellation token to stop the Actor
108137 cancellation_token : CancellationToken ,
138+ /// JoinHandle for waiting on actor completion
139+ join_handle : Arc < Mutex < Option < ActorJoinHandle > > > ,
109140}
110141
111142impl < A : Actor > Clone for ActorRef < A > {
112143 fn clone ( & self ) -> Self {
113144 Self {
114145 tx : self . tx . clone ( ) ,
115146 cancellation_token : self . cancellation_token . clone ( ) ,
147+ join_handle : self . join_handle . clone ( ) ,
116148 }
117149 }
118150}
@@ -121,9 +153,11 @@ impl<A: Actor> ActorRef<A> {
121153 fn new ( actor : A ) -> Self {
122154 let ( tx, mut rx) = mpsc:: channel :: < ActorInMsg < A > > ( ) ;
123155 let cancellation_token = CancellationToken :: new ( ) ;
156+ let join_handle = Arc :: new ( Mutex :: new ( None ) ) ;
124157 let handle = ActorRef {
125158 tx,
126159 cancellation_token,
160+ join_handle : join_handle. clone ( ) ,
127161 } ;
128162 let handle_clone = handle. clone ( ) ;
129163 let inner_future = async move {
@@ -136,47 +170,62 @@ impl<A: Actor> ActorRef<A> {
136170 // Optionally warn if the Actor future blocks for too much time
137171 let inner_future = warn_on_block:: WarnOnBlocking :: new ( inner_future) ;
138172
139- // Ignore the JoinHandle for now. Maybe we'll use it in the future
140- let _join_handle = rt:: spawn ( inner_future) ;
173+ let task_handle = rt:: spawn ( inner_future) ;
174+ let mut guard = join_handle
175+ . lock ( )
176+ . unwrap_or_else ( |poisoned| poisoned. into_inner ( ) ) ;
177+ * guard = Some ( ActorJoinHandle :: Task ( task_handle) ) ;
141178
142179 handle_clone
143180 }
144181
145182 fn new_blocking ( actor : A ) -> Self {
146183 let ( tx, mut rx) = mpsc:: channel :: < ActorInMsg < A > > ( ) ;
147184 let cancellation_token = CancellationToken :: new ( ) ;
185+ let join_handle = Arc :: new ( Mutex :: new ( None ) ) ;
148186 let handle = ActorRef {
149187 tx,
150188 cancellation_token,
189+ join_handle : join_handle. clone ( ) ,
151190 } ;
152191 let handle_clone = handle. clone ( ) ;
153- // Ignore the JoinHandle for now. Maybe we'll use it in the future
154- let _join_handle = rt:: spawn_blocking ( || {
192+ let task_handle = rt:: spawn_blocking ( || {
155193 rt:: block_on ( async move {
156194 if let Err ( error) = actor. run ( & handle, & mut rx) . await {
157195 tracing:: trace!( %error, "Actor crashed" )
158196 } ;
159197 } )
160198 } ) ;
199+ let mut guard = join_handle
200+ . lock ( )
201+ . unwrap_or_else ( |poisoned| poisoned. into_inner ( ) ) ;
202+ * guard = Some ( ActorJoinHandle :: Task ( task_handle) ) ;
203+
161204 handle_clone
162205 }
163206
164207 fn new_on_thread ( actor : A ) -> Self {
165208 let ( tx, mut rx) = mpsc:: channel :: < ActorInMsg < A > > ( ) ;
166209 let cancellation_token = CancellationToken :: new ( ) ;
210+ let join_handle = Arc :: new ( Mutex :: new ( None ) ) ;
167211 let handle = ActorRef {
168212 tx,
169213 cancellation_token,
214+ join_handle : join_handle. clone ( ) ,
170215 } ;
171216 let handle_clone = handle. clone ( ) ;
172- // Ignore the JoinHandle for now. Maybe we'll use it in the future
173- let _join_handle = threads:: spawn ( || {
217+ let thread_handle = threads:: spawn ( || {
174218 threads:: block_on ( async move {
175219 if let Err ( error) = actor. run ( & handle, & mut rx) . await {
176220 tracing:: trace!( %error, "Actor crashed" )
177221 } ;
178222 } )
179223 } ) ;
224+ let mut guard = join_handle
225+ . lock ( )
226+ . unwrap_or_else ( |poisoned| poisoned. into_inner ( ) ) ;
227+ * guard = Some ( ActorJoinHandle :: Thread ( thread_handle) ) ;
228+
180229 handle_clone
181230 }
182231
@@ -220,9 +269,19 @@ impl<A: Actor> ActorRef<A> {
220269 /// Waits for the actor to stop.
221270 ///
222271 /// This method returns a future that completes when the actor has finished
223- /// processing and exited its main loop.
272+ /// processing and exited its main loop. Can only be called once; subsequent
273+ /// calls return immediately.
224274 pub async fn join ( & self ) {
225- self . cancellation_token . cancelled ( ) . await
275+ let handle = {
276+ let mut guard = self
277+ . join_handle
278+ . lock ( )
279+ . unwrap_or_else ( |poisoned| poisoned. into_inner ( ) ) ;
280+ guard. take ( )
281+ } ;
282+ if let Some ( h) = handle {
283+ h. join ( ) . await ;
284+ }
226285 }
227286}
228287
0 commit comments