Skip to content

Commit 7bb3277

Browse files
dulinrileymeta-codesync[bot]
authored andcommitted
Add controller actor for meshes (#1941)
Summary: Pull Request resolved: #1941 Fixes #1848 We want to support an important ownership axiom of Actors, that every actor which is alive has a single alive owner. If that owner is stopped (either voluntarily or via a crash), we want the child actors to stop as well. To this end, whenever we spawn a normal actor on a ProcMesh, we also spawn a "controller" actor on the owner. This controller's only job at the moment is to ensure during cleanup it calls "stop" on the actor it owns. This should mean as long as your actor stops on its own, all of its children will recursively stop. This mechanism doesn't handle an unclean exit on the owner, in which case we wouldn't be able to run the cleanup hook. That will need to be fixed separately. It is an important case for when the client might just exit for some reason, or get Ctrl-C'd. Had to enhance the `hyperactor::export` macro to handle generics. "hyperactor::remote" still doesn't, but luckily we don't even need to do a remote spawn of this actor. Reviewed By: mariusae Differential Revision: D86917010 fbshipit-source-id: 473a8de925a55bf994a1da9aaee8f36ba44423ae
1 parent 2fc4680 commit 7bb3277

File tree

6 files changed

+223
-16
lines changed

6 files changed

+223
-16
lines changed

hyperactor_macros/src/lib.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,7 @@ impl Parse for ExportAttr {
15501550
pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
15511551
let input: DeriveInput = parse_macro_input!(item as DeriveInput);
15521552
let data_type_name = &input.ident;
1553+
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
15531554

15541555
let ExportAttr { spawn, handlers } = parse_macro_input!(attr as ExportAttr);
15551556
let tys = HandlerSpec::add_indexed(handlers);
@@ -1560,7 +1561,7 @@ pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
15601561

15611562
for ty in &tys {
15621563
handles.push(quote! {
1563-
impl hyperactor::actor::RemoteHandles<#ty> for #data_type_name {}
1564+
impl #impl_generics hyperactor::actor::RemoteHandles<#ty> for #data_type_name #ty_generics #where_clause {}
15641565
});
15651566
bindings.push(quote! {
15661567
ports.bind::<#ty>();
@@ -1573,24 +1574,24 @@ pub fn export(attr: TokenStream, item: TokenStream) -> TokenStream {
15731574
let mut expanded = quote! {
15741575
#input
15751576

1576-
impl hyperactor::actor::Referable for #data_type_name {}
1577+
impl #impl_generics hyperactor::actor::Referable for #data_type_name #ty_generics #where_clause {}
15771578

15781579
#(#handles)*
15791580

15801581
#(#type_registrations)*
15811582

15821583
// Always export the `Signal` type.
1583-
impl hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name {}
1584+
impl #impl_generics hyperactor::actor::RemoteHandles<hyperactor::actor::Signal> for #data_type_name #ty_generics #where_clause {}
15841585

1585-
impl hyperactor::actor::Binds<#data_type_name> for #data_type_name {
1586+
impl #impl_generics hyperactor::actor::Binds<#data_type_name #ty_generics> for #data_type_name #ty_generics #where_clause {
15861587
fn bind(ports: &hyperactor::proc::Ports<Self>) {
15871588
#(#bindings)*
15881589
}
15891590
}
15901591

15911592
// TODO: just use Named derive directly here.
1592-
impl hyperactor::data::Named for #data_type_name {
1593-
fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name)) }
1593+
impl #impl_generics hyperactor::data::Named for #data_type_name #ty_generics #where_clause {
1594+
fn typename() -> &'static str { concat!(std::module_path!(), "::", stringify!(#data_type_name #ty_generics)) }
15941595
}
15951596
};
15961597

hyperactor_mesh/src/v1.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
1313
pub mod actor_mesh;
1414
pub mod host_mesh;
15+
pub mod mesh_controller;
1516
pub mod proc_mesh;
1617
pub mod testactor;
1718
pub mod testing;
@@ -143,6 +144,9 @@ pub enum Error {
143144
#[error("error spawning actor: {0}")]
144145
SingletonActorSpawnError(anyhow::Error),
145146

147+
#[error("error spawning controller actor for mesh {0}: {1}")]
148+
ControllerActorSpawnError(Name, anyhow::Error),
149+
146150
#[error("error: {0} does not exist")]
147151
NotExist(Name),
148152

hyperactor_mesh/src/v1/host_mesh.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use hyperactor::Actor;
10+
use hyperactor::ActorHandle;
911
use hyperactor::accum::ReducerOpts;
1012
use hyperactor::channel::ChannelTransport;
1113
use hyperactor::clock::Clock;
@@ -62,6 +64,8 @@ pub use crate::v1::host_mesh::mesh_agent::HostMeshAgent;
6264
use crate::v1::host_mesh::mesh_agent::HostMeshAgentProcMeshTrampoline;
6365
use crate::v1::host_mesh::mesh_agent::ProcState;
6466
use crate::v1::host_mesh::mesh_agent::ShutdownHostClient;
67+
use crate::v1::mesh_controller::HostMeshController;
68+
use crate::v1::mesh_controller::ProcMeshController;
6569
use crate::v1::proc_mesh::ProcRef;
6670

6771
declare_attrs! {
@@ -126,7 +130,10 @@ impl HostRef {
126130
/// This call returns `Ok(()))` only after the agent has finished
127131
/// the termination pass and released the host, so the host is no
128132
/// longer reachable when this returns.
129-
async fn shutdown(&self, cx: &impl hyperactor::context::Actor) -> anyhow::Result<()> {
133+
pub(crate) async fn shutdown(
134+
&self,
135+
cx: &impl hyperactor::context::Actor,
136+
) -> anyhow::Result<()> {
130137
let agent = self.mesh_agent();
131138
let terminate_timeout =
132139
hyperactor::config::global::get(crate::bootstrap::MESH_TERMINATE_TIMEOUT);
@@ -426,6 +433,14 @@ impl HostMesh {
426433
},
427434
current_ref: HostMeshRef::new(name, extent.into(), hosts).unwrap(),
428435
};
436+
437+
// Spawn a unique mesh controller for each proc mesh, so the type of the
438+
// mesh can be preserved.
439+
let _controller: ActorHandle<HostMeshController> =
440+
HostMeshController::spawn(cx, mesh.deref().clone())
441+
.await
442+
.map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?;
443+
429444
tracing::info!(name = "HostMeshStatus", status = "Allocate::Created");
430445
Ok(mesh)
431446
}
@@ -889,6 +904,14 @@ impl HostMeshRef {
889904
let mesh =
890905
ProcMesh::create_owned_unchecked(cx, mesh_name, extent, self.clone(), procs).await;
891906
tracing::info!(name = "ProcMeshStatus", status = "Spawn::Created",);
907+
if let Ok(ref mesh) = mesh {
908+
// Spawn a unique mesh controller for each proc mesh, so the type of the
909+
// mesh can be preserved.
910+
let _controller: ActorHandle<ProcMeshController> =
911+
ProcMeshController::spawn(cx, mesh.deref().clone())
912+
.await
913+
.map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?;
914+
}
892915
mesh
893916
}
894917

@@ -969,8 +992,8 @@ impl HostMeshRef {
969992
.await
970993
{
971994
Ok(statuses) => {
972-
let failed = statuses.values().any(|s| s.is_failure());
973-
if failed {
995+
let all_stopped = statuses.values().all(|s| s.is_terminating());
996+
if !all_stopped {
974997
tracing::error!(
975998
name = "ProcMeshStatus",
976999
mesh_name = %proc_mesh_name,
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::fmt::Debug;
10+
11+
use async_trait::async_trait;
12+
use hyperactor::Actor;
13+
use hyperactor::Instance;
14+
use hyperactor::ProcId;
15+
use hyperactor::actor::ActorError;
16+
use hyperactor::actor::Referable;
17+
use ndslice::ViewExt;
18+
use ndslice::view::Ranked;
19+
20+
use crate::v1::actor_mesh::ActorMeshRef;
21+
use crate::v1::host_mesh::HostMeshRef;
22+
use crate::v1::proc_mesh::ProcMeshRef;
23+
24+
#[hyperactor::export(spawn = false)]
25+
pub(crate) struct ActorMeshController<A>
26+
where
27+
A: Referable,
28+
{
29+
mesh: ActorMeshRef<A>,
30+
}
31+
32+
impl<A: Referable> Debug for ActorMeshController<A> {
33+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34+
f.debug_struct("MeshController")
35+
.field("mesh", &self.mesh)
36+
.finish()
37+
}
38+
}
39+
40+
#[async_trait]
41+
impl<A: Referable> Actor for ActorMeshController<A> {
42+
type Params = ActorMeshRef<A>;
43+
async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
44+
Ok(Self { mesh: params })
45+
}
46+
47+
async fn cleanup(
48+
&mut self,
49+
this: &Instance<Self>,
50+
_err: Option<&ActorError>,
51+
) -> Result<(), anyhow::Error> {
52+
// Cannot use "ActorMesh::stop" as it's only defined on ActorMesh, not ActorMeshRef.
53+
self.mesh
54+
.proc_mesh()
55+
.stop_actor_by_name(this, self.mesh.name().clone())
56+
.await?;
57+
Ok(())
58+
}
59+
}
60+
61+
#[derive(Debug)]
62+
#[hyperactor::export(spawn = true)]
63+
pub(crate) struct ProcMeshController {
64+
mesh: ProcMeshRef,
65+
}
66+
67+
#[async_trait]
68+
impl Actor for ProcMeshController {
69+
type Params = ProcMeshRef;
70+
async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
71+
Ok(Self { mesh: params })
72+
}
73+
74+
async fn cleanup(
75+
&mut self,
76+
this: &Instance<Self>,
77+
_err: Option<&ActorError>,
78+
) -> Result<(), anyhow::Error> {
79+
// Cannot use "ProcMesh::stop" as it's only defined on ProcMesh, not ProcMeshRef.
80+
let names = self.mesh.proc_ids().collect::<Vec<ProcId>>();
81+
let region = self.mesh.region().clone();
82+
if let Some(hosts) = self.mesh.hosts() {
83+
hosts
84+
.stop_proc_mesh(this, self.mesh.name(), names, region)
85+
.await
86+
} else {
87+
Ok(())
88+
}
89+
}
90+
}
91+
92+
#[derive(Debug)]
93+
#[hyperactor::export(spawn = true)]
94+
pub(crate) struct HostMeshController {
95+
mesh: HostMeshRef,
96+
}
97+
98+
#[async_trait]
99+
impl Actor for HostMeshController {
100+
type Params = HostMeshRef;
101+
async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
102+
Ok(Self { mesh: params })
103+
}
104+
105+
async fn cleanup(
106+
&mut self,
107+
this: &Instance<Self>,
108+
_err: Option<&ActorError>,
109+
) -> Result<(), anyhow::Error> {
110+
// Cannot use "HostMesh::shutdown" as it's only defined on HostMesh, not HostMeshRef.
111+
for host in self.mesh.values() {
112+
if let Err(e) = host.shutdown(this).await {
113+
tracing::warn!(host = %host, error = %e, "host shutdown failed");
114+
}
115+
}
116+
Ok(())
117+
}
118+
}

hyperactor_mesh/src/v1/proc_mesh.rs

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std::sync::atomic::Ordering;
1919
use std::time::Duration;
2020

2121
use hyperactor::Actor;
22+
use hyperactor::ActorHandle;
2223
use hyperactor::ActorId;
2324
use hyperactor::ActorRef;
2425
use hyperactor::Named;
@@ -75,6 +76,7 @@ use crate::v1::Name;
7576
use crate::v1::ValueMesh;
7677
use crate::v1::host_mesh::mesh_agent::ProcState;
7778
use crate::v1::host_mesh::mesh_to_rankedvalues_with_default;
79+
use crate::v1::mesh_controller::ActorMeshController;
7880

7981
declare_attrs! {
8082
/// The maximum idle time between updates while spawning actor
@@ -686,6 +688,12 @@ impl ProcMeshRef {
686688
&self.name
687689
}
688690

691+
/// Returns the HostMeshRef that this ProcMeshRef might be backed by.
692+
/// Returns None if this ProcMeshRef is backed by an Alloc instead of a host mesh.
693+
pub fn hosts(&self) -> Option<&HostMeshRef> {
694+
self.host_mesh.as_ref()
695+
}
696+
689697
/// The current statuses of procs in this mesh.
690698
pub async fn status(&self, cx: &impl context::Actor) -> v1::Result<ValueMesh<bool>> {
691699
let vm: ValueMesh<_> = self.map_into(|proc_ref| {
@@ -809,7 +817,7 @@ impl ProcMeshRef {
809817
}
810818

811819
/// Returns an iterator over the proc ids in this mesh.
812-
fn proc_ids(&self) -> impl Iterator<Item = ProcId> {
820+
pub(crate) fn proc_ids(&self) -> impl Iterator<Item = ProcId> {
813821
self.ranks.iter().map(|proc_ref| proc_ref.proc_id.clone())
814822
}
815823

@@ -942,7 +950,7 @@ impl ProcMeshRef {
942950
// overlays are applied, it emits a new StatusMesh snapshot.
943951
// `wait()` loops on it, deciding when the stream is
944952
// "complete" (no more NotExist) or times out.
945-
match GetRankStatus::wait(
953+
let mesh = match GetRankStatus::wait(
946954
rx,
947955
self.ranks.len(),
948956
config::global::get(ACTOR_SPAWN_MAX_IDLE),
@@ -979,7 +987,14 @@ impl ProcMeshRef {
979987
);
980988
Err(Error::ActorSpawnError { statuses: legacy })
981989
}
982-
}
990+
}?;
991+
// Spawn a unique mesh manager for each actor mesh, so the type of the
992+
// mesh can be preserved.
993+
let _controller: ActorHandle<ActorMeshController<A>> =
994+
ActorMeshController::<A>::spawn(cx, mesh.deref().clone())
995+
.await
996+
.map_err(|e| Error::ControllerActorSpawnError(mesh.name().clone(), e))?;
997+
Ok(mesh)
983998
}
984999

9851000
/// Send stop actors message to all mesh agents for a specific mesh name
@@ -1034,10 +1049,11 @@ impl ProcMeshRef {
10341049
.await
10351050
{
10361051
Ok(statuses) => {
1037-
let has_failed = statuses
1038-
.values()
1039-
.any(|s| matches!(s, Status::Failed(_) | Status::Timeout(_)));
1040-
if !has_failed {
1052+
// Check that all actors are in some terminal state.
1053+
// Failed is ok, because one of these actors may have failed earlier
1054+
// and we're trying to stop the others.
1055+
let all_stopped = statuses.values().all(|s| s.is_terminating());
1056+
if all_stopped {
10411057
Ok(())
10421058
} else {
10431059
let legacy = mesh_to_rankedvalues_with_default(

python/tests/test_python_actors.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
current_rank,
6666
current_size,
6767
endpoint,
68+
ProcMesh,
6869
)
6970
from monarch.tools.config import defaults
7071
from typing_extensions import assert_type
@@ -2007,3 +2008,47 @@ def test_cleanup_async():
20072008
cleanup.check.call_one().get()
20082009
cast(ActorMesh[ActorWithCleanup], cleanup).stop().get()
20092010
assert counter.value.call_one().get() == 1
2011+
2012+
2013+
class WrapperActor(Actor):
2014+
"""Just spawns an actor and does nothing with it."""
2015+
2016+
def __init__(self, procs: ProcMesh, counter: Counter) -> None:
2017+
# Ensure both the proc mesh and actor mesh owned by this actor are
2018+
# cleaned up, and that there's no race between the two.
2019+
self.procs = this_host().spawn_procs()
2020+
# Also test that when spawning actors on a foreign proc mesh, the actors
2021+
# are cleaned up even though the client is still alive.
2022+
self.mesh_on_passed_procs = procs.spawn(
2023+
"inner_passed", ActorWithCleanup, counter
2024+
)
2025+
self.mesh_on_owned_procs = self.procs.spawn(
2026+
"inner_owned", ActorWithCleanup, counter
2027+
)
2028+
2029+
@endpoint
2030+
def check(self) -> None:
2031+
self.mesh_on_passed_procs.check.call().get()
2032+
self.mesh_on_owned_procs.check.call().get()
2033+
2034+
# Note that there is no __cleanup__ defined, the inner mesh should be
2035+
# automatically stopped without needing to define one.
2036+
2037+
2038+
def test_recursive_stop():
2039+
"""Tests that if A owns B, and A is stopped, B is also stopped. Cleanup
2040+
actors are used because we can observe a side effect of them stopping"""
2041+
procs = this_host().spawn_procs(per_host={"gpus": 1})
2042+
counter = procs.spawn("counter", Counter, 0)
2043+
wrapper = procs.spawn("wrapper", WrapperActor, procs, counter)
2044+
# Call an endpoint to ensure it is constructed.
2045+
wrapper.check.call_one().get()
2046+
# Calling stop on WrapperActor should also stop its owned ActorWithCleanup.
2047+
cast(ActorMesh[WrapperActor], wrapper).stop().get()
2048+
# The incr messages in the cleanups have no sequencing guarantee with when
2049+
# stop is complete, nor with any further messages sent from this client.
2050+
# So we need to make sure there is time for both messages to be processed.
2051+
time.sleep(10)
2052+
# Two increments to the counter: one for the actors on the owned proc mesh,
2053+
# and one for the actors on the passed-in proc mesh.
2054+
assert counter.value.call_one().get() == 2

0 commit comments

Comments
 (0)