diff --git a/docs/source/books/hyperactor-book/src/actors/index.md b/docs/source/books/hyperactor-book/src/actors/index.md index 3e732e0e9..1da0317fb 100644 --- a/docs/source/books/hyperactor-book/src/actors/index.md +++ b/docs/source/books/hyperactor-book/src/actors/index.md @@ -8,7 +8,7 @@ This chapter introduces the actor system in hyperactor. We'll cover: - The [`Actor`](./actor.md) trait and its lifecycle hooks - The [`Handler`](./handler.md) trait for defining message-handling behavior -- The [`RemotableActor`](./remotable_actor.md) trait for enabling remote spawning +- The [`RemoteSpawn`](./remotable_actor.md) trait for enabling remote spawning - The [`Checkpointable`](./checkpointable.md) trait for supporting actor persistence and recovery - The [`Referable`](./remote_actor.md) marker trait for remotely referencable types - The [`Binds`](./binds.md) trait for wiring exported ports to reference types diff --git a/docs/source/books/hyperactor-book/src/actors/remotable_actor.md b/docs/source/books/hyperactor-book/src/actors/remotable_actor.md index d060cfb45..5bd151393 100644 --- a/docs/source/books/hyperactor-book/src/actors/remotable_actor.md +++ b/docs/source/books/hyperactor-book/src/actors/remotable_actor.md @@ -1,26 +1,30 @@ # The `RemoteableActor` Trait ```rust -pub trait RemotableActor: Actor -where - Self::Params: RemoteMessage, -{ +pub trait RemoteSpawn: Actor + Referable + Binds { + /// The type of parameters used to instantiate the actor remotely. + type Params: RemoteMessage; + + /// Creates a new actor instance given its instantiation parameters. + async fn new(params: Self::Params) -> anyhow::Result; + fn gspawn( proc: &Proc, name: &str, serialized_params: Data, - ) -> Pin> + Send>>; + ) -> Pin> + Send>> { /* default impl. */} fn get_type_id() -> TypeId { TypeId::of::() } } ``` -The `RemotableActor` trait marks an actor type as spawnable across process boundaries. 
It enables hyperactor's remote spawning and registration system, allowing actors to be created from serialized parameters in a different `Proc`. +The `RemoteSpawn` trait marks an actor type as spawnable across process boundaries. It enables hyperactor's remote spawning and registration system, allowing actors to be created from serialized parameters in a different `Proc`. ## Requirements - The actor type must also implement `Actor`. -- Its `Params` type (used in `Actor::new`) must implement `RemoteMessage`, so it can be serialized and transmitted over the network. +- Its `Params` type (used in `RemoteSpawn::new`) must implement `RemoteMessage`, so it can be serialized and transmitted over the network. +- `new` creates a new instance of the actor given its parameters ## `gspawn` ```rust @@ -39,41 +43,8 @@ The method deserializes the parameters, creates the actor, and returns its `Acto This is used internally by hyperactor's remote actor registry and `spawn` services. Ordinary users generally don't call this directly. -> **Note:** This is not an `async fn` because `RemotableActor` must be object-safe. +> **Note:** This is not an `async fn` because `RemoteSpawn` must be object-safe. ## `get_type_id` Returns a stable `TypeId` for the actor type. Used to identify actor types at runtime—e.g., in registration tables or type-based routing logic. - -## Blanket Implementation - -The RemotableActor trait is automatically implemented for any actor type `A` that: -- implements `Actor` and `Referable`, -- and whose `Params` type implements `RemoteMessage`. - -This allows `A` to be remotely registered and instantiated from serialized data, typically via the runtime's registration mechanism. 
- -```rust -impl RemotableActor for A -where - A: Actor + Referable, - A: Binds, - A::Params: RemoteMessage, -{ - fn gspawn( - proc: &Proc, - name: &str, - serialized_params: Data, - ) -> Pin> + Send>> { - let proc = proc.clone(); - let name = name.to_string(); - Box::pin(async move { - let handle = proc - .spawn::(&name, bincode::deserialize(&serialized_params)?) - .await?; - Ok(handle.bind::().actor_id) - }) - } -} -``` -Note the `Binds` bound: this trait specifies how an actor's ports are wired determining which message types the actor can receive remotely. The resulting `ActorId` corresponds to a port-bound, remotely callable version of the actor. diff --git a/docs/source/books/hyperactor-book/src/macros/export.md b/docs/source/books/hyperactor-book/src/macros/export.md index 4b5260937..73f6045b2 100644 --- a/docs/source/books/hyperactor-book/src/macros/export.md +++ b/docs/source/books/hyperactor-book/src/macros/export.md @@ -18,7 +18,7 @@ The macro expands to include: - A `Binds` implementation that registers supported message types - Implementations of `RemoteHandles` for each type in the `handlers = [...]` list - A `Referable` marker implementation - - If `spawn = true`, a `RemotableActor` implementation and an inventory registration of the `spawn` function. + - If `spawn = true`, the actor's `RemoteSpawn` implementation is registered in the remote actor inventory. This enables the actor to be: - Spawned dynamically by name @@ -46,7 +46,7 @@ impl Named for ShoppingListActor { ``` If `spawn = true`, the macro also emits: ```rust -impl RemotableActor for ShoppingListActor {} +impl RemoteSpawn for ShoppingListActor {} ``` This enables remote spawning via the default `gspawn` provided by a blanket implementation. 
diff --git a/hyper/src/commands/demo.rs b/hyper/src/commands/demo.rs index 68e455a9e..fee3e947e 100644 --- a/hyper/src/commands/demo.rs +++ b/hyper/src/commands/demo.rs @@ -17,6 +17,7 @@ use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Named; use hyperactor::RefClient; +use hyperactor::RemoteSpawn; use hyperactor::channel::ChannelAddr; use hyperactor::forward; use hyperactor::id; @@ -210,7 +211,7 @@ enum DemoMessage { Error(String, #[reply] OncePortRef<()>), } -#[derive(Debug, Default, Actor)] +#[derive(Debug, Default)] #[hyperactor::export( spawn = true, handlers = [ @@ -219,6 +220,8 @@ enum DemoMessage { )] struct DemoActor; +impl Actor for DemoActor {} + #[async_trait] #[forward(DemoMessage)] impl DemoMessageHandler for DemoActor { @@ -243,7 +246,7 @@ impl DemoMessageHandler for DemoActor { async fn spawn_child(&mut self, cx: &Context) -> Result, anyhow::Error> { tracing::info!("demo: spawn child"); - Ok(Self::spawn(cx, ()).await?.bind()) + Ok(Self.spawn(cx).await?.bind()) } async fn error(&mut self, _cx: &Context, message: String) -> Result<(), anyhow::Error> { diff --git a/hyperactor/example/derive.rs b/hyperactor/example/derive.rs index e1a8d6cc6..dfa9c7d1c 100644 --- a/hyperactor/example/derive.rs +++ b/hyperactor/example/derive.rs @@ -49,17 +49,11 @@ struct GetItemCount { } // Define an actor. -#[derive(Debug, Actor, Default)] -#[hyperactor::export( - spawn = true, - handlers = [ - ShoppingList, - ClearList, - GetItemCount, - ], -)] +#[derive(Debug, Default)] struct ShoppingListActor(HashSet); +impl Actor for ShoppingListActor {} + // ShoppingListHandler is the trait generated by derive(Handler) above. // We implement the trait here for the actor, defining a handler for // each ShoppingList message. @@ -140,7 +134,7 @@ async fn main() -> Result<(), anyhow::Error> { // Spawn our actor, and get a handle for rank 0. 
let shopping_list_actor: hyperactor::ActorHandle = - proc.spawn("shopping", ()).await?; + proc.spawn("shopping", ShoppingListActor::default()).await?; let shopping_api: hyperactor::ActorRef = shopping_list_actor.bind(); // We join the system, so that we can send messages to actors. let (client, _) = proc.instance("client").unwrap(); diff --git a/hyperactor/example/stream.rs b/hyperactor/example/stream.rs index 11bdd39c8..27c7e3e68 100644 --- a/hyperactor/example/stream.rs +++ b/hyperactor/example/stream.rs @@ -18,16 +18,19 @@ use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::proc::Proc; use serde::Deserialize; use serde::Serialize; -#[derive(Debug, Actor, Default)] +#[derive(Debug, Default)] struct CounterActor { subscribers: Vec>, n: u64, } +impl Actor for CounterActor {} + #[derive(Serialize, Deserialize, Debug, Named)] struct Subscribe(PortRef); @@ -52,15 +55,14 @@ struct CountClient { counter: PortRef, } -#[async_trait] -impl Actor for CountClient { - // Where to send subscribe messages. - type Params = PortRef; - - async fn new(counter: PortRef) -> Result { - Ok(Self { counter }) +impl CountClient { + fn new(counter: PortRef) -> Self { + Self { counter } } +} +#[async_trait] +impl Actor for CountClient { async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { // Subscribe to the counter on initialization. We give it our u64 port to report // messages back to. @@ -81,13 +83,19 @@ impl Handler for CountClient { async fn main() { let proc = Proc::local(); - let counter_actor: ActorHandle = proc.spawn("counter", ()).await.unwrap(); + let counter_actor: ActorHandle = proc + .spawn("counter", CounterActor::default()) + .await + .unwrap(); for i in 0..10 { // Spawn new "countees". Every time each subscribes, the counter broadcasts // the count to everyone. 
let _countee_actor: ActorHandle = proc - .spawn(&format!("countee_{}", i), counter_actor.port().bind()) + .spawn( + &format!("countee_{}", i), + CountClient::new(counter_actor.port().bind()), + ) .await .unwrap(); #[allow(clippy::disallowed_methods)] diff --git a/hyperactor/src/actor.rs b/hyperactor/src/actor.rs index 745641a68..0c03753f5 100644 --- a/hyperactor/src/actor.rs +++ b/hyperactor/src/actor.rs @@ -68,12 +68,6 @@ pub mod remote; /// actor is determined by the set (and order) of messages it receives. #[async_trait] pub trait Actor: Sized + Send + Debug + 'static { - /// The type of initialization parameters accepted by this actor. - type Params: Send + 'static; - - /// Creates a new actor instance given its instantiation parameters. - async fn new(params: Self::Params) -> Result; - /// Initialize the actor, after the runtime has been fully initialized. /// Init thus provides a mechanism by which an actor can reliably and always /// receive some initial event that can be used to kick off further @@ -104,11 +98,8 @@ pub trait Actor: Sized + Send + Debug + 'static { /// Spawn a child actor, given a spawning capability (usually given by [`Instance`]). /// The spawned actor will be supervised by the parent (spawning) actor. - async fn spawn( - cx: &impl context::Actor, - params: Self::Params, - ) -> anyhow::Result> { - cx.instance().spawn(params).await + async fn spawn(self, cx: &impl context::Actor) -> anyhow::Result> { + cx.instance().spawn(self).await } /// Spawns this actor in a detached state, handling its messages @@ -117,8 +108,8 @@ pub trait Actor: Sized + Send + Debug + 'static { /// /// Actors spawned through `spawn_detached` are not attached to a supervision /// hierarchy, and not managed by a [`Proc`]. 
- async fn spawn_detached(params: Self::Params) -> Result, anyhow::Error> { - Proc::local().spawn("anon", params).await + async fn spawn_detached(self) -> Result, anyhow::Error> { + Proc::local().spawn("anon", self).await } /// This method is used by the runtime to spawn the actor server. It can be @@ -175,13 +166,7 @@ pub fn handle_undeliverable_message( /// An actor that does nothing. It is used to represent "client only" actors, /// returned by [`Proc::instance`]. #[async_trait] -impl Actor for () { - type Params = (); - - async fn new(params: Self::Params) -> Result { - Ok(params) - } -} +impl Actor for () {} impl Referable for () {} @@ -240,57 +225,31 @@ where /// An `Actor` that can be spawned remotely. /// -/// Blanket-implemented for actors that opt in to remote spawn by also -/// implementing `Referable` and `Binds`, with serializable -/// params: -/// -/// ```rust,ignore -/// impl RemotableActor for A -/// where -/// A: Actor + Referable + Binds, -/// A::Params: RemoteMessage, -/// {} -/// ``` -/// /// Bounds explained: +/// - `Actor`: only actors may be remotely spawned. /// - `Referable`: marks the type as eligible for typed remote /// references (`ActorRef`); required because remote spawn /// ultimately hands back an `ActorId` that higher-level APIs may /// re-type as `ActorRef`. -/// - `Binds`: lets the runtime wire this actor's message ports -/// when it is spawned (the blanket impl calls `handle.bind::()`). -/// - `A::Params: RemoteMessage`: constructor params must be -/// (de)serializable to cross a process boundary. +/// - `Binds`: lets the runtime wire this actor's message ports +/// when it is spawned (the blanket impl calls `handle.bind::()`). /// /// `gspawn` is a type-erased entry point used by the remote /// spawn/registry machinery. It takes serialized params and returns -/// the new actor’s `ActorId`; application code shouldn’t call it +/// the new actor's `ActorId`; application code shouldn't call it /// directly. 
-pub trait RemotableActor: Actor -where - Self::Params: RemoteMessage, -{ +#[async_trait] +pub trait RemoteSpawn: Actor + Referable + Binds { + /// The type of parameters used to instantiate the actor remotely. + type Params: RemoteMessage; + + /// Creates a new actor instance given its instantiation parameters. + async fn new(params: Self::Params) -> anyhow::Result; + /// A type-erased entry point to spawn this actor. This is /// primarily used by hyperactor's remote actor registration /// mechanism. // TODO: consider making this 'private' -- by moving it into a non-public trait as in [`cap`]. - fn gspawn( - proc: &Proc, - name: &str, - serialized_params: Data, - ) -> Pin> + Send>>; - - /// The type ID of this actor. - fn get_type_id() -> TypeId { - TypeId::of::() - } -} - -impl RemotableActor for A -where - A: Actor + Referable + Binds, - A::Params: RemoteMessage, -{ fn gspawn( proc: &Proc, name: &str, @@ -299,9 +258,9 @@ where let proc = proc.clone(); let name = name.to_string(); Box::pin(async move { - let handle = proc - .spawn::(&name, bincode::deserialize(&serialized_params)?) - .await?; + let params = bincode::deserialize(&serialized_params)?; + let actor = Self::new(params).await?; + let handle = proc.spawn(&name, actor).await?; // We return only the ActorId, not a typed ActorRef. // Callers that hold this ID can interact with the actor // only via the serialized/opaque messaging path, which @@ -313,9 +272,25 @@ where // // This will be replaced by a proper export/registry // mechanism. - Ok(handle.bind::().actor_id) + Ok(handle.bind::().actor_id) }) } + + /// The type ID of this actor. + fn get_type_id() -> TypeId { + TypeId::of::() + } +} + +/// If an actor implements Default, we use this as the +/// `RemoteSpawn` implementation, too. 
+#[async_trait] +impl + Default> RemoteSpawn for A { + type Params = (); + + async fn new(_params: Self::Params) -> anyhow::Result { + Ok(Default::default()) + } } #[async_trait] @@ -722,9 +697,9 @@ impl Clone for ActorHandle { /// - and can be carried in [`ActorRef`] values across process /// boundaries. /// -/// In contrast, [`RemotableActor`] is the trait that marks *actors* +/// In contrast, [`RemoteSpawn`] is the trait that marks *actors* /// that can actually be **spawned remotely**. A behavior may be a -/// `Referable` but is never a `RemotableActor`. +/// `Referable` but is never a `RemoteSpawn`. pub trait Referable: Named + Send + Sync {} /// Binds determines how an actor's ports are bound to a specific @@ -744,13 +719,16 @@ pub trait RemoteHandles: Referable {} /// # use serde::Serialize; /// # use serde::Deserialize; /// # use hyperactor::Named; +/// # use hyperactor::Actor; /// /// // First, define a behavior, based on handling a single message type `()`. /// hyperactor::behavior!(UnitBehavior, ()); /// -/// #[derive(hyperactor::Actor, Debug, Default)] +/// #[derive(Debug, Default)] /// struct MyActor; /// +/// impl Actor for MyActor {} +/// /// #[async_trait::async_trait] /// impl hyperactor::Handler<()> for MyActor { /// async fn handle( @@ -790,7 +768,6 @@ mod tests { use crate::checkpoint::CheckpointError; use crate::checkpoint::Checkpointable; use crate::test_utils::pingpong::PingPongActor; - use crate::test_utils::pingpong::PingPongActorParams; use crate::test_utils::pingpong::PingPongMessage; use crate::test_utils::proc_supervison::ProcSupervisionCoordinator; // for macros @@ -798,13 +775,7 @@ mod tests { struct EchoActor(PortRef); #[async_trait] - impl Actor for EchoActor { - type Params = PortRef; - - async fn new(params: PortRef) -> Result { - Ok(Self(params)) - } - } + impl Actor for EchoActor {} #[async_trait] impl Handler for EchoActor { @@ -820,7 +791,8 @@ mod tests { let proc = Proc::local(); let client = proc.attach("client").unwrap(); 
let (tx, mut rx) = client.open_port(); - let handle = proc.spawn::("echo", tx.bind()).await.unwrap(); + let actor = EchoActor(tx.bind()); + let handle = proc.spawn::("echo", actor).await.unwrap(); handle.send(123u64).unwrap(); handle.drain_and_stop().unwrap(); handle.await; @@ -834,14 +806,14 @@ mod tests { let client = proc.attach("client").unwrap(); let (undeliverable_msg_tx, _) = client.open_port(); - let ping_pong_actor_params = - PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None); + let ping_actor = PingPongActor::new(Some(undeliverable_msg_tx.bind()), None, None); + let pong_actor = PingPongActor::new(Some(undeliverable_msg_tx.bind()), None, None); let ping_handle = proc - .spawn::("ping", ping_pong_actor_params.clone()) + .spawn::("ping", ping_actor) .await .unwrap(); let pong_handle = proc - .spawn::("pong", ping_pong_actor_params) + .spawn::("pong", pong_actor) .await .unwrap(); @@ -865,14 +837,17 @@ mod tests { ProcSupervisionCoordinator::set(&proc).await.unwrap(); let error_ttl = 66; - let ping_pong_actor_params = - PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), Some(error_ttl)); + + let ping_actor = + PingPongActor::new(Some(undeliverable_msg_tx.bind()), Some(error_ttl), None); + let pong_actor = + PingPongActor::new(Some(undeliverable_msg_tx.bind()), Some(error_ttl), None); let ping_handle = proc - .spawn::("ping", ping_pong_actor_params.clone()) + .spawn::("ping", ping_actor) .await .unwrap(); let pong_handle = proc - .spawn::("pong", ping_pong_actor_params) + .spawn::("pong", pong_actor) .await .unwrap(); @@ -898,12 +873,6 @@ mod tests { #[async_trait] impl Actor for InitActor { - type Params = (); - - async fn new(_params: ()) -> Result { - Ok(Self(false)) - } - async fn init(&mut self, _this: &Instance) -> Result<(), anyhow::Error> { self.0 = true; Ok(()) @@ -925,7 +894,8 @@ mod tests { #[tokio::test] async fn test_init() { let proc = Proc::local(); - let handle = proc.spawn::("init", ()).await.unwrap(); + let actor = 
InitActor(false); + let handle = proc.spawn::("init", actor).await.unwrap(); let client = proc.attach("client").unwrap(); let (port, receiver) = client.open_once_port(); @@ -944,16 +914,7 @@ mod tests { } #[async_trait] - impl Actor for CheckpointActor { - type Params = PortRef; - - async fn new(params: PortRef) -> Result { - Ok(Self { - sum: 0, - port: params, - }) - } - } + impl Actor for CheckpointActor {} #[async_trait] impl Handler for CheckpointActor { @@ -992,10 +953,8 @@ mod tests { async fn new() -> Self { let proc = Proc::local(); let values: MultiValues = Arc::new(Mutex::new((0, "".to_string()))); - let handle = proc - .spawn::("myactor", values.clone()) - .await - .unwrap(); + let actor = MultiActor(values.clone()); + let handle = proc.spawn::("myactor", actor).await.unwrap(); let (client, client_handle) = proc.instance("client").unwrap(); Self { proc, @@ -1030,13 +989,7 @@ mod tests { struct MultiActor(MultiValues); #[async_trait] - impl Actor for MultiActor { - type Params = MultiValues; - - async fn new(init: Self::Params) -> Result { - Ok(Self(init)) - } - } + impl Actor for MultiActor {} #[async_trait] impl Handler for MultiActor { @@ -1111,13 +1064,15 @@ mod tests { #[tokio::test] async fn test_actor_handle_downcast() { - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] struct NothingActor; + impl Actor for NothingActor {} + // Just test that we can round-trip the handle through a downcast. let proc = Proc::local(); - let handle = proc.spawn::("nothing", ()).await.unwrap(); + let handle = proc.spawn("nothing", NothingActor).await.unwrap(); let cell = handle.cell(); // Invalid actor doesn't succeed. diff --git a/hyperactor/src/actor/remote.rs b/hyperactor/src/actor/remote.rs index 09b7e5dcb..0a4fe7fff 100644 --- a/hyperactor/src/actor/remote.rs +++ b/hyperactor/src/actor/remote.rs @@ -42,8 +42,8 @@ macro_rules! remote { $crate::submit! 
{ $crate::actor::remote::SpawnableActor { name: &[<$actor:snake:upper _NAME>], - gspawn: <$actor as $crate::actor::RemotableActor>::gspawn, - get_type_id: <$actor as $crate::actor::RemotableActor>::get_type_id, + gspawn: <$actor as $crate::actor::RemoteSpawn>::gspawn, + get_type_id: <$actor as $crate::actor::RemoteSpawn>::get_type_id, } } } @@ -141,13 +141,17 @@ mod tests { use crate as hyperactor; // for macros use crate::Context; use crate::Handler; + use crate::RemoteSpawn; #[derive(Debug)] #[hyperactor::export(handlers = [()])] struct MyActor; #[async_trait] - impl Actor for MyActor { + impl Actor for MyActor {} + + #[async_trait] + impl RemoteSpawn for MyActor { type Params = bool; async fn new(params: bool) -> Result { diff --git a/hyperactor/src/host.rs b/hyperactor/src/host.rs index b95c1749a..8e1c4766e 100644 --- a/hyperactor/src/host.rs +++ b/hyperactor/src/host.rs @@ -1182,13 +1182,16 @@ pub mod testing { use crate::Context; use crate::Handler; use crate::OncePortRef; + use crate::RemoteSpawn; /// Just a simple actor, available in both the bootstrap binary as well as /// hyperactor tests. - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] #[hyperactor::export(handlers = [OncePortRef])] pub struct EchoActor; + impl Actor for EchoActor {} + #[async_trait] impl Handler> for EchoActor { async fn handle( diff --git a/hyperactor/src/lib.rs b/hyperactor/src/lib.rs index 1df220283..6a67fec65 100644 --- a/hyperactor/src/lib.rs +++ b/hyperactor/src/lib.rs @@ -100,6 +100,7 @@ pub use actor::Actor; pub use actor::ActorHandle; pub use actor::Handler; pub use actor::RemoteHandles; +pub use actor::RemoteSpawn; // Re-export public dependencies of hyperactor_macros codegen. #[doc(hidden)] pub use anyhow; @@ -112,8 +113,6 @@ pub use cityhasher; #[doc(hidden)] pub use dashmap; // For intern_typename! 
pub use data::Named; -#[doc(hidden)] -pub use hyperactor_macros::Actor; #[doc(inline)] pub use hyperactor_macros::AttrValue; #[doc(inline)] diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 0c36fda11..fb0f011f6 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -3104,9 +3104,11 @@ mod tests { ); } - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] struct Foo; + impl Actor for Foo {} + // Test that a message delivery failure causes the sending actor // to stop running. #[tokio::test] @@ -3123,7 +3125,7 @@ mod tests { let mut proc = Proc::new(proc_id.clone(), proc_forwarder); ProcSupervisionCoordinator::set(&proc).await.unwrap(); - let foo = proc.spawn::("foo", ()).await.unwrap(); + let foo = proc.spawn("foo", Foo).await.unwrap(); let return_handle = foo.port::>(); let message = MessageEnvelope::new( foo.actor_id().clone(), diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index 724782386..a520568c9 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -511,7 +511,7 @@ impl Proc { pub async fn spawn( &self, name: &str, - params: A::Params, + actor: A, ) -> Result, anyhow::Error> { let actor_id = self.allocate_root_id(name)?; let span = tracing::span!( @@ -525,7 +525,6 @@ impl Proc { let _guard = span.clone().entered(); Instance::new(self.clone(), actor_id.clone(), false, None) }; - let actor = A::new(params).instrument(span.clone()).await?; // Add this actor to the proc's actor ledger. We do not actively remove // inactive actors from ledger, because the actor's state can be inferred // from its weak cell. 
@@ -588,12 +587,11 @@ impl Proc { async fn spawn_child( &self, parent: InstanceCell, - params: A::Params, + actor: A, ) -> Result, anyhow::Error> { let actor_id = self.allocate_child_id(parent.actor_id())?; let (instance, mut actor_loop_receivers, work_rx) = Instance::new(self.clone(), actor_id, false, Some(parent.clone())); - let actor = A::new(params).await?; Ok(instance .start(actor, actor_loop_receivers.take().unwrap(), work_rx) .await) @@ -1533,13 +1531,10 @@ impl Instance { } /// Spawn on child on this instance. Currently used only by cap::CanSpawn. - pub(crate) async fn spawn( - &self, - params: C::Params, - ) -> anyhow::Result> { + pub(crate) async fn spawn(&self, actor: C) -> anyhow::Result> { self.inner .proc - .spawn_child(self.inner.cell.clone(), params) + .spawn_child(self.inner.cell.clone(), actor) .await } @@ -2194,10 +2189,12 @@ mod tests { } } - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] #[export] struct TestActor; + impl Actor for TestActor {} + #[derive(Handler, HandleClient, Debug)] enum TestActorMessage { Reply(oneshot::Sender<()>), @@ -2276,7 +2273,7 @@ mod tests { cx: &crate::Context, reply: oneshot::Sender>, ) -> Result<(), anyhow::Error> { - let handle = ::spawn(cx, ()).await?; + let handle = TestActor::default().spawn(cx).await?; reply.send(handle).unwrap(); Ok(()) } @@ -2286,7 +2283,7 @@ mod tests { #[async_timed_test(timeout_secs = 30)] async fn test_spawn_actor() { let proc = Proc::local(); - let handle = proc.spawn::("test", ()).await.unwrap(); + let handle = proc.spawn("test", TestActor::default()).await.unwrap(); // Check on the join handle. 
assert!(logs_contain( @@ -2335,8 +2332,14 @@ mod tests { #[async_timed_test(timeout_secs = 30)] async fn test_proc_actors_messaging() { let proc = Proc::local(); - let first = proc.spawn::("first", ()).await.unwrap(); - let second = proc.spawn::("second", ()).await.unwrap(); + let first = proc + .spawn::("first", TestActor::default()) + .await + .unwrap(); + let second = proc + .spawn::("second", TestActor::default()) + .await + .unwrap(); let (tx, rx) = oneshot::channel::<()>(); let reply_message = TestActorMessage::Reply(tx); first @@ -2345,9 +2348,12 @@ mod tests { rx.await.unwrap(); } - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] + #[export] struct LookupTestActor; + impl Actor for LookupTestActor {} + #[derive(Handler, HandleClient, Debug)] enum LookupTestMessage { ActorExists(ActorRef, #[reply] OncePortRef), @@ -2370,9 +2376,15 @@ mod tests { let proc = Proc::local(); let (client, _handle) = proc.instance("client").unwrap(); - let target_actor = proc.spawn::("target", ()).await.unwrap(); + let target_actor = proc + .spawn::("target", TestActor::default()) + .await + .unwrap(); let target_actor_ref = target_actor.bind(); - let lookup_actor = proc.spawn::("lookup", ()).await.unwrap(); + let lookup_actor = proc + .spawn::("lookup", LookupTestActor::default()) + .await + .unwrap(); assert!( lookup_actor @@ -2430,7 +2442,10 @@ mod tests { async fn test_spawn_child() { let proc = Proc::local(); - let first = proc.spawn::("first", ()).await.unwrap(); + let first = proc + .spawn::("first", TestActor::default()) + .await + .unwrap(); let second = TestActor::spawn_child(&first).await; let third = TestActor::spawn_child(&second).await; @@ -2498,7 +2513,10 @@ mod tests { async fn test_child_lifecycle() { let proc = Proc::local(); - let root = proc.spawn::("root", ()).await.unwrap(); + let root = proc + .spawn::("root", TestActor::default()) + .await + .unwrap(); let root_1 = TestActor::spawn_child(&root).await; let root_2 = 
TestActor::spawn_child(&root).await; let root_2_1 = TestActor::spawn_child(&root_2).await; @@ -2519,7 +2537,10 @@ mod tests { // be actor failure(s) in this test which trigger supervision. ProcSupervisionCoordinator::set(&proc).await.unwrap(); - let root = proc.spawn::("root", ()).await.unwrap(); + let root = proc + .spawn::("root", TestActor::default()) + .await + .unwrap(); let root_1 = TestActor::spawn_child(&root).await; let root_2 = TestActor::spawn_child(&root).await; let root_2_1 = TestActor::spawn_child(&root_2).await; @@ -2559,7 +2580,10 @@ mod tests { let proc = Proc::local(); // Add the 1st root. This root will remain active until the end of the test. - let root: ActorHandle = proc.spawn::("root", ()).await.unwrap(); + let root: ActorHandle = proc + .spawn::("root", TestActor::default()) + .await + .unwrap(); wait_until_idle(&root).await; { let snapshot = proc.state().ledger.snapshot(); @@ -2573,8 +2597,10 @@ mod tests { } // Add the 2nd root. - let another_root: ActorHandle = - proc.spawn::("another_root", ()).await.unwrap(); + let another_root: ActorHandle = proc + .spawn::("another_root", TestActor::default()) + .await + .unwrap(); wait_until_idle(&another_root).await; { let snapshot = proc.state().ledger.snapshot(); @@ -2782,13 +2808,7 @@ mod tests { struct TestActor(Arc); #[async_trait] - impl Actor for TestActor { - type Params = Arc; - - async fn new(param: Arc) -> Result { - Ok(Self(param)) - } - } + impl Actor for TestActor {} #[async_trait] impl Handler>> for TestActor { @@ -2816,10 +2836,8 @@ mod tests { let proc = Proc::local(); let state = Arc::new(AtomicUsize::new(0)); - let handle = proc - .spawn::("test", state.clone()) - .await - .unwrap(); + let actor = TestActor(state.clone()); + let handle = proc.spawn::("test", actor).await.unwrap(); let client = proc.attach("client").unwrap(); let (tx, rx) = client.open_once_port(); handle.send(tx).unwrap(); @@ -2843,7 +2861,10 @@ mod tests { ProcSupervisionCoordinator::set(&proc).await.unwrap(); 
let (client, _handle) = proc.instance("client").unwrap(); - let actor_handle = proc.spawn::("test", ()).await.unwrap(); + let actor_handle = proc + .spawn::("test", TestActor::default()) + .await + .unwrap(); actor_handle .panic(&client, "some random failure".to_string()) .await @@ -2872,12 +2893,6 @@ mod tests { #[async_trait] impl Actor for TestActor { - type Params = (Arc, bool); - - async fn new(param: (Arc, bool)) -> Result { - Ok(Self(param.0, param.1)) - } - async fn handle_supervision_event( &mut self, _this: &Instance, @@ -2920,13 +2935,13 @@ mod tests { let root_2_1_state = Arc::new(AtomicBool::new(false)); let root = proc - .spawn::("root", (root_state.clone(), false)) + .spawn::("root", TestActor(root_state.clone(), false)) .await .unwrap(); let root_1 = proc .spawn_child::( root.cell().clone(), - ( + TestActor( root_1_state.clone(), true, /* set true so children's event stops here */ ), @@ -2934,19 +2949,28 @@ mod tests { .await .unwrap(); let root_1_1 = proc - .spawn_child::(root_1.cell().clone(), (root_1_1_state.clone(), false)) + .spawn_child::( + root_1.cell().clone(), + TestActor(root_1_1_state.clone(), false), + ) .await .unwrap(); let root_1_1_1 = proc - .spawn_child::(root_1_1.cell().clone(), (root_1_1_1_state.clone(), false)) + .spawn_child::( + root_1_1.cell().clone(), + TestActor(root_1_1_1_state.clone(), false), + ) .await .unwrap(); let root_2 = proc - .spawn_child::(root.cell().clone(), (root_2_state.clone(), false)) + .spawn_child::(root.cell().clone(), TestActor(root_2_state.clone(), false)) .await .unwrap(); let root_2_1 = proc - .spawn_child::(root_2.cell().clone(), (root_2_1_state.clone(), false)) + .spawn_child::( + root_2.cell().clone(), + TestActor(root_2_1_state.clone(), false), + ) .await .unwrap(); @@ -2978,9 +3002,11 @@ mod tests { #[async_timed_test(timeout_secs = 30)] async fn test_instance() { - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] struct TestActor; + impl Actor for TestActor {} + #[async_trait] impl 
Handler<(String, PortRef)> for TestActor { async fn handle( @@ -2997,7 +3023,7 @@ mod tests { let (instance, handle) = proc.instance("my_test_actor").unwrap(); - let child_actor = TestActor::spawn(&instance, ()).await.unwrap(); + let child_actor = TestActor::default().spawn(&instance).await.unwrap(); let (port, mut receiver) = instance.open_port(); child_actor @@ -3028,7 +3054,10 @@ mod tests { // Intentionally not setting a proc supervison coordinator. This // should cause the process to terminate. // ProcSupervisionCoordinator::set(&proc).await.unwrap(); - let root = proc.spawn::("root", ()).await.unwrap(); + let root = proc + .spawn::("root", TestActor::default()) + .await + .unwrap(); let (client, _handle) = proc.instance("client").unwrap(); root.fail(&client, anyhow::anyhow!("some random failure")) .await @@ -3058,9 +3087,11 @@ mod tests { #[ignore = "until trace recording is turned back on"] #[test] fn test_handler_logging() { - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] struct LoggingActor; + impl Actor for LoggingActor {} + impl LoggingActor { async fn wait(handle: &ActorHandle) { let barrier = Arc::new(Barrier::new(2)); @@ -3121,7 +3152,7 @@ mod tests { } trace_and_block(async { - let handle = LoggingActor::spawn_detached(()).await.unwrap(); + let handle = LoggingActor::default().spawn_detached().await.unwrap(); handle.send("hello world".to_string()).unwrap(); handle.send("hello world again".to_string()).unwrap(); handle.send(123u64).unwrap(); diff --git a/hyperactor/src/test_utils/pingpong.rs b/hyperactor/src/test_utils/pingpong.rs index da50eace7..bce19ca37 100644 --- a/hyperactor/src/test_utils/pingpong.rs +++ b/hyperactor/src/test_utils/pingpong.rs @@ -21,6 +21,7 @@ use crate::Instance; use crate::Named; use crate::OncePortRef; use crate::PortRef; +use crate::RemoteSpawn; use crate::clock::Clock; use crate::clock::RealClock; use crate::mailbox::MessageEnvelope; @@ -34,9 +35,10 @@ use crate::mailbox::UndeliverableMessageError; 
#[derive(Serialize, Deserialize, Debug, Named)] pub struct PingPongMessage(pub u64, pub ActorRef, pub OncePortRef); -/// Initialization parameters for `PingPongActor`s. -#[derive(Debug, Named, Serialize, Deserialize, Clone)] -pub struct PingPongActorParams { +/// A PingPong actor that can play the PingPong game by sending messages around. +#[derive(Debug)] +#[hyperactor::export(spawn = true, handlers = [PingPongMessage])] +pub struct PingPongActor { /// A port to send undeliverable messages to. undeliverable_port_ref: Option>>, /// The TTL at which the actor will exit with error. @@ -45,40 +47,40 @@ pub struct PingPongActorParams { delay: Option, } -impl PingPongActorParams { - /// Create a new set of initialization parameters. +impl PingPongActor { + /// Create a new ping pong actor with the following parameters: + /// + /// - `undeliverable_port_ref`: A port to send undeliverable messages to. + /// - `error_ttl`: The TTL at which the actor will exit with error. + /// - `delay`: Manual delay before sending handling the message. pub fn new( undeliverable_port_ref: Option>>, error_ttl: Option, + delay: Option, ) -> Self { Self { undeliverable_port_ref, error_ttl, - delay: None, + delay, } } - - /// Set the delay - pub fn set_delay(&mut self, delay: Duration) { - self.delay = Some(delay); - } -} - -/// A PingPong actor that can play the PingPong game by sending messages around. 
-#[derive(Debug)] -#[hyperactor::export(handlers = [PingPongMessage])] -pub struct PingPongActor { - params: PingPongActorParams, } #[async_trait] -impl Actor for PingPongActor { - type Params = PingPongActorParams; +impl RemoteSpawn for PingPongActor { + type Params = ( + Option>>, + Option, + Option, + ); - async fn new(params: Self::Params) -> Result { - Ok(Self { params }) + async fn new((undeliverable_port_ref, error_ttl, delay): Self::Params) -> anyhow::Result { + Ok(Self::new(undeliverable_port_ref, error_ttl, delay)) } +} +#[async_trait] +impl Actor for PingPongActor { // This is an override of the default actor behavior. It is used // for testing the mechanism for returning undeliverable messages to // their senders. @@ -87,7 +89,7 @@ impl Actor for PingPongActor { cx: &Instance, undelivered: crate::mailbox::Undeliverable, ) -> Result<(), anyhow::Error> { - match &self.params.undeliverable_port_ref { + match &self.undeliverable_port_ref { Some(port) => port.send(cx, undelivered).unwrap(), None => { let Undeliverable(envelope) = undelivered; @@ -111,13 +113,13 @@ impl Handler for PingPongActor { ) -> anyhow::Result<()> { // PingPongActor sends the messages back and forth. When it's ttl = 0, it will stop. // User can set a preconfigured TTL that can cause mocked problem: such as an error. 
- if Some(ttl) == self.params.error_ttl { + if Some(ttl) == self.error_ttl { anyhow::bail!("PingPong handler encountered an Error"); } if ttl == 0 { done_port.send(cx, true)?; } else { - if let Some(delay) = self.params.delay { + if let Some(delay) = self.delay { RealClock.sleep(delay).await; } let next_message = PingPongMessage(ttl - 1, cx.bind(), done_port); @@ -126,5 +128,3 @@ impl Handler for PingPongActor { Ok(()) } } - -hyperactor::remote!(PingPongActor); diff --git a/hyperactor/src/test_utils/proc_supervison.rs b/hyperactor/src/test_utils/proc_supervison.rs index 6e40309c6..e63a1b883 100644 --- a/hyperactor/src/test_utils/proc_supervison.rs +++ b/hyperactor/src/test_utils/proc_supervison.rs @@ -42,8 +42,9 @@ impl ProcSupervisionCoordinator { /// proc. pub async fn set(proc: &Proc) -> Result { let state = ReportedEvent::new(); + let actor = ProcSupervisionCoordinator(state.clone()); let coordinator = proc - .spawn::("coordinator", state.clone()) + .spawn::("coordinator", actor) .await?; proc.set_supervision_coordinator(coordinator.port())?; Ok(state) @@ -69,13 +70,7 @@ impl ReportedEvent { } #[async_trait] -impl Actor for ProcSupervisionCoordinator { - type Params = ReportedEvent; - - async fn new(param: ReportedEvent) -> Result { - Ok(Self(param)) - } -} +impl Actor for ProcSupervisionCoordinator {} #[async_trait] impl Handler for ProcSupervisionCoordinator { diff --git a/hyperactor/test/host_bootstrap.rs b/hyperactor/test/host_bootstrap.rs index 85643e6d2..d15dfef78 100644 --- a/hyperactor/test/host_bootstrap.rs +++ b/hyperactor/test/host_bootstrap.rs @@ -18,7 +18,8 @@ async fn main() { let proc = ProcessProcManager::::boot_proc(|proc| async move { - proc.spawn("echo", ()).await + proc.spawn("echo", hyperactor::host::testing::EchoActor) + .await }) .await .unwrap(); diff --git a/hyperactor_macros/src/lib.rs b/hyperactor_macros/src/lib.rs index b04c4ac7a..2cef1f21b 100644 --- a/hyperactor_macros/src/lib.rs +++ b/hyperactor_macros/src/lib.rs @@ -2142,96 
+2142,6 @@ pub fn derive_unbind(input: TokenStream) -> TokenStream { TokenStream::from(expand) } -/// Derives the `Actor` trait for a struct. By default, generates an implementation -/// with no params (`type Params = ()` and `async fn new(_params: ()) -> Result`). -/// This requires that the Actor implements [`Default`]. -/// -/// If the `#[actor(passthrough)]` attribute is specified, generates an implementation -/// with where the parameter type is `Self` -/// (`type Params = Self` and `async fn new(instance: Self) -> Result`). -/// -/// # Examples -/// -/// Default behavior: -/// ``` -/// #[derive(Actor, Default)] -/// struct MyActor(u64); -/// ``` -/// -/// Generates: -/// ```ignore -/// #[async_trait] -/// impl Actor for MyActor { -/// type Params = (); -/// -/// async fn new(_params: ()) -> Result { -/// Ok(Default::default()) -/// } -/// } -/// ``` -/// -/// Passthrough behavior: -/// ``` -/// #[derive(Actor, Default)] -/// #[actor(passthrough)] -/// struct MyActor(u64); -/// ``` -/// -/// Generates: -/// ```ignore -/// #[async_trait] -/// impl Actor for MyActor { -/// type Params = Self; -/// -/// async fn new(instance: Self) -> Result { -/// Ok(instance) -/// } -/// } -/// ``` -#[proc_macro_derive(Actor, attributes(actor))] -pub fn derive_actor(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - - let is_passthrough = input.attrs.iter().any(|attr| { - if attr.path().is_ident("actor") { - if let Ok(meta) = attr.parse_args_with( - syn::punctuated::Punctuated::::parse_terminated, - ) { - return meta.iter().any(|ident| ident == "passthrough"); - } - } - false - }); - - let expanded = if is_passthrough { - quote! 
{ - #[hyperactor::async_trait::async_trait] - impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause { - type Params = Self; - - async fn new(instance: Self) -> Result { - Ok(instance) - } - } - } - } else { - quote! { - #[hyperactor::async_trait::async_trait] - impl #impl_generics hyperactor::Actor for #name #ty_generics #where_clause { - type Params = (); - - async fn new(_params: ()) -> Result { - Ok(Default::default()) - } - } - } - }; - - TokenStream::from(expanded) -} - // Helper function for common parsing and validation fn parse_observe_function( attr: TokenStream, diff --git a/hyperactor_macros/tests/basic.rs b/hyperactor_macros/tests/basic.rs index df25e2ce1..93351a196 100644 --- a/hyperactor_macros/tests/basic.rs +++ b/hyperactor_macros/tests/basic.rs @@ -21,6 +21,7 @@ use hyperactor::Handler; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::RemoteSpawn; use hyperactor::forward; use hyperactor::instrument; use hyperactor::instrument_infallible; @@ -78,10 +79,12 @@ enum TestVariantForms { }, } -#[derive(Debug, Default, Actor)] +#[derive(Debug, Default)] #[hyperactor::export(handlers = [TestVariantForms])] struct TestVariantFormsActor {} +impl Actor for TestVariantFormsActor {} + #[async_trait] #[forward(TestVariantForms)] impl TestVariantFormsHandler for TestVariantFormsActor { @@ -136,17 +139,7 @@ enum GenericArgMessage { #[derive(Debug)] struct GenericArgActor {} -#[derive(Clone, Debug, Serialize, Deserialize)] -struct GenericArgParams {} - -#[async_trait] -impl Actor for GenericArgActor { - type Params = GenericArgParams; - - async fn new(_params: Self::Params) -> Result { - Ok(Self {}) - } -} +impl Actor for GenericArgActor {} #[async_trait] #[forward(GenericArgMessage)] @@ -156,24 +149,14 @@ impl GenericArgMessageHandler for GenericArgActor { } } -#[derive(Actor, Default, Debug)] +#[derive(Default, Debug)] struct DefaultActorTest { value: u64, } 
-static_assertions::assert_impl_all!(DefaultActorTest: Actor); +impl Actor for DefaultActorTest {} -#[derive(Actor, Default, Debug)] -#[actor(passthrough)] -struct PassthroughActorTest { - value: u64, -} - -static_assertions::assert_impl_all!(PassthroughActorTest: Actor); -static_assertions::assert_type_eq_all!( - ::Params, - PassthroughActorTest -); +static_assertions::assert_impl_all!(DefaultActorTest: Actor); // Test struct support for Handler derive #[derive(Handler, Debug, Named, Serialize, Deserialize)] @@ -199,10 +182,7 @@ mod tests { async fn test_client_macros() { let proc = Proc::local(); let (client, _) = proc.instance("client").unwrap(); - let actor_handle = proc - .spawn::("foo", ()) - .await - .unwrap(); + let actor_handle = proc.spawn("foo", TestVariantFormsActor {}).await.unwrap(); assert_eq!(actor_handle.call_struct(&client, 10).await.unwrap(), 10,); diff --git a/hyperactor_macros/tests/export.rs b/hyperactor_macros/tests/export.rs index 81d22d4f8..78db578b5 100644 --- a/hyperactor_macros/tests/export.rs +++ b/hyperactor_macros/tests/export.rs @@ -13,6 +13,7 @@ use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use serde::Deserialize; use serde::Serialize; @@ -32,21 +33,14 @@ struct TestActor { forward_port: PortRef, } -#[derive(Debug, Clone, Named, Serialize, Deserialize)] -struct TestActorParams { - forward_port: PortRef, -} - -#[async_trait] -impl Actor for TestActor { - type Params = TestActorParams; - - async fn new(params: Self::Params) -> anyhow::Result { - let Self::Params { forward_port } = params; - Ok(Self { forward_port }) +impl TestActor { + fn new(forward_port: PortRef) -> Self { + Self { forward_port } } } +impl Actor for TestActor {} + #[derive(Debug, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)] struct TestMessage(String); @@ -113,10 +107,7 @@ mod tests { let proc = Proc::local(); let (client, _) = 
proc.instance("client").unwrap(); let (tx, mut rx) = client.open_port(); - let params = TestActorParams { - forward_port: tx.bind(), - }; - let actor_handle = proc.spawn::("foo", params).await.unwrap(); + let actor_handle = proc.spawn("test", TestActor::new(tx.bind())).await.unwrap(); // This will call binds actor_handle.bind::(); // Verify that the ports can be gotten successfully. @@ -194,10 +185,7 @@ mod tests { let proc = Proc::local(); let (client, _) = proc.instance("client").unwrap(); let (tx, mut rx) = client.open_port(); - let params = TestActorParams { - forward_port: tx.bind(), - }; - let actor_handle = proc.spawn::("actor", params).await.unwrap(); + let actor_handle = proc.spawn("test", TestActor::new(tx.bind())).await.unwrap(); actor_handle.send(123u64).unwrap(); actor_handle.send(TestMessage("foo".to_string())).unwrap(); diff --git a/hyperactor_mesh/benches/bench_actor.rs b/hyperactor_mesh/benches/bench_actor.rs index 86befd05f..c0b29c35f 100644 --- a/hyperactor_mesh/benches/bench_actor.rs +++ b/hyperactor_mesh/benches/bench_actor.rs @@ -16,6 +16,7 @@ use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::clock::Clock; use serde::Deserialize; @@ -40,8 +41,10 @@ pub struct BenchActor { processing_time: Duration, } +impl Actor for BenchActor {} + #[async_trait] -impl Actor for BenchActor { +impl RemoteSpawn for BenchActor { type Params = Duration; async fn new(params: Duration) -> Result { Ok(Self { diff --git a/hyperactor_mesh/examples/dining_philosophers.rs b/hyperactor_mesh/examples/dining_philosophers.rs index 2106d6b68..c3008f2ad 100644 --- a/hyperactor_mesh/examples/dining_philosophers.rs +++ b/hyperactor_mesh/examples/dining_philosophers.rs @@ -21,6 +21,7 @@ use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::context; 
use hyperactor_mesh::comm::multicast::CastInfo; @@ -83,8 +84,10 @@ struct PhilosopherActorParams { size: usize, } +impl Actor for PhilosopherActor {} + #[async_trait] -impl Actor for PhilosopherActor { +impl RemoteSpawn for PhilosopherActor { type Params = PhilosopherActorParams; async fn new(params: Self::Params) -> Result { diff --git a/hyperactor_mesh/examples/sieve.rs b/hyperactor_mesh/examples/sieve.rs index fa7da5502..9cbb5f57f 100644 --- a/hyperactor_mesh/examples/sieve.rs +++ b/hyperactor_mesh/examples/sieve.rs @@ -23,6 +23,7 @@ use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; use hyperactor::Proc; +use hyperactor::RemoteSpawn; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::Mesh; use hyperactor_mesh::ProcMesh; @@ -82,8 +83,13 @@ impl Handler for SieveActor { } None => { msg.prime_collector.send(cx, msg.number)?; - self.next = - Some(SieveActor::spawn(cx, SieveParams { prime: msg.number }).await?); + + self.next = Some( + SieveActor::new(SieveParams { prime: msg.number }) + .await? + .spawn(cx) + .await?, + ); } } } @@ -91,8 +97,10 @@ impl Handler for SieveActor { } } +impl Actor for SieveActor {} + #[async_trait] -impl Actor for SieveActor { +impl RemoteSpawn for SieveActor { type Params = SieveParams; /// Creates a sieve actor for `prime`. 
diff --git a/hyperactor_mesh/examples/test_bench.rs b/hyperactor_mesh/examples/test_bench.rs index 95dc66b5e..2798b3ff2 100644 --- a/hyperactor_mesh/examples/test_bench.rs +++ b/hyperactor_mesh/examples/test_bench.rs @@ -34,7 +34,7 @@ use ndslice::extent; use serde::Deserialize; use serde::Serialize; -#[derive(Actor, Default, Debug)] +#[derive(Default, Debug)] #[hyperactor::export( spawn = true, handlers = [ @@ -43,6 +43,8 @@ use serde::Serialize; )] struct TestActor {} +impl Actor for TestActor {} + #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)] enum TestMessage { Ping(#[binding(include)] PortRef), diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 03b9306c9..eaef8e7ba 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -628,6 +628,7 @@ pub(crate) mod test_util { use hyperactor::Handler; use hyperactor::Instance; use hyperactor::PortRef; + use hyperactor::RemoteSpawn; use ndslice::extent; use super::*; @@ -637,7 +638,7 @@ pub(crate) mod test_util { // be an entry in the spawnable actor registry in the executable // 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor // mesh test suite. - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] #[hyperactor::export( spawn = true, handlers = [ @@ -650,6 +651,8 @@ pub(crate) mod test_util { )] pub struct TestActor; + impl Actor for TestActor {} + /// Request message to retrieve the actor's rank. 
/// /// The `bool` in the tuple controls the outcome of the handler: @@ -764,6 +767,14 @@ pub(crate) mod test_util { #[async_trait] impl Actor for ProxyActor { + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await?); + Ok(()) + } + } + + #[async_trait] + impl RemoteSpawn for ProxyActor { type Params = (); async fn new(_params: Self::Params) -> Result { @@ -776,8 +787,7 @@ pub(crate) mod test_util { use crate::alloc::Allocator; use crate::alloc::LocalAllocator; - let mut allocator = LocalAllocator; - let alloc = allocator + let alloc = LocalAllocator .allocate(AllocSpec { extent: extent! { replica = 1 }, constraints: Default::default(), @@ -787,6 +797,7 @@ pub(crate) mod test_util { }) .await .unwrap(); + let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap()); let leaked: &'static Arc = Box::leak(Box::new(proc_mesh)); Ok(Self { @@ -794,11 +805,6 @@ pub(crate) mod test_util { actor_mesh: None, }) } - - async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { - self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await?); - Ok(()) - } } #[async_trait] @@ -847,6 +853,7 @@ mod tests { use hyperactor::ActorId; use hyperactor::PortRef; use hyperactor::ProcId; + use hyperactor::RemoteSpawn; use hyperactor::WorldId; use hyperactor::attrs::Attrs; use hyperactor::data::Encoding; @@ -937,7 +944,6 @@ mod tests { async fn test_ping_pong() { use hyperactor::test_utils::pingpong::PingPongActor; use hyperactor::test_utils::pingpong::PingPongMessage; - use hyperactor::test_utils::pingpong::PingPongActorParams; let alloc = $allocator .allocate(AllocSpec { @@ -953,9 +959,8 @@ mod tests { let mesh = ProcMesh::allocate(alloc).await.unwrap(); let (undeliverable_msg_tx, _) = mesh.client().open_port(); - let ping_pong_actor_params = PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None); let actor_mesh: RootActorMesh = mesh - .spawn::(&instance, 
"ping-pong", &ping_pong_actor_params) + .spawn::(&instance, "ping-pong", &(Some(undeliverable_msg_tx.bind()), None, None)) .await .unwrap(); @@ -970,7 +975,6 @@ mod tests { #[tokio::test] async fn test_pingpong_full_mesh() { use hyperactor::test_utils::pingpong::PingPongActor; - use hyperactor::test_utils::pingpong::PingPongActorParams; use hyperactor::test_utils::pingpong::PingPongMessage; use futures::future::join_all; @@ -992,8 +996,7 @@ mod tests { let instance = $crate::v1::testing::instance().await; let proc_mesh = ProcMesh::allocate(alloc).await.unwrap(); let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port(); - let params = PingPongActorParams::new(Some(undeliverable_tx.bind()), None); - let actor_mesh = proc_mesh.spawn::(&instance, "pingpong", ¶ms).await.unwrap(); + let actor_mesh = proc_mesh.spawn::(&instance, "pingpong", &(Some(undeliverable_tx.bind()), None, None)).await.unwrap(); let slice = actor_mesh.shape().slice(); let mut futures = Vec::new(); @@ -1275,7 +1278,6 @@ mod tests { hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default()); use hyperactor::test_utils::pingpong::PingPongActor; - use hyperactor::test_utils::pingpong::PingPongActorParams; use hyperactor::test_utils::pingpong::PingPongMessage; use crate::alloc::ProcStopReason; @@ -1302,12 +1304,16 @@ mod tests { let mut mesh = ProcMesh::allocate(alloc).await.unwrap(); let mut events = mesh.events().unwrap(); - let ping_pong_actor_params = PingPongActorParams::new( - Some(PortRef::attest_message_port(mesh.client().self_id())), - None, - ); let actor_mesh: RootActorMesh = mesh - .spawn::(&instance, "ping-pong", &ping_pong_actor_params) + .spawn::( + &instance, + "ping-pong", + &( + Some(PortRef::attest_message_port(mesh.client().self_id())), + None, + None, + ), + ) .await .unwrap(); @@ -1419,7 +1425,6 @@ mod tests { #[tokio::test] async fn test_stop_actor_mesh() { use hyperactor::test_utils::pingpong::PingPongActor; - use 
hyperactor::test_utils::pingpong::PingPongActorParams; use hyperactor::test_utils::pingpong::PingPongMessage; let config = hyperactor::config::global::lock(); @@ -1441,17 +1446,29 @@ mod tests { let instance = crate::v1::testing::instance().await; let mesh = ProcMesh::allocate(alloc).await.unwrap(); - let ping_pong_actor_params = PingPongActorParams::new( - Some(PortRef::attest_message_port(mesh.client().self_id())), - None, - ); let mesh_one: RootActorMesh = mesh - .spawn::(&instance, "mesh_one", &ping_pong_actor_params) + .spawn::( + &instance, + "mesh_one", + &( + Some(PortRef::attest_message_port(mesh.client().self_id())), + None, + None, + ), + ) .await .unwrap(); let mesh_two: RootActorMesh = mesh - .spawn::(&instance, "mesh_two", &ping_pong_actor_params) + .spawn::( + &instance, + "mesh_two", + &( + Some(PortRef::attest_message_port(mesh.client().self_id())), + None, + None, + ), + ) .await .unwrap(); @@ -1688,6 +1705,7 @@ mod tests { use hyperactor::Actor; use hyperactor::Context; use hyperactor::Handler; + use hyperactor::RemoteSpawn; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::channel::ChannelTx; @@ -1716,7 +1734,10 @@ mod tests { struct EchoActor(ChannelTx); #[async_trait] - impl Actor for EchoActor { + impl Actor for EchoActor {} + + #[async_trait] + impl RemoteSpawn for EchoActor { type Params = ChannelAddr; async fn new(params: ChannelAddr) -> Result { diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index 439c7b276..c597cce5d 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -671,6 +671,7 @@ pub mod test_utils { use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; + use hyperactor::RemoteSpawn; use libc::atexit; use tokio::sync::broadcast::Receiver; use tokio::sync::broadcast::Sender; @@ -688,7 +689,7 @@ pub mod test_utils { // be an entry in the spawnable actor registry in the executable // 'hyperactor_mesh_test_bootstrap' 
for the `tests::process` actor // mesh test suite. - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] #[hyperactor::export( spawn = true, handlers = [ @@ -697,6 +698,8 @@ pub mod test_utils { )] pub struct TestActor; + impl Actor for TestActor {} + #[derive(Debug, Serialize, Deserialize, Named, Clone)] pub struct Wait; diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index a83cf6be4..679d269ff 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -2593,10 +2593,10 @@ mod test_alloc { task2_allocator_handle.await.unwrap(); } - #[tracing_test::traced_test] #[async_timed_test(timeout_secs = 60)] #[cfg(fbcode_build)] async fn test_remote_process_alloc_signal_handler() { + hyperactor_telemetry::initialize_logging_for_test(); let num_proc_meshes = 5; let hosts_per_proc_mesh = 5; diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index 605b51a6b..4056a7b5e 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -490,10 +490,12 @@ impl Bootstrap { let (host, _handle) = ok!(Host::serve(manager, addr).await); let addr = host.addr().clone(); - let host_mesh_agent = ok!(host - .system_proc() - .clone() - .spawn::("agent", HostAgentMode::Process(host)) + let system_proc = host.system_proc().clone(); + let host_mesh_agent = ok!(system_proc + .spawn::( + "agent", + HostMeshAgent::new(HostAgentMode::Process(host)), + ) .await); tracing::info!( @@ -2352,6 +2354,7 @@ mod tests { use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::ProcId; + use hyperactor::RemoteSpawn; use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; @@ -2614,14 +2617,19 @@ mod tests { // Spawn the log client and disable aggregation (immediate // print + tap push). 
- let log_client: ActorRef = - proc.spawn("log_client", ()).await.unwrap().bind(); + let log_client_actor = LogClientActor::new(()).await.unwrap(); + let log_client: ActorRef = proc + .spawn("log_client", log_client_actor) + .await + .unwrap() + .bind(); log_client.set_aggregate(&client, None).await.unwrap(); // Spawn the forwarder in this proc (it will serve // BOOTSTRAP_LOG_CHANNEL). + let log_forwarder_actor = LogForwardActor::new(log_client.clone()).await.unwrap(); let _log_forwarder: ActorRef = proc - .spawn("log_forwarder", log_client.clone()) + .spawn("log_forwarder", log_forwarder_actor) .await .unwrap() .bind(); diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index bbbc63f4d..18c0b15f9 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -75,7 +75,7 @@ struct ReceiveState { /// This is the comm actor used for efficient and scalable message multicasting /// and result accumulation. -#[derive(Debug)] +#[derive(Debug, Default)] #[hyperactor::export( spawn = true, handlers = [ @@ -162,16 +162,6 @@ impl CommActorMode { #[async_trait] impl Actor for CommActor { - type Params = CommActorParams; - - async fn new(_params: Self::Params) -> Result { - Ok(Self { - send_seq: HashMap::new(), - recv_state: HashMap::new(), - mode: Default::default(), - }) - } - // This is an override of the default actor behavior. 
async fn handle_undeliverable_message( &mut self, @@ -522,7 +512,10 @@ pub mod test_utils { } #[async_trait] - impl Actor for TestActor { + impl Actor for TestActor {} + + #[async_trait] + impl hyperactor::RemoteSpawn for TestActor { type Params = TestActorParams; async fn new(params: Self::Params) -> Result { diff --git a/hyperactor_mesh/src/connect.rs b/hyperactor_mesh/src/connect.rs index 455502229..62f3ec987 100644 --- a/hyperactor_mesh/src/connect.rs +++ b/hyperactor_mesh/src/connect.rs @@ -397,15 +397,18 @@ mod tests { use hyperactor::Actor; use hyperactor::Context; use hyperactor::Handler; + use hyperactor::RemoteSpawn; use hyperactor::proc::Proc; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use super::*; - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] struct EchoActor {} + impl Actor for EchoActor {} + #[async_trait] impl Handler for EchoActor { async fn handle( @@ -427,7 +430,7 @@ mod tests { let proc = Proc::local(); let (client, _client_handle) = proc.instance("client")?; let (connect, completer) = Connect::allocate(client.self_id().clone(), client); - let actor = proc.spawn::("actor", ()).await?; + let actor = proc.spawn("actor", EchoActor {}).await?; actor.send(connect)?; let (mut rd, mut wr) = completer.complete().await?.into_split(); let send = [3u8, 4u8, 5u8, 6u8]; diff --git a/hyperactor_mesh/src/logging.rs b/hyperactor_mesh/src/logging.rs index 2e1b817c5..e1865df35 100644 --- a/hyperactor_mesh/src/logging.rs +++ b/hyperactor_mesh/src/logging.rs @@ -998,6 +998,21 @@ pub struct LogForwardActor { #[async_trait] impl Actor for LogForwardActor { + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?; + + // Make sure we start the flush loop periodically so the log channel will not deadlock. 
+ self.flush_tx + .lock() + .await + .send(LogMessage::Flush { sync_version: None }) + .await?; + Ok(()) + } +} + +#[async_trait] +impl hyperactor::RemoteSpawn for LogForwardActor { type Params = ActorRef; async fn new(logging_client_ref: Self::Params) -> Result { @@ -1045,18 +1060,6 @@ impl Actor for LogForwardActor { stream_to_client: true, }) } - - async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { - this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?; - - // Make sure we start the flush loop periodically so the log channel will not deadlock. - self.flush_tx - .lock() - .await - .send(LogMessage::Flush { sync_version: None }) - .await?; - Ok(()) - } } #[async_trait] @@ -1182,6 +1185,25 @@ pub struct LogClientActor { current_unflushed_procs: usize, } +impl Default for LogClientActor { + fn default() -> Self { + // Initialize aggregators + let mut aggregators = HashMap::new(); + aggregators.insert(OutputTarget::Stderr, Aggregator::new()); + aggregators.insert(OutputTarget::Stdout, Aggregator::new()); + + Self { + aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC), + aggregators, + last_flush_time: RealClock.system_time_now(), + next_flush_deadline: None, + current_flush_version: 0, + current_flush_port: None, + current_unflushed_procs: 0, + } + } +} + impl LogClientActor { fn print_aggregators(&mut self) { for (output_target, aggregator) in self.aggregators.iter_mut() { @@ -1222,27 +1244,7 @@ impl LogClientActor { } #[async_trait] -impl Actor for LogClientActor { - /// The aggregation window in seconds. 
- type Params = (); - - async fn new(_: ()) -> Result { - // Initialize aggregators - let mut aggregators = HashMap::new(); - aggregators.insert(OutputTarget::Stderr, Aggregator::new()); - aggregators.insert(OutputTarget::Stdout, Aggregator::new()); - - Ok(Self { - aggregate_window_sec: Some(DEFAULT_AGGREGATE_WINDOW_SEC), - aggregators, - last_flush_time: RealClock.system_time_now(), - next_flush_deadline: None, - current_flush_version: 0, - current_flush_port: None, - current_unflushed_procs: 0, - }) - } -} +impl Actor for LogClientActor {} impl Drop for LogClientActor { fn drop(&mut self) { @@ -1458,6 +1460,7 @@ mod tests { use std::sync::Arc; use std::sync::Mutex; + use hyperactor::RemoteSpawn; use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTx; @@ -1599,10 +1602,15 @@ mod tests { unsafe { std::env::set_var(BOOTSTRAP_LOG_CHANNEL, log_channel.to_string()); } - let log_client: ActorRef = - proc.spawn("log_client", ()).await.unwrap().bind(); + let log_client_actor = LogClientActor::new(()).await.unwrap(); + let log_client: ActorRef = proc + .spawn("log_client", log_client_actor) + .await + .unwrap() + .bind(); + let log_forwarder_actor = LogForwardActor::new(log_client.clone()).await.unwrap(); let log_forwarder: ActorRef = proc - .spawn("log_forwarder", log_client) + .spawn("log_forwarder", log_forwarder_actor) .await .unwrap() .bind(); diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 634d97dd5..b6cc4d2df 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -23,9 +23,9 @@ use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Instance; use hyperactor::RemoteMessage; +use hyperactor::RemoteSpawn; use hyperactor::WorldId; use hyperactor::actor::ActorStatus; -use hyperactor::actor::Referable; use hyperactor::actor::remote::Remote; use hyperactor::channel; use hyperactor::channel::ChannelAddr; @@ -501,7 +501,7 @@ impl ProcMesh { /// 
Referable`. /// - `A::Params: RemoteMessage` - params must serialize for /// cross-proc spawn. - async fn spawn_on_procs( + async fn spawn_on_procs( cx: &impl context::Actor, agents: impl IntoIterator> + '_, actor_name: &str, @@ -608,7 +608,7 @@ impl ProcMesh { /// Referable`. /// - `A::Params: RemoteMessage` — params must be serializable to /// cross proc boundaries when launching each actor. - pub async fn spawn( + pub async fn spawn( &self, cx: &impl context::Actor, actor_name: &str, @@ -958,7 +958,7 @@ impl ProcEvents { pub trait SharedSpawnable { // `Actor`: the type actually runs in the mesh; // `Referable`: so we can hand back ActorRef in RootActorMesh - async fn spawn( + async fn spawn( self, cx: &impl context::Actor, actor_name: &str, @@ -972,7 +972,7 @@ pub trait SharedSpawnable { impl + Send + Sync + 'static> SharedSpawnable for D { // `Actor`: the type actually runs in the mesh; // `Referable`: so we can hand back ActorRef in RootActorMesh - async fn spawn( + async fn spawn( self, cx: &impl context::Actor, actor_name: &str, diff --git a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs index 8ef942c78..07a527ee8 100644 --- a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs @@ -287,12 +287,6 @@ impl ProcMeshAgent { #[async_trait] impl Actor for ProcMeshAgent { - type Params = Self; - - async fn new(params: Self::Params) -> Result { - Ok(params) - } - async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { self.proc.set_supervision_coordinator(this.port())?; Ok(()) diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index a1b3b8d10..8f8c67bb0 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -289,7 +289,10 @@ mod tests { } #[async_trait] - impl Actor for MeshPingPongActor { + impl Actor for MeshPingPongActor {} + + #[async_trait] + impl hyperactor::RemoteSpawn for MeshPingPongActor { type 
Params = MeshPingPongActorParams; async fn new(params: Self::Params) -> Result { diff --git a/hyperactor_mesh/src/resource/mesh.rs b/hyperactor_mesh/src/resource/mesh.rs index 3f463ce54..2084d14d4 100644 --- a/hyperactor_mesh/src/resource/mesh.rs +++ b/hyperactor_mesh/src/resource/mesh.rs @@ -101,9 +101,11 @@ mod test { type State = (); } - #[derive(Actor, Debug, Default, Named, Serialize, Deserialize)] + #[derive(Debug, Default, Named, Serialize, Deserialize)] struct TestMeshController; + impl Actor for TestMeshController {} + // Ensure that TestMeshController conforms to the Controller behavior for TestMesh. handler! { TestMeshController, diff --git a/hyperactor_mesh/src/test_utils.rs b/hyperactor_mesh/src/test_utils.rs index f21be6a1f..78a30db82 100644 --- a/hyperactor_mesh/src/test_utils.rs +++ b/hyperactor_mesh/src/test_utils.rs @@ -12,6 +12,7 @@ use hyperactor::Bind; use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use serde::Deserialize; use serde::Serialize; @@ -20,18 +21,20 @@ use serde::Serialize; #[derive(Serialize, Deserialize, Debug, Named, Clone, Bind, Unbind)] pub struct EmptyMessage(); -#[derive(Debug, PartialEq, Default, Actor)] +#[derive(Debug, PartialEq, Default)] #[hyperactor::export( + spawn = true, handlers = [ EmptyMessage { cast = true }, ], )] pub struct EmptyActor(); +impl Actor for EmptyActor {} + #[async_trait] impl Handler for EmptyActor { async fn handle(&mut self, _: &Context, _: EmptyMessage) -> Result<(), anyhow::Error> { Ok(()) } } -hyperactor::remote!(EmptyActor); diff --git a/hyperactor_mesh/src/v1/host_mesh.rs b/hyperactor_mesh/src/v1/host_mesh.rs index 26f4bd0cd..fb5e15955 100644 --- a/hyperactor_mesh/src/v1/host_mesh.rs +++ b/hyperactor_mesh/src/v1/host_mesh.rs @@ -266,10 +266,9 @@ impl HostMesh { let manager = BootstrapProcManager::new(bootstrap_cmd)?; let (host, _handle) = Host::serve(manager, addr).await?; let addr = host.addr().clone(); - let 
host_mesh_agent = host - .system_proc() - .clone() - .spawn::("agent", HostAgentMode::Process(host)) + let system_proc = host.system_proc().clone(); + let host_mesh_agent = system_proc + .spawn::("agent", HostMeshAgent::new(HostAgentMode::Process(host))) .await .map_err(v1::Error::SingletonActorSpawnError)?; host_mesh_agent.bind::(); @@ -436,10 +435,11 @@ impl HostMesh { // Spawn a unique mesh controller for each proc mesh, so the type of the // mesh can be preserved. - let _controller: ActorHandle = - HostMeshController::spawn(cx, mesh.deref().clone()) - .await - .map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?; + let controller = HostMeshController::new(mesh.deref().clone()); + controller + .spawn(cx) + .await + .map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?; tracing::info!(name = "HostMeshStatus", status = "Allocate::Created"); Ok(mesh) @@ -949,10 +949,11 @@ impl HostMeshRef { if let Ok(ref mesh) = mesh { // Spawn a unique mesh controller for each proc mesh, so the type of the // mesh can be preserved. - let _controller: ActorHandle = - ProcMeshController::spawn(cx, mesh.deref().clone()) - .await - .map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?; + let controller = ProcMeshController::new(mesh.deref().clone()); + controller + .spawn(cx) + .await + .map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?; } mesh } diff --git a/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs b/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs index 9d89d64e0..d316b9369 100644 --- a/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs @@ -106,11 +106,22 @@ struct ProcCreationState { ShutdownHost ] )] +#[derive(Default)] pub struct HostMeshAgent { host: Option, created: HashMap, } +impl HostMeshAgent { + /// Create a new host mesh agent running in the provided mode. 
+ pub fn new(mode: HostAgentMode) -> Self { + Self { + host: Some(mode), + created: HashMap::new(), + } + } +} + impl fmt::Debug for HostMeshAgent { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HostMeshAgent") @@ -121,27 +132,7 @@ impl fmt::Debug for HostMeshAgent { } #[async_trait] -impl Actor for HostMeshAgent { - type Params = HostAgentMode; - - async fn new(host: HostAgentMode) -> anyhow::Result { - if let HostAgentMode::Process(_) = host { - let (directory, file) = hyperactor_telemetry::log_file_path( - hyperactor_telemetry::env::Env::current(), - None, - ) - .unwrap(); - eprintln!( - "Monarch internal logs are being written to {}/{}.log", - directory, file - ); - } - Ok(Self { - host: Some(host), - created: HashMap::new(), - }) - } -} +impl Actor for HostMeshAgent {} #[async_trait] impl Handler> for HostMeshAgent { @@ -438,6 +429,14 @@ pub(crate) struct HostMeshAgentProcMeshTrampoline { #[async_trait] impl Actor for HostMeshAgentProcMeshTrampoline { + async fn init(&mut self, this: &Instance) -> anyhow::Result<()> { + self.reply_port.send(this, self.host_mesh_agent.bind())?; + Ok(()) + } +} + +#[async_trait] +impl hyperactor::RemoteSpawn for HostMeshAgentProcMeshTrampoline { type Params = ( ChannelTransport, PortRef>, @@ -462,22 +461,28 @@ impl Actor for HostMeshAgentProcMeshTrampoline { HostAgentMode::Process(host) }; - let host_mesh_agent = host - .system_proc() - .clone() - .spawn::("agent", host) - .await?; + if let HostAgentMode::Process(_) = &host { + let (directory, file) = hyperactor_telemetry::log_file_path( + hyperactor_telemetry::env::Env::current(), + None, + ) + .unwrap(); + eprintln!( + "Monarch internal logs are being written to {}/{}.log", + directory, file + ); + } + + let system_proc = host.system_proc().clone(); + let actor = HostMeshAgent::new(host); + + let host_mesh_agent = system_proc.spawn::("agent", actor).await?; Ok(Self { host_mesh_agent, reply_port, }) } - - async fn init(&mut self, this: &Instance) -> 
anyhow::Result<()> { - self.reply_port.send(this, self.host_mesh_agent.bind())?; - Ok(()) - } } #[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)] @@ -525,7 +530,13 @@ mod tests { let host_addr = host.addr().clone(); let system_proc = host.system_proc().clone(); let host_agent = system_proc - .spawn::("agent", HostAgentMode::Process(host)) + .spawn::( + "agent", + HostMeshAgent { + host: Some(HostAgentMode::Process(host)), + created: HashMap::new(), + }, + ) .await .unwrap(); diff --git a/hyperactor_mesh/src/v1/mesh_controller.rs b/hyperactor_mesh/src/v1/mesh_controller.rs index 7ab7341c2..b79f4ffd5 100644 --- a/hyperactor_mesh/src/v1/mesh_controller.rs +++ b/hyperactor_mesh/src/v1/mesh_controller.rs @@ -21,7 +21,7 @@ use crate::v1::actor_mesh::ActorMeshRef; use crate::v1::host_mesh::HostMeshRef; use crate::v1::proc_mesh::ProcMeshRef; -#[hyperactor::export(spawn = false)] +#[hyperactor::export] pub(crate) struct ActorMeshController where A: Referable, @@ -29,6 +29,13 @@ where mesh: ActorMeshRef, } +impl ActorMeshController { + /// Create a new mesh controller based on the provided reference. 
+ pub(crate) fn new(mesh: ActorMeshRef) -> Self { + Self { mesh } + } +} + impl Debug for ActorMeshController { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MeshController") @@ -39,11 +46,6 @@ impl Debug for ActorMeshController { #[async_trait] impl Actor for ActorMeshController { - type Params = ActorMeshRef; - async fn new(params: Self::Params) -> Result { - Ok(Self { mesh: params }) - } - async fn cleanup( &mut self, this: &Instance, @@ -59,18 +61,20 @@ impl Actor for ActorMeshController { } #[derive(Debug)] -#[hyperactor::export(spawn = true)] +#[hyperactor::export] pub(crate) struct ProcMeshController { mesh: ProcMeshRef, } -#[async_trait] -impl Actor for ProcMeshController { - type Params = ProcMeshRef; - async fn new(params: Self::Params) -> Result { - Ok(Self { mesh: params }) +impl ProcMeshController { + /// Create a new proc controller based on the provided reference. + pub(crate) fn new(mesh: ProcMeshRef) -> Self { + Self { mesh } } +} +#[async_trait] +impl Actor for ProcMeshController { async fn cleanup( &mut self, this: &Instance, @@ -90,18 +94,20 @@ impl Actor for ProcMeshController { } #[derive(Debug)] -#[hyperactor::export(spawn = true)] +#[hyperactor::export] pub(crate) struct HostMeshController { mesh: HostMeshRef, } -#[async_trait] -impl Actor for HostMeshController { - type Params = HostMeshRef; - async fn new(params: Self::Params) -> Result { - Ok(Self { mesh: params }) +impl HostMeshController { + /// Create a new host controller based on the provided reference. 
+ pub(crate) fn new(mesh: HostMeshRef) -> Self { + Self { mesh } } +} +#[async_trait] +impl Actor for HostMeshController { async fn cleanup( &mut self, this: &Instance, diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs index d3a1d04e9..b7e7c5354 100644 --- a/hyperactor_mesh/src/v1/proc_mesh.rs +++ b/hyperactor_mesh/src/v1/proc_mesh.rs @@ -25,6 +25,7 @@ use hyperactor::ActorRef; use hyperactor::Named; use hyperactor::ProcId; use hyperactor::RemoteMessage; +use hyperactor::RemoteSpawn; use hyperactor::accum::ReducerOpts; use hyperactor::actor::ActorStatus; use hyperactor::actor::Referable; @@ -834,7 +835,7 @@ impl ProcMeshRef { /// inside the `ActorMesh`. /// - `A::Params: RemoteMessage` - spawn parameters must be /// serializable and routable. - pub async fn spawn( + pub async fn spawn( &self, cx: &impl context::Actor, name: &str, @@ -853,7 +854,7 @@ impl ProcMeshRef { /// /// Note: avoid using service actors if possible; the mechanism will /// be replaced by an actor registry. - pub async fn spawn_service( + pub async fn spawn_service( &self, cx: &impl context::Actor, name: &str, @@ -884,7 +885,7 @@ impl ProcMeshRef { proc_mesh=self.name.to_string(), actor_name=name.to_string(), ))] - pub(crate) async fn spawn_with_name( + pub(crate) async fn spawn_with_name( &self, cx: &impl context::Actor, name: Name, @@ -915,15 +916,12 @@ impl ProcMeshRef { result } - async fn spawn_with_name_inner( + async fn spawn_with_name_inner( &self, cx: &impl context::Actor, name: Name, params: &A::Params, - ) -> v1::Result> - where - A::Params: RemoteMessage, - { + ) -> v1::Result> { let remote = Remote::collect(); // `Referable` ensures the type `A` is registered with // `Remote`. @@ -1030,10 +1028,11 @@ impl ProcMeshRef { }?; // Spawn a unique mesh manager for each actor mesh, so the type of the // mesh can be preserved. 
- let _controller: ActorHandle> = - ActorMeshController::::spawn(cx, mesh.deref().clone()) - .await - .map_err(|e| Error::ControllerActorSpawnError(mesh.name().clone(), e))?; + let controller = ActorMeshController::::new(mesh.deref().clone()); + controller + .spawn(cx) + .await + .map_err(|e| Error::ControllerActorSpawnError(mesh.name().clone(), e))?; Ok(mesh) } diff --git a/hyperactor_mesh/src/v1/testactor.rs b/hyperactor_mesh/src/v1/testactor.rs index df56443d4..512fcac81 100644 --- a/hyperactor_mesh/src/v1/testactor.rs +++ b/hyperactor_mesh/src/v1/testactor.rs @@ -28,6 +28,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::PortRef; use hyperactor::RefClient; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; #[cfg(test)] use hyperactor::clock::Clock as _; @@ -53,7 +54,7 @@ use crate::v1::ActorMeshRef; use crate::v1::testing; /// A simple test actor used by various unit tests. -#[derive(Actor, Default, Debug)] +#[derive(Default, Debug)] #[hyperactor::export( spawn = true, handlers = [ @@ -67,6 +68,8 @@ use crate::v1::testing; )] pub struct TestActor; +impl Actor for TestActor {} + /// A message that returns the recipient actor's id. 
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)] pub struct GetActorId(#[binding(include)] pub PortRef); @@ -131,12 +134,6 @@ pub struct TestActorWithSupervisionHandling; #[async_trait] impl Actor for TestActorWithSupervisionHandling { - type Params = (); - - async fn new(_params: Self::Params) -> Result { - Ok(Self {}) - } - async fn handle_supervision_event( &mut self, _this: &Instance, @@ -219,12 +216,15 @@ impl Handler for TestActor { } } -#[derive(Default, Debug)] +#[derive(Debug)] #[hyperactor::export(spawn = true)] pub struct FailingCreateTestActor; #[async_trait] -impl Actor for FailingCreateTestActor { +impl Actor for FailingCreateTestActor {} + +#[async_trait] +impl hyperactor::RemoteSpawn for FailingCreateTestActor { type Params = (); async fn new(_params: Self::Params) -> Result { diff --git a/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs b/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs index 13f5b6737..bc3a9fd37 100644 --- a/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs +++ b/hyperactor_mesh/test/hyperactor_mesh_proxy_test.rs @@ -21,6 +21,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::PortRef; use hyperactor::Proc; +use hyperactor::RemoteSpawn; use hyperactor::channel::ChannelTransport; use hyperactor_mesh::Mesh; use hyperactor_mesh::ProcMesh; @@ -63,8 +64,10 @@ struct Args { )] pub struct TestActor; +impl Actor for TestActor {} + #[async_trait] -impl Actor for TestActor { +impl RemoteSpawn for TestActor { type Params = (); async fn new(_params: Self::Params) -> Result { @@ -108,6 +111,14 @@ impl fmt::Debug for ProxyActor { #[async_trait] impl Actor for ProxyActor { + async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { + self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await.unwrap()); + Ok(()) + } +} + +#[async_trait] +impl RemoteSpawn for ProxyActor { type Params = String; async fn new(exe_path: Self::Params) -> anyhow::Result { @@ -133,11 +144,6 @@ impl Actor 
for ProxyActor { actor_mesh: None, }) } - - async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { - self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await.unwrap()); - Ok(()) - } } #[async_trait] diff --git a/hyperactor_multiprocess/src/ping_pong.rs b/hyperactor_multiprocess/src/ping_pong.rs index 694a1397b..7d8f57235 100644 --- a/hyperactor_multiprocess/src/ping_pong.rs +++ b/hyperactor_multiprocess/src/ping_pong.rs @@ -21,7 +21,6 @@ mod tests { use hyperactor::reference::WorldId; use hyperactor::simnet; use hyperactor::test_utils::pingpong::PingPongActor; - use hyperactor::test_utils::pingpong::PingPongActorParams; use hyperactor::test_utils::pingpong::PingPongMessage; use crate::System; @@ -104,12 +103,11 @@ mod tests { .await .unwrap(); - let params = PingPongActorParams::new(None, None); spawn::( cx, &bootstrap.proc_actor.bind(), actor_index.to_string().as_str(), - &params, + &(None, None, None), ) .await .unwrap() diff --git a/hyperactor_multiprocess/src/proc_actor.rs b/hyperactor_multiprocess/src/proc_actor.rs index 26b6b607a..6d70ee279 100644 --- a/hyperactor_multiprocess/src/proc_actor.rs +++ b/hyperactor_multiprocess/src/proc_actor.rs @@ -27,12 +27,11 @@ use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::PortRef; use hyperactor::RefClient; -use hyperactor::RemoteMessage; +use hyperactor::RemoteSpawn; use hyperactor::WorldId; use hyperactor::actor::ActorErrorKind; use hyperactor::actor::ActorHandle; use hyperactor::actor::ActorStatus; -use hyperactor::actor::Referable; use hyperactor::actor::remote::Remote; use hyperactor::channel; use hyperactor::channel::ChannelAddr; @@ -425,9 +424,9 @@ impl ProcActor { let handle = match proc .clone() - .spawn::( + .spawn( "proc", - ProcActorParams { + ProcActor::new(ProcActorParams { proc: proc.clone(), world_id: world_id.clone(), system_actor_ref: SYSTEM_ACTOR_REF.clone(), @@ -438,7 +437,7 @@ impl ProcActor { supervision_update_interval, labels, lifecycle_mode, - }, + }), 
) .await { @@ -449,11 +448,7 @@ impl ProcActor { } }; - let comm_actor = match proc - .clone() - .spawn::("comm", Default::default()) - .await - { + let comm_actor = match proc.clone().spawn("comm", CommActor::default()).await { Ok(handle) => handle, Err(e) => { Self::failed_proc_bootstrap_cleanup(mailbox_handle).await; @@ -502,18 +497,6 @@ impl ProcActor { #[async_trait] impl Actor for ProcActor { - type Params = ProcActorParams; - - async fn new(params: ProcActorParams) -> Result { - let last_successful_supervision_update = params.proc.clock().system_time_now(); - Ok(Self { - params, - state: ProcState::AwaitingJoin, - remote: Remote::collect(), - last_successful_supervision_update, - }) - } - async fn init(&mut self, this: &Instance) -> anyhow::Result<()> { // Bind ports early so that when the proc actor joins, it can serve. this.bind::(); @@ -550,6 +533,16 @@ impl Actor for ProcActor { } impl ProcActor { + fn new(params: ProcActorParams) -> Self { + let last_successful_supervision_update = params.proc.clock().system_time_now(); + Self { + params, + state: ProcState::AwaitingJoin, + remote: Remote::collect(), + last_successful_supervision_update, + } + } + /// This proc's rank in the world. fn rank(&self) -> Index { self.params @@ -837,15 +830,12 @@ impl Handler for ProcActor { /// Convenience utility to spawn an actor on a proc. Spawn returns /// with the new ActorRef on success. 
-pub async fn spawn( +pub async fn spawn( cx: &impl context::Actor, proc_actor: &ActorRef, actor_name: &str, params: &A::Params, -) -> Result, anyhow::Error> -where - A::Params: RemoteMessage, -{ +) -> Result, anyhow::Error> { let remote = Remote::collect(); let (spawned_port, mut spawned_receiver) = open_port(cx); let ActorId(proc_id, _, _) = (*proc_actor).clone().into(); @@ -880,6 +870,7 @@ mod tests { use std::collections::HashSet; use std::time::Duration; + use hyperactor::RemoteSpawn; use hyperactor::actor::ActorStatus; use hyperactor::channel; use hyperactor::channel::ChannelAddr; @@ -891,7 +882,6 @@ mod tests { use hyperactor::id; use hyperactor::reference::ActorRef; use hyperactor::test_utils::pingpong::PingPongActor; - use hyperactor::test_utils::pingpong::PingPongActorParams; use hyperactor::test_utils::pingpong::PingPongMessage; use maplit::hashset; use rand::Rng; @@ -970,7 +960,7 @@ mod tests { server_handle.await; } - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] #[hyperactor::export( spawn = true, handlers = [ @@ -979,6 +969,8 @@ mod tests { )] struct TestActor; + impl Actor for TestActor {} + #[derive(Handler, HandleClient, RefClient, Serialize, Deserialize, Debug, Named)] enum TestActorMessage { Increment(u64, #[reply] OncePortRef), @@ -1031,7 +1023,7 @@ mod tests { } // Sleep - #[derive(Debug, Default, Actor)] + #[derive(Debug, Default)] #[hyperactor::export( spawn = true, handlers = [ @@ -1040,6 +1032,8 @@ mod tests { )] struct SleepActor {} + impl Actor for SleepActor {} + #[async_trait] impl Handler for SleepActor { async fn handle(&mut self, _cx: &Context, message: u64) -> anyhow::Result<()> { @@ -1444,17 +1438,21 @@ mod tests { let proc_1_client = proc_1.attach("client").unwrap(); let (proc_1_undeliverable_tx, mut _proc_1_undeliverable_rx) = proc_1_client.open_port(); - let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None); // Spawn two actors 'ping' and 'pong' where 'ping' runs on // 'world[0]' 
and 'pong' on 'world[1]' (that is, not on the // same proc). let ping_handle = proc_0 - .spawn::("ping", ping_params) + .spawn( + "ping", + PingPongActor::new(Some(proc_0_undeliverable_tx.bind()), None, None), + ) .await .unwrap(); - let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None); let pong_handle = proc_1 - .spawn::("pong", pong_params) + .spawn( + "pong", + PingPongActor::new(Some(proc_1_undeliverable_tx.bind()), None, None), + ) .await .unwrap(); @@ -1571,14 +1569,18 @@ mod tests { // Spawn two actors 'ping' and 'pong' where 'ping' runs on // 'world[0]' and 'pong' on 'world[1]' (that is, not on the // same proc). - let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None); let ping_handle = proc_0 - .spawn::("ping", ping_params) + .spawn( + "ping", + PingPongActor::new(Some(proc_0_undeliverable_tx.bind()), None, None), + ) .await .unwrap(); - let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None); let pong_handle = proc_1 - .spawn::("pong", pong_params) + .spawn( + "pong", + PingPongActor::new(Some(proc_1_undeliverable_tx.bind()), None, None), + ) .await .unwrap(); @@ -1733,12 +1735,11 @@ mod tests { .await .unwrap(); let (undeliverable_msg_tx, _) = cx.mailbox().open_port(); - let params = PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None); let actor_ref = spawn::( cx, &bootstrap.proc_actor.bind(), &actor_id.to_string(), - &params, + &(Some(undeliverable_msg_tx.bind()), None, None), ) .await .unwrap(); diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index 2712a3950..c790042aa 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -35,6 +35,7 @@ use hyperactor::PortHandle; use hyperactor::PortRef; use hyperactor::ProcId; use hyperactor::RefClient; +use hyperactor::RemoteSpawn; use hyperactor::WorldId; use hyperactor::actor::Handler; use 
hyperactor::channel::ChannelAddr; @@ -1135,6 +1136,17 @@ pub static SYSTEM_ACTOR_REF: LazyLock> = LazyLock::new(|| ActorRef::attest(id!(system[0].root))); impl SystemActor { + fn new(params: SystemActorParams) -> Self { + let supervision_update_timeout = params.supervision_update_timeout.clone(); + Self { + params, + supervision_state: SystemSupervisionState::new(supervision_update_timeout), + worlds: HashMap::new(), + worlds_to_stop: HashMap::new(), + shutting_down: false, + } + } + /// Adds a new world that's awaiting creation to the worlds. fn add_new_world(&mut self, world_id: WorldId) -> Result<(), anyhow::Error> { let world_state = WorldState { @@ -1180,7 +1192,7 @@ impl SystemActor { clock, ); let actor_handle = system_proc - .spawn::(SYSTEM_ACTOR_ID.name(), params) + .spawn(SYSTEM_ACTOR_ID.name(), SystemActor::new(params)) .await?; Ok((actor_handle, system_proc)) @@ -1200,19 +1212,6 @@ impl SystemActor { #[async_trait] impl Actor for SystemActor { - type Params = SystemActorParams; - - async fn new(params: SystemActorParams) -> Result { - let supervision_update_timeout = params.supervision_update_timeout.clone(); - Ok(Self { - params, - supervision_state: SystemSupervisionState::new(supervision_update_timeout), - worlds: HashMap::new(), - worlds_to_stop: HashMap::new(), - shutting_down: false, - }) - } - async fn init(&mut self, cx: &Instance) -> Result<(), anyhow::Error> { // Start to periodically check the unhealthy worlds. cx.self_message_with_delay(MaintainWorldHealth {}, Duration::from_secs(0))?; @@ -1859,7 +1858,6 @@ mod tests { use hyperactor::mailbox::PortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::simnet; - use hyperactor::test_utils::pingpong::PingPongActorParams; use super::*; use crate::System; @@ -2291,14 +2289,18 @@ mod tests { // Spawn two actors 'ping' and 'pong' where 'ping' runs on // 'world[0]' and 'pong' on 'world[1]' (that is, not on the // same proc). 
- let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None); let ping_handle = proc_0 - .spawn::("ping", ping_params) + .spawn( + "ping", + PingPongActor::new(Some(proc_0_undeliverable_tx.bind()), None, None), + ) .await .unwrap(); - let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None); let pong_handle = proc_1 - .spawn::("pong", pong_params) + .spawn( + "pong", + PingPongActor::new(Some(proc_1_undeliverable_tx.bind()), None, None), + ) .await .unwrap(); diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 11c5ab587..f2b5fe13f 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -90,7 +90,10 @@ impl LoggingMeshClient { fn spawn(instance: PyInstance, proc_mesh: &PyProcMesh) -> PyResult { let proc_mesh = proc_mesh.try_inner()?; PyPythonTask::new(async move { - let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?; + let client_actor = proc_mesh + .client_proc() + .spawn("log_client", LogClientActor::default()) + .await?; let client_actor_ref = client_actor.bind(); let forwarder_mesh = instance_dispatch!(instance, |cx| { proc_mesh diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 8d7ed8738..201679b67 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -145,11 +145,12 @@ impl _Controller { .proc() .spawn( &Name::new("mesh_controller").to_string(), - MeshControllerActorParams { + MeshControllerActor::new(MeshControllerActorParams { proc_mesh, id, rank_map, - }, + }) + .await, ) .await? 
}); @@ -683,7 +684,33 @@ struct MeshControllerActor { rank_map: Option>, } +struct MeshControllerActorParams { + proc_mesh: SharedCell, + id: usize, + rank_map: Option>, +} + impl MeshControllerActor { + async fn new( + MeshControllerActorParams { + proc_mesh, + id, + rank_map, + }: MeshControllerActorParams, + ) -> Self { + let world_size = proc_mesh.borrow().unwrap().shape().slice().len(); + MeshControllerActor { + proc_mesh, + workers: None, + brokers: None, + history: History::new(world_size), + id, + debugger_active: None, + debugger_paused: VecDeque::new(), + rank_map, + } + } + fn workers(&self) -> SharedCellRef> { self.workers.as_ref().unwrap().borrow().unwrap() } @@ -764,41 +791,8 @@ impl MeshControllerActor { } } -impl Debug for MeshControllerActor { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MeshControllerActor").finish() - } -} - -struct MeshControllerActorParams { - proc_mesh: SharedCell, - id: usize, - rank_map: Option>, -} - #[async_trait] impl Actor for MeshControllerActor { - type Params = MeshControllerActorParams; - async fn new( - MeshControllerActorParams { - proc_mesh, - id, - rank_map, - }: Self::Params, - ) -> Result { - let world_size = proc_mesh.borrow().unwrap().shape().slice().len(); - Ok(MeshControllerActor { - proc_mesh, - workers: None, - brokers: None, - history: History::new(world_size), - id, - debugger_active: None, - debugger_paused: VecDeque::new(), - rank_map, - }) - } - async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { let controller_actor_ref: ActorRef = this.bind(); let proc_mesh = self.proc_mesh.borrow().unwrap(); @@ -828,6 +822,12 @@ impl Actor for MeshControllerActor { } } +impl Debug for MeshControllerActor { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MeshControllerActor").finish() + } +} + impl MeshControllerActor { fn rank_of_worker(&self, actor_id: &ActorId) -> usize { if actor_id.proc_id().is_ranked() { diff --git 
a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index a5213c743..bc6aa34b3 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -22,6 +22,7 @@ use hyperactor::Named; use hyperactor::OncePortHandle; use hyperactor::PortHandle; use hyperactor::ProcId; +use hyperactor::RemoteSpawn; use hyperactor::actor::ActorError; use hyperactor::actor::ActorErrorKind; use hyperactor::actor::ActorStatus; @@ -556,25 +557,6 @@ fn update_undeliverable_envelope_for_casting( #[async_trait] impl Actor for PythonActor { - type Params = PickledPyObject; - - async fn new(actor_type: PickledPyObject) -> Result { - Ok(Python::with_gil(|py| -> Result { - let unpickled = actor_type.unpickle(py)?; - let class_type: &Bound<'_, PyType> = unpickled.downcast()?; - let actor: PyObject = class_type.call0()?.into_py_any(py)?; - - // Only create per-actor TaskLocals if not using shared runtime - let task_locals = (!hyperactor::config::global::get(SHARED_ASYNCIO_RUNTIME)) - .then(|| Python::allow_threads(py, create_task_locals)); - Ok(Self { - actor, - task_locals, - instance: None, - }) - })?) - } - async fn cleanup( &mut self, this: &Instance, @@ -709,6 +691,28 @@ impl Actor for PythonActor { } } +#[async_trait] +impl RemoteSpawn for PythonActor { + type Params = PickledPyObject; + + async fn new(actor_type: PickledPyObject) -> Result { + Ok(Python::with_gil(|py| -> Result { + let unpickled = actor_type.unpickle(py)?; + let class_type: &Bound<'_, PyType> = unpickled.downcast()?; + let actor: PyObject = class_type.call0()?.into_py_any(py)?; + + // Only create per-actor TaskLocals if not using shared runtime + let task_locals = (!hyperactor::config::global::get(SHARED_ASYNCIO_RUNTIME)) + .then(|| Python::allow_threads(py, create_task_locals)); + Ok(Self { + actor, + task_locals, + instance: None, + }) + })?) + } +} + /// Create a new TaskLocals with its own asyncio event loop in a dedicated thread. 
fn create_task_locals() -> pyo3_async_runtimes::TaskLocals { Python::with_gil(|py| { diff --git a/monarch_hyperactor/src/code_sync/auto_reload.rs b/monarch_hyperactor/src/code_sync/auto_reload.rs index bc8e34536..dceb66781 100644 --- a/monarch_hyperactor/src/code_sync/auto_reload.rs +++ b/monarch_hyperactor/src/code_sync/auto_reload.rs @@ -15,6 +15,7 @@ use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use monarch_types::SerializablePyErr; use pyo3::prelude::*; use serde::Deserialize; @@ -37,11 +38,19 @@ pub struct AutoReloadActor { state: Result<(Arc, PyObject), SerializablePyErr>, } +impl Actor for AutoReloadActor {} + #[async_trait] -impl Actor for AutoReloadActor { +impl RemoteSpawn for AutoReloadActor { type Params = AutoReloadParams; async fn new(Self::Params {}: Self::Params) -> Result { + AutoReloadActor::new().await + } +} + +impl AutoReloadActor { + pub(crate) async fn new() -> Result { Ok(Self { state: tokio::task::spawn_blocking(move || { Python::with_gil(|py| { @@ -51,9 +60,7 @@ impl Actor for AutoReloadActor { .await?, }) } -} -impl AutoReloadActor { fn create_state(py: Python) -> PyResult<(Arc, PyObject)> { // Import the Python AutoReloader class let auto_reload_module = py.import("monarch._src.actor.code_sync.auto_reload")?; diff --git a/monarch_hyperactor/src/code_sync/conda_sync.rs b/monarch_hyperactor/src/code_sync/conda_sync.rs index 079e3163f..517f27716 100644 --- a/monarch_hyperactor/src/code_sync/conda_sync.rs +++ b/monarch_hyperactor/src/code_sync/conda_sync.rs @@ -19,6 +19,7 @@ use hyperactor::Bind; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::connect::Connect; @@ -61,18 +62,11 @@ pub struct CondaSyncMessage { #[derive(Debug, Named, Serialize, Deserialize)] pub struct CondaSyncParams {} -#[derive(Debug)] 
+#[derive(Debug, Default)] #[hyperactor::export(spawn = true, handlers = [CondaSyncMessage { cast = true }])] pub struct CondaSyncActor {} -#[async_trait] -impl Actor for CondaSyncActor { - type Params = CondaSyncParams; - - async fn new(CondaSyncParams {}: Self::Params) -> Result { - Ok(Self {}) - } -} +impl Actor for CondaSyncActor {} #[async_trait] impl Handler for CondaSyncActor { diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs index bcad94185..4f14a3d62 100644 --- a/monarch_hyperactor/src/code_sync/manager.rs +++ b/monarch_hyperactor/src/code_sync/manager.rs @@ -30,6 +30,7 @@ use hyperactor::Context; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::context; use hyperactor::forward; @@ -62,15 +63,12 @@ use tokio::net::TcpStream; use crate::code_sync::WorkspaceLocation; use crate::code_sync::auto_reload::AutoReloadActor; use crate::code_sync::auto_reload::AutoReloadMessage; -use crate::code_sync::auto_reload::AutoReloadParams; use crate::code_sync::conda_sync::CondaSyncActor; use crate::code_sync::conda_sync::CondaSyncMessage; -use crate::code_sync::conda_sync::CondaSyncParams; use crate::code_sync::conda_sync::CondaSyncResult; use crate::code_sync::rsync::RsyncActor; use crate::code_sync::rsync::RsyncDaemon; use crate::code_sync::rsync::RsyncMessage; -use crate::code_sync::rsync::RsyncParams; use crate::code_sync::rsync::RsyncResult; #[derive(Clone, Serialize, Deserialize, Debug)] @@ -212,8 +210,10 @@ pub struct CodeSyncManager { rank: once_cell::sync::OnceCell, } +impl Actor for CodeSyncManager {} + #[async_trait] -impl Actor for CodeSyncManager { +impl RemoteSpawn for CodeSyncManager { type Params = CodeSyncManagerParams; async fn new(CodeSyncManagerParams {}: Self::Params) -> Result { @@ -233,7 +233,7 @@ impl CodeSyncManager { cx: &Context<'a, Self>, ) -> Result<&'a ActorHandle> { self.rsync - 
.get_or_try_init(RsyncActor::spawn(cx, RsyncParams {})) + .get_or_try_init(RsyncActor::default().spawn(cx)) .await } @@ -242,7 +242,7 @@ impl CodeSyncManager { cx: &Context<'a, Self>, ) -> Result<&'a ActorHandle> { self.auto_reload - .get_or_try_init(AutoReloadActor::spawn(cx, AutoReloadParams {})) + .get_or_try_init(async move { AutoReloadActor::new().await?.spawn(cx).await }) .await } @@ -251,7 +251,7 @@ impl CodeSyncManager { cx: &Context<'a, Self>, ) -> Result<&'a ActorHandle> { self.conda_sync - .get_or_try_init(CondaSyncActor::spawn(cx, CondaSyncParams {})) + .get_or_try_init(CondaSyncActor::default().spawn(cx)) .await } } diff --git a/monarch_hyperactor/src/code_sync/rsync.rs b/monarch_hyperactor/src/code_sync/rsync.rs index 31bd50504..82f19c935 100644 --- a/monarch_hyperactor/src/code_sync/rsync.rs +++ b/monarch_hyperactor/src/code_sync/rsync.rs @@ -28,6 +28,7 @@ use hyperactor::Bind; use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; @@ -336,25 +337,13 @@ pub struct RsyncMessage { pub workspace: WorkspaceLocation, } -#[derive(Debug, Named, Serialize, Deserialize)] -pub struct RsyncParams { - //pub workspace: WorkspaceLocation, -} - -#[derive(Debug)] +#[derive(Debug, Default)] #[hyperactor::export(spawn = true, handlers = [RsyncMessage { cast = true }])] pub struct RsyncActor { //workspace: WorkspaceLocation, } -#[async_trait] -impl Actor for RsyncActor { - type Params = RsyncParams; - - async fn new(RsyncParams {}: Self::Params) -> Result { - Ok(Self {}) - } -} +impl Actor for RsyncActor {} #[async_trait] impl Handler for RsyncActor { @@ -529,16 +518,13 @@ mod tests { let proc_mesh = ProcMesh::allocate(alloc).await?; - // Create RsyncParams - all actors will use the same target workspace for this test - let params = RsyncParams {}; - // TODO: thread through context, or access the actual python context; // for now this 
is basically equivalent (arguably better) to using the proc mesh client. let instance = global_root_client(); // Spawn actor mesh with RsyncActors let actor_mesh = proc_mesh - .spawn::(&instance, "rsync_test", &params) + .spawn::(&instance, "rsync_test", &()) .await?; // Test rsync_mesh function - this coordinates rsync operations across the mesh diff --git a/monarch_hyperactor/src/local_state_broker.rs b/monarch_hyperactor/src/local_state_broker.rs index 6225a0fc9..c0b37fb99 100644 --- a/monarch_hyperactor/src/local_state_broker.rs +++ b/monarch_hyperactor/src/local_state_broker.rs @@ -16,6 +16,7 @@ use hyperactor::ActorRef; use hyperactor::Context; use hyperactor::Handler; use hyperactor::OncePortHandle; +use hyperactor::RemoteSpawn; use pyo3::prelude::*; #[derive(Debug)] @@ -30,13 +31,15 @@ pub enum LocalStateBrokerMessage { Get(usize, OncePortHandle), } -#[derive(Debug, Default, Actor)] +#[derive(Debug, Default)] #[hyperactor::export(spawn = true)] pub struct LocalStateBrokerActor { states: HashMap, ports: HashMap>, } +impl Actor for LocalStateBrokerActor {} + #[async_trait] impl Handler for LocalStateBrokerActor { async fn handle( diff --git a/monarch_hyperactor/src/logging.rs b/monarch_hyperactor/src/logging.rs index 8b99c09b8..5a9dd795c 100644 --- a/monarch_hyperactor/src/logging.rs +++ b/monarch_hyperactor/src/logging.rs @@ -19,6 +19,7 @@ use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Named; use hyperactor::RefClient; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use monarch_types::SerializablePyErr; use pyo3::prelude::*; @@ -63,9 +64,10 @@ impl LoggerRuntimeActor { Ok(()) } } +impl Actor for LoggerRuntimeActor {} #[async_trait] -impl Actor for LoggerRuntimeActor { +impl RemoteSpawn for LoggerRuntimeActor { type Params = (); async fn new(_: ()) -> Result { diff --git a/monarch_hyperactor/src/proc.rs b/monarch_hyperactor/src/proc.rs index 86bac31b7..cc2cb2f32 100644 --- a/monarch_hyperactor/src/proc.rs +++
b/monarch_hyperactor/src/proc.rs @@ -26,6 +26,7 @@ use std::time::SystemTime; use anyhow::Result; use hyperactor::ActorRef; use hyperactor::RemoteMessage; +use hyperactor::RemoteSpawn; use hyperactor::actor::Signal; use hyperactor::channel; use hyperactor::channel::ChannelAddr; @@ -61,6 +62,7 @@ use pyo3::types::PyType; use tokio::sync::OnceCell; use tokio::sync::watch; +use crate::actor::PythonActor; use crate::actor::PythonActorHandle; use crate::mailbox::PyMailbox; use crate::runtime::get_tokio_runtime; @@ -146,7 +148,10 @@ impl PyProc { crate::runtime::future_into_py(py, async move { Ok(PythonActorHandle { inner: proc - .spawn(name.as_deref().unwrap_or("anon"), pickled_type) + .spawn( + name.as_deref().unwrap_or("anon"), + PythonActor::new(pickled_type).await?, + ) .await?, }) }) @@ -163,8 +168,11 @@ impl PyProc { let pickled_type = PickledPyObject::pickle(actor.as_any())?; Ok(PythonActorHandle { inner: signal_safe_block_on(py, async move { - proc.spawn(name.as_deref().unwrap_or("anon"), pickled_type) - .await + proc.spawn( + name.as_deref().unwrap_or("anon"), + PythonActor::new(pickled_type).await?, + ) + .await }) .map_err(|e| PyRuntimeError::new_err(e.to_string()))??, }) diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 938495f62..5198ee2c9 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -11,10 +11,8 @@ use std::fmt::Display; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use hyperactor::Actor; -use hyperactor::RemoteMessage; +use hyperactor::RemoteSpawn; use hyperactor::WorldId; -use hyperactor::actor::Referable; use hyperactor::context; use hyperactor::context::Mailbox as _; use hyperactor::proc::Instance; @@ -89,15 +87,12 @@ impl From for TrackedProcMesh { } impl TrackedProcMesh { - pub async fn spawn( + pub async fn spawn( &self, cx: &impl context::Actor, actor_name: &str, params: &A::Params, - ) -> Result>, anyhow::Error> - where - A::Params: RemoteMessage, 
- { + ) -> Result>, anyhow::Error> { let mesh = self.cell.borrow()?; let actor = mesh.spawn(cx, actor_name, params).await?; Ok(self.children.insert(actor)) diff --git a/monarch_hyperactor/src/v1/actor_mesh.rs b/monarch_hyperactor/src/v1/actor_mesh.rs index fd773f11e..b1199c6ee 100644 --- a/monarch_hyperactor/src/v1/actor_mesh.rs +++ b/monarch_hyperactor/src/v1/actor_mesh.rs @@ -18,7 +18,7 @@ use hyperactor::RemoteMessage; use hyperactor::actor::ActorErrorKind; use hyperactor::actor::ActorStatus; use hyperactor::actor::Referable; -use hyperactor::actor::RemotableActor; +use hyperactor::actor::RemoteSpawn; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::context; @@ -529,7 +529,7 @@ async fn actor_states_monitor( canceled: CancellationToken, supervision_display_name: String, ) where - A: Actor + RemotableActor + Referable, + A: Actor + RemoteSpawn + Referable, A::Params: RemoteMessage, F: Fn(MeshFailure), { diff --git a/monarch_hyperactor/src/v1/logging.rs b/monarch_hyperactor/src/v1/logging.rs index be6e9462d..ceb1de624 100644 --- a/monarch_hyperactor/src/v1/logging.rs +++ b/monarch_hyperactor/src/v1/logging.rs @@ -205,7 +205,10 @@ impl LoggingMeshClient { instance_dispatch!(instance, async move |cx_instance| { cx_instance .proc() - .spawn(&Name::new("log_client").to_string(), ()) + .spawn( + &Name::new("log_client").to_string(), + LogClientActor::default(), + ) .await })?; let client_actor_ref = client_actor.bind(); diff --git a/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs b/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs index 27e41d79d..b65993b4b 100644 --- a/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs +++ b/monarch_rdma/examples/cuda_ping_pong/src/cuda_ping_pong.rs @@ -67,6 +67,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::Proc; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::channel::ChannelTransport; use 
hyperactor::supervision::ActorSupervisionEvent; @@ -258,6 +259,19 @@ pub struct CudaRdmaActor { #[async_trait] impl Actor for CudaRdmaActor { + async fn handle_supervision_event( + &mut self, + _cx: &Instance, + _event: &ActorSupervisionEvent, + ) -> Result { + tracing::error!("CudaRdmaActor supervision event: {:?}", _event); + tracing::error!("CudaRdmaActor error occurred, stop the worker process, exit code: 1"); + std::process::exit(1); + } +} + +#[async_trait] +impl RemoteSpawn for CudaRdmaActor { type Params = (ActorRef, usize, usize); async fn new(params: Self::Params) -> Result { @@ -354,16 +368,6 @@ impl Actor for CudaRdmaActor { }) } } - - async fn handle_supervision_event( - &mut self, - _cx: &Instance, - _event: &ActorSupervisionEvent, - ) -> Result { - tracing::error!("CudaRdmaActor supervision event: {:?}", _event); - tracing::error!("CudaRdmaActor error occurred, stop the worker process, exit code: 1"); - std::process::exit(1); - } } // Message to initialize the buffer with data diff --git a/monarch_rdma/examples/parameter_server/src/parameter_server.rs b/monarch_rdma/examples/parameter_server/src/parameter_server.rs index 2d694d5fe..6d4ca76ae 100644 --- a/monarch_rdma/examples/parameter_server/src/parameter_server.rs +++ b/monarch_rdma/examples/parameter_server/src/parameter_server.rs @@ -66,6 +66,7 @@ use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::PortRef; use hyperactor::Proc; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::channel::ChannelTransport; use hyperactor::context::Mailbox as _; @@ -111,6 +112,21 @@ pub struct ParameterServerActor { #[async_trait] impl Actor for ParameterServerActor { + async fn handle_supervision_event( + &mut self, + _cx: &Instance, + _event: &ActorSupervisionEvent, + ) -> Result { + tracing::error!("parameterServerActor supervision event: {:?}", _event); + tracing::error!( + "parameterServerActor error occurred, stop the worker process, exit code: 1" + ); + 
std::process::exit(1); + } +} + +#[async_trait] +impl RemoteSpawn for ParameterServerActor { type Params = (ActorRef, usize); async fn new(_params: Self::Params) -> Result { @@ -129,18 +145,6 @@ impl Actor for ParameterServerActor { owner_ref, }) } - - async fn handle_supervision_event( - &mut self, - _cx: &Instance, - _event: &ActorSupervisionEvent, - ) -> Result { - tracing::error!("parameterServerActor supervision event: {:?}", _event); - tracing::error!( - "parameterServerActor error occurred, stop the worker process, exit code: 1" - ); - std::process::exit(1); - } } // Message to get handles to the parameter server's weights and gradient buffers. @@ -246,6 +250,19 @@ pub struct WorkerActor { #[async_trait] impl Actor for WorkerActor { + async fn handle_supervision_event( + &mut self, + _cx: &Instance, + _event: &ActorSupervisionEvent, + ) -> Result { + tracing::error!("workerActor supervision event: {:?}", _event); + tracing::error!("workerActor error occurred, stop the worker process, exit code: 1"); + std::process::exit(1); + } +} + +#[async_trait] +impl RemoteSpawn for WorkerActor { type Params = (); async fn new(_params: Self::Params) -> Result { @@ -259,16 +276,6 @@ impl Actor for WorkerActor { rdma_manager: None, }) } - - async fn handle_supervision_event( - &mut self, - _cx: &Instance, - _event: &ActorSupervisionEvent, - ) -> Result { - tracing::error!("workerActor supervision event: {:?}", _event); - tracing::error!("workerActor error occurred, stop the worker process, exit code: 1"); - std::process::exit(1); - } } // Message to initialize the worker. 
diff --git a/monarch_rdma/src/rdma_manager_actor.rs b/monarch_rdma/src/rdma_manager_actor.rs index 936e2620f..290771232 100644 --- a/monarch_rdma/src/rdma_manager_actor.rs +++ b/monarch_rdma/src/rdma_manager_actor.rs @@ -45,6 +45,7 @@ use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::RemoteSpawn; use hyperactor::clock::Clock; use hyperactor::supervision::ActorSupervisionEvent; use serde::Deserialize; @@ -508,7 +509,7 @@ impl RdmaManagerActor { } #[async_trait] -impl Actor for RdmaManagerActor { +impl RemoteSpawn for RdmaManagerActor { type Params = Option; async fn new(params: Self::Params) -> Result { @@ -561,7 +562,10 @@ impl Actor for RdmaManagerActor { pci_to_device, }) } +} +#[async_trait] +impl Actor for RdmaManagerActor { async fn init(&mut self, _this: &Instance) -> Result<(), anyhow::Error> { tracing::debug!("RdmaManagerActor initialized with lazy domain/QP creation"); Ok(()) diff --git a/monarch_rdma/src/test_utils.rs b/monarch_rdma/src/test_utils.rs index 58d2a88ea..2f0370d59 100644 --- a/monarch_rdma/src/test_utils.rs +++ b/monarch_rdma/src/test_utils.rs @@ -85,6 +85,7 @@ pub mod test_utils { use hyperactor::Instance; use hyperactor::Proc; use hyperactor::RefClient; + use hyperactor::RemoteSpawn; use hyperactor::channel::ChannelTransport; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; @@ -124,8 +125,10 @@ pub mod test_utils { context: SendSyncCudaContext, } + impl Actor for CudaActor {} + #[async_trait::async_trait] - impl Actor for CudaActor { + impl RemoteSpawn for CudaActor { type Params = i32; async fn new(device_id: i32) -> Result { diff --git a/monarch_tensor_worker/src/borrow.rs b/monarch_tensor_worker/src/borrow.rs index 35957e6d7..9af653348 100644 --- a/monarch_tensor_worker/src/borrow.rs +++ b/monarch_tensor_worker/src/borrow.rs @@ -175,6 +175,7 @@ mod tests { use anyhow::Context; use anyhow::Result; + use hyperactor::RemoteSpawn; use 
hyperactor::proc::Proc; use monarch_messages::controller::ControllerMessage; use monarch_messages::worker::WorkerMessage; @@ -200,12 +201,13 @@ mod tests { let worker_handle = proc .spawn::( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await?, ) .await .unwrap(); @@ -337,12 +339,14 @@ mod tests { let worker_handle = proc .spawn::( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); diff --git a/monarch_tensor_worker/src/comm.rs b/monarch_tensor_worker/src/comm.rs index edaafed24..14831ab19 100644 --- a/monarch_tensor_worker/src/comm.rs +++ b/monarch_tensor_worker/src/comm.rs @@ -163,6 +163,8 @@ impl NcclCommActor { } } +impl Actor for NcclCommActor {} + /// Initialization parameters for `NcclCommActor`. #[derive(Debug, Clone)] pub enum CommParams { @@ -186,11 +188,8 @@ pub enum CommParams { FromComm(Arc>), } -#[async_trait] -impl Actor for NcclCommActor { - type Params = CommParams; - - async fn new(params: Self::Params) -> Result { +impl NcclCommActor { + pub async fn new(params: CommParams) -> Result { match params { CommParams::New { device, @@ -198,6 +197,7 @@ impl Actor for NcclCommActor { world_size, rank, } => { + // TODO: this should probably be done in the actor's 'init' let comm = spawn_blocking(move || Communicator::new(device, world_size, unique_id, rank)) .await @@ -251,7 +251,10 @@ impl CommMessageHandler for NcclCommActor { .await .unwrap()?; - NcclCommActor::spawn(cx, CommParams::FromComm(Arc::new(Mutex::new(split_comm)))).await + NcclCommActor::new(CommParams::FromComm(Arc::new(Mutex::new(split_comm)))) + .await?
+ .spawn(cx) + .await } async fn split_from( @@ -268,7 +271,9 @@ impl CommMessageHandler for NcclCommActor { match split_comm { Some(split_comm) => Ok(Some( - NcclCommActor::spawn(cx, CommParams::FromComm(Arc::new(Mutex::new(split_comm)))) + NcclCommActor::new(CommParams::FromComm(Arc::new(Mutex::new(split_comm)))) + .await? + .spawn(cx) .await?, )), None => Ok(None), @@ -1029,6 +1034,7 @@ mod tests { use anyhow::Result; use futures::future::try_join_all; + use hyperactor::RemoteSpawn; use hyperactor::actor::ActorStatus; use hyperactor::proc::Proc; use monarch_messages::worker::WorkerMessageClient; @@ -1058,28 +1064,26 @@ mod tests { let unique_id = UniqueId::new().unwrap(); let device0 = CudaDevice::new(DeviceIndex(0)); - let handle0 = proc.spawn::( - "comm0", - CommParams::New { - device: device0, - unique_id: unique_id.clone(), - world_size: 2, - rank: 0, - }, - ); + let actor0 = NcclCommActor::new(CommParams::New { + device: device0, + unique_id: unique_id.clone(), + world_size: 2, + rank: 0, + }); let device1 = CudaDevice::new(DeviceIndex(1)); - let handle1 = proc.spawn::( - "comm1", - CommParams::New { - device: device1, - unique_id, - world_size: 2, - rank: 1, - }, - ); - let (handle0, handle1) = tokio::join!(handle0, handle1); - let (handle0, handle1) = (handle0.unwrap(), handle1.unwrap()); + let actor1 = NcclCommActor::new(CommParams::New { + device: device1, + unique_id, + world_size: 2, + rank: 1, + }); + + let (actor0, actor1) = tokio::join!(actor0, actor1); + let (actor0, actor1) = (actor0.unwrap(), actor1.unwrap()); + + let handle0 = actor0.spawn_detached().await.unwrap(); + let handle1 = actor1.spawn_detached().await.unwrap(); let cell0 = TensorCell::new(factory_float_tensor(&[1.0], device0.into())); @@ -1127,7 +1131,7 @@ mod tests { let unique_id = UniqueId::new().unwrap(); let device0 = CudaDevice::new(DeviceIndex(0)); - let handle0 = NcclCommActor::spawn_detached(CommParams::New { + let actor0 = NcclCommActor::new(CommParams::New { device: 
device0, unique_id: unique_id.clone(), world_size: 2, @@ -1135,14 +1139,18 @@ mod tests { }); let device1 = CudaDevice::new(DeviceIndex(1)); - let handle1 = NcclCommActor::spawn_detached(CommParams::New { + let actor1 = NcclCommActor::new(CommParams::New { device: device1, unique_id, world_size: 2, rank: 1, }); - let (handle0, handle1) = tokio::join!(handle0, handle1); - let (handle0, handle1) = (handle0.unwrap(), handle1.unwrap()); + + let (actor0, actor1) = tokio::join!(actor0, actor1); + let (actor0, actor1) = (actor0.unwrap(), actor1.unwrap()); + + let handle0 = actor0.spawn_detached().await.unwrap(); + let handle1 = actor1.spawn_detached().await.unwrap(); let cell0 = TensorCell::new(factory_float_tensor(&[1.0], device0.into())); @@ -1198,28 +1206,24 @@ mod tests { let unique_id = UniqueId::new()?; let device0 = CudaDevice::new(DeviceIndex(0)); - let handle0 = proc.spawn::( - "comm0", - CommParams::New { - device: device0, - unique_id: unique_id.clone(), - world_size: 2, - rank: 0, - }, - ); - + let actor0 = NcclCommActor::new(CommParams::New { + device: device0, + unique_id: unique_id.clone(), + world_size: 2, + rank: 0, + }); let device1 = CudaDevice::new(DeviceIndex(1)); - let handle1 = proc.spawn::( - "comm1", - CommParams::New { - device: device1, - unique_id, - world_size: 2, - rank: 1, - }, - ); - let (handle0, handle1) = tokio::join!(handle0, handle1); - let (handle0, handle1) = (handle0.unwrap(), handle1.unwrap()); + let actor1 = NcclCommActor::new(CommParams::New { + device: device1, + unique_id, + world_size: 2, + rank: 1, + }); + let (actor0, actor1) = tokio::join!(actor0, actor1); + let (actor0, actor1) = (actor0.unwrap(), actor1.unwrap()); + + let handle0 = proc.spawn("comm0", actor0).await.unwrap(); + let handle1 = proc.spawn("comm1", actor1).await.unwrap(); let cell0 = TensorCell::new(factory_float_tensor(&[1.0], device0.into())); let dest_rank = 0; @@ -1268,14 +1272,16 @@ mod tests { let world_size = 4; let workers = 
try_join_all((0..world_size).map(async |rank| { - proc.spawn::( + proc.spawn( &format!("worker{}", rank), - WorkerParams { + WorkerActor::new(WorkerParams { world_size, rank, device_index: Some(rank.try_into()?), controller_actor: controller_ref.clone(), - }, + }) + .await + .unwrap(), ) .await })) @@ -1442,26 +1448,30 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let handle1 = proc - .spawn::( + .spawn( "worker1", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 2, rank: 0, device_index: Some(0), controller_actor: controller_ref.clone(), - }, + }) + .await + .unwrap(), ) .await .unwrap(); let handle2 = proc - .spawn::( + .spawn( "worker2", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 2, rank: 1, device_index: Some(1), controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1610,14 +1620,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: Some(0), controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1736,7 +1748,7 @@ mod tests { let unique_id = UniqueId::new()?; let device0 = CudaDevice::new(DeviceIndex(0)); - let handle0 = NcclCommActor::spawn_detached(CommParams::New { + let actor0 = NcclCommActor::new(CommParams::New { device: device0, unique_id: unique_id.clone(), world_size: 2, @@ -1744,14 +1756,18 @@ mod tests { }); let device1 = CudaDevice::new(DeviceIndex(1)); - let handle1 = NcclCommActor::spawn_detached(CommParams::New { + let actor1 = NcclCommActor::new(CommParams::New { device: device1, unique_id, world_size: 2, rank: 1, }); - let (handle0, handle1) = tokio::join!(handle0, handle1); - let (handle0, handle1) = (handle0?, handle1?); + + let (actor0, actor1) = tokio::join!(actor0, actor1); + let 
(actor0, actor1) = (actor0?, actor1?); + + let handle0 = actor0.spawn_detached().await.unwrap(); + let handle1 = actor1.spawn_detached().await.unwrap(); let cell0 = TensorCell::new(factory_float_tensor(&[1.0], device0.into())); let port0 = client.open_once_port(); diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index bc8a5bbbb..cf7fe90b5 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -56,6 +56,7 @@ use hyperactor::ActorRef; use hyperactor::Bind; use hyperactor::Handler; use hyperactor::Named; +use hyperactor::RemoteSpawn; use hyperactor::Unbind; use hyperactor::actor::ActorHandle; use hyperactor::context; @@ -216,8 +217,10 @@ impl WorkerActor { } } +impl Actor for WorkerActor {} + #[async_trait] -impl Actor for WorkerActor { +impl RemoteSpawn for WorkerActor { type Params = WorkerParams; async fn new( @@ -297,15 +300,14 @@ impl WorkerMessageHandler for WorkerActor { let device = self .device .expect("tried to init backend network on a non-CUDA worker"); - let comm = NcclCommActor::spawn( - cx, - CommParams::New { - device, - unique_id, - world_size: self.world_size.try_into().unwrap(), - rank: self.rank.try_into().unwrap(), - }, - ) + let comm = NcclCommActor::new(CommParams::New { + device, + unique_id, + world_size: self.world_size.try_into().unwrap(), + rank: self.rank.try_into().unwrap(), + }) + .await? 
+ .spawn(cx) .await?; let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into()); @@ -440,18 +442,16 @@ impl WorkerMessageHandler for WorkerActor { result: StreamRef, creation_mode: StreamCreationMode, ) -> Result<()> { - let handle: ActorHandle = StreamActor::spawn( - cx, - StreamParams { - world_size: self.world_size, - rank: self.rank, - creation_mode, - id: result, - device: self.device, - controller_actor: self.controller_actor.clone(), - respond_with_python_message: self.respond_with_python_message, - }, - ) + let handle: ActorHandle = StreamActor::new(StreamParams { + world_size: self.world_size, + rank: self.rank, + creation_mode, + id: result, + device: self.device, + controller_actor: self.controller_actor.clone(), + respond_with_python_message: self.respond_with_python_message, + }) + .spawn(cx) .await?; self.streams.insert(result, Arc::new(handle)); Ok(()) @@ -1115,6 +1115,7 @@ mod tests { use anyhow::Result; use hyperactor::Instance; + use hyperactor::RemoteSpawn; use hyperactor::WorldId; use hyperactor::actor::ActorStatus; use hyperactor::channel::ChannelAddr; @@ -1162,14 +1163,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1255,14 +1258,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1317,12 +1322,14 @@ mod tests { let worker_handle = proc .spawn::( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, 
rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1386,14 +1393,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1459,14 +1468,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1747,14 +1758,16 @@ mod tests { let (client, controller_ref, _) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1821,14 +1834,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1902,26 +1917,30 @@ mod tests { let (client, controller_ref, _) = proc.attach_actor("controller").unwrap(); let worker_handle1 = proc - .spawn::( + .spawn( "worker0", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 2, rank: 0, device_index: Some(0), controller_actor: controller_ref.clone(), - }, + }) + .await + .unwrap(), ) .await .unwrap(); let worker_handle2 = proc - .spawn::( + .spawn( "worker1", - WorkerParams 
{ + WorkerActor::new(WorkerParams { world_size: 2, rank: 1, device_index: Some(1), controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -1950,14 +1969,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); @@ -2039,14 +2060,16 @@ mod tests { let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap(); let worker_handle = proc - .spawn::( + .spawn( "worker", - WorkerParams { + WorkerActor::new(WorkerParams { world_size: 1, rank: 0, device_index: None, controller_actor: controller_ref, - }, + }) + .await + .unwrap(), ) .await .unwrap(); diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 605cda893..6cc08a193 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -465,10 +465,8 @@ pub struct StreamParams { pub respond_with_python_message: bool, } -#[async_trait] -impl Actor for StreamActor { - type Params = StreamParams; - async fn new( +impl StreamActor { + pub fn new( StreamParams { world_size, rank, @@ -477,9 +475,9 @@ impl Actor for StreamActor { controller_actor, creation_mode, respond_with_python_message, - }: Self::Params, - ) -> Result { - Ok(Self { + }: StreamParams, + ) -> Self { + Self { world_size, rank, env: HashMap::new(), @@ -493,9 +491,12 @@ impl Actor for StreamActor { active_recording: None, respond_with_python_message, last_seq_error: None, - }) + } } +} +#[async_trait] +impl Actor for StreamActor { async fn init(&mut self, cx: &Instance) -> Result<()> { // These thread locals are exposed via python functions, so we need to set them in the // same thread that python will run in. 
That means we need to initialize them here in @@ -2095,9 +2096,9 @@ mod tests { let (supervision_tx, supervision_rx) = client.open_port(); proc.set_supervision_coordinator(supervision_tx)?; let stream_actor = proc - .spawn::( + .spawn( "stream", - StreamParams { + StreamActor::new(StreamParams { world_size, rank: 0, creation_mode: StreamCreationMode::UseDefaultStream, @@ -2105,7 +2106,7 @@ mod tests { device: Some(CudaDevice::new(0.into())), controller_actor: controller_actor.clone(), respond_with_python_message: false, - }, + }), ) .await?; @@ -2578,14 +2579,16 @@ mod tests { let dummy_comm = test_setup .proc - .spawn::( + .spawn( "comm", - CommParams::New { + NcclCommActor::new(CommParams::New { device: CudaDevice::new(0.into()), unique_id: UniqueId::new()?, world_size: 1, rank: 0, - }, + }) + .await + .unwrap(), ) .await?; @@ -2971,9 +2974,9 @@ mod tests { let borrower_stream = test_setup .proc - .spawn::( + .spawn( "stream1", - StreamParams { + StreamActor::new(StreamParams { world_size: 1, rank: 0, creation_mode: StreamCreationMode::CreateNewStream, @@ -2981,7 +2984,7 @@ mod tests { device: Some(CudaDevice::new(0.into())), controller_actor: test_setup.controller_actor.clone(), respond_with_python_message: false, - }, + }), ) .await?; @@ -3261,14 +3264,16 @@ mod tests { let comm = Arc::new( test_setup .proc - .spawn::( + .spawn( "comm", - CommParams::New { + NcclCommActor::new(CommParams::New { device: CudaDevice::new(0.into()), unique_id: UniqueId::new()?, world_size: 1, rank: 0, - }, + }) + .await + .unwrap(), ) .await?, ); @@ -3567,25 +3572,25 @@ mod tests { let recording_ref = test_setup.next_ref(); let unique_id = UniqueId::new()?; - let comm0 = test_setup.proc.spawn::( - "comm0", - CommParams::New { - device: CudaDevice::new(0.into()), - unique_id: unique_id.clone(), - world_size: 2, - rank: 0, - }, - ); - let comm1 = test_setup.proc.spawn::( - "comm1", - CommParams::New { - device: CudaDevice::new(1.into()), - unique_id, - world_size: 2, - rank: 1, - }, 
- ); - let (comm0, comm1) = tokio::try_join!(comm0, comm1)?; + let device0 = CudaDevice::new(0.into()); + let actor0 = NcclCommActor::new(CommParams::New { + device: device0, + unique_id: unique_id.clone(), + world_size: 2, + rank: 0, + }); + let device1 = CudaDevice::new(1.into()); + let actor1 = NcclCommActor::new(CommParams::New { + device: device1, + unique_id, + world_size: 2, + rank: 1, + }); + let (actor0, actor1) = tokio::join!(actor0, actor1); + let (actor0, actor1) = (actor0.unwrap(), actor1.unwrap()); + + let comm0 = test_setup.proc.spawn("comm0", actor0).await.unwrap(); + let comm1 = test_setup.proc.spawn("comm1", actor1).await.unwrap(); let comm0 = Arc::new(comm0); let comm1 = Arc::new(comm1); @@ -3599,9 +3604,9 @@ mod tests { let send_stream = test_setup.stream_actor.clone(); let recv_stream = test_setup .proc - .spawn::( + .spawn( "recv_stream", - StreamParams { + StreamActor::new(StreamParams { world_size: 2, rank: 1, creation_mode: StreamCreationMode::CreateNewStream, @@ -3609,7 +3614,7 @@ mod tests { device: Some(CudaDevice::new(1.into())), controller_actor: test_setup.controller_actor.clone(), respond_with_python_message: false, - }, + }), ) .await?;