
Commit 456db48

thomasywang authored and facebook-github-bot committed
Mesh slice supervision (#824)
Summary: Pull Request resolved: #824

Suppose we have a ProcMesh with 4 gpus. On this mesh we spawn an ActorMesh A and an ActorMesh B. We create a slice of ActorMesh A, SliceA_1, containing only gpu 4, and a slice SliceA_2 containing only gpu 1. If the Actor A on gpu 4 crashes, we should see the following health states:

- ActorMesh A is unhealthy (contains A on gpu=4)
- SliceA_1 is unhealthy (contains A on gpu=4)
- SliceA_2 is healthy (does not contain A on gpu=4)
- ActorMesh B is healthy (contains gpu=4, but not Actor A)

Implementation:

1. All supervision event streams are created when a RootActorMesh is spawned. A tx-rx pair is created and the tx is inserted into a map acting as a router.
2. The router now holds a Vec of senders instead of a single sender. For each actor mesh in the router, bind can be called to create another tx-rx pair. The router manages the txs, sending a message on every tx for a given actor mesh name whenever there is a supervision event. The rx is returned to be used by mesh slices (a toy sketch of this fan-out follows the summary).
3. The spawned RootActorMesh gets a copy of the Arc holding the router so that mesh slices can bind to it.
4. PythonActorMeshes contain a monitor, a loop that listens for the next supervision event from a stream. If a ::Crashed event comes in, an Arc<> tracking the health state is updated. The monitor now also takes in the shape of the mesh it is monitoring, and only updates the health state to ::Crashed if the crashed actor is within that shape.
5. When a PythonActorMesh is sliced, a PythonActorMeshRef is created. We add a monitor to PythonActorMeshRefs. It is an Option<> because once the reference has been serialized and deserialized, we can no longer monitor it.
6. When we cast to a PythonActorMeshRef, we first check the health state and return a SupervisionError if the mesh is unhealthy.

Differential Revision: D79821712
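To make the fan-out in steps 1-3 concrete, here is a minimal sketch in plain Python with asyncio (all names are hypothetical; the real router is Rust code built on hyperactor channels, and an asyncio.Queue plays the role of a tx-rx pair):

```python
import asyncio
from collections import defaultdict


class SupervisionRouter:
    """Toy model of the supervision router; not the real monarch API."""

    def __init__(self) -> None:
        # One RootActorMesh name maps to a list of senders (step 2).
        self._senders: dict[str, list[asyncio.Queue]] = defaultdict(list)

    def bind(self, mesh_name: str) -> asyncio.Queue:
        # Create another tx-rx pair for this mesh name; the returned rx is
        # what a mesh slice would consume. With asyncio.Queue the tx and rx
        # are the same object, which keeps the toy small.
        rx: asyncio.Queue = asyncio.Queue()
        self._senders[mesh_name].append(rx)
        return rx

    def route(self, mesh_name: str, event: object) -> None:
        # A supervision event is delivered on every tx bound for this mesh
        # name, so the root mesh and all of its slices observe it.
        for tx in self._senders[mesh_name]:
            tx.put_nowait(event)
```

In this model, spawning a RootActorMesh performs the first bind (step 1) and every slice binds again (step 2), so a single event fans out to the root mesh and all of its slices.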
1 parent 6103677 commit 456db48

File tree: 2 files changed, +160 -5 lines

monarch_hyperactor/src/actor_mesh.rs (102 additions, 5 deletions)
```diff
@@ -21,6 +21,7 @@ use hyperactor_mesh::Mesh;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::actor_mesh::ActorSupervisionEvents;
+use hyperactor_mesh::dashmap::DashMap;
 use hyperactor_mesh::reference::ActorMeshRef;
 use hyperactor_mesh::sel;
 use hyperactor_mesh::shared_cell::SharedCell;
@@ -171,6 +172,8 @@ pub(crate) struct PythonActorMeshImpl {
     unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
     user_monitor_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
     monitor: tokio::task::JoinHandle<()>,
+    /// Needed for slices to track health, as ActorMeshes only track the latest supervision event
+    crashed_ranks: Arc<DashMap<usize, ActorSupervisionEvent>>,
 }
 
 impl PythonActorMeshImpl {
@@ -185,10 +188,12 @@ impl PythonActorMeshImpl {
         let (user_monitor_sender, _) =
             tokio::sync::broadcast::channel::<Option<ActorSupervisionEvent>>(1);
         let unhealthy_event = Arc::new(std::sync::Mutex::new(Unhealthy::SoFarSoGood));
-        let monitor = tokio::spawn(PythonActorMeshImpl::actor_mesh_monitor(
+        let crashed_ranks = Arc::new(DashMap::new());
+        let monitor = tokio::spawn(Self::actor_mesh_monitor(
             events,
             user_monitor_sender.clone(),
             Arc::clone(&unhealthy_event),
+            crashed_ranks.clone(),
         ));
         PythonActorMeshImpl {
             inner,
@@ -197,6 +202,7 @@ impl PythonActorMeshImpl {
             unhealthy_event,
             user_monitor_sender,
             monitor,
+            crashed_ranks,
         }
     }
     /// Monitor of the actor mesh. It processes supervision errors for the mesh, and keeps mesh
@@ -205,14 +211,19 @@ impl PythonActorMeshImpl {
         mut events: ActorSupervisionEvents,
         user_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
         unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
+        crashed_ranks: Arc<DashMap<usize, ActorSupervisionEvent>>,
     ) {
         loop {
             let event = events.next().await;
             tracing::debug!("actor_mesh_monitor received supervision event: {event:?}");
             let mut inner_unhealthy_event = unhealthy_event.lock().unwrap();
             match &event {
                 None => *inner_unhealthy_event = Unhealthy::StreamClosed,
-                Some(event) => *inner_unhealthy_event = Unhealthy::Crashed(event.clone()),
+                Some(event) => {
+                    // Record every crashed rank; slices filter by their own shape at cast time.
+                    crashed_ranks.insert(event.actor_id.rank(), event.clone());
+                    *inner_unhealthy_event = Unhealthy::Crashed(event.clone())
+                }
             }
 
             // Ignore the sender error when there is no receiver,
```
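The monitor change above is the crux of the feature: the root mesh used to keep only the latest supervision event, so a slice created after a crash had no way to tell whether one of its own ranks had already failed. A rough Python model of the new loop (hypothetical names; plain dicts stand in for DashMap and the Unhealthy enum):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class SupervisionEvent:
    rank: int
    status: str


async def actor_mesh_monitor(
    events: asyncio.Queue,
    crashed_ranks: dict[int, SupervisionEvent],
    health: dict,
) -> None:
    # Mirror of the loop above: keep the latest unhealthy state, but also
    # record every crashed rank so that slices created later can still see
    # crashes that happened before they existed.
    while True:
        event = await events.get()
        if event is None:  # stream closed, e.g. the proc mesh shut down
            health["state"] = "stream_closed"
            return
        crashed_ranks[event.rank] = event
        health["state"] = ("crashed", event)
```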
```diff
@@ -236,7 +247,15 @@ impl PythonActorMeshImpl {
 
     fn bind(&self) -> PyResult<PythonActorMeshRef> {
         let mesh = self.try_inner()?;
-        Ok(PythonActorMeshRef { inner: mesh.bind() })
+        let root_health_state = Some(RootHealthState {
+            user_monitor_sender: self.user_monitor_sender.clone(),
+            unhealthy_event: Arc::clone(&self.unhealthy_event),
+            crashed_ranks: Arc::clone(&self.crashed_ranks),
+        });
+        Ok(PythonActorMeshRef {
+            inner: mesh.bind(),
+            root_health_state,
+        })
     }
 }
 
```
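bind() clones the Arc-held state into a RootHealthState (defined in the next hunk), so a bound reference shares the same live crash map as the root's monitor. That handle cannot cross the wire, which is why the field on the ref is an Option marked #[serde(skip)]. A toy Python illustration of the round trip (hypothetical names):

```python
import json


class MeshRefModel:
    """Toy stand-in for PythonActorMeshRef; names are hypothetical."""

    def __init__(self, ranks: set[int], crashed_ranks: dict | None) -> None:
        self.ranks = ranks
        # Live, shared view of the root's crash map, installed by bind().
        # A deserialized ref ends up with None, mirroring the combination of
        # Option<RootHealthState> and #[serde(skip)].
        self.crashed_ranks = crashed_ranks

    def serialize(self) -> str:
        # Only the mesh description survives; the health handle is skipped.
        return json.dumps(sorted(self.ranks))

    @classmethod
    def deserialize(cls, data: str) -> "MeshRefModel":
        return cls(set(json.loads(data)), crashed_ranks=None)
```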
```diff
@@ -351,13 +370,59 @@ impl PythonActorMeshImpl {
     }
 }
 
+#[derive(Clone, Debug)]
+struct RootHealthState {
+    user_monitor_sender: tokio::sync::broadcast::Sender<Option<ActorSupervisionEvent>>,
+    unhealthy_event: Arc<std::sync::Mutex<Unhealthy<ActorSupervisionEvent>>>,
+    crashed_ranks: Arc<DashMap<usize, ActorSupervisionEvent>>,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 struct PythonActorMeshRef {
     inner: ActorMeshRef<PythonActor>,
+    #[serde(skip)]
+    // If the reference has been serialized and sent over the wire
+    // we no longer have access to the underlying mesh's state.
+    root_health_state: Option<RootHealthState>,
 }
 
 impl ActorMeshProtocol for PythonActorMeshRef {
     fn cast(&self, message: PythonMessage, selection: Selection, client: Mailbox) -> PyResult<()> {
+        match &self.root_health_state {
+            Some(RootHealthState {
+                crashed_ranks,
+                unhealthy_event,
+                ..
+            }) => {
+                // Iterate through all crashed ranks in the root mesh and take the first
+                // rank that is in the sliced mesh.
+                match self
+                    .inner
+                    .shape()
+                    .slice()
+                    .iter()
+                    .find_map(|rank| crashed_ranks.get(&rank).map(|entry| entry.value().clone()))
+                {
+                    Some(event) => {
+                        return Err(SupervisionError::new_err(format!(
+                            "Actor {:?} is unhealthy with reason: {}",
+                            event.actor_id, event.actor_status
+                        )));
+                    }
+                    None => {
+                        if matches!(
+                            &*unhealthy_event.lock().unwrap_or_else(|e| e.into_inner()),
+                            Unhealthy::StreamClosed
+                        ) {
+                            return Err(SupervisionError::new_err(
+                                "actor mesh is stopped due to proc mesh shutdown".to_string(),
+                            ));
+                        }
+                    }
+                }
+            }
+            None => (),
+        }
         self.inner
             .cast(&client, selection, message.clone())
             .map_err(|err| PyException::new_err(err.to_string()))?;
```
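The cast-time check walks the slice's own ranks and fails on the first one present in the shared crash map; crashes outside the slice are ignored, which is exactly what keeps SliceA_2 from the summary healthy. The same logic in Python (hypothetical names, reusing SupervisionEvent from the monitor sketch above):

```python
def first_crash_in_slice(
    slice_ranks: set[int],
    crashed_ranks: dict[int, SupervisionEvent],
) -> SupervisionEvent | None:
    # Equivalent of the find_map above: the first crashed rank that falls
    # inside this slice's shape, if any.
    for rank in slice_ranks:
        if rank in crashed_ranks:
            return crashed_ranks[rank]
    return None


def cast_or_raise(slice_ranks: set[int], crashed_ranks: dict, message: object) -> None:
    event = first_crash_in_slice(slice_ranks, crashed_ranks)
    if event is not None:
        raise RuntimeError(f"rank {event.rank} is unhealthy: {event.status}")
    # ...otherwise the message would actually be cast to the slice's actors.
```

Since the crash map is keyed by rank, the check costs one hash lookup per rank in the slice, independent of how many meshes share the proc mesh.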
```diff
@@ -369,9 +434,41 @@ impl ActorMeshProtocol for PythonActorMeshRef {
             .inner
             .new_with_shape(shape.get_inner().clone())
             .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
-        Ok(Box::new(Self { inner: sliced }))
+        Ok(Box::new(Self {
+            inner: sliced,
+            root_health_state: self.root_health_state.clone(),
+        }))
     }
 
+    fn supervision_event(&self) -> PyResult<Option<PyShared>> {
+        match &self.root_health_state {
+            Some(RootHealthState {
+                user_monitor_sender,
+                ..
+            }) => {
+                let mut receiver = user_monitor_sender.subscribe();
+                PyPythonTask::new(async move {
+                    let event = receiver.recv().await;
+                    let event = match event {
+                        Ok(Some(event)) => PyActorSupervisionEvent::from(event.clone()),
+                        Ok(None) | Err(_) => PyActorSupervisionEvent {
+                            // Dummy actor as placeholder to indicate the whole mesh is stopped
+                            // TODO(albertli): remove this when pushing all supervision logic to rust.
+                            actor_id: id!(default[0].actor[0]).into(),
+                            actor_status: "actor mesh is stopped due to proc mesh shutdown"
+                                .to_string(),
+                        },
+                    };
+                    Ok(PyErr::new::<SupervisionError, _>(format!(
+                        "Actor {:?} exited because of the following reason: {}",
+                        event.actor_id, event.actor_status
+                    )))
+                })
+                .map(|mut x| x.spawn().map(Some))?
+            }
+            None => Ok(None),
+        }
+    }
     fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
         let bytes =
             bincode::serialize(self).map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
```
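supervision_event() is only available on a live reference: it subscribes to the root's broadcast channel and resolves to an error describing the next event, with a closed channel standing for whole-mesh shutdown; a reference that lost its health state (e.g. after deserialization) returns None. A small asyncio model (hypothetical names; a None payload plays the role of the closed channel, and SupervisionEvent is from the monitor sketch above):

```python
import asyncio


async def next_supervision_error(receiver: asyncio.Queue) -> Exception:
    # Model of supervision_event() on a live ref: await the next broadcast
    # event and surface it as an error rather than handling it in place.
    event = await receiver.get()
    if event is None:
        return RuntimeError("actor mesh is stopped due to proc mesh shutdown")
    return RuntimeError(
        f"Actor rank {event.rank} exited because of the following reason: {event.status}"
    )
```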
```diff
@@ -504,7 +601,7 @@ impl ActorMeshProtocol for AsyncActorMesh {
         let mesh = self.mesh.clone();
         Ok(Box::new(AsyncActorMesh::new(
             self.queue.clone(),
-            false,
+            self.supervised,
             async { Ok(mesh.await?.new_with_shape(shape)?) },
         )))
     }
```

python/tests/test_actor_error.py (58 additions, 0 deletions)
```diff
@@ -685,3 +685,61 @@ async def test_supervision_with_sending_error():
     await actor_mesh.check.call()
     with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason"):
         await actor_mesh.check_with_payload.call(payload="a")
+
+
+async def test_slice_supervision() -> None:
+    pm = await proc_mesh(gpus=4)
+    healthy_mesh = await pm.spawn("healthy", HealthyActor)
+    error_mesh = await pm.spawn("error", ErrorActor)
+    slice_1 = error_mesh.slice(gpus=slice(2, 4))
+    slice_2 = error_mesh.slice(gpus=2)
+    slice_3 = error_mesh.slice(gpus=3)
+
+    # Trigger supervision error on gpus=3
+    with pytest.raises(SupervisionError, match="did not handle supervision event"):
+        await slice_3.fail_with_supervision_error.call()
+
+    # Mesh containing all gpus is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await error_mesh.check.call()
+
+    # Slice containing only gpus=3 is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await slice_3.check.call()
+
+    # Slice containing gpus=3 is unhealthy
+    with pytest.raises(SupervisionError, match="Actor .* is unhealthy with reason:"):
+        await slice_1.check.call()
+
+    # Slice not containing gpus=3 is healthy
+    check = await slice_2.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+    # Other actor mesh on the same proc mesh is healthy
+    check = await healthy_mesh.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+
+async def test_mesh_slices_inherit_parent_errors() -> None:
+    pm = await proc_mesh(gpus=4)
+    error_mesh = await pm.spawn("error", ErrorActor)
+    slice_1 = error_mesh.slice(gpus=slice(2, 4))
+
+    # Trigger supervision error on gpus=2, 3
+    with pytest.raises(SupervisionError):
+        await slice_1.fail_with_supervision_error.call()
+
+    # Newly created slice containing gpu=3 is unhealthy
+    slice_2 = error_mesh.slice(gpus=3)
+    with pytest.raises(SupervisionError):
+        await slice_2.check.call()
+
+    # Newly created slice containing gpu=1 is healthy
+    slice_3 = error_mesh.slice(gpus=1)
+    check = await slice_3.check.call()
+    for _, item in check.items():
+        assert item == "this is a healthy check"
+
+    await pm.stop()
```
