@@ -10,6 +10,7 @@ use std::error::Error;
10
10
use std:: future:: Future ;
11
11
use std:: pin:: Pin ;
12
12
use std:: sync:: Arc ;
13
+ use std:: sync:: Weak ;
13
14
14
15
use futures:: future:: FutureExt ;
15
16
use futures:: future:: Shared ;
@@ -21,6 +22,7 @@ use hyperactor_mesh::Mesh;
21
22
use hyperactor_mesh:: RootActorMesh ;
22
23
use hyperactor_mesh:: actor_mesh:: ActorMesh ;
23
24
use hyperactor_mesh:: actor_mesh:: ActorSupervisionEvents ;
25
+ use hyperactor_mesh:: dashmap:: DashMap ;
24
26
use hyperactor_mesh:: reference:: ActorMeshRef ;
25
27
use hyperactor_mesh:: sel;
26
28
use hyperactor_mesh:: shared_cell:: SharedCell ;
@@ -168,9 +170,8 @@ pub(crate) struct PythonActorMeshImpl {
168
170
inner : SharedCell < RootActorMesh < ' static , PythonActor > > ,
169
171
client : PyMailbox ,
170
172
_keepalive : Keepalive ,
171
- unhealthy_event : Arc < std:: sync:: Mutex < Unhealthy < ActorSupervisionEvent > > > ,
172
- user_monitor_sender : tokio:: sync:: broadcast:: Sender < Option < ActorSupervisionEvent > > ,
173
173
monitor : tokio:: task:: JoinHandle < ( ) > ,
174
+ health_state : Arc < RootHealthState > ,
174
175
}
175
176
176
177
impl PythonActorMeshImpl {
@@ -184,41 +185,46 @@ impl PythonActorMeshImpl {
184
185
) -> Self {
185
186
let ( user_monitor_sender, _) =
186
187
tokio:: sync:: broadcast:: channel :: < Option < ActorSupervisionEvent > > ( 1 ) ;
187
- let unhealthy_event = Arc :: new ( std :: sync :: Mutex :: new ( Unhealthy :: SoFarSoGood ) ) ;
188
- let monitor = tokio :: spawn ( PythonActorMeshImpl :: actor_mesh_monitor (
189
- events ,
190
- user_monitor_sender . clone ( ) ,
191
- Arc :: clone ( & unhealthy_event ) ,
192
- ) ) ;
188
+ let health_state = Arc :: new ( RootHealthState {
189
+ user_monitor_sender ,
190
+ unhealthy_event : std :: sync :: Mutex :: new ( Unhealthy :: SoFarSoGood ) ,
191
+ crashed_ranks : DashMap :: new ( ) ,
192
+ } ) ;
193
+ let monitor = tokio :: spawn ( Self :: actor_mesh_monitor ( events , health_state . clone ( ) ) ) ;
193
194
PythonActorMeshImpl {
194
195
inner,
195
196
client,
196
197
_keepalive : keepalive,
197
- unhealthy_event,
198
- user_monitor_sender,
199
198
monitor,
199
+ health_state,
200
200
}
201
201
}
202
202
/// Monitor of the actor mesh. It processes supervision errors for the mesh, and keeps mesh
203
203
/// health state up to date.
204
204
async fn actor_mesh_monitor (
205
205
mut events : ActorSupervisionEvents ,
206
- user_sender : tokio:: sync:: broadcast:: Sender < Option < ActorSupervisionEvent > > ,
207
- unhealthy_event : Arc < std:: sync:: Mutex < Unhealthy < ActorSupervisionEvent > > > ,
206
+ health_state : Arc < RootHealthState > ,
208
207
) {
209
208
loop {
210
209
let event = events. next ( ) . await ;
211
210
tracing:: debug!( "actor_mesh_monitor received supervision event: {event:?}" ) ;
212
- let mut inner_unhealthy_event = unhealthy_event. lock ( ) . unwrap ( ) ;
213
- match & event {
214
- None => * inner_unhealthy_event = Unhealthy :: StreamClosed ,
215
- Some ( event) => * inner_unhealthy_event = Unhealthy :: Crashed ( event. clone ( ) ) ,
211
+ {
212
+ let mut inner_unhealthy_event = health_state. unhealthy_event . lock ( ) . unwrap ( ) ;
213
+ match & event {
214
+ None => * inner_unhealthy_event = Unhealthy :: StreamClosed ,
215
+ Some ( event) => {
216
+ health_state
217
+ . crashed_ranks
218
+ . insert ( event. actor_id . rank ( ) , event. clone ( ) ) ;
219
+ * inner_unhealthy_event = Unhealthy :: Crashed ( event. clone ( ) )
220
+ }
221
+ }
216
222
}
217
223
218
224
// Ignore the sender error when there is no receiver,
219
225
// which happens when there is no active requests to this
220
226
// mesh.
221
- let ret = user_sender . send ( event. clone ( ) ) ;
227
+ let ret = health_state . user_monitor_sender . send ( event. clone ( ) ) ;
222
228
tracing:: debug!( "actor_mesh_monitor user_sender send: {ret:?}" ) ;
223
229
224
230
if event. is_none ( ) {
@@ -236,13 +242,18 @@ impl PythonActorMeshImpl {
236
242
237
243
fn bind ( & self ) -> PyResult < PythonActorMeshRef > {
238
244
let mesh = self . try_inner ( ) ?;
239
- Ok ( PythonActorMeshRef { inner : mesh. bind ( ) } )
245
+ let root_health_state = Some ( Arc :: downgrade ( & self . health_state ) ) ;
246
+ Ok ( PythonActorMeshRef {
247
+ inner : mesh. bind ( ) ,
248
+ root_health_state,
249
+ } )
240
250
}
241
251
}
242
252
243
253
impl ActorMeshProtocol for PythonActorMeshImpl {
244
254
fn cast ( & self , message : PythonMessage , selection : Selection , mailbox : Mailbox ) -> PyResult < ( ) > {
245
255
let unhealthy_event = self
256
+ . health_state
246
257
. unhealthy_event
247
258
. lock ( )
248
259
. expect ( "failed to acquire unhealthy_event lock" ) ;
@@ -268,7 +279,7 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
268
279
Ok ( ( ) )
269
280
}
270
281
fn supervision_event ( & self ) -> PyResult < Option < PyShared > > {
271
- let mut receiver = self . user_monitor_sender . subscribe ( ) ;
282
+ let mut receiver = self . health_state . user_monitor_sender . subscribe ( ) ;
272
283
PyPythonTask :: new ( async move {
273
284
let event = receiver. recv ( ) . await ;
274
285
let event = match event {
@@ -313,6 +324,7 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
313
324
impl PythonActorMeshImpl {
314
325
fn get_supervision_event ( & self ) -> PyResult < Option < PyActorSupervisionEvent > > {
315
326
let unhealthy_event = self
327
+ . health_state
316
328
. unhealthy_event
317
329
. lock ( )
318
330
. expect ( "failed to acquire unhealthy_event lock" ) ;
@@ -351,13 +363,62 @@ impl PythonActorMeshImpl {
351
363
}
352
364
}
353
365
366
+ #[ derive( Debug ) ]
367
+ struct RootHealthState {
368
+ user_monitor_sender : tokio:: sync:: broadcast:: Sender < Option < ActorSupervisionEvent > > ,
369
+ unhealthy_event : std:: sync:: Mutex < Unhealthy < ActorSupervisionEvent > > ,
370
+ crashed_ranks : DashMap < usize , ActorSupervisionEvent > ,
371
+ }
372
+
354
373
#[ derive( Debug , Serialize , Deserialize ) ]
355
374
struct PythonActorMeshRef {
356
375
inner : ActorMeshRef < PythonActor > ,
376
+ #[ serde( skip) ]
377
+ // If the reference has been serialized and sent over the wire
378
+ // we no longer have access to the underlying mesh's state
379
+ root_health_state : Option < Weak < RootHealthState > > ,
357
380
}
358
381
359
382
impl ActorMeshProtocol for PythonActorMeshRef {
360
383
fn cast ( & self , message : PythonMessage , selection : Selection , client : Mailbox ) -> PyResult < ( ) > {
384
+ if let Some ( root_health_state) = & self . root_health_state {
385
+ // MeshRef has not been serialized and sent over the wire so we can actually validate
386
+ // if the underlying mesh still exists
387
+ if let Some ( root_health_state) = root_health_state. upgrade ( ) {
388
+ // iterate through all crashed ranks in the root mesh and take first rank
389
+ // that is in the sliced mesh
390
+ match self . inner . shape ( ) . slice ( ) . iter ( ) . find_map ( |rank| {
391
+ root_health_state
392
+ . crashed_ranks
393
+ . get ( & rank)
394
+ . map ( |entry| entry. value ( ) . clone ( ) )
395
+ } ) {
396
+ Some ( event) => {
397
+ return Err ( SupervisionError :: new_err ( format ! (
398
+ "Actor {:?} is unhealthy with reason: {}" ,
399
+ event. actor_id, event. actor_status
400
+ ) ) ) ;
401
+ }
402
+ None => {
403
+ if matches ! (
404
+ & * root_health_state
405
+ . unhealthy_event
406
+ . lock( )
407
+ . unwrap_or_else( |e| e. into_inner( ) ) ,
408
+ Unhealthy :: StreamClosed
409
+ ) {
410
+ return Err ( SupervisionError :: new_err (
411
+ "actor mesh is stopped due to proc mesh shutdown" . to_string ( ) ,
412
+ ) ) ;
413
+ }
414
+ }
415
+ }
416
+ } else {
417
+ return Err ( SupervisionError :: new_err (
418
+ "actor mesh is stopped due to proc mesh shutdown" . to_string ( ) ,
419
+ ) ) ;
420
+ }
421
+ }
361
422
self . inner
362
423
. cast ( & client, selection, message. clone ( ) )
363
424
. map_err ( |err| PyException :: new_err ( err. to_string ( ) ) ) ?;
@@ -369,9 +430,36 @@ impl ActorMeshProtocol for PythonActorMeshRef {
369
430
. inner
370
431
. new_with_shape ( shape. get_inner ( ) . clone ( ) )
371
432
. map_err ( |e| PyErr :: new :: < PyValueError , _ > ( e. to_string ( ) ) ) ?;
372
- Ok ( Box :: new ( Self { inner : sliced } ) )
433
+ Ok ( Box :: new ( Self {
434
+ inner : sliced,
435
+ root_health_state : self . root_health_state . clone ( ) ,
436
+ } ) )
373
437
}
374
438
439
+ fn supervision_event ( & self ) -> PyResult < Option < PyShared > > {
440
+ match self . root_health_state . as_ref ( ) . and_then ( |x| x. upgrade ( ) ) {
441
+ Some ( root_health_state) => {
442
+ let mut receiver = root_health_state. user_monitor_sender . subscribe ( ) ;
443
+ let slice = self . inner . shape ( ) . slice ( ) . clone ( ) ;
444
+ PyPythonTask :: new ( async move {
445
+ while let Ok ( Some ( event) ) = receiver. recv ( ) . await {
446
+ if slice. iter ( ) . any ( |rank| rank == event. actor_id . rank ( ) ) {
447
+ return Ok ( PyErr :: new :: < SupervisionError , _ > ( format ! (
448
+ "Actor {:?} exited because of the following reason: {}" ,
449
+ event. actor_id, event. actor_status
450
+ ) ) ) ;
451
+ }
452
+ }
453
+ Ok ( PyErr :: new :: < SupervisionError , _ > ( format ! (
454
+ "Actor {:?} exited because of the following reason: actor mesh is stopped due to proc mesh shutdown" ,
455
+ id!( default [ 0 ] . actor[ 0 ] )
456
+ ) ) )
457
+ } )
458
+ . map ( |mut x| x. spawn ( ) . map ( Some ) ) ?
459
+ }
460
+ None => Ok ( None ) ,
461
+ }
462
+ }
375
463
fn __reduce__ < ' py > ( & self , py : Python < ' py > ) -> PyResult < ( Bound < ' py , PyAny > , Bound < ' py , PyAny > ) > {
376
464
let bytes =
377
465
bincode:: serialize ( self ) . map_err ( |e| PyErr :: new :: < PyValueError , _ > ( e. to_string ( ) ) ) ?;
@@ -504,7 +592,7 @@ impl ActorMeshProtocol for AsyncActorMesh {
504
592
let mesh = self . mesh . clone ( ) ;
505
593
Ok ( Box :: new ( AsyncActorMesh :: new (
506
594
self . queue . clone ( ) ,
507
- false ,
595
+ self . supervised ,
508
596
async { Ok ( mesh. await ?. new_with_shape ( shape) ?) } ,
509
597
) ) )
510
598
}
0 commit comments