Commit 8239b40

thomasywang authored and facebook-github-bot committed
Allow proc_mesh.stop() after actor_mesh.stop() (#877)
Summary:
Pull Request resolved: #877

Currently, calling proc_mesh.stop() fails if an actor mesh spawned on that proc mesh has already been stopped.

This is because the Python-binding wrapper for ProcMesh (TrackedProcMesh) holds a map of SharedCells, one for each RootActorMesh spawned on it. When a PythonActorMesh is stopped, its RootActorMesh is dropped, but the state of the TrackedProcMesh is not updated to remove the corresponding entry. When a ProcMesh is stopped, it attempts to discard every SharedCell containing its RootActorMeshes, and returns an error if any RootActorMesh has already been dropped (the error we are seeing).

We could make stopping the PythonActorMesh update the state of the TrackedProcMesh, but this would involve plumbing references to TrackedProcMesh and related state through to PythonActorMesh.

A simpler solution is to allow the TrackedProcMesh to ignore the errors that arise from attempting to discard a SharedCell whose RootActorMesh has already been dropped.

Reviewed By: ahmadsharif1

Differential Revision: D80269385
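The failure mode described in the summary can be illustrated with a toy model. This is a hypothetical, synchronous Python sketch and not the Monarch implementation: `Cell` and `TrackedPool` are made-up stand-ins for the Rust `SharedCell` and the pool held by `TrackedProcMesh`, and the strict `discard_all` here is assumed to stop at the first empty cell.

```python
class EmptyCellError(Exception):
    """Toy stand-in for the Rust EmptyCellError: the cell's value is already gone."""


class Cell:
    """Hypothetical stand-in for a SharedCell holding one actor mesh."""

    def __init__(self, value):
        self._value = value

    def discard(self):
        # Discarding an already-empty cell is an error, as described in the summary.
        if self._value is None:
            raise EmptyCellError("cell is already empty")
        self._value = None


class TrackedPool:
    """Hypothetical stand-in for the map of cells kept by TrackedProcMesh."""

    def __init__(self):
        self.cells = {}

    def spawn(self, name):
        cell = Cell(name)
        self.cells[name] = cell
        return cell

    def discard_all(self):
        # Strict variant: the first empty cell aborts the whole operation.
        for cell in self.cells.values():
            cell.discard()

    def discard_or_error_all(self):
        # Tolerant variant: record one outcome per cell and keep going.
        results = []
        for cell in self.cells.values():
            try:
                cell.discard()
                results.append(None)
            except EmptyCellError as err:
                results.append(err)
        return results


pool = TrackedPool()
am_1 = pool.spawn("am_1")
pool.spawn("am_2")
am_1.discard()  # models actor_mesh.stop(): the pool still tracks the now-empty cell

results = pool.discard_or_error_all()  # models the new proc_mesh.stop() path
print([type(r).__name__ for r in results])
```

In this model the strict variant would raise on `am_1` and never reach `am_2`, which corresponds to the reported error; the tolerant variant empties every remaining cell and leaves it to the caller to decide what to do with the per-cell errors.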
1 parent cce99fb commit 8239b40

3 files changed: 23 additions and 1 deletion

hyperactor_mesh/src/shared_cell.rs

Lines changed: 11 additions & 0 deletions
@@ -14,6 +14,7 @@ use std::sync::atomic::Ordering;

 use async_trait::async_trait;
 use dashmap::DashMap;
+use futures::future::join_all;
 use futures::future::try_join_all;
 use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
 use preempt_rwlock::PreemptibleRwLock;
@@ -219,6 +220,16 @@ impl SharedCellPool {
         .await?;
         Ok(())
     }
+
+    /// Run `take` on all cells in the pool and immediately drop them or produce an error if the cell has already been taken
+    pub async fn discard_or_error_all(self) -> Vec<Result<(), EmptyCellError>> {
+        join_all(
+            self.map
+                .iter()
+                .map(|r| async move { r.value().discard().await }),
+        )
+        .await
+    }
 }

 /// Trait to facilitate storing `SharedCell`s of different types in a single pool.
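
The hunk above adds a `join_all`-based method alongside the existing `try_join_all` import. The difference between the two combinators can be sketched with a rough Python analogy, assuming `asyncio.gather` as a stand-in: with `return_exceptions=True` it yields one outcome per awaitable (similar in spirit to `join_all` returning a `Vec<Result<...>>`), while the default mode propagates the first exception (similar in spirit to `try_join_all(...).await?`). The `discard` coroutine and the `(name, already_taken)` tuples are hypothetical; the analogy is not exact, since `try_join_all` drops the remaining futures on the first error whereas `asyncio.gather` does not cancel them.

```python
import asyncio


class EmptyCellError(Exception):
    """Toy stand-in for the Rust EmptyCellError."""


async def discard(name, already_taken):
    # Hypothetical per-cell discard: fails if the value was taken earlier.
    await asyncio.sleep(0)
    if already_taken:
        raise EmptyCellError(name)
    return None


async def discard_all(cells):
    # Strict: the first EmptyCellError propagates to the caller.
    await asyncio.gather(*(discard(n, t) for n, t in cells))


async def discard_or_error_all(cells):
    # Tolerant: every cell is attempted; errors are returned, not raised.
    return await asyncio.gather(
        *(discard(n, t) for n, t in cells), return_exceptions=True
    )


cells = [("am_1", True), ("am_2", False)]  # am_1 was stopped individually

try:
    asyncio.run(discard_all(cells))
    strict_failed = False
except EmptyCellError:
    strict_failed = True

outcomes = asyncio.run(discard_or_error_all(cells))
print(strict_failed, [type(o).__name__ for o in outcomes])
```

Returning the per-cell outcomes rather than swallowing them inside the pool keeps the decision with the caller, which can ignore them entirely or inspect them.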

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 3 additions & 1 deletion
@@ -233,7 +233,9 @@ impl PyProcMesh {
             let (proc_mesh, children) = tracked_proc_mesh.into_inner();

             // Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
-            children.discard_all().await?;
+            // Discarding actor meshes that have been individually stopped will result in an expected error
+            // which we can safely ignore
+            children.discard_or_error_all().await;

             // Finally, take ownership of the inner proc mesh, which will allowing dropping it.
             let _proc_mesh = proc_mesh.take().await?;

python/tests/test_python_actors.py

Lines changed: 9 additions & 0 deletions
@@ -1012,6 +1012,15 @@ async def test_actor_mesh_stop(self) -> None:
         await am_2.print.call("hello 3")
         await am_2.log.call("hello 4")

+        await pm.stop()
+
+    async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None:
+        pm = proc_mesh(gpus=2)
+        am = await pm.spawn("printer", Printer)
+
+        await cast(ActorMesh, am).stop()
+        await pm.stop()
+

 class PortedActor(Actor):
     @endpoint(explicit_response_port=True)
