Skip to content

Commit c306dcd

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Allow proc_mesh.stop() after actor_mesh.stop()
Summary: Currently calling proc_mesh.stop() will fail if an actor_mesh spawned on this proc_mesh has been stopped. This is because the python-binding wrapper for ProcMesh (TrackedProcMesh) holds onto a map of SharedCells containing each RootActorMesh spawned on it. When an PythonActorMesh is stopped, the RootActorMesh is dropped, but does not update the state of the TrackedProcMesh to remove this. When a ProcMesh is stopped, it will attempt to discard every SharedCell containing it's RootActorMeshes, but will return an error if any RootActorMesh has been dropped (The error we are seeing) We can make it such that stopping the PythonActorMesh will update the state of the TrackedProcMesh, but this would involve plumbing through references of TrackedProcMesh and related state to PythonActorMesh A simpler solution, it to simply allow the TrackedProcMesh to discard the errors related to attempting to discard a SharedCell where the RootActorMesh has already been dropped Differential Revision: D80269385
1 parent ba2a023 commit c306dcd

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

hyperactor_mesh/src/shared_cell.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use std::sync::atomic::Ordering;
1414

1515
use async_trait::async_trait;
1616
use dashmap::DashMap;
17+
use futures::future::join_all;
1718
use futures::future::try_join_all;
1819
use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
1920
use preempt_rwlock::PreemptibleRwLock;
@@ -219,6 +220,16 @@ impl SharedCellPool {
219220
.await?;
220221
Ok(())
221222
}
223+
224+
/// Run `take` on all cells in the pool and immediately drop them or produce an error if the cell has already been taken
225+
pub async fn discard_or_error_all(self) -> Vec<Result<(), EmptyCellError>> {
226+
join_all(
227+
self.map
228+
.iter()
229+
.map(|r| async move { r.value().discard().await }),
230+
)
231+
.await
232+
}
222233
}
223234

224235
/// Trait to facilitate storing `SharedCell`s of different types in a single pool.

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ impl PyProcMesh {
229229
let (proc_mesh, children) = tracked_proc_mesh.into_inner();
230230

231231
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
232-
children.discard_all().await?;
232+
// Discarding actor meshes that have been individually stopped will result in an expected error
233+
// which we can safely ignore
234+
children.discard_or_error_all().await;
233235

234236
// Finally, take ownership of the inner proc mesh, which will allowing dropping it.
235237
let _proc_mesh = proc_mesh.take().await?;

python/tests/test_python_actors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,15 @@ async def test_actor_mesh_stop(self) -> None:
10091009
await am_2.print.call("hello 3")
10101010
await am_2.log.call("hello 4")
10111011

1012+
await pm.stop()
1013+
1014+
async def test_proc_mesh_stop_after_actor_mesh_stop(self) -> None:
1015+
pm = proc_mesh(gpus=2)
1016+
am = await pm.spawn("printer", Printer)
1017+
1018+
await cast(ActorMesh, am).stop()
1019+
await pm.stop()
1020+
10121021

10131022
class PortedActor(Actor):
10141023
@endpoint(explicit_response_port=True)

0 commit comments

Comments
 (0)