diff --git a/hyperactor_mesh/src/shared_cell.rs b/hyperactor_mesh/src/shared_cell.rs index 6cdabf9b1..fc2b7f809 100644 --- a/hyperactor_mesh/src/shared_cell.rs +++ b/hyperactor_mesh/src/shared_cell.rs @@ -14,6 +14,7 @@ use std::sync::atomic::Ordering; use async_trait::async_trait; use dashmap::DashMap; +use futures::future::join_all; use futures::future::try_join_all; use preempt_rwlock::OwnedPreemptibleRwLockReadGuard; use preempt_rwlock::PreemptibleRwLock; @@ -219,6 +220,16 @@ impl SharedCellPool { .await?; Ok(()) } + + /// Run `take` on all cells in the pool and immediately drop them or produce an error if the cell has already been taken + pub async fn discard_or_error_all(self) -> Vec> { + join_all( + self.map + .iter() + .map(|r| async move { r.value().discard().await }), + ) + .await + } } /// Trait to facilitate storing `SharedCell`s of different types in a single pool. diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 1e2837f77..3a0a22907 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -233,7 +233,9 @@ impl PyProcMesh { let (proc_mesh, children) = tracked_proc_mesh.into_inner(); // Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused". - children.discard_all().await?; + // Discarding actor meshes that have been individually stopped will result in an expected error + // which we can safely ignore + children.discard_or_error_all().await; // Finally, take ownership of the inner proc mesh, which will allowing dropping it. let _proc_mesh = proc_mesh.take().await?; diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 58f4f16dc..269409a99 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -1012,6 +1012,15 @@ async def test_actor_mesh_stop(self) -> None: await am_2.print.call("hello 3") await am_2.log.call("hello 4") + await pm.stop() + + async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None: + pm = proc_mesh(gpus=2) + am = await pm.spawn("printer", Printer) + + await cast(ActorMesh, am).stop() + await pm.stop() + class PortedActor(Actor): @endpoint(explicit_response_port=True)