Commit cbfd5e4

thomasywang authored and facebook-github-bot committed
Mesh slice supervision (#824)
Summary:
Suppose we have a ProcMesh with 4 gpus. On this mesh we spawn an ActorMesh A and an ActorMesh B, then create a slice of ActorMesh A, SliceA_1, containing only gpu 4, and a slice SliceA_2 containing only gpu 1. If the Actor A on gpu 4 crashes, we should observe the following health states:
- ActorMesh A is unhealthy (contains A gpu=4)
- SliceA_1 is unhealthy (contains A gpu=4)
- SliceA_2 is healthy (does not contain A gpu=4)
- ActorMesh B is healthy (contains gpu=4 but not Actor A)

Implementation:
1. An `Arc<DashMap<usize, ActorSupervisionEvent>>` is created to track all `Actor` crashes; this is necessary because the `Unhealthy` state only tracks the latest event. This map is called `crashed_ranks`.
2. `crashed_ranks` is passed into the monitor loop and updated whenever an `Actor` crashes.
3. Before casting through a `PythonActorMeshRef`, we check `crashed_ranks` and return a `SupervisionError` for the first crashed rank found that belongs to the mesh.
4. When monitoring supervision events through the `PortListener`, the `PythonTask` loops, skipping any `ActorSupervisionEvent` whose rank falls outside the mesh.

Reviewed By: pzhan9

Differential Revision: D79821712
1 parent 6103677 · commit cbfd5e4
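To make the cast-time check (steps 1-3 of the summary) concrete, here is a minimal, self-contained sketch: a crash recorded for one rank poisons exactly the slices whose rank sets contain it. A plain `HashMap` behind a `Mutex` stands in for the shared `DashMap`, a rank list stands in for the mesh slice, and all names here (`HealthState`, `SupervisionEvent`, `first_crash_in_slice`) are illustrative stand-ins, not the crate's API.

// Hypothetical sketch: a shared map of crashed ranks, consulted at cast time.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct SupervisionEvent {
    rank: usize,
    status: String,
}

struct HealthState {
    // rank -> latest crash event for that rank (a `DashMap` in the real code)
    crashed_ranks: Mutex<HashMap<usize, SupervisionEvent>>,
}

// Return the first crash event whose rank belongs to this slice, if any.
fn first_crash_in_slice(state: &HealthState, slice_ranks: &[usize]) -> Option<SupervisionEvent> {
    let crashed = state.crashed_ranks.lock().unwrap();
    slice_ranks.iter().find_map(|rank| crashed.get(rank).cloned())
}

fn main() {
    let state = Arc::new(HealthState {
        crashed_ranks: Mutex::new(HashMap::new()),
    });
    // Simulate the monitor loop recording a crash on rank 3.
    state.crashed_ranks.lock().unwrap().insert(
        3,
        SupervisionEvent { rank: 3, status: "actor panicked".to_string() },
    );
    // A slice over ranks {2, 3} observes the crash and would refuse the cast...
    if let Some(event) = first_crash_in_slice(&state, &[2, 3]) {
        println!("unhealthy: rank {} ({})", event.rank, event.status);
    }
    // ...while a slice over ranks {0, 1} stays healthy and the cast proceeds.
    assert!(first_crash_in_slice(&state, &[0, 1]).is_none());
}

In the actual diff below, the same lookup is a `find_map` over `self.inner.shape().slice().iter()` against `crashed_ranks`, and a sliced reference holds only a `Weak` pointer to the root mesh's health state: if the mesh has been dropped the cast fails with the shutdown error, and if the reference was serialized over the wire (the field is `#[serde(skip)]`) the local check is skipped entirely.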

File tree

2 files changed: +167 -21 lines changed


monarch_hyperactor/src/actor_mesh.rs

Lines changed: 109 additions & 21 deletions
@@ -10,6 +10,7 @@ use std::error::Error;
 use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
+use std::sync::Weak;
 
 use futures::future::FutureExt;
 use futures::future::Shared;
@@ -21,6 +22,7 @@ use hyperactor_mesh::Mesh;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
+use hyperactor_mesh::dashmap::DashMap;
 use hyperactor_mesh::reference::ActorMeshRef;
 use hyperactor_mesh::sel;
 use hyperactor_mesh::shared_cell::SharedCell;
@@ -168,9 +170,8 @@ pub(crate) struct PythonActorMeshImpl {
     inner: SharedCell<RootActorMesh<'static, PythonActor>>,
     client: PyMailbox,
     _keepalive: Keepalive,
-    unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
-    user_monitor_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
     monitor: tokio::task::JoinHandle<()>,
+    health_state: Arc<RootHealthState>,
 }
 
 impl PythonActorMeshImpl {
@@ -184,41 +185,46 @@ impl PythonActorMeshImpl {
     ) -> Self {
         let (user_monitor_sender, _) =
             tokio::sync::broadcast::channel::<Option<ActorSupervisionEvent>>(1);
-        let unhealthy_event = Arc::new(std::sync::Mutex::new(Unhealthy::SoFarSoGood));
-        let monitor = tokio::spawn(PythonActorMeshImpl::actor_mesh_monitor(
-            events,
-            user_monitor_sender.clone(),
-            Arc::clone(&unhealthy_event),
-        ));
+        let health_state = Arc::new(RootHealthState {
+            user_monitor_sender,
+            unhealthy_event: std::sync::Mutex::new(Unhealthy::SoFarSoGood),
+            crashed_ranks: DashMap::new(),
+        });
+        let monitor = tokio::spawn(Self::actor_mesh_monitor(events, health_state.clone()));
         PythonActorMeshImpl {
             inner,
             client,
             _keepalive: keepalive,
-            unhealthy_event,
-            user_monitor_sender,
             monitor,
+            health_state,
         }
     }
     /// Monitor of the actor mesh. It processes supervision errors for the mesh, and keeps mesh
     /// health state up to date.
     async fn actor_mesh_monitor(
         mut events: ActorSupervisionEvents,
-        user_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
-        unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
+        health_state: Arc<RootHealthState>,
     ) {
         loop {
             let event = events.next().await;
             tracing::debug!("actor_mesh_monitor received supervision event: {event:?}");
-            let mut inner_unhealthy_event = unhealthy_event.lock().unwrap();
-            match &event {
-                None => *inner_unhealthy_event = Unhealthy::StreamClosed,
-                Some(event) => *inner_unhealthy_event = Unhealthy::Crashed(event.clone()),
+            {
+                let mut inner_unhealthy_event = health_state.unhealthy_event.lock().unwrap();
+                match &event {
+                    None => *inner_unhealthy_event = Unhealthy::StreamClosed,
+                    Some(event) => {
+                        health_state
+                            .crashed_ranks
+                            .insert(event.actor_id.rank(), event.clone());
+                        *inner_unhealthy_event = Unhealthy::Crashed(event.clone())
+                    }
+                }
             }
 
             // Ignore the sender error when there is no receiver,
             // which happens when there is no active requests to this
             // mesh.
-            let ret = user_sender.send(event.clone());
+            let ret = health_state.user_monitor_sender.send(event.clone());
             tracing::debug!("actor_mesh_monitor user_sender send: {ret:?}");
 
             if event.is_none() {
@@ -236,13 +242,18 @@ impl PythonActorMeshImpl {
 
     fn bind(&self) -> PyResult<PythonActorMeshRef> {
         let mesh = self.try_inner()?;
-        Ok(PythonActorMeshRef { inner: mesh.bind() })
+        let root_health_state = Some(Arc::downgrade(&self.health_state));
+        Ok(PythonActorMeshRef {
+            inner: mesh.bind(),
+            root_health_state,
+        })
     }
 }
 
 impl ActorMeshProtocol for PythonActorMeshImpl {
     fn cast(&self, message: PythonMessage, selection: Selection, mailbox: Mailbox) -> PyResult<()> {
         let unhealthy_event = self
+            .health_state
             .unhealthy_event
             .lock()
             .expect("failed to acquire unhealthy_event lock");
@@ -268,7 +279,7 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
         Ok(())
     }
     fn supervision_event(&self) -> PyResult<Option<PyShared>> {
-        let mut receiver = self.user_monitor_sender.subscribe();
+        let mut receiver = self.health_state.user_monitor_sender.subscribe();
         PyPythonTask::new(async move {
             let event = receiver.recv().await;
             let event = match event {
@@ -313,6 +324,7 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
 impl PythonActorMeshImpl {
     fn get_supervision_event(&self) -> PyResult<Option<PyActorSupervisionEvent>> {
         let unhealthy_event = self
+            .health_state
             .unhealthy_event
             .lock()
             .expect("failed to acquire unhealthy_event lock");
@@ -351,13 +363,62 @@ impl PythonActorMeshImpl {
     }
 }
 
+#[derive(Debug)]
+struct RootHealthState {
+    user_monitor_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
+    unhealthy_event: std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>,
+    crashed_ranks: DashMap<usize, ActorSupervisionEvent>,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct PythonActorMeshRef {
     inner: ActorMeshRef<PythonActor>,
+    #[serde(skip)]
+    // If the reference has been serialized and sent over the wire
+    // we no longer have access to the underlying mesh's state
+    root_health_state: Option<Weak<RootHealthState>>,
 }
 
 impl ActorMeshProtocol for PythonActorMeshRef {
     fn cast(&self, message: PythonMessage, selection: Selection, client: Mailbox) -> PyResult<()> {
+        if let Some(root_health_state) = &self.root_health_state {
+            // MeshRef has not been serialized and sent over the wire so we can actually validate
+            // if the underlying mesh still exists
+            if let Some(root_health_state) = root_health_state.upgrade() {
+                // iterate through all crashed ranks in the root mesh and take first rank
+                // that is in the sliced mesh
+                match self.inner.shape().slice().iter().find_map(|rank| {
+                    root_health_state
+                        .crashed_ranks
+                        .get(&rank)
+                        .map(|entry| entry.value().clone())
+                }) {
+                    Some(event) => {
+                        return Err(SupervisionError::new_err(format!(
+                            "Actor {:?} is unhealthy with reason: {}",
+                            event.actor_id, event.actor_status
+                        )));
+                    }
+                    None => {
+                        if matches!(
+                            &*root_health_state
+                                .unhealthy_event
+                                .lock()
+                                .unwrap_or_else(|e| e.into_inner()),
+                            Unhealthy::StreamClosed
+                        ) {
+                            return Err(SupervisionError::new_err(
+                                "actor mesh is stopped due to proc mesh shutdown".to_string(),
+                            ));
+                        }
+                    }
+                }
+            } else {
+                return Err(SupervisionError::new_err(
+                    "actor mesh is stopped due to proc mesh shutdown".to_string(),
+                ));
+            }
+        }
         self.inner
             .cast(&client, selection, message.clone())
             .map_err(|err| PyException::new_err(err.to_string()))?;
@@ -369,9 +430,36 @@ impl ActorMeshProtocol for PythonActorMeshRef {
             .inner
             .new_with_shape(shape.get_inner().clone())
             .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
-        Ok(Box::new(Self { inner: sliced }))
+        Ok(Box::new(Self {
+            inner: sliced,
+            root_health_state: self.root_health_state.clone(),
+        }))
     }
 
+    fn supervision_event(&self) -> PyResult<Option<PyShared>> {
+        match self.root_health_state.as_ref().and_then(|x| x.upgrade()) {
+            Some(root_health_state) => {
+                let mut receiver = root_health_state.user_monitor_sender.subscribe();
+                let slice = self.inner.shape().slice().clone();
+                PyPythonTask::new(async move {
+                    while let Ok(Some(event)) = receiver.recv().await {
+                        if slice.iter().any(|rank| rank == event.actor_id.rank()) {
+                            return Ok(PyErr::new::<SupervisionError, _>(format!(
+                                "Actor {:?} exited because of the following reason: {}",
+                                event.actor_id, event.actor_status
+                            )));
+                        }
+                    }
+                    Ok(PyErr::new::<SupervisionError, _>(format!(
+                        "Actor {:?} exited because of the following reason: actor mesh is stopped due to proc mesh shutdown",
+                        id!(default[0].actor[0])
+                    )))
+                })
+                .map(|mut x| x.spawn().map(Some))?
+            }
+            None => Ok(None),
+        }
+    }
     fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
         let bytes =
             bincode::serialize(self).map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
@@ -504,7 +592,7 @@ impl ActorMeshProtocol for AsyncActorMesh {
         let mesh = self.mesh.clone();
         Ok(Box::new(AsyncActorMesh::new(
             self.queue.clone(),
-            false,
+            self.supervised,
             async { Ok(mesh.await?.new_with_shape(shape)?) },
         )))
     }

python/tests/test_actor_error.py

Lines changed: 58 additions & 0 deletions
@@ -685,3 +685,61 @@ async def test_supervision_with_sending_error():
     await actor_mesh.check.call()
     with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason"):
         await actor_mesh.check_with_payload.call(payload="a")
+
+
+async def test_slice_supervision() -> None:
+    pm = await proc_mesh(gpus=4)
+    healthy_mesh = await pm.spawn("healthy", HealthyActor)
+    error_mesh = await pm.spawn("error", ErrorActor)
+    slice_1 = error_mesh.slice(gpus=slice(2, 4))
+    slice_2 = error_mesh.slice(gpus=2)
+    slice_3 = error_mesh.slice(gpus=3)
+
+    # Trigger supervision error on gpus=3
+    with pytest.raises(SupervisionError, match="did not handle supervision event"):
+        await slice_3.fail_with_supervision_error.call()
+
+    # Mesh containing all gpus is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await error_mesh.check.call()
+
+    # Slice containing only gpus=3 is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await slice_3.check.call()
+
+    # Slice containing gpus=3 is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await slice_1.check.call()
+
+    # Slice not containing gpus=3 is healthy
+    check = await slice_2.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+    # Other actor mesh on the same proc mesh is healthy
+    check = await healthy_mesh.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+
+async def test_mesh_slices_inherit_parent_errors() -> None:
+    pm = await proc_mesh(gpus=4)
+    error_mesh = await pm.spawn("error", ErrorActor)
+    slice_1 = error_mesh.slice(gpus=slice(2, 4))
+
+    # Trigger supervision error on gpus=2, 3
+    with pytest.raises(SupervisionError):
+        await slice_1.fail_with_supervision_error.call()
+
+    # Newly created slice containing gpu=3 is unhealthy
+    slice_2 = error_mesh.slice(gpus=3)
+    with pytest.raises(SupervisionError):
+        await slice_2.check.call()
+
+    # Newly created slice containing gpu=1 is healthy
+    slice_3 = error_mesh.slice(gpus=1)
+    check = await slice_3.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+    await pm.stop()
