diff --git a/controller/src/lib.rs b/controller/src/lib.rs index d703c291a..5828004df 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -634,7 +634,6 @@ mod tests { use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::context::Mailbox as _; - use hyperactor::data::Named; use hyperactor::id; use hyperactor::mailbox::BoxedMailboxSender; use hyperactor::mailbox::DialMailboxRouter; @@ -1129,8 +1128,7 @@ mod tests { // Build a supervisor. let sup_mail = system.attach().await.unwrap(); - let (sup_tx, _sup_rx) = sup_mail.open_port::(); - sup_tx.bind_to(ProcSupervisionMessage::port()); + let (_sup_tx, _sup_rx) = sup_mail.bind_actor_port::(); let sup_ref = ActorRef::::attest(sup_mail.self_id().clone()); // Construct a system sender. @@ -1360,8 +1358,7 @@ mod tests { // Build a supervisor. let sup_mail = system.attach().await.unwrap(); - let (sup_tx, _sup_rx) = sup_mail.open_port::(); - sup_tx.bind_to(ProcSupervisionMessage::port()); + let (_sup_tx, _sup_rx) = sup_mail.bind_actor_port::(); let sup_ref = ActorRef::::attest(sup_mail.self_id().clone()); // Construct a system sender. @@ -1665,9 +1662,8 @@ mod tests { .await .unwrap(); - let (client_supervision_tx, mut client_supervision_rx) = - client_mailbox.open_port::(); - client_supervision_tx.bind_to(ClientMessage::port()); + let (_client_supervision_tx, mut client_supervision_rx) = + client_mailbox.bind_actor_port::(); // mock a proc actor that doesn't update supervision state let ( @@ -1682,14 +1678,17 @@ mod tests { // Join the world. server_handle .system_actor_handle() - .send(SystemMessage::Join { - proc_id: local_proc_id.clone(), - world_id, - proc_message_port: local_proc_message_port.bind(), - proc_addr: local_proc_addr, - labels: HashMap::new(), - lifecycle_mode: ProcLifecycleMode::ManagedBySystem, - }) + .send( + &client_mailbox, + SystemMessage::Join { + proc_id: local_proc_id.clone(), + world_id, + proc_message_port: local_proc_message_port.bind(), + proc_addr: local_proc_addr, + labels: HashMap::new(), + lifecycle_mode: ProcLifecycleMode::ManagedBySystem, + }, + ) .unwrap(); assert_matches!( @@ -1726,9 +1725,8 @@ mod tests { // Client actor. let mut system = System::new(server_handle.local_addr().clone()); let client_mailbox = system.attach().await.unwrap(); - let (client_supervision_tx, mut client_supervision_rx) = - client_mailbox.open_port::(); - client_supervision_tx.bind_to(ClientMessage::port()); + let (_client_supervision_tx, mut client_supervision_rx) = + client_mailbox.bind_actor_port::(); // Bootstrap the controller let controller_id = id!(controller[0].root); @@ -1784,14 +1782,17 @@ mod tests { // Join the world. server_handle .system_actor_handle() - .send(SystemMessage::Join { - proc_id: local_proc_id.clone(), - world_id, - proc_message_port: local_proc_message_port.bind(), - proc_addr: local_proc_addr, - labels: HashMap::new(), - lifecycle_mode: ProcLifecycleMode::ManagedBySystem, - }) + .send( + &client_mailbox, + SystemMessage::Join { + proc_id: local_proc_id.clone(), + world_id, + proc_message_port: local_proc_message_port.bind(), + proc_addr: local_proc_addr, + labels: HashMap::new(), + lifecycle_mode: ProcLifecycleMode::ManagedBySystem, + }, + ) .unwrap(); assert_matches!( @@ -1865,9 +1866,8 @@ mod tests { // Client actor. 
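// Aside: the recurring change in these tests collapses the old two-step
// "open a port, then bind it to the message type's well-known actor port"
// pattern into one call. A minimal sketch of the equivalence, assuming the
// Mailbox API introduced in this diff (`sup_mail` stands in for any mailbox):
//
//     // Before: open an anonymous port, then bind it to M's actor port.
//     let (sup_tx, sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
//     sup_tx.bind_to(ProcSupervisionMessage::port());
//
//     // After: one call opens the port and binds it to M::port().
//     let (sup_tx, sup_rx) = sup_mail.bind_actor_port::<ProcSupervisionMessage>();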
let mut system = System::new(server_handle.local_addr().clone()); let client_mailbox = system.attach().await.unwrap(); - let (client_supervision_tx, mut client_supervision_rx) = + let (_client_supervision_tx, mut client_supervision_rx) = client_mailbox.open_port::(); - client_supervision_tx.bind_to(ClientMessage::port()); // Bootstrap the controller let controller_id = id!(controller[0].root); diff --git a/hyperactor/src/actor.rs b/hyperactor/src/actor.rs index 90dcdd96d..2b53e2493 100644 --- a/hyperactor/src/actor.rs +++ b/hyperactor/src/actor.rs @@ -622,11 +622,15 @@ impl ActorHandle { /// Send a message to the actor. Messages sent through the handle /// are always queued in process, and do not require serialization. - pub fn send(&self, message: M) -> Result<(), MailboxSenderError> + pub fn send( + &self, + cx: &impl context::Actor, + message: M, + ) -> Result<(), MailboxSenderError> where A: Handler, { - self.ports.get().send(message) + self.ports.get().send(cx, message) } /// Return a port for the provided message type handled by the actor. @@ -712,9 +716,12 @@ pub trait RemoteHandles: Referable {} #[cfg(test)] mod tests { + use std::collections::HashMap; use std::sync::Mutex; use std::time::Duration; + use timed_test::async_timed_test; + use tokio::sync::mpsc; use tokio::time::timeout; use super::*; @@ -724,6 +731,13 @@ mod tests { use crate::PortRef; use crate::checkpoint::CheckpointError; use crate::checkpoint::Checkpointable; + use crate::config; + use crate::id; + use crate::mailbox::BoxedMailboxSender; + use crate::mailbox::MailboxSender; + use crate::mailbox::monitored_return_handle; + use crate::proc::SEQ_INFO; + use crate::proc::SeqInfo; use crate::test_utils::pingpong::PingPongActor; use crate::test_utils::pingpong::PingPongActorParams; use crate::test_utils::pingpong::PingPongMessage; @@ -753,10 +767,10 @@ mod tests { #[tokio::test] async fn test_server_basic() { let proc = Proc::local(); - let client = proc.attach("client").unwrap(); + let (client, _) = proc.instance("client").unwrap(); let (tx, mut rx) = client.open_port(); let handle = proc.spawn::("echo", tx.bind()).await.unwrap(); - handle.send(123u64).unwrap(); + handle.send(&client, 123u64).unwrap(); handle.drain_and_stop().unwrap(); handle.await; @@ -766,7 +780,7 @@ mod tests { #[tokio::test] async fn test_ping_pong() { let proc = Proc::local(); - let client = proc.attach("client").unwrap(); + let (client, _) = proc.instance("client").unwrap(); let (undeliverable_msg_tx, _) = client.open_port(); let ping_pong_actor_params = @@ -783,7 +797,10 @@ mod tests { let (local_port, local_receiver) = client.open_once_port(); ping_handle - .send(PingPongMessage(10, pong_handle.bind(), local_port.bind())) + .send( + &client, + PingPongMessage(10, pong_handle.bind(), local_port.bind()), + ) .unwrap(); assert!(local_receiver.recv().await.unwrap()); @@ -792,7 +809,7 @@ mod tests { #[tokio::test] async fn test_ping_pong_on_handler_error() { let proc = Proc::local(); - let client = proc.attach("client").unwrap(); + let (client, _) = proc.instance("client").unwrap(); let (undeliverable_msg_tx, _) = client.open_port(); // Need to set a supervison coordinator for this Proc because there will @@ -814,11 +831,14 @@ mod tests { let (local_port, local_receiver) = client.open_once_port(); ping_handle - .send(PingPongMessage( - error_ttl + 1, // will encounter an error at TTL=66 - pong_handle.bind(), - local_port.bind(), - )) + .send( + &client, + PingPongMessage( + error_ttl + 1, // will encounter an error at TTL=66 + pong_handle.bind(), + 
local_port.bind(),
+ ),
+ )
.unwrap();

// TODO: Fix this receiver hanging issue in T200423722.
@@ -861,10 +881,10 @@ mod tests {
async fn test_init() {
let proc = Proc::local();
let handle = proc.spawn::("init", ()).await.unwrap();
- let client = proc.attach("client").unwrap();
+ let (client, _) = proc.instance("client").unwrap();
let (port, receiver) = client.open_once_port();
- handle.send(port).unwrap();
+ handle.send(&client, port).unwrap();
assert!(receiver.recv().await.unwrap());
handle.drain_and_stop().unwrap();
@@ -946,12 +966,12 @@ mod tests {
M: RemoteMessage,
MultiActor: Handler<M>,
{
- self.handle.send(message).unwrap()
+ self.handle.send(&self.client, message).unwrap()
}
async fn sync(&self) {
let (port, done) = self.client.open_once_port::<bool>();
- self.handle.send(port).unwrap();
+ self.handle.send(&self.client, port).unwrap();
assert!(done.recv().await.unwrap());
}
@@ -1062,4 +1082,280 @@ mod tests {
handle.drain_and_stop().unwrap();
handle.await;
}
+
+ // Reports the sequence number assigned to each message it receives.
+ #[derive(Debug)]
+ #[hyperactor::export(handlers = [String, Callback])]
+ struct GetSeqActor(PortRef<(String, SeqInfo)>);
+
+ #[async_trait]
+ impl Actor for GetSeqActor {
+ type Params = PortRef<(String, SeqInfo)>;
+
+ async fn new(params: PortRef<(String, SeqInfo)>) -> Result<Self, anyhow::Error> {
+ Ok(Self(params))
+ }
+ }
+
+ #[async_trait]
+ impl Handler<String> for GetSeqActor {
+ async fn handle(
+ &mut self,
+ cx: &Context<Self>,
+ message: String,
+ ) -> Result<(), anyhow::Error> {
+ let Self(port) = self;
+ let seq_info = cx.headers().get(SEQ_INFO).unwrap();
+ port.send(cx, (message, seq_info.clone()))?;
+ Ok(())
+ }
+ }
+
+ // Unlike Handler<String>, where the sender provides the string message
+ // directly, in Handler<Callback> the sender provides a port, and the
+ // handler replies to that port with its own callback port. The sender can
+ // then send the string message through this callback port.
+ #[derive(Clone, Debug, Serialize, Deserialize, Named)]
+ struct Callback(PortRef<PortRef<String>>);
+
+ #[async_trait]
+ impl Handler<Callback> for GetSeqActor {
+ async fn handle(
+ &mut self,
+ cx: &Context<Self>,
+ message: Callback,
+ ) -> Result<(), anyhow::Error> {
+ let (handle, mut receiver) = cx.open_port::<String>();
+ let callback_ref = handle.bind();
+ message.0.send(cx, callback_ref).unwrap();
+ let msg = receiver.recv().await.unwrap();
+ self.handle(cx, msg).await
+ }
+ }
+
+ #[async_timed_test(timeout_secs = 30)]
+ async fn test_sequencing_actor_handle_basic() {
+ let proc = Proc::local();
+ let (client, _) = proc.instance("client").unwrap();
+ let (tx, mut rx) = client.open_port();
+
+ let actor_handle = proc
+ .spawn::<GetSeqActor>("get_seq", tx.bind())
+ .await
+ .unwrap();
+
+ // Verify that an unbound handle can send messages.
+ actor_handle.send(&client, "unbound".to_string()).unwrap();
+ assert_eq!(
+ rx.recv().await.unwrap(),
+ ("unbound".to_string(), SeqInfo::Unordered)
+ );
+
+ let actor_ref: ActorRef<GetSeqActor> = actor_handle.bind();
+
+ let session_id = client.sequencer().session_id();
+ let mut expected_seq = 0;
+ // Interleave messages sent through the handle and the reference.
+ for m in 0..10 {
+ actor_handle.send(&client, format!("{m}")).unwrap();
+ expected_seq += 1;
+ assert_eq!(
+ rx.recv().await.unwrap(),
+ (
+ format!("{m}"),
+ SeqInfo::Session {
+ session_id,
+ seq: expected_seq,
+ }
+ )
+ );
+
+ for n in 0..2 {
+ actor_ref.port().send(&client, format!("{m}-{n}")).unwrap();
+ expected_seq += 1;
+ assert_eq!(
+ rx.recv().await.unwrap(),
+ (
+ format!("{m}-{n}"),
+ SeqInfo::Session {
+ session_id,
+ seq: expected_seq,
+ }
+ )
+ );
+ }
+ }
+ }
+
+ // Verify that we can pass port refs between sender and destination actors
+ // back and forth, and send messages through them without deadlocking.
+ #[async_timed_test(timeout_secs = 30)]
+ async fn test_sequencing_actor_handle_callback() {
+ let config = config::global::lock();
+ let _guard = config.override_key(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+
+ let proc = Proc::local();
+ let (client, _) = proc.instance("client").unwrap();
+ let (tx, mut rx) = client.open_port();
+
+ let actor_handle = proc
+ .spawn::<GetSeqActor>("get_seq", tx.bind())
+ .await
+ .unwrap();
+ let actor_ref: ActorRef<GetSeqActor> = actor_handle.bind();
+
+ let (callback_tx, mut callback_rx) = client.open_port();
+ actor_ref
+ .send(&client, Callback(callback_tx.bind()))
+ .unwrap();
+ let msg_port_ref = callback_rx.recv().await.unwrap();
+ msg_port_ref.send(&client, "finally".to_string()).unwrap();
+
+ let session_id = client.sequencer().session_id();
+ assert_eq!(
+ rx.recv().await.unwrap(),
+ (
+ "finally".to_string(),
+ SeqInfo::Session { session_id, seq: 1 }
+ )
+ );
+ }
+
+ // Adds a delay before posting to the destination proc. Useful for tests
+ // requiring latency injection.
+ #[derive(Debug)]
+ struct DelayedMailboxSender {
+ relay_tx: mpsc::UnboundedSender<MessageEnvelope>,
+ }
+
+ impl DelayedMailboxSender {
+ // Use a fixed 1s latency for every message if the plan is empty; tasks
+ // sharing the same deadline race, making delivery order effectively random.
+ fn boxed(dest_proc: Proc, latency_plan: HashMap<u64, Duration>) -> BoxedMailboxSender {
+ let (relay_tx, mut relay_rx) = mpsc::unbounded_channel();
+ tokio::spawn(async move {
+ let mut count = 0;
+ while let Some(envelope) = relay_rx.recv().await {
+ count += 1;
+
+ let latency = if latency_plan.is_empty() {
+ Duration::from_millis(1000)
+ } else {
+ latency_plan.get(&count).unwrap().clone()
+ };
+
+ let dest_proc_clone = dest_proc.clone();
+ tokio::spawn(async move {
+ // tokio::time::sleep is used directly here instead of
+ // Clock::sleep, which is an async method on a clock instance.
+ #[allow(clippy::disallowed_methods)]
+ tokio::time::sleep(latency).await;
+ dest_proc_clone.post(envelope, monitored_return_handle());
+ });
+ }
+ });
+
+ BoxedMailboxSender::new(Self { relay_tx })
+ }
+ }
+
+ impl MailboxSender for DelayedMailboxSender {
+ fn post_unchecked(
+ &self,
+ envelope: MessageEnvelope,
+ _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
+ ) {
+ self.relay_tx.send(envelope).unwrap();
+ }
+ }
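// For reference, a latency plan keys the injected delays by the 1-based
// arrival index of each message at the relay, as consumed in `boxed` above.
// A hedged sketch of a plan that flips the arrival order of two messages
// (the same values the deterministic test below uses):
//
//     let latency_plan: HashMap<u64, Duration> = maplit::hashmap! {
//         1 => Duration::from_millis(1000), // first message is held back...
//         2 => Duration::from_millis(0),    // ...so the second overtakes it
//     };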
+
+ async fn assert_out_of_order_delivery(
+ expected: Vec<(String, u64)>,
+ latency_plan: HashMap<u64, Duration>,
+ ) {
+ let local_proc: Proc = Proc::local();
+ let (client, _) = local_proc.instance("local").unwrap();
+ let (tx, mut rx) = client.open_port();
+
+ let handle = local_proc
+ .spawn::<GetSeqActor>("get_seq", tx.bind())
+ .await
+ .unwrap();
+
+ let actor_ref: ActorRef<GetSeqActor> = handle.bind();
+
+ let remote_proc = Proc::new(
+ id!(remote[0]),
+ DelayedMailboxSender::boxed(local_proc.clone(), latency_plan),
+ );
+ let (remote_client, _) = remote_proc.instance("remote").unwrap();
+ // Send the messages out in the order of their expected sequence numbers.
+ let mut messages = expected.clone();
+ messages.sort_by_key(|v| v.1);
+ for (message, _seq) in messages {
+ actor_ref.send(&remote_client, message).unwrap();
+ }
+ let session_id = remote_client.sequencer().session_id();
+ for expect in expected {
+ let expected = (
+ expect.0,
+ SeqInfo::Session {
+ session_id,
+ seq: expect.1,
+ },
+ );
+ assert_eq!(rx.recv().await.unwrap(), expected);
+ }
+
+ handle.drain_and_stop().unwrap();
+ handle.await;
+ }
+
+ // Send several messages, using DelayedMailboxSender and the latency plan to
+ // ensure these messages arrive at the handler's workq in a deterministic
+ // out-of-order way. Then verify that the actor handler still processes these
+ // messages in their sending order when the reordering buffer is enabled.
+ #[async_timed_test(timeout_secs = 30)]
+ async fn test_sequencing_actor_ref_out_of_order_deterministic() {
+ let config = config::global::lock();
+
+ let latency_plan = maplit::hashmap! {
+ 1 => Duration::from_millis(1000),
+ 2 => Duration::from_millis(0),
+ };
+
+ // With the actor-side re-ordering buffer disabled, the messages are
+ // processed in the order they arrive.
+ let _guard = config.override_key(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER, false);
+ assert_out_of_order_delivery(
+ vec![("second".to_string(), 2), ("first".to_string(), 1)],
+ latency_plan.clone(),
+ )
+ .await;
+
+ // With the actor-side re-ordering buffer enabled, the messages are
+ // re-ordered before being processed.
+ let _guard = config.override_key(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+ assert_out_of_order_delivery(
+ vec![("first".to_string(), 1), ("second".to_string(), 2)],
+ latency_plan.clone(),
+ )
+ .await;
+ }
+
+ // Send a large number of messages, using DelayedMailboxSender to ensure these
+ // messages arrive at the handler's workq in a random order. Then verify that
+ // the actor handler still processes these messages in their sending order
+ // with the reordering buffer enabled.
+ #[async_timed_test(timeout_secs = 30)]
+ async fn test_sequencing_actor_ref_out_of_order_random() {
+ let config = config::global::lock();
+
+ // With the actor-side re-ordering buffer enabled, the messages are
+ // re-ordered before being processed.
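// The toggling above relies on the config override guard: the value set by
// `override_key` stays in effect while the returned guard is alive. A minimal
// sketch of scoping one test body under the override (`with_reordering_buffer`
// is a hypothetical helper, not part of this change):
//
//     async fn with_reordering_buffer<F: Future>(enabled: bool, fut: F) -> F::Output {
//         let config = config::global::lock();
//         let _guard = config.override_key(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER, enabled);
//         fut.await // runs with the override installed; released when `_guard` drops
//     }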
+ let _guard = config.override_key(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER, true); + let expected = (1..10000) + .map(|i| (format!("msg{i}"), i)) + .collect::>(); + + assert_out_of_order_delivery(expected, HashMap::new()).await; + } } diff --git a/hyperactor/src/clock.rs b/hyperactor/src/clock.rs index 56746629a..f7b72ba96 100644 --- a/hyperactor/src/clock.rs +++ b/hyperactor/src/clock.rs @@ -22,7 +22,6 @@ use serde::Serialize; use crate::Mailbox; use crate::channel::ChannelAddr; -use crate::data::Named; use crate::id; use crate::mailbox::DeliveryError; use crate::mailbox::MailboxSender; @@ -261,9 +260,8 @@ impl SimClock { static SIMCLOCK_MAILBOX: OnceLock = OnceLock::new(); SIMCLOCK_MAILBOX.get_or_init(|| { let mailbox = Mailbox::new_detached(id!(proc[0].proc).clone()); - let (undeliverable_messages, mut rx) = - mailbox.open_port::>(); - undeliverable_messages.bind_to(Undeliverable::::port()); + let (_undeliverable_messages, mut rx) = + mailbox.bind_actor_port::>(); tokio::spawn(async move { while let Ok(Undeliverable(mut envelope)) = rx.recv().await { envelope.set_error(DeliveryError::BrokenLink( diff --git a/hyperactor/src/config.rs b/hyperactor/src/config.rs index 0214da7fa..44b7d6aff 100644 --- a/hyperactor/src/config.rs +++ b/hyperactor/src/config.rs @@ -110,7 +110,12 @@ declare_attrs! { pub attr MESSAGE_LATENCY_SAMPLING_RATE: f32 = 0.01; /// Whether to enable client sequence assignment. - pub attr ENABLE_CLIENT_SEQ_ASSIGNMENT: bool = false; + @meta(CONFIG_ENV_VAR = "HYPERACTOR_ENABLE_DEST_ACTOR_REORDERING_BUFFER".to_string()) + pub attr ENABLE_DEST_ACTOR_REORDERING_BUFFER: bool = false; + + /// Whether to use native v1 casting in v1 ActorMesh. + @meta(CONFIG_ENV_VAR = "HYPERACTOR_ENABLE_NATIVE_V1_CASTING".to_string()) + pub attr ENABLE_NATIVE_V1_CASTING: bool = false; /// Timeout for [`Host::spawn`] to await proc readiness. /// diff --git a/hyperactor/src/data.rs b/hyperactor/src/data.rs index 24b49a058..3fc161987 100644 --- a/hyperactor/src/data.rs +++ b/hyperactor/src/data.rs @@ -23,6 +23,9 @@ use serde::de::DeserializeOwned; use crate as hyperactor; use crate::config; +/// Actor handler port should have its most significant bit set to 1. +pub(crate) static ACTOR_PORT_BIT: u64 = 1 << 63; + /// A [`Named`] type is a type that has a globally unique name. pub trait Named: Sized + 'static { /// The globally unique type name for the type. @@ -46,7 +49,7 @@ pub trait Named: Sized + 'static { /// The globally unique port for this type. Typed ports are in the range /// of 1<<63..1<<64-1. fn port() -> u64 { - Self::typehash() | (1 << 63) + Self::typehash() | ACTOR_PORT_BIT } /// If the named type is an enum, this returns the name of the arm diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index c2756401b..57efc13dd 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -17,14 +17,17 @@ //! //! ``` //! # use hyperactor::mailbox::Mailbox; +//! # use hyperactor::Proc; //! # use hyperactor::reference::{ActorId, ProcId, WorldId}; //! # tokio_test::block_on(async { +//! # let proc = Proc::local(); +//! # let (client, _) = proc.instance("client").unwrap(); //! # let proc_id = ProcId::Ranked(WorldId("world".to_string()), 0); //! # let actor_id = ActorId(proc_id, "actor".to_string(), 0); //! let mbox = Mailbox::new_detached(actor_id); //! let (port, mut receiver) = mbox.open_port::(); //! -//! port.send(123).unwrap(); +//! port.send(&client, 123).unwrap(); //! assert_eq!(receiver.recv().await.unwrap(), 123u64); //! # }) //! ``` @@ -63,8 +66,6 @@ //! 
implementation to avoid a serialization roundtrip when passing //! messages locally. -#![allow(dead_code)] // Allow until this is used outside of tests. - use std::any::Any; use std::collections::BTreeMap; use std::collections::BTreeSet; @@ -118,6 +119,8 @@ use crate::context; use crate::data::Serialized; use crate::id; use crate::metrics; +use crate::proc::SEQ_INFO; +use crate::proc::SeqInfo; use crate::reference::ActorId; use crate::reference::PortId; use crate::reference::Reference; @@ -941,43 +944,6 @@ impl Future for MailboxServerHandle { } } -// A `MailboxServer` (such as a router) can receive a message -// that couldn't reach its destination. We can use the fact that -// servers are `MailboxSender`s to attempt to forward them back to -// their senders. -fn server_return_handle(server: T) -> PortHandle> { - let (return_handle, mut rx) = undeliverable::new_undeliverable_port(); - - tokio::task::spawn(async move { - while let Ok(Undeliverable(mut envelope)) = rx.recv().await { - if let Ok(Undeliverable(e)) = envelope.deserialized::>() - { - // A non-returnable undeliverable. - UndeliverableMailboxSender.post(e, monitored_return_handle()); - continue; - } - envelope.set_error(DeliveryError::BrokenLink( - "message was undeliverable".to_owned(), - )); - server.post( - MessageEnvelope::new( - envelope.sender().clone(), - PortRef::>::attest_message_port( - envelope.sender(), - ) - .port_id() - .clone(), - Serialized::serialize(&Undeliverable(envelope)).unwrap(), - Attrs::new(), - ), - monitored_return_handle(), - ); - } - }); - - return_handle -} - /// Serve a port on the provided [`channel::Rx`]. This dispatches all /// channel messages directly to the port. pub trait MailboxServer: MailboxSender + Clone + Sized + 'static { @@ -1006,6 +972,9 @@ pub trait MailboxServer: MailboxSender + Clone + Sized + 'static { envelope.set_error(DeliveryError::BrokenLink( "message was undeliverable".to_owned(), )); + let mut headers = Attrs::new(); + // Ordering is not required when returning Undeliverable. + headers.set(SEQ_INFO, SeqInfo::Unordered); server.post( MessageEnvelope::new( envelope.sender().clone(), @@ -1015,7 +984,7 @@ pub trait MailboxServer: MailboxSender + Clone + Sized + 'static { .port_id() .clone(), Serialized::serialize(&Undeliverable(envelope)).unwrap(), - Attrs::new(), + headers, ), monitored_return_handle(), ); @@ -1088,7 +1057,9 @@ impl MailboxClient { tokio::spawn(async move { let result = return_receiver.await; if let Ok(message) = result { - let _ = return_handle_0.send(Undeliverable(message)); + // When returning messages, we do not care whether the messages are delivered + // out of order. + let _ = return_handle_0.anon_send(Undeliverable(message)); } else { // Sender dropped, this task can end. } @@ -1244,6 +1215,18 @@ impl Mailbox { ) } + /// Bind this message's actor port to this actor's mailbox. This method is + /// normally used: + /// 1. when we need to intercept a message sent to a handler, and re-route + /// that message to the returned receiver; + /// 2. mock this message's handler when it is not implemented for this actor + /// type, with the returned receiver. + pub(crate) fn bind_actor_port(&self) -> (PortHandle, PortReceiver) { + let (handle, receiver) = self.open_port(); + handle.bind_actor_port(); + (handle, receiver) + } + /// Open a new port with an accumulator. This port accepts A::Update type /// messages, accumulate them into A::State with the given accumulator. 
/// The latest changed state can be received from the returned receiver as @@ -1372,13 +1355,14 @@ impl Mailbox { PortRef::attest(port_id) } - fn bind_to(&self, handle: &PortHandle, port_index: u64) { + fn bind_to_actor_port(&self, handle: &PortHandle) { assert_eq!( handle.mailbox.actor_id(), self.actor_id(), "port does not belong to mailbox" ); + let port_index = M::port(); let port_id = self.actor_id().port_id(port_index); match self.inner.ports.entry(port_index) { Entry::Vacant(entry) => { @@ -1557,7 +1541,7 @@ impl PortHandle { } } - fn location(&self) -> PortLocation { + pub(crate) fn location(&self) -> PortLocation { match self.bound.get() { Some(port_id) => PortLocation::Bound(port_id.clone()), None => PortLocation::new_unbound::(self.mailbox.actor_id().clone()), @@ -1565,11 +1549,52 @@ impl PortHandle { } /// Send a message to this port. - pub fn send(&self, message: M) -> Result<(), MailboxSenderError> { + pub fn send(&self, cx: &impl context::Actor, message: M) -> Result<(), MailboxSenderError> { let mut headers = Attrs::new(); crate::mailbox::headers::set_send_timestamp(&mut headers); + match self.bound.get() { + Some(bound_port) => { + // Message sent from handle is delivered immediately. It could + // race with messages from refs. So we need to assign seq to + // preserve the ordering. + if bound_port.is_actor_port() { + let sequencer = cx.instance().sequencer(); + let seq = sequencer.assign_seq(self.mailbox.actor_id()); + let seq_info = SeqInfo::Session { + session_id: sequencer.session_id(), + seq, + }; + headers.set(SEQ_INFO, seq_info); + } + } + None => { + // we do not have info to know whether this handle is used for + // enqueue port or not. Since enqueue port requires the SEQ_INFO + // header, we set it in for all messages sent from unbound handles. + headers.set(SEQ_INFO, SeqInfo::Unordered); + } + } + + // Encountering error means the port is closed. So we do not need to + // rollback the seq, because no message can be delivered to it, and + // subsequently do not need to worry about out-of-sequence for messages + // after this seq. + self.sender.send(headers, message).map_err(|err| { + MailboxSenderError::new_unbound::( + self.mailbox.actor_id().clone(), + MailboxSenderErrorKind::Other(err), + ) + }) + } + + /// Send a message to this port without a known sender. This method should + /// only be used if you do not care about out-of-ordering delivery. + pub fn anon_send(&self, message: M) -> Result<(), MailboxSenderError> { + let mut headers = Attrs::new(); + crate::mailbox::headers::set_send_timestamp(&mut headers); + headers.set(SEQ_INFO, SeqInfo::Unordered); self.sender.send(headers, message).map_err(|err| { MailboxSenderError::new_unbound::( self.mailbox.actor_id().clone(), @@ -1608,10 +1633,23 @@ impl PortHandle { ) } - /// Bind to a specific port index. This is used by [`actor::Binder`] implementations to - /// bind actor refs. This is not intended for general use. - pub fn bind_to(&self, port_index: u64) { - self.mailbox.bind_to(self, port_index); + /// Bind to this message's actor port. This method will panic if the handle + /// is already bound. + /// + /// This is used by [`actor::Binder`] implementations to bind actor refs. + /// This is not intended for general use. 
+ pub(crate) fn bind_actor_port(&self) { + let port_id = self.mailbox.actor_id().port_id(M::port()); + self.bound + .set(port_id) + .map_err(|p| { + format!( + "could not bind port handle {} as {p}: already bound", + self.port_index + ) + }) + .unwrap(); + self.mailbox.bind_to_actor_port(self); } } @@ -2584,6 +2622,7 @@ mod tests { use crate::channel::sim::SimAddr; use crate::clock::Clock; use crate::clock::RealClock; + use crate::context::Mailbox as _; use crate::data::Serialized; use crate::id; use crate::proc::Proc; @@ -2627,32 +2666,33 @@ mod tests { #[tokio::test] async fn test_mailbox_accum() { - let mbox = Mailbox::new_detached(id!(test[0].test)); - let (port, mut receiver) = mbox.open_accum_port(accum::max::()); + let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); + let (port, mut receiver) = client.mailbox().open_accum_port(accum::max::()); for i in -3..4 { - port.send(i).unwrap(); + port.send(&client, i).unwrap(); let received: accum::Max = receiver.recv().await.unwrap(); let msg = received.get(); assert_eq!(msg, &i); } // Send a smaller or same value. Should still receive the previous max. for i in -3..4 { - port.send(i).unwrap(); + port.send(&client, i).unwrap(); assert_eq!(receiver.recv().await.unwrap().get(), &3); } // send a larger value. Should receive the new max. - port.send(4).unwrap(); + port.send(&client, 4).unwrap(); assert_eq!(receiver.recv().await.unwrap().get(), &4); // Send multiple updates. Should only receive the final change. for i in 5..10 { - port.send(i).unwrap(); + port.send(&client, i).unwrap(); } assert_eq!(receiver.recv().await.unwrap().get(), &9); - port.send(1).unwrap(); - port.send(3).unwrap(); - port.send(2).unwrap(); + port.send(&client, 1).unwrap(); + port.send(&client, 3).unwrap(); + port.send(&client, 2).unwrap(); assert_eq!(receiver.recv().await.unwrap().get(), &9); } @@ -2680,9 +2720,10 @@ mod tests { #[tokio::test] #[ignore] // error behavior changed, but we will bring it back async fn test_mailbox_once() { - let mbox = Mailbox::new_detached(id!(test[0].test)); + let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); - let (port, receiver) = mbox.open_once_port::(); + let (port, receiver) = client.open_once_port::(); // let port_id = port.port_id().clone(); @@ -2951,19 +2992,20 @@ mod tests { #[tokio::test] async fn test_enqueue_port() { - let mbox = Mailbox::new_detached(id!(test[0].test)); + let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let count = Arc::new(AtomicUsize::new(0)); let count_clone = count.clone(); - let port = mbox.open_enqueue_port(move |_, n| { + let port = client.mailbox().open_enqueue_port(move |_, n| { count_clone.fetch_add(n, Ordering::SeqCst); Ok(()) }); - port.send(10).unwrap(); - port.send(5).unwrap(); - port.send(1).unwrap(); - port.send(0).unwrap(); + port.send(&client, 10).unwrap(); + port.send(&client, 5).unwrap(); + port.send(&client, 1).unwrap(); + port.send(&client, 0).unwrap(); assert_eq!(count.load(Ordering::SeqCst), 16); } @@ -3028,6 +3070,7 @@ mod tests { let proc_id = id!(quux[0]); let mut proc = Proc::new(proc_id.clone(), proc_forwarder); ProcSupervisionCoordinator::set(&proc).await.unwrap(); + let (client, _) = proc.instance("client").unwrap(); let foo = proc.spawn::("foo", ()).await.unwrap(); let return_handle = foo.port::>(); @@ -3037,7 +3080,7 @@ mod tests { Serialized::serialize(&1u64).unwrap(), Attrs::new(), ); - return_handle.send(Undeliverable(message)).unwrap(); + return_handle.send(&client, 
Undeliverable(message)).unwrap();

RealClock
.sleep(tokio::time::Duration::from_millis(100))
.await;
@@ -3069,7 +3112,9 @@
Serialized::serialize(&1u64).unwrap(),
Attrs::new(),
);
- return_handle.send(Undeliverable(envelope.clone())).unwrap();
+ return_handle
+ .anon_send(Undeliverable(envelope.clone()))
+ .unwrap();
// Check we receive the undelivered message.
assert!(
RealClock
@@ -3631,8 +3676,7 @@
actor_id.clone(),
BoxedMailboxSender::new(AsyncLoopForwarder),
);
- let (ret_port, mut ret_rx) = mailbox.open_port::<Undeliverable<MessageEnvelope>>();
- ret_port.bind_to(Undeliverable::<MessageEnvelope>::port());
+ let (ret_port, mut ret_rx) = mailbox.bind_actor_port::<Undeliverable<MessageEnvelope>>();

// Create a destination not owned by this mailbox to force
// forwarding.
@@ -3679,9 +3723,8 @@
actor_id.clone(),
BoxedMailboxSender::new(PanickingMailboxSender),
);
- let (undeliverable_tx, mut undeliverable_rx) =
- mailbox.open_port::<Undeliverable<MessageEnvelope>>();
- undeliverable_tx.bind_to(Undeliverable::<MessageEnvelope>::port());
+ let (_undeliverable_tx, mut undeliverable_rx) =
+ mailbox.bind_actor_port::<Undeliverable<MessageEnvelope>>();

// Open a local user u64 port.
let (user_port, mut user_rx) = mailbox.open_port::<u64>();
@@ -3723,13 +3766,51 @@
#[tokio::test]
async fn test_port_contramap() {
- let mbox = Mailbox::new_detached(id!(test[0].test));
- let (handle, mut rx) = mbox.open_port();
+ let proc = Proc::local();
+ let (client, _) = proc.instance("client").unwrap();
+ let (handle, mut rx) = client.open_port();
handle
.contramap(|m| (1, m))
- .send("hello".to_string())
+ .send(&client, "hello".to_string())
.unwrap();
assert_eq!(rx.recv().await.unwrap(), (1, "hello".to_string()));
}
+
+ #[test]
+ #[should_panic(expected = "already bound")]
+ fn test_bind_port_handle_to_actor_port_twice() {
+ let mbox = Mailbox::new_detached(id!(test[0].test));
+ let (handle, _rx) = mbox.open_port::<String>();
+ handle.bind_actor_port();
+ handle.bind_actor_port();
+ }
+
+ #[test]
+ fn test_bind_port_handle_to_actor_port() {
+ let mbox = Mailbox::new_detached(id!(test[0].test));
+ let default_port = mbox.actor_id().port_id(String::port());
+ let (handle, _rx) = mbox.open_port::<String>();
+ // The handle's port index is allocated by the mailbox, not the actor port.
+ assert_ne!(default_port.index(), handle.port_index);
+ // Bind the handle to the actor port.
+ handle.bind_actor_port();
+ assert_matches!(handle.location(), PortLocation::Bound(port) if port == default_port);
+ // bind() can still be called; it just will not change the handle's state.
+ handle.bind();
+ handle.bind();
+ assert_matches!(handle.location(), PortLocation::Bound(port) if port == default_port);
+ }
+
+ #[test]
+ #[should_panic(expected = "already bound")]
+ fn test_bind_port_handle_to_actor_port_when_already_bound() {
+ let mbox = Mailbox::new_detached(id!(test[0].test));
+ let (handle, _rx) = mbox.open_port::<String>();
+ // Bind the handle to the port allocated by the mailbox.
+ handle.bind();
+ assert_matches!(handle.location(), PortLocation::Bound(port) if port.index() == handle.port_index);
+ // Since the handle is already bound, calling bind_actor_port() on it panics.
+ handle.bind_actor_port(); + } } diff --git a/hyperactor/src/mailbox/undeliverable.rs b/hyperactor/src/mailbox/undeliverable.rs index cff25d173..af55e2c1c 100644 --- a/hyperactor/src/mailbox/undeliverable.rs +++ b/hyperactor/src/mailbox/undeliverable.rs @@ -17,6 +17,7 @@ use crate::ActorId; use crate::Message; use crate::Named; use crate::PortId; +use crate::Proc; use crate::actor::ActorStatus; use crate::id; use crate::mailbox::DeliveryError; @@ -94,7 +95,7 @@ pub(crate) fn return_undeliverable( envelope: MessageEnvelope, ) { let envelope_copy = envelope.clone(); - if (return_handle.send(Undeliverable(envelope))).is_err() { + if (return_handle.anon_send(Undeliverable(envelope))).is_err() { UndeliverableMailboxSender.post(envelope_copy, /*unsued*/ return_handle) } } @@ -157,6 +158,9 @@ pub fn supervise_undeliverable_messages_with( F: Fn(&MessageEnvelope) + Send + Sync + 'static, { crate::init::get_runtime().spawn(async move { + // Create a local client for this task. + let proc = Proc::local(); + let (client, _) = proc.instance("undeliverable_supervisor").unwrap(); while let Ok(Undeliverable(mut env)) = rx.recv().await { // Let caller log/trace before we mutate. on_undeliverable(&env); @@ -173,12 +177,15 @@ pub fn supervise_undeliverable_messages_with( let actor_id = env.dest().actor_id().clone(); let headers = env.headers().clone(); - if let Err(e) = sink.send(ActorSupervisionEvent::new( - actor_id, - ActorStatus::Failed(format!("message not delivered: {}", env)), - Some(headers), - None, - )) { + if let Err(e) = sink.send( + &client, + ActorSupervisionEvent::new( + actor_id, + ActorStatus::Failed(format!("message not delivered: {}", env)), + Some(headers), + None, + ), + ) { tracing::warn!( %e, actor=%env.dest().actor_id(), diff --git a/hyperactor/src/message.rs b/hyperactor/src/message.rs index 003d982a8..9f0f4eb31 100644 --- a/hyperactor/src/message.rs +++ b/hyperactor/src/message.rs @@ -249,7 +249,7 @@ impl IndexedErasedUnbound { Ok(()) } }); - port_handle.bind_to(IndexedErasedUnbound::::port()); + port_handle.bind_actor_port(); Ok(()) } } diff --git a/hyperactor/src/ordering.rs b/hyperactor/src/ordering.rs index 851f16c4f..ea3ccefd4 100644 --- a/hyperactor/src/ordering.rs +++ b/hyperactor/src/ordering.rs @@ -164,7 +164,6 @@ impl OrderedSender { } pub(crate) fn direct_send(&self, msg: T) -> Result<(), SendError> { - assert!(!self.enable_buffering); self.tx.send(msg) } } diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index 3990ccc00..e980072ce 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -13,6 +13,7 @@ use std::any::Any; use std::any::TypeId; +use std::any::type_name; use std::collections::HashMap; use std::fmt; use std::future::Future; @@ -389,9 +390,9 @@ impl Proc { .map_err(|existing| anyhow::anyhow!("coordinator port is already set to {existing}")) } - fn handle_supervision_event(&self, event: ActorSupervisionEvent) { + fn handle_supervision_event(&self, cx: &impl context::Actor, event: ActorSupervisionEvent) { let result = match self.state().supervision_coordinator_port.get() { - Some(port) => port.send(event).map_err(anyhow::Error::from), + Some(port) => port.send(cx, event).map_err(anyhow::Error::from), None => Err(anyhow::anyhow!( "coordinator port is not set for proc {}", self.proc_id() @@ -477,8 +478,7 @@ impl Proc { R: Referable + RemoteHandles, { let (instance, _handle) = self.instance(name)?; - let (handle, rx) = instance.open_port::(); - handle.bind_to(M::port()); + let (_handle, rx) = instance.bind_actor_port::(); let actor_ref = 
ActorRef::attest(instance.self_id().clone());
Ok((instance, actor_ref, rx))
}
@@ -947,7 +947,7 @@ impl Instance {
let mailbox = Mailbox::new(actor_id.clone(), BoxedMailboxSender::new(proc.downgrade()));
let (work_tx, work_rx) = ordered_channel(
actor_id.to_string(),
- config::global::get(config::ENABLE_CLIENT_SEQ_ASSIGNMENT),
+ config::global::get(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER),
);
let ports: Arc<Ports<A>> = Arc::new(Ports::new(mailbox.clone(), work_tx));
proc.state().proc_muxer.bind_mailbox(mailbox.clone());
@@ -1045,7 +1045,9 @@ impl Instance {
let clock = self.proc.state().clock.clone();
tokio::spawn(async move {
clock.non_advancing_sleep(delay).await;
- if let Err(e) = port.send(message) {
+ // There is only one message from this context, so there is no need
+ // to worry about out-of-order delivery.
+ if let Err(e) = port.anon_send(message) {
// TODO: this is a fire-n-forget thread. We need to
// handle errors in a better way.
tracing::info!("{}: error sending delayed message: {}", self_id, e);
@@ -1114,7 +1116,7 @@ impl Instance {
if let Some(parent) = self.cell.maybe_unlink_parent() {
if let Some(event) = event {
// Parent exists, failure should be propagated to the parent.
- parent.send_supervision_event_or_crash(event);
+ parent.send_supervision_event_or_crash(&self, event);
}
// TODO: we should get rid of this signal, and use *only* supervision events for
// the purpose of conveying lifecycle changes
@@ -1133,7 +1135,7 @@
// Note that an orphaned actor is unexpected and would only happen if
// there is a bug.
if let Some(event) = event {
- self.proc.handle_supervision_event(event);
+ self.proc.handle_supervision_event(&self, event);
}
}
self.change_status(actor_status);
@@ -1509,6 +1511,17 @@ impl context::Actor for &Context<'_, A> {
}
}

+impl Instance<()> {
+ /// See [Mailbox::bind_actor_port] for details.
+ pub fn bind_actor_port(&self) -> (PortHandle, PortReceiver) {
+ assert!(
+ self.actor_task_handle().is_none(),
+ "can only bind actor port on instance with no running actor task"
+ );
+ self.mailbox.bind_actor_port()
+ }
+}
+
#[derive(Debug)]
enum ActorType {
Named(&'static TypeInfo),
@@ -1657,7 +1670,9 @@ impl InstanceCell {
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ActorError`.
pub fn signal(&self, signal: Signal) -> Result<(), ActorError> {
if let Some((signal_port, _)) = &self.inner.actor_loop {
- signal_port.send(signal).map_err(ActorError::from)
+ // The owner of InstanceCell would not use a PortRef to send signals,
+ // so we do not need to worry about out-of-order delivery here.
+ signal_port.anon_send(signal).map_err(ActorError::from)
} else {
tracing::warn!(
"{}: attempted to send signal {} to detached actor",
@@ -1676,10 +1691,14 @@
/// Note that "let it crash" is the default behavior when a supervision event
/// cannot be delivered upstream. It is the upstream's responsibility to
/// detect and handle crashes.
- pub fn send_supervision_event_or_crash(&self, event: ActorSupervisionEvent) {
+ pub fn send_supervision_event_or_crash(
+ &self,
+ cx: &impl context::Actor,
+ event: ActorSupervisionEvent,
+ ) {
match &self.inner.actor_loop {
Some((_, supervision_port)) => {
- if let Err(err) = supervision_port.send(event) {
+ if let Err(err) = supervision_port.send(cx, event) {
tracing::error!(
"{}: failed to send supervision event to actor: {:?}. Crash the process.",
self.actor_id(),
@@ -1828,17 +1847,25 @@ pub struct Ports {
}
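// The attribute carried in SEQ_INFO (redefined as an enum just below)
// round-trips through the `Display`/`FromStr` pair that follows: ordered
// messages render as "<session-uuid>:<seq>" and unordered ones as the literal
// "unordered". A small sketch of the round-trip, assuming the uuid crate's
// v4 feature is enabled:
//
//     let ordered = SeqInfo::Session { session_id: Uuid::new_v4(), seq: 7 };
//     let parsed: SeqInfo = ordered.to_string().parse().unwrap();
//     assert_eq!(parsed, ordered);
//     assert_eq!("unordered".parse::<SeqInfo>().unwrap(), SeqInfo::Unordered);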
/// A message's sequence number information.
-#[derive(Serialize, Deserialize, Clone, Named, AttrValue)]
-pub struct SeqInfo {
- /// Message's session ID
- pub session_id: Uuid,
- /// Message's sequence number in the given session.
- pub seq: u64,
+#[derive(Debug, Serialize, Deserialize, Clone, Named, AttrValue, PartialEq)]
+pub enum SeqInfo {
+ /// Messages with the same session ID should be delivered in order.
+ Session {
+ /// Message's session ID
+ session_id: Uuid,
+ /// Message's sequence number in the given session.
+ seq: u64,
+ },
+ /// This message does not require ordering and thus has no sequence number.
+ Unordered,
}

impl fmt::Display for SeqInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "{}:{}", self.session_id, self.seq)
+ match self {
+ Self::Unordered => write!(f, "unordered"),
+ Self::Session { session_id, seq } => write!(f, "{}:{}", session_id, seq),
+ }
}
}

@@ -1846,13 +1873,17 @@ impl std::str::FromStr for SeqInfo {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
+ if s == "unordered" {
+ return Ok(SeqInfo::Unordered);
+ }
+
let parts: Vec<_> = s.split(':').collect();
if parts.len() != 2 {
return Err(anyhow::anyhow!("invalid SeqInfo: {}", s));
}
let session_id: Uuid = parts[0].parse()?;
let seq: u64 = parts[1].parse()?;
- Ok(SeqInfo { session_id, seq })
+ Ok(SeqInfo::Session { session_id, seq })
}
}

@@ -1908,18 +1939,37 @@ impl Ports {
hyperactor_telemetry::kv_pairs!("actor_id" => actor_id.clone()),
);
if workq.enable_buffering {
- let SeqInfo { session_id, seq } =
- seq_info.expect("SEQ_INFO must be set when buffering is enabled");
-
- // TODO: return the message contained in the error instead of dropping them when converting
- // to anyhow::Error. In that way, the message can be picked up by mailbox and returned to sender.
- workq.send(session_id, seq, work).map_err(|e| match e {
- OrderedSenderError::InvalidZeroSeq(_) => {
- anyhow::anyhow!("seq must be greater than 0")
+ match seq_info {
+ Some(SeqInfo::Session { session_id, seq }) => {
+ // TODO: return the message contained in the error instead of dropping them when converting
+ // to anyhow::Error. In that way, the message can be picked up by mailbox and returned to sender.
+ workq.send(session_id, seq, work).map_err(|e| match e {
+ OrderedSenderError::InvalidZeroSeq(_) => {
+ let error_msg = format!(
+ "in enqueue func for {}, got seq 0 for message type {}",
+ actor_id,
+ std::any::type_name::<M>(),
+ );
+ tracing::error!(error_msg);
+ anyhow::anyhow!(error_msg)
+ }
+ OrderedSenderError::SendError(e) => anyhow::Error::from(e),
+ OrderedSenderError::FlushError(e) => e,
+ })
}
- OrderedSenderError::SendError(e) => anyhow::Error::from(e),
- OrderedSenderError::FlushError(e) => e,
- })
+ Some(SeqInfo::Unordered) => {
+ workq.direct_send(work).map_err(anyhow::Error::from)
+ }
+ None => {
+ let error_msg = format!(
+ "in enqueue func for {}, buffering is enabled, but SEQ_INFO is not set for message type {}",
+ actor_id,
+ std::any::type_name::<M>(),
+ );
+ tracing::error!(error_msg);
+ anyhow::bail!(error_msg);
+ }
+ }
} else {
workq.direct_send(work).map_err(anyhow::Error::from)
}
@@ -1947,24 +1997,15 @@ impl Ports {
}
}

- /// Bind the given message type to its default port.
+ /// Bind the given message type to its actor port.
pub fn bind(&self)
where
A: Handler<M>,
{
- self.bind_to::<M>(M::port());
- }
-
- /// Bind the given message type to the provided port.
- /// Ports cannot be rebound to different message types;
- /// and attempting to do so will result in a panic.
- pub fn bind_to(&self, port_index: u64) - where - A: Handler, - { + let port_index = M::port(); match self.bound.entry(port_index) { Entry::Vacant(entry) => { - self.get::().bind_to(port_index); + self.get::().bind_actor_port(); entry.insert(M::typename()); } Entry::Occupied(entry) => { @@ -2051,9 +2092,12 @@ mod tests { } impl TestActor { - async fn spawn_child(parent: &ActorHandle) -> ActorHandle { + async fn spawn_child( + cx: &impl context::Actor, + parent: &ActorHandle, + ) -> ActorHandle { let (tx, rx) = oneshot::channel(); - parent.send(TestActorMessage::Spawn(tx)).unwrap(); + parent.send(cx, TestActorMessage::Spawn(tx)).unwrap(); rx.await.unwrap() } } @@ -2083,12 +2127,12 @@ mod tests { async fn forward( &mut self, - _cx: &crate::Context, + cx: &crate::Context, destination: ActorHandle, message: Box, ) -> Result<(), anyhow::Error> { // TODO: this needn't be async - destination.send(*message)?; + destination.send(cx, *message)?; Ok(()) } @@ -2127,6 +2171,7 @@ mod tests { #[tokio::test] async fn test_spawn_actor() { let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let handle = proc.spawn::("test", ()).await.unwrap(); // Check on the join handle. @@ -2144,7 +2189,7 @@ mod tests { // Send a ping-pong to the actor. Wait for the actor to become idle. let (tx, rx) = oneshot::channel::<()>(); - handle.send(TestActorMessage::Reply(tx)).unwrap(); + handle.send(&client, TestActorMessage::Reply(tx)).unwrap(); rx.await.unwrap(); state @@ -2157,7 +2202,7 @@ mod tests { let (exit_tx, exit_rx) = oneshot::channel::<()>(); handle - .send(TestActorMessage::Wait(enter_tx, exit_rx)) + .send(&client, TestActorMessage::Wait(enter_tx, exit_rx)) .unwrap(); enter_rx.await.unwrap(); assert_matches!(*state.borrow(), ActorStatus::Processing(instant, _) if instant <= RealClock.system_time_now()); @@ -2176,12 +2221,16 @@ mod tests { #[tokio::test] async fn test_proc_actors_messaging() { let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let first = proc.spawn::("first", ()).await.unwrap(); let second = proc.spawn::("second", ()).await.unwrap(); let (tx, rx) = oneshot::channel::<()>(); let reply_message = TestActorMessage::Reply(tx); first - .send(TestActorMessage::Forward(second, Box::new(reply_message))) + .send( + &client, + TestActorMessage::Forward(second, Box::new(reply_message)), + ) .unwrap(); rx.await.unwrap(); } @@ -2270,10 +2319,11 @@ mod tests { #[tokio::test] async fn test_spawn_child() { let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let first = proc.spawn::("first", ()).await.unwrap(); - let second = TestActor::spawn_child(&first).await; - let third = TestActor::spawn_child(&second).await; + let second = TestActor::spawn_child(&client, &first).await; + let third = TestActor::spawn_child(&client, &second).await; // Check we've got the join handles. 
assert!(logs_with_scope_contain( @@ -2328,17 +2378,18 @@ mod tests { #[tokio::test] async fn test_child_lifecycle() { let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let root = proc.spawn::("root", ()).await.unwrap(); - let root_1 = TestActor::spawn_child(&root).await; - let root_2 = TestActor::spawn_child(&root).await; - let root_2_1 = TestActor::spawn_child(&root_2).await; + let root_1 = TestActor::spawn_child(&client, &root).await; + let root_2 = TestActor::spawn_child(&client, &root).await; + let root_2_1 = TestActor::spawn_child(&client, &root_2).await; root.drain_and_stop().unwrap(); root.await; for actor in [root_1, root_2, root_2_1] { - assert!(actor.send(TestActorMessage::Noop()).is_err()); + assert!(actor.send(&client, TestActorMessage::Noop()).is_err()); assert_matches!(actor.await, ActorStatus::Stopped); } } @@ -2346,19 +2397,21 @@ mod tests { #[tokio::test] async fn test_parent_failure() { let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); // Need to set a supervison coordinator for this Proc because there will // be actor failure(s) in this test which trigger supervision. ProcSupervisionCoordinator::set(&proc).await.unwrap(); let root = proc.spawn::("root", ()).await.unwrap(); - let root_1 = TestActor::spawn_child(&root).await; - let root_2 = TestActor::spawn_child(&root).await; - let root_2_1 = TestActor::spawn_child(&root_2).await; + let root_1 = TestActor::spawn_child(&client, &root).await; + let root_2 = TestActor::spawn_child(&client, &root).await; + let root_2_1 = TestActor::spawn_child(&client, &root_2).await; root_2 - .send(TestActorMessage::Fail(anyhow::anyhow!( - "some random failure" - ))) + .send( + &client, + TestActorMessage::Fail(anyhow::anyhow!("some random failure")), + ) .unwrap(); let root_2_actor_id = root_2.actor_id().clone(); assert_matches!( @@ -2387,6 +2440,7 @@ mod tests { } let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); // Add the 1st root. This root will remain active until the end of the test. 
let root: ActorHandle = proc.spawn::("root", ()).await.unwrap(); @@ -2438,7 +2492,7 @@ mod tests { // root -> root_1 -> root_1_1 // |-> root_2 - let root_1 = TestActor::spawn_child(&root).await; + let root_1 = TestActor::spawn_child(&client, &root).await; wait_until_idle(&root_1).await; { let snapshot = proc.state().ledger.snapshot(); @@ -2465,7 +2519,7 @@ mod tests { ); } - let root_1_1 = TestActor::spawn_child(&root_1).await; + let root_1_1 = TestActor::spawn_child(&client, &root_1).await; wait_until_idle(&root_1_1).await; { let snapshot = proc.state().ledger.snapshot(); @@ -2504,7 +2558,7 @@ mod tests { ); } - let root_2 = TestActor::spawn_child(&root).await; + let root_2 = TestActor::spawn_child(&client, &root).await; wait_until_idle(&root_2).await; { let snapshot = proc.state().ledger.snapshot(); @@ -2644,11 +2698,11 @@ mod tests { .spawn::("test", state.clone()) .await .unwrap(); - let client = proc.attach("client").unwrap(); + let (client, _) = proc.instance("client").unwrap(); let (tx, rx) = client.open_once_port(); - handle.send(tx).unwrap(); + handle.send(&client, tx).unwrap(); let usize_handle = rx.recv().await.unwrap(); - usize_handle.send(123).unwrap(); + usize_handle.send(&client, 123).unwrap(); handle.drain_and_stop().unwrap(); handle.await; @@ -2734,6 +2788,7 @@ mod tests { } let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let reported_event = ProcSupervisionCoordinator::set(&proc).await.unwrap(); let root_state = Arc::new(AtomicBool::new(false)); @@ -2777,13 +2832,13 @@ mod tests { // fail `root_1_1_1`, the supervision msg should be propagated to // `root_1` because `root_1` has set `true` to `handle_supervision_event`. root_1_1_1 - .send::("some random failure".into()) + .send::(&client, "some random failure".into()) .unwrap(); // fail `root_2_1`, the supervision msg should be propagated to // ProcSupervisionCoordinator. root_2_1 - .send::("some random failure".into()) + .send::(&client, "some random failure".into()) .unwrap(); RealClock.sleep(Duration::from_secs(1)).await; @@ -2857,6 +2912,7 @@ mod tests { } let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); let (event_tx, mut event_rx) = tokio::sync::mpsc::unbounded_channel(); @@ -2875,7 +2931,9 @@ mod tests { // Grandchild fails, triggering failure up the tree, finally receiving // the event at the root. 
- grandchild.send("trigger failure".to_string()).unwrap(); + grandchild + .send(&client, "trigger failure".to_string()) + .unwrap(); assert!(grandchild.await.is_failed()); assert!(child.await.is_failed()); @@ -2929,7 +2987,7 @@ mod tests { let (port, mut receiver) = instance.open_port(); child_actor - .send(("hello".to_string(), port.bind())) + .send(&instance, ("hello".to_string(), port.bind())) .unwrap(); let message = receiver.recv().await.unwrap(); @@ -2990,9 +3048,9 @@ mod tests { struct LoggingActor; impl LoggingActor { - async fn wait(handle: &ActorHandle) { + async fn wait(cx: &impl context::Actor, handle: &ActorHandle) { let barrier = Arc::new(Barrier::new(2)); - handle.send(barrier.clone()).unwrap(); + handle.send(cx, barrier.clone()).unwrap(); barrier.wait().await; } } @@ -3049,12 +3107,16 @@ mod tests { } trace_and_block(async { - let handle = LoggingActor::spawn_detached(()).await.unwrap(); - handle.send("hello world".to_string()).unwrap(); - handle.send("hello world again".to_string()).unwrap(); - handle.send(123u64).unwrap(); + let proc = Proc::local(); + let (client, _) = proc.instance("client").unwrap(); + let handle: ActorHandle = proc.spawn("logging", ()).await.unwrap(); + handle.send(&client, "hello world".to_string()).unwrap(); + handle + .send(&client, "hello world again".to_string()) + .unwrap(); + handle.send(&client, 123u64).unwrap(); - LoggingActor::wait(&handle).await; + LoggingActor::wait(&client, &handle).await; let events = handle.cell().inner.recording.tail(); assert_eq!(events.len(), 3); @@ -3067,7 +3129,7 @@ mod tests { let stacks = { let barriers = Arc::new((Barrier::new(2), Barrier::new(2))); - handle.send(Arc::clone(&barriers)).unwrap(); + handle.send(&client, Arc::clone(&barriers)).unwrap(); barriers.0.wait().await; let stacks = handle.cell().inner.recording.stacks(); barriers.1.wait().await; diff --git a/hyperactor/src/reference.rs b/hyperactor/src/reference.rs index f8e5b5ca8..15b9aeeee 100644 --- a/hyperactor/src/reference.rs +++ b/hyperactor/src/reference.rs @@ -54,6 +54,7 @@ use crate::attrs::Attrs; use crate::channel::ChannelAddr; use crate::context; use crate::context::MailboxExt; +use crate::data::ACTOR_PORT_BIT; use crate::data::Serialized; use crate::data::TypeInfo; use crate::mailbox::MailboxSenderError; @@ -62,6 +63,8 @@ use crate::mailbox::PortSink; use crate::message::Bind; use crate::message::Bindings; use crate::message::Unbind; +use crate::proc::SEQ_INFO; +use crate::proc::SeqInfo; pub mod lex; pub mod name; @@ -877,13 +880,15 @@ impl PortId { self.1 } + pub(crate) fn is_actor_port(&self) -> bool { + self.1 & ACTOR_PORT_BIT != 0 + } + /// Send a serialized message to this port, provided a sending capability, /// such as [`crate::actor::Instance`]. It is the sender's responsibility /// to ensure that the provided message is well-typed. pub fn send(&self, cx: &impl context::Actor, serialized: Serialized) { - let mut headers = Attrs::new(); - crate::mailbox::headers::set_send_timestamp(&mut headers); - cx.post(self.clone(), headers, serialized); + self.send_with_headers(cx, serialized, Attrs::new()); } /// Send a serialized message to this port, provided a sending capability, @@ -894,8 +899,31 @@ impl PortId { cx: &impl context::Actor, serialized: Serialized, mut headers: Attrs, + ) { + self.send_with_headers_with_option(cx, serialized, headers, true); + } + + /// Similar to [`PortId::send_with_headers`], but allows the caller to + /// decide whether to set the sequence info header with this method. 
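// A worked example of the numbering this sequencing scheme produces, assuming
// `assign_seq` keeps one counter per destination actor within a session (as
// the tests in this diff exercise): if a client sends A1 and A2 to actor A
// and B1 to actor B in the same session `s`, the headers carry
//
//     A1 -> SeqInfo::Session { session_id: s, seq: 1 }
//     A2 -> SeqInfo::Session { session_id: s, seq: 2 }
//     B1 -> SeqInfo::Session { session_id: s, seq: 1 } // independent per-destination counter
//
// Only actor ports (index with the high bit set) participate; sends to
// anonymous ports skip the header.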
+ pub fn send_with_headers_with_option(
+ &self,
+ cx: &impl context::Actor,
+ serialized: Serialized,
+ mut headers: Attrs,
+ set_seq_info: bool,
) {
crate::mailbox::headers::set_send_timestamp(&mut headers);
+ if set_seq_info && self.is_actor_port() {
+ // This method is infallible, so it is okay to assign the sequence
+ // number without worrying about rollback.
+ let sequencer = cx.instance().sequencer();
+ let seq = sequencer.assign_seq(self.actor_id());
+ let seq_info = SeqInfo::Session {
+ session_id: sequencer.session_id(),
+ seq,
+ };
+ headers.set(SEQ_INFO, seq_info);
+ }
cx.post(self.clone(), headers, serialized);
}

@@ -925,8 +953,8 @@ impl FromStr for PortId {
impl fmt::Display for PortId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let PortId(actor_id, port) = self;
- if port & (1 << 63) != 0 {
- let type_info = TypeInfo::get(*port).or_else(|| TypeInfo::get(*port & !(1 << 63)));
+ if port & ACTOR_PORT_BIT != 0 {
+ let type_info = TypeInfo::get(*port).or_else(|| TypeInfo::get(*port & !ACTOR_PORT_BIT));
let typename = type_info.map_or("unknown", TypeInfo::typename);
write!(f, "{}[{}<{}>]", actor_id, port, typename)
} else {
@@ -1036,14 +1064,8 @@ impl PortRef {
/// Send a serialized message to this port, provided a sending capability, such as
/// [`crate::actor::Instance`].
- pub fn send_serialized(
- &self,
- cx: &impl context::Actor,
- mut headers: Attrs,
- message: Serialized,
- ) {
- crate::mailbox::headers::set_send_timestamp(&mut headers);
- cx.post(self.port_id.clone(), headers, message);
+ pub fn send_serialized(&self, cx: &impl context::Actor, headers: Attrs, message: Serialized) {
+ self.port_id.send_with_headers(cx, message, headers);
}

/// Convert this port into a sink that can be used to send messages using the given capability.
@@ -1350,8 +1372,13 @@ impl<'a, A: Referable> From<&'a GangRef<A>> for &'a GangId {
mod tests {
use rand::seq::SliceRandom;
use rand::thread_rng;
+ use tokio::sync::mpsc;
+ use uuid::Uuid;

use super::*;
+ use crate::Proc;
+ use crate::context::Mailbox as _;
+ use crate::mailbox::PortLocation;

#[test]
fn test_reference_parse() {
@@ -1512,4 +1539,112 @@
"test[234].testactor[1][17867850292987402005]"
);
}
+
+ #[tokio::test]
+ async fn test_sequencing_from_port_handle_ref_and_id() {
+ let proc = Proc::local();
+ let (client, _) = proc.instance("client").unwrap();
+ let (tx, mut rx) = mpsc::unbounded_channel();
+ let port_handle = client.mailbox().open_enqueue_port(move |headers, _m: ()| {
+ let seq_info = headers.get(SEQ_INFO);
+ tx.send(seq_info.cloned()).unwrap();
+ Ok(())
+ });
+ port_handle.send(&client, ()).unwrap();
+ // No seq will be assigned for an unbound port handle.
+ assert!(rx.try_recv().unwrap().is_none());
+
+ port_handle.bind_actor_port();
+ let port_id = match port_handle.location() {
+ PortLocation::Bound(port_id) => port_id,
+ _ => panic!("port_handle should be bound"),
+ };
+ assert!(port_id.is_actor_port());
+ let port_ref = PortRef::attest(port_id.clone());
+
+ port_handle.send(&client, ()).unwrap();
+ let SeqInfo::Session {
+ session_id,
+ mut seq,
+ } = rx.try_recv().unwrap().unwrap()
+ else {
+ panic!("expected session info");
+ };
+ assert_eq!(session_id, client.sequencer().session_id());
+ assert_eq!(seq, 1);
+
+ fn assert_seq_info(
+ rx: &mut mpsc::UnboundedReceiver<Option<SeqInfo>>,
+ session_id: Uuid,
+ seq: &mut u64,
+ ) {
+ *seq += 1;
+ let SeqInfo::Session {
+ session_id: rcved_session_id,
+ seq: rcved_seq,
+ } = rx.try_recv().unwrap().unwrap()
+ else {
+ panic!("expected session info");
+ };
+ assert_eq!(rcved_session_id, session_id);
+ assert_eq!(rcved_seq, *seq);
+ }
+
+ // Interleave sends from port_handle, port_ref, and port_id.
+ for _ in 0..10 {
+ // From port_handle
+ port_handle.send(&client, ()).unwrap();
+ assert_seq_info(&mut rx, session_id, &mut seq);
+
+ // From port_ref
+ for _ in 0..2 {
+ port_ref.send(&client, ()).unwrap();
+ assert_seq_info(&mut rx, session_id, &mut seq);
+ }
+
+ // From port_id
+ for _ in 0..3 {
+ port_id.send(&client, Serialized::serialize(&()).unwrap());
+ assert_seq_info(&mut rx, session_id, &mut seq);
+ }
+ }
+
+ assert_eq!(rx.try_recv().unwrap_err(), mpsc::error::TryRecvError::Empty);
+ }
+
+ #[tokio::test]
+ async fn test_sequencing_from_port_handle_bound_to_allocated_port() {
+ let proc = Proc::local();
+ let (client, _) = proc.instance("client").unwrap();
+ let (tx, mut rx) = mpsc::unbounded_channel();
+ let port_handle = client.mailbox().open_enqueue_port(move |headers, _m: ()| {
+ let seq_info = headers.get(SEQ_INFO);
+ tx.send(seq_info.cloned()).unwrap();
+ Ok(())
+ });
+ port_handle.send(&client, ()).unwrap();
+ // SeqInfo::Unordered is set for an unbound port handle, since the
+ // handler's ordered channel expects the SEQ_INFO header to be set.
+ assert_eq!(rx.try_recv().unwrap().unwrap(), SeqInfo::Unordered);
+
+ // Bind to the allocated port.
+ port_handle.bind();
+ let port_id = match port_handle.location() {
+ PortLocation::Bound(port_id) => port_id,
+ _ => panic!("port_handle should be bound"),
+ };
+ // Since the port is not an actor port, no seq will be assigned, no
+ // matter whether the message is sent from the handle, ref, or port id.
+ assert!(!port_id.is_actor_port());
+
+ port_handle.send(&client, ()).unwrap();
+ assert!(rx.try_recv().unwrap().is_none());
+
+ let port_ref = PortRef::attest(port_id.clone());
+ port_ref.send(&client, ()).unwrap();
+ assert!(rx.try_recv().unwrap().is_none());
+
+ port_id.send(&client, Serialized::serialize(&()).unwrap());
+ assert!(rx.try_recv().unwrap().is_none());
+ }
}
diff --git a/hyperactor_macros/src/lib.rs b/hyperactor_macros/src/lib.rs
index 6e4a909ad..e64f96679 100644
--- a/hyperactor_macros/src/lib.rs
+++ b/hyperactor_macros/src/lib.rs
@@ -912,11 +912,7 @@ fn derive_client(input: TokenStream, is_handle: bool) -> TokenStream {
// The client implementation methods.
let mut impl_methods = Vec::new();
- let send_message = if is_handle {
- quote! { self.send(message)? }
- } else {
- quote! { self.send(cx, message)? }
- };
+ let send_message = quote! { self.send(cx, message)?
}; let global_log_level = parse_log_level(&input.attrs).ok().unwrap_or(None); for message in &messages { diff --git a/hyperactor_macros/tests/export.rs b/hyperactor_macros/tests/export.rs index 82c4c5aef..ebaffeda4 100644 --- a/hyperactor_macros/tests/export.rs +++ b/hyperactor_macros/tests/export.rs @@ -199,8 +199,10 @@ mod tests { }; let actor_handle = proc.spawn::("actor", params).await.unwrap(); - actor_handle.send(123u64).unwrap(); - actor_handle.send(TestMessage("foo".to_string())).unwrap(); + actor_handle.send(&client, 123u64).unwrap(); + actor_handle + .send(&client, TestMessage("foo".to_string())) + .unwrap(); let myref: ActorRef = actor_handle.bind(); myref.port().send(&client, MyGeneric(())).unwrap(); diff --git a/hyperactor_mesh/Cargo.toml b/hyperactor_mesh/Cargo.toml index 775fe7636..3c0b097a2 100644 --- a/hyperactor_mesh/Cargo.toml +++ b/hyperactor_mesh/Cargo.toml @@ -81,9 +81,11 @@ tokio-stream = { version = "0.1.17", features = ["fs", "io-util", "net", "signal tokio-util = { version = "0.7.15", features = ["full"] } tracing = { version = "0.1.41", features = ["attributes", "valuable"] } tracing-subscriber = { version = "0.3.20", features = ["chrono", "env-filter", "json", "local-time", "parking_lot", "registry"] } +uuid = { version = "1.2", features = ["serde", "v4", "v5", "v6", "v7", "v8"] } [dev-dependencies] bytes = { version = "1.10", features = ["serde"] } +fastrand = "2.1.1" itertools = "0.14.0" maplit = "1.0" proptest = "1.5" diff --git a/hyperactor_mesh/examples/sieve.rs b/hyperactor_mesh/examples/sieve.rs index 59982d428..d52d4e4a8 100644 --- a/hyperactor_mesh/examples/sieve.rs +++ b/hyperactor_mesh/examples/sieve.rs @@ -77,7 +77,7 @@ impl Handler for SieveActor { if !msg.number.is_multiple_of(self.prime) { match &self.next { Some(next) => { - next.send(msg)?; + next.send(cx, msg)?; } None => { msg.prime_collector.send(cx, msg.number)?; diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 8cbf4dbc7..25e33bf7c 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -1092,17 +1092,17 @@ mod tests { let actor_mesh: RootActorMesh = mesh.spawn("test", &()).await.unwrap(); let actor_ref = actor_mesh.get(0).unwrap(); let mut headers = Attrs::new(); - set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone()); + set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone(), None); actor_ref.send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone())).unwrap(); assert_eq!(0, reply_port_receiver.recv().await.unwrap()); - set_cast_info_on_headers(&mut headers, extent.point_of_rank(1).unwrap(), mesh.client().self_id().clone()); + set_cast_info_on_headers(&mut headers, extent.point_of_rank(1).unwrap(), mesh.client().self_id().clone(), None); actor_ref.port() .send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone())) .unwrap(); assert_eq!(1, reply_port_receiver.recv().await.unwrap()); - set_cast_info_on_headers(&mut headers, extent.point_of_rank(2).unwrap(), mesh.client().self_id().clone()); + set_cast_info_on_headers(&mut headers, extent.point_of_rank(2).unwrap(), mesh.client().self_id().clone(), None); actor_ref.actor_id() .port_id(GetRank::port()) .send_with_headers( @@ -1358,21 +1358,27 @@ mod tests { //#[tracing_test::traced_test] #[async_timed_test(timeout_secs = 30)] async fn test_oversized_frames() { + use hyperactor::context::Mailbox as _; + use 
hyperactor::mailbox::MailboxSender;
+
         // Reproduced from 'net.rs'.
         #[derive(Debug, Serialize, Deserialize, PartialEq)]
         enum Frame<M> {
             Init(u64),
             Message(u64, M),
         }
-        // Calculate the frame length for the given message.
-        fn frame_length(src: &ActorId, dst: &PortId, pay: &Payload) -> usize {
+        // Build a message envelope and frame with empty headers.
+        fn build_message(
+            src: &ActorId,
+            dst: &PortId,
+            pay: &Payload,
+        ) -> (MessageEnvelope, serde_multipart::Message) {
             let serialized = Serialized::serialize(pay).unwrap();
-            let mut headers = Attrs::new();
-            hyperactor::mailbox::headers::set_send_timestamp(&mut headers);
+            let headers = Attrs::new();
             let envelope = MessageEnvelope::new(src.clone(), dst.clone(), serialized, headers);
-            let frame = Frame::Message(0u64, envelope);
+            let frame = Frame::Message(0u64, envelope.clone());
             let message = serde_multipart::serialize_illegal_bincode(&frame).unwrap();
-            message.frame_len()
+            (envelope, message)
         }
         // This process: short delivery timeout.
@@ -1407,19 +1413,23 @@
         // Message sized to exactly max frame length.
         let payload = Payload {
-            part: Part::from(Bytes::from(vec![0u8; 698])),
+            part: Part::from(Bytes::from(vec![0u8; 762])),
             reply_port: reply_handle.bind(),
         };
-        let frame_len = frame_length(
+        let (envelope, message) = build_message(
             proc_mesh.client().self_id(),
             dest.port::<Payload>().port_id(),
             &payload,
         );
-        assert_eq!(frame_len, 1024);
-
-        // Send direct. A cast message is > 1024 bytes.
-        dest.send(proc_mesh.client(), payload).unwrap();
-        #[allow(clippy::disallowed_methods)]
+        assert_eq!(message.frame_len(), 1024);
+
+        // Send the envelope directly, so no extra headers will be added to
+        // increase the frame size.
+        MailboxSender::post(
+            proc_mesh.client().mailbox(),
+            envelope,
+            hyperactor::mailbox::monitored_return_handle(),
+        );
         let result = RealClock
             .timeout(Duration::from_secs(2), reply_receiver.recv())
             .await;
@@ -1427,18 +1437,19 @@
         // Message sized to max frame length + 1.
         let payload = Payload {
-            part: Part::from(Bytes::from(vec![0u8; 699])),
+            part: Part::from(Bytes::from(vec![0u8; 763])),
             reply_port: reply_handle.bind(),
         };
-        let frame_len = frame_length(
+        let (_envelope, message) = build_message(
             proc_mesh.client().self_id(),
             dest.port::<Payload>().port_id(),
             &payload,
         );
-        assert_eq!(frame_len, 1025); // over the max frame len
+        assert_eq!(message.frame_len(), 1025); // over the max frame len
-        // Send direct or cast. Either are guaranteed over the
-        // limit and will fail.
+        // Send direct or cast. Either is guaranteed to be over the limit and
+        // will fail. The actual frame size is bigger than 1025 since extra
+        // headers will be added.
if rand::thread_rng().gen_bool(0.5) { dest.send(proc_mesh.client(), payload).unwrap(); } else { diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index c6231d1a9..0b25b8628 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -3322,7 +3322,7 @@ mod tests { actor_mesh .cast(&instance, testactor::GetActorId(port.bind())) .unwrap(); - let got_id = rx.recv().await.unwrap(); + let (got_id, _seq) = rx.recv().await.unwrap(); assert_eq!( got_id, actor_mesh.values().next().unwrap().actor_id().clone() diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index 57cc2369f..197d842be 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -7,6 +7,8 @@ */ use crate::comm::multicast::CAST_ORIGINATING_SENDER; +use crate::comm::multicast::CastMessageV1; +use crate::comm::multicast::ForwardMessageV1; use crate::reference::ActorMeshId; use crate::resource; pub mod multicast; @@ -33,7 +35,13 @@ use hyperactor::mailbox::Undeliverable; use hyperactor::mailbox::UndeliverableMailboxSender; use hyperactor::mailbox::UndeliverableMessageError; use hyperactor::mailbox::monitored_return_handle; +use hyperactor::message::ErasedUnbound; +use hyperactor::proc::SeqInfo; use hyperactor::reference::UnboundPort; +use hyperactor_mesh_macros::sel; +use ndslice::Point; +use ndslice::Selection; +use ndslice::View; use ndslice::selection::routing::RoutingFrame; use serde::Deserialize; use serde::Serialize; @@ -82,6 +90,8 @@ struct ReceiveState { CommActorMode, CastMessage, ForwardMessage, + CastMessageV1, + ForwardMessageV1, ], )] pub struct CommActor { @@ -261,47 +271,29 @@ impl CommActor { seq: usize, last_seqs: &mut HashMap, ) -> Result<()> { - // Split ports, if any, and update message with new ports. In this - // way, children actors will reply to this comm actor's ports, instead - // of to the original ports provided by parent. - message.data_mut().visit_mut::( - |UnboundPort(port_id, reducer_spec, reducer_opts)| { - let split = port_id.split(cx, reducer_spec.clone(), reducer_opts.clone())?; - - #[cfg(test)] - tests::collect_split_port(port_id, &split, deliver_here); - - *port_id = split; - Ok(()) - }, - )?; + split_ports(cx, message.data_mut(), deliver_here)?; // Deliver message here, if necessary. if deliver_here { let rank_on_root_mesh = mode.self_rank(cx.self_id())?; let cast_rank = message.relative_rank(rank_on_root_mesh)?; - // Replace ranks with self ranks. - message - .data_mut() - .visit_mut::(|resource::Rank(rank)| { - *rank = Some(cast_rank); - Ok(()) - })?; let cast_shape = message.shape(); - let point = cast_shape + let cast_point = cast_shape .extent() .point_of_rank(cast_rank) .expect("rank out of bounds"); + + // Replace ranks with self ranks. + replace_with_self_ranks(&cast_point, message.data_mut())?; + let mut headers = cx.headers().clone(); - set_cast_info_on_headers(&mut headers, point, message.sender().clone()); - cx.post( - cx.self_id() - .proc_id() - .actor_id(message.dest_port().actor_name(), 0) - .port_id(message.dest_port().port()), - headers, - Serialized::serialize(message.data())?, - ); + set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone(), None); + let dest_port_id = cx + .self_id() + .proc_id() + .actor_id(message.dest_port().actor_name(), 0) + .port_id(message.dest_port().port()); + dest_port_id.send_with_headers(cx, Serialized::serialize(message.data())?, headers); } // Forward to peers. 
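A note on the sequencing scheme the following hunk implements: the sender assigns one sequence number per destination rank up front and ships the full per-rank assignment inside the cast envelope; each comm actor that ends up delivering locally stamps only its own rank's seq into the delivery headers. A minimal sketch of that delivery-side stamping, assuming the per-rank seqs are flattened into a plain vector (the patch itself carries them in a ValueMesh); the helper name and signature are illustrative, not part of the patch:

    // Illustrative sketch only: stamp the destination rank's seq into the
    // delivery headers, using Attrs, SEQ_INFO, and SeqInfo as defined in
    // this patch. `seqs` holds one client-assigned seq per rank.
    fn stamp_seq_for_rank(headers: &mut Attrs, session_id: Uuid, seqs: &[u64], rank: usize) {
        let seq = seqs[rank];
        headers.set(SEQ_INFO, SeqInfo::Session { session_id, seq });
    }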
@@ -456,6 +448,113 @@ impl Handler for CommActor {
     }
 }
+
+// Split ports, if any, and update the message with the new ports. In this
+// way, child actors will reply to this comm actor's ports, instead of to
+// the original ports provided by the parent.
+fn split_ports(
+    cx: &Context<CommActor>,
+    data: &mut ErasedUnbound,
+    _deliver_here: bool,
+) -> anyhow::Result<()> {
+    data.visit_mut::<UnboundPort>(|UnboundPort(port_id, reducer_spec, reducer_opts)| {
+        let split = port_id.split(cx, reducer_spec.clone(), reducer_opts.clone())?;
+
+        #[cfg(test)]
+        tests::collect_split_port(port_id, &split, _deliver_here);
+
+        *port_id = split;
+        Ok(())
+    })
+}
+
+fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
+    data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
+        *rank = Some(cast_point.rank());
+        Ok(())
+    })
+}
+
+#[async_trait]
+impl Handler<CastMessageV1> for CommActor {
+    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
+        let slice = cast_message.dest_region.slice().clone();
+        let frame = RoutingFrame::root(sel!(*), slice);
+        let forward_message = ForwardMessageV1 {
+            dests: vec![frame],
+            message: cast_message,
+        };
+        Handler::<ForwardMessageV1>::handle(self, cx, forward_message).await
+    }
+}
+
+#[async_trait]
+impl Handler<ForwardMessageV1> for CommActor {
+    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
+        let ForwardMessageV1 { dests, mut message } = fwd_message;
+        // Resolve/dedup routing frames.
+        let rank_on_root_mesh = self.mode.self_rank(cx.self_id())?;
+        let (deliver_here, next_steps) =
+            ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
+                panic!("Choice encountered in CommActor routing")
+            })?;
+
+        split_ports(cx, &mut message.data, deliver_here)?;
+
+        // Deliver the message here, if necessary.
+        if deliver_here {
+            let cast_point = message.dest_region.point_of_base_rank(rank_on_root_mesh)?;
+            // Replace ranks with self ranks.
+            replace_with_self_ranks(&cast_point, &mut message.data)?;
+
+            let seq = message
+                .cast_headers
+                .seqs
+                .get(cast_point.rank())
+                .expect("mismatched seqs and dest_region");
+            // The headers should already contain a SEQ_INFO, set by the
+            // previous comm actor when forwarding the message. We overwrite
+            // it here with the SEQ_INFO from the original sender carried in
+            // the CastMessageV1, so the destination actor only sees the
+            // SEQ_INFO from the original sender.
+            let mut headers = cx.headers().clone();
+            set_cast_info_on_headers(
+                &mut headers,
+                cast_point,
+                message.cast_headers.sender.clone(),
+                Some(SeqInfo::Session {
+                    session_id: message.cast_headers.session_id,
+                    seq,
+                }),
+            );
+            let dest_port_id = cx
+                .self_id()
+                .proc_id()
+                .actor_id(message.dest_port.actor_name(), 0)
+                .port_id(message.dest_port.port());
+            dest_port_id.send_with_headers_with_option(
+                cx,
+                Serialized::serialize(&message.data)?,
+                headers,
+                /*set_seq_info=*/ false,
+            );
+        }
+
+        // Forward to peers.
+ for (peer_rank_on_root_mesh, dests) in next_steps { + let forward_message = ForwardMessageV1 { + dests, + message: message.clone(), + }; + let child = self + .mode + .peer_for_rank(cx.self_id(), peer_rank_on_root_mesh)?; + child.send(cx, forward_message)?; + } + + Ok(()) + } +} + pub mod test_utils { use anyhow::Result; use async_trait::async_trait; @@ -552,6 +651,9 @@ mod tests { use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::config; + use hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER; + use hyperactor::config::ENABLE_NATIVE_V1_CASTING; + use hyperactor::config::global::ConfigLock; use hyperactor::context::Mailbox; use hyperactor::mailbox::PortReceiver; use hyperactor::mailbox::open_port; @@ -1131,8 +1233,7 @@ mod tests { } } - #[async_timed_test(timeout_secs = 30)] - async fn test_cast_and_reply_v1() { + async fn execute_cast_and_reply_v1() { let MeshSetupV1 { instance, actor_mesh_ref, @@ -1147,8 +1248,22 @@ mod tests { } #[async_timed_test(timeout_secs = 30)] - async fn test_cast_and_accum_v1() { - let config = config::global::lock(); + async fn test_cast_and_reply_v1_retrofit() { + let config = hyperactor::config::global::lock(); + let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false); + let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, false); + execute_cast_and_reply_v1().await + } + + #[async_timed_test(timeout_secs = 30)] + async fn test_cast_and_reply_v1_native() { + let config = hyperactor::config::global::lock(); + let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true); + let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true); + execute_cast_and_reply_v1().await + } + + async fn execute_cast_and_accum_v1(config: &ConfigLock) { // Use temporary config for this test let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1); @@ -1163,4 +1278,20 @@ mod tests { let ranks = actor_mesh_ref.values().collect::>(); execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await; } + + #[async_timed_test(timeout_secs = 30)] + async fn test_cast_and_accum_v1_retrofit() { + let config = config::global::lock(); + let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false); + let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, false); + execute_cast_and_accum_v1(&config).await + } + + #[async_timed_test(timeout_secs = 30)] + async fn test_cast_and_accum_v1_native() { + let config = config::global::lock(); + let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true); + let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true); + execute_cast_and_accum_v1(&config).await + } } diff --git a/hyperactor_mesh/src/comm/multicast.rs b/hyperactor_mesh/src/comm/multicast.rs index e5de966ae..db8e18884 100644 --- a/hyperactor_mesh/src/comm/multicast.rs +++ b/hyperactor_mesh/src/comm/multicast.rs @@ -20,17 +20,22 @@ use hyperactor::declare_attrs; use hyperactor::message::Castable; use hyperactor::message::ErasedUnbound; use hyperactor::message::IndexedErasedUnbound; +use hyperactor::proc::SEQ_INFO; +use hyperactor::proc::SeqInfo; use hyperactor::reference::ActorId; use ndslice::Extent; use ndslice::Point; +use ndslice::Region; use ndslice::Shape; use ndslice::Slice; use ndslice::selection::Selection; use ndslice::selection::routing::RoutingFrame; use serde::Deserialize; use serde::Serialize; +use uuid::Uuid; use crate::reference::ActorMeshId; +use crate::v1; /// A union of slices that can be used to represent arbitrary subset of /// 
ranks in a gang. It is represented by a Slice together with a Selection.
@@ -230,6 +235,72 @@
 pub(crate) struct ForwardMessage {
     pub(crate) message: CastMessageEnvelope,
 }
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub(super) struct CastMessageHeaders {
+    /// The client who sent this message.
+    pub(super) sender: ActorId,
+    /// The client-assigned session id of this message.
+    pub(super) session_id: Uuid,
+    /// The client-assigned sequence numbers of this message.
+    pub(super) seqs: v1::ValueMesh<u64>,
+}
+
+/// This message is used to start casting a message to a group of actors.
+#[derive(Serialize, Deserialize, Debug, Clone, Named)]
+pub(crate) struct CastMessageV1 {
+    /// The information used to set the headers of the messages sent to the
+    /// destination actors. This information is not used by the comm actors
+    /// for routing.
+    pub(super) cast_headers: CastMessageHeaders,
+    /// The destination mesh's region.
+    pub(super) dest_region: Region,
+    /// The destination port of the message. It can match multiple actors
+    /// via a rank wildcard.
+    pub(super) dest_port: DestinationPort,
+    /// The serialized message.
+    pub(super) data: ErasedUnbound,
+}
+
+impl CastMessageV1 {
+    /// Create a new CastMessageV1.
+    pub(crate) fn new<A, M>(
+        sender: ActorId,
+        dest_mesh: &v1::Name,
+        dest_region: Region,
+        message: M,
+        session_id: Uuid,
+        seqs: v1::ValueMesh<u64>,
+    ) -> Result<Self>
+    where
+        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
+        M: Castable + RemoteMessage,
+    {
+        let data = ErasedUnbound::try_from_message(message)?;
+        let cast_headers = CastMessageHeaders {
+            sender,
+            session_id,
+            seqs,
+        };
+        Ok(Self {
+            cast_headers,
+            dest_region,
+            dest_port: DestinationPort::new::<A, M>(dest_mesh.to_string()),
+            data,
+        })
+    }
+}
+
+/// Forward a message to the procs of the next hops. This is used by a comm
+/// actor to forward a message to other comm actors following the selection
+/// topology. This message is not visible to the clients.
+#[derive(Serialize, Deserialize, Debug, Clone, Named)]
+pub(super) struct ForwardMessageV1 {
+    /// The destinations of the message.
+    pub(super) dests: Vec<RoutingFrame>,
+    /// The message to distribute.
+    pub(super) message: CastMessageV1,
+}
+
 declare_attrs! {
     /// Used inside headers to store the originating sender of a cast.
     pub attr CAST_ORIGINATING_SENDER: ActorId;
@@ -238,9 +309,22 @@ declare_attrs!
{ pub attr CAST_POINT: Point; } -pub fn set_cast_info_on_headers(headers: &mut Attrs, cast_point: Point, sender: ActorId) { +pub(crate) fn set_cast_info_on_headers( + headers: &mut Attrs, + cast_point: Point, + sender: ActorId, + seq_info: Option, +) { headers.set(CAST_POINT, cast_point); headers.set(CAST_ORIGINATING_SENDER, sender); + match seq_info { + Some(i) => { + headers.set(SEQ_INFO, i); + } + None => { + headers.remove(SEQ_INFO); + } + } } pub trait CastInfo { diff --git a/hyperactor_mesh/src/connect.rs b/hyperactor_mesh/src/connect.rs index 455502229..87a7fead2 100644 --- a/hyperactor_mesh/src/connect.rs +++ b/hyperactor_mesh/src/connect.rs @@ -425,10 +425,10 @@ mod tests { #[tokio::test] async fn test_simple_connection() -> Result<()> { let proc = Proc::local(); - let (client, _client_handle) = proc.instance("client")?; + let (client, _) = proc.instance("client")?; let (connect, completer) = Connect::allocate(client.self_id().clone(), client); let actor = proc.spawn::("actor", ()).await?; - actor.send(connect)?; + actor.send(&completer.caps, connect)?; let (mut rd, mut wr) = completer.complete().await?.into_split(); let send = [3u8, 4u8, 5u8, 6u8]; try_join!( diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 72c688219..b566d558c 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -22,7 +22,6 @@ use hyperactor::ActorHandle; use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Instance; -use hyperactor::Named; use hyperactor::RemoteMessage; use hyperactor::WorldId; use hyperactor::actor::ActorStatus; @@ -188,9 +187,8 @@ pub fn global_root_client() -> &'static Instance<()> { // The hook logs each undeliverable, along with whether a sink // was present at the time of receipt, which helps diagnose // lost or misrouted events. - let (undeliverable_tx, undeliverable_rx) = - client.open_port::>(); - undeliverable_tx.bind_to(Undeliverable::::port()); + let (_undeliverable_tx, undeliverable_rx) = + client.bind_actor_port::>(); hyperactor::mailbox::supervise_undeliverable_messages_with( undeliverable_rx, crate::proc_mesh::get_global_supervision_sink, @@ -346,9 +344,8 @@ impl ProcMesh { // `global_root_client()`. let (client, _handle) = client_proc.instance("client")?; // Bind an undeliverable message port in the client. 
-        let (undeliverable_messages, client_undeliverable_receiver) =
-            client.open_port::<Undeliverable<MessageEnvelope>>();
-        undeliverable_messages.bind_to(Undeliverable::<MessageEnvelope>::port());
+        let (_undeliverable_messages, client_undeliverable_receiver) =
+            client.bind_actor_port::<Undeliverable<MessageEnvelope>>();
         hyperactor::mailbox::supervise_undeliverable_messages(
             supervision_port.clone(),
             client_undeliverable_receiver,
diff --git a/hyperactor_mesh/src/v1/actor_mesh.rs b/hyperactor_mesh/src/v1/actor_mesh.rs
index 78c175b07..2df96cb3b 100644
--- a/hyperactor_mesh/src/v1/actor_mesh.rs
+++ b/hyperactor_mesh/src/v1/actor_mesh.rs
@@ -19,6 +19,7 @@ use hyperactor::RemoteHandles;
 use hyperactor::RemoteMessage;
 use hyperactor::actor::Referable;
 use hyperactor::attrs::Attrs;
+use hyperactor::config;
 use hyperactor::context;
 use hyperactor::message::Castable;
 use hyperactor::message::IndexedErasedUnbound;
@@ -27,6 +28,7 @@ use hyperactor_mesh_macros::sel;
 use ndslice::Selection;
 use ndslice::ViewExt as _;
 use ndslice::view;
+use ndslice::view::MapIntoExt;
 use ndslice::view::Region;
 use ndslice::view::View;
 use serde::Deserialize;
@@ -36,7 +38,10 @@ use serde::Serializer;
 use crate::CommActor;
 use crate::actor_mesh as v0_actor_mesh;
+use crate::actor_mesh::CAST_ACTOR_MESH_ID;
 use crate::comm::multicast;
+use crate::comm::multicast::CastMessageV1;
+use crate::metrics;
 use crate::proc_mesh::mesh_agent::ActorState;
 use crate::reference::ActorMeshId;
 use crate::resource;
@@ -143,7 +148,16 @@
         M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
     {
         if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
-            self.cast_v0(cx, message, root_comm_actor)
+            if config::global::get(config::ENABLE_NATIVE_V1_CASTING) {
+                assert!(
+                    config::global::get(config::ENABLE_DEST_ACTOR_REORDERING_BUFFER),
+                    "Native V1 casting requires ENABLE_DEST_ACTOR_REORDERING_BUFFER to be enabled",
+                );
+                self.cast_v1(cx, message, root_comm_actor.actor_id().name());
+                Ok(())
+            } else {
+                self.cast_v0(cx, message, root_comm_actor)
+            }
         } else {
             for (point, actor) in self.iter() {
                 let create_rank = point.rank();
@@ -214,6 +228,61 @@
         }
     }
+
+    fn cast_v1<M>(&self, cx: &impl context::Actor, message: M, comm_actor_name: &str)
+    where
+        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
+        M: Castable + RemoteMessage,
+    {
+        let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
+            "message_type" => M::typename(),
+            "message_variant" => message.arm().unwrap_or_default(),
+        ));
+
+        let actor_ids: ValueMesh<_> = self.proc_mesh.map_into(|proc| proc.actor_id(&self.name));
+        // TODO: pick a random actor to send the cast to so we can have
+        // round-robin load balancing.
+        let comm_actor_id = actor_ids
+            .iter()
+            .next()
+            .expect("mesh should have at least one actor")
+            .1
+            .proc_id()
+            .actor_id(comm_actor_name, 0);
+        let comm_actor_ref = ActorRef::<CommActor>::attest(comm_actor_id);
+
+        // This block is infallible, so it is okay to assign the sequence
+        // numbers without worrying about rollback.
+        {
+            let sequencer = cx.instance().sequencer();
+            let seqs = actor_ids.map_into(|actor_id| sequencer.assign_seq(actor_id));
+
+            let cast_message = CastMessageV1::new::<A, M>(
+                cx.instance().self_id().clone(),
+                &self.name,
+                self.region(),
+                message,
+                sequencer.session_id(),
+                seqs,
+            )
+            .expect("infallible because CastMessage should not fail for serialization");
+
+            let mut headers = Attrs::new();
+            headers.set(
+                multicast::CAST_ORIGINATING_SENDER,
+                cx.instance().self_id().clone(),
+            );
+            // Set CAST_ACTOR_MESH_ID temporarily to support supervision's
+            // v0 transition. Should be removed once supervision is migrated
+            // and ActorMeshId is deleted.
+            let actor_mesh_id = ActorMeshId::V1(self.name.clone());
+            headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
+
+            comm_actor_ref
+                .send_with_headers(cx, headers, cast_message)
+                .expect("infallible because CastMessage should not fail for serialization");
+        }
+    }
+
     pub async fn actor_states(
         &self,
         cx: &impl context::Actor,
@@ -477,12 +546,12 @@ mod tests {
         .expect("rank 3 exists")
         .send(instance, testactor::GetActorId(port.bind()))
         .expect("send to rank 3 should succeed");
-    let id_a = RealClock
+    let (id_a, _) = RealClock
         .timeout(Duration::from_secs(3), rx.recv())
         .await
         .expect("timed out waiting for first reply")
         .expect("channel closed before first reply");
-    let id_b = RealClock
+    let (id_b, _) = RealClock
         .timeout(Duration::from_secs(3), rx.recv())
         .await
         .expect("timed out waiting for second reply")
diff --git a/hyperactor_mesh/src/v1/host_mesh.rs b/hyperactor_mesh/src/v1/host_mesh.rs
index f951df479..a1b6f42c0 100644
--- a/hyperactor_mesh/src/v1/host_mesh.rs
+++ b/hyperactor_mesh/src/v1/host_mesh.rs
@@ -662,10 +662,14 @@ mod tests {
     use std::collections::HashSet;
     use std::collections::VecDeque;
+    use hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER;
+    use hyperactor::config::ENABLE_NATIVE_V1_CASTING;
+    use hyperactor::config::global::ConfigLock;
     use hyperactor::context::Mailbox as _;
     use itertools::Itertools;
     use ndslice::ViewExt;
     use ndslice::extent;
+    use timed_test::async_timed_test;
     use tokio::process::Command;
     use super::*;
@@ -704,9 +708,7 @@
     );
 }
-    #[tokio::test]
-    async fn test_allocate() {
-        let config = hyperactor::config::global::lock();
+    async fn execute_allocate(config: &ConfigLock) {
         let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
         let instance = testing::instance().await;
@@ -763,7 +765,7 @@
         .collect();
     while !expected_actor_ids.is_empty() {
-        let actor_id = rx.recv().await.unwrap();
+        let (actor_id, _seq) = rx.recv().await.unwrap();
         assert!(
             expected_actor_ids.remove(&actor_id),
             "got {actor_id}, expect {expected_actor_ids:?}"
         );
@@ -804,6 +806,21 @@
     }
 }
+
+    #[async_timed_test(timeout_secs = 180)]
+    async fn test_allocate() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
+        execute_allocate(&config).await;
+    }
+
+    #[async_timed_test(timeout_secs = 180)]
+    async fn test_allocate_v1() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
+        let _guard1 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+        execute_allocate(&config).await;
+    }
+
     /// Allocate a new port on localhost. This drops the listener, releasing the socket,
     /// before returning. Hyperactor's channel::net applies SO_REUSEADDR, so we do not have
     /// to wait out the socket's TIMED_WAIT state.
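For orientation before the sequencing tests that follow: the sequencer used by cast_v1 above is scoped per sending instance, with a fresh session id per sequencer and an independent monotonic counter per destination actor. That is why a first cast from a given client observes seq 1 on every member of the destination mesh, regardless of how many other clients have cast to it. A minimal sketch of such a sequencer, assuming seqs start at 1 as the tests expect (illustrative only, not the actual hyperactor implementation):

    use std::collections::HashMap;
    use std::sync::Mutex;
    use uuid::Uuid;

    struct Sequencer {
        session_id: Uuid,
        // One monotonic counter per destination actor id.
        next: Mutex<HashMap<ActorId, u64>>,
    }

    impl Sequencer {
        fn new() -> Self {
            Self {
                session_id: Uuid::new_v4(),
                next: Mutex::new(HashMap::new()),
            }
        }

        fn session_id(&self) -> Uuid {
            self.session_id
        }

        // Seqs start at 1; each destination actor gets its own counter.
        fn assign_seq(&self, dest: &ActorId) -> u64 {
            let mut next = self.next.lock().unwrap();
            let counter = next.entry(dest.clone()).or_insert(0);
            *counter += 1;
            *counter
        }
    }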
@@ -814,9 +831,7 @@ mod tests {
     ChannelAddr::Tcp(listener.local_addr().unwrap())
 }
-    #[tokio::test]
-    async fn test_extrinsic_allocation() {
-        let config = hyperactor::config::global::lock();
+    async fn execute_extrinsic_allocation(config: &ConfigLock) {
         let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
         let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
@@ -857,6 +872,20 @@
         .expect("hosts shutdown");
 }
+
+    #[tokio::test]
+    async fn test_extrinsic_allocation() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
+        execute_extrinsic_allocation(&config).await;
+    }
+
+    #[tokio::test]
+    async fn test_extrinsic_allocation_v1() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
+        let _guard1 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+        execute_extrinsic_allocation(&config).await;
+    }
+
     #[tokio::test]
     async fn test_failing_proc_allocation() {
         let program = buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap();
diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs
index b0d04a33e..8cf11eaf2 100644
--- a/hyperactor_mesh/src/v1/proc_mesh.rs
+++ b/hyperactor_mesh/src/v1/proc_mesh.rs
@@ -43,6 +43,7 @@
 use crate::assign::Ranks;
 use crate::comm::CommActorMode;
 use crate::proc_mesh::mesh_agent;
 use crate::proc_mesh::mesh_agent::ActorState;
+use crate::proc_mesh::mesh_agent::MeshAgentMessage;
 use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
 use crate::proc_mesh::mesh_agent::ProcMeshAgent;
 use crate::proc_mesh::mesh_agent::ReconfigurableMailboxSender;
@@ -310,18 +311,17 @@
     let (config_handle, mut config_receiver) = cx.mailbox().open_port();
     for (rank, AllocatedProc { mesh_agent, ..
}) in running.iter().enumerate() {
+        let message = MeshAgentMessage::Configure {
+            rank,
+            forwarder: proc_channel_addr.clone(),
+            supervisor: None, // no supervisor; we just crash
+            address_book: address_book.clone(),
+            configured: config_handle.bind(),
+            record_supervision_events: true,
+        };
         mesh_agent
-            .configure(
-                cx,
-                rank,
-                proc_channel_addr.clone(),
-                None, // no supervisor; we just crash
-                address_book.clone(),
-                config_handle.bind(),
-                true,
-            )
-            .await
-            .map_err(Error::ConfigurationError)?;
+            .send(cx, message)
+            .map_err(|e| Error::ConfigurationError(e.into()))?;
     }
     let mut completed = Ranks::new(running.len());
     while !completed.is_full() {
@@ -671,11 +671,21 @@ impl view::RankedSliceable for ProcMeshRef {
 #[cfg(test)]
 mod tests {
     use std::assert_matches::assert_matches;
-
+    use std::ops::Deref;
+
+    use hyperactor::Proc;
+    use hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER;
+    use hyperactor::config::ENABLE_NATIVE_V1_CASTING;
+    use hyperactor::context;
+    use hyperactor::id;
+    use hyperactor::mailbox::BoxableMailboxSender;
+    use hyperactor::mailbox::DialMailboxRouter;
     use ndslice::ViewExt;
     use ndslice::extent;
     use timed_test::async_timed_test;
+    use uuid::Uuid;
+
     use super::*;
     use crate::v1;
     use crate::v1::testactor;
     use crate::v1::testing;
@@ -702,8 +712,7 @@
     );
 }
-    #[async_timed_test(timeout_secs = 30)]
-    async fn test_spawn_actor() {
+    async fn execute_spawn_actor() {
         hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
         let instance = testing::instance().await;
@@ -714,6 +723,163 @@
     }
 }
+
+    #[async_timed_test(timeout_secs = 30)]
+    async fn test_spawn_actor_v1_casting() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
+        let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+        execute_spawn_actor().await;
+    }
+
+    #[async_timed_test(timeout_secs = 30)]
+    async fn test_spawn_actor_v0_casting() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
+        execute_spawn_actor().await;
+    }
+
+    // * Spawn an actor mesh, and then
+    // * do a random number of casts to it to bump the seq numbers for all
+    //   actors participating in the cast.
+    // * This is to avoid the test mistakenly passing.
+    async fn spawn_for_seq_test(
+        cx: &impl context::Actor,
+        proc_mesh: &ProcMeshRef,
+    ) -> ActorMesh<testactor::TestActor> {
+        let actor_mesh = proc_mesh
+            .spawn::<testactor::TestActor>(cx, "test", &())
+            .await
+            .unwrap();
+
+        let (instance, _) = cx
+            .instance()
+            .proc()
+            .instance(&format!("random_casts_{}", Uuid::now_v7()))
+            .unwrap();
+        let n = fastrand::u64(3..10);
+        for _ in 0..n {
+            actor_mesh.cast(&instance, ()).unwrap();
+        }
+        actor_mesh
+    }
+
+    #[async_timed_test(timeout_secs = 30)]
+    async fn test_seq_from_same_sender_to_different_meshes() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
+        let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+
+        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
+        let instance = testing::instance().await;
+        let session_id = instance.sequencer().session_id();
+
+        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
+            let proc_mesh_ref = proc_mesh.deref();
+
+            // Sequence numbers are scoped based on the (client, dest) pair.
+            // So casts to different meshes from the same client instance each
+            // result in seq 1 for all casts.
+            let handles = (0..3)
+                .map(|_| {
+                    let proc_mesh_ref_clone = proc_mesh_ref.clone();
+                    tokio::spawn(async move {
+                        let actor_mesh =
+                            spawn_for_seq_test(instance, &proc_mesh_ref_clone).await;
+                        let expected_seqs = vec![1; 8];
+                        testactor::assert_casting_correctness(
+                            &actor_mesh,
+                            instance,
+                            Some((session_id, expected_seqs)),
+                        )
+                        .await;
+                    })
+                })
+                .collect::<Vec<_>>();
+            futures::future::join_all(handles).await;
+        }
+    }
+
+    // Verify that the seq numbers are assigned correctly when we cast to
+    // different views of the same root mesh.
+    #[async_timed_test(timeout_secs = 30)]
+    async fn test_seq_from_same_sender_to_different_views() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
+        let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+
+        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
+
+        let instance = testing::instance().await;
+        let session_id = instance.sequencer().session_id();
+
+        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
+            let actor_mesh = spawn_for_seq_test(instance, &proc_mesh).await;
+
+            // First cast. The seq should be 1 for all actors.
+            let expected_seqs = vec![1; 8];
+            testactor::assert_casting_correctness(
+                &actor_mesh,
+                instance,
+                Some((session_id, expected_seqs)),
+            )
+            .await;
+
+            // Verify casting to the sliced actor mesh
+            let sliced_actor_mesh = actor_mesh.range("replicas", 1..3).unwrap();
+            // Second cast. The seq should be 2 for actors in the sliced mesh.
+            let expected_seqs = vec![2; 4];
+            testactor::assert_casting_correctness(
+                &sliced_actor_mesh,
+                instance,
+                Some((session_id, expected_seqs)),
+            )
+            .await;
+
+            // Verify casting to a different sliced actor mesh
+            let sliced_actor_mesh = actor_mesh.range("replicas", 0..2).unwrap();
+            // For actors in the previous sliced mesh, the seq should be 3 since
+            // this is the third cast for them. For other actors, the seq should
+            // be 2.
+            let expected_seqs = vec![2, 2, 3, 3];
+            testactor::assert_casting_correctness(
+                &sliced_actor_mesh,
+                instance,
+                Some((session_id, expected_seqs)),
+            )
+            .await;
+        }
+    }
+
+    #[async_timed_test(timeout_secs = 30)]
+    async fn test_seq_from_different_senders() {
+        let config = hyperactor::config::global::lock();
+        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
+        let _guard2 = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
+
+        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
+        let proc = Proc::new(id!(test[0]), DialMailboxRouter::new().boxed());
+        let (instance, _) = proc.instance("test_client").unwrap();
+        let (first_instance, _) = proc.instance("first_client").unwrap();
+        let (second_instance, _) = proc.instance("second_client").unwrap();
+        let (third_instance, _) = proc.instance("third_client").unwrap();
+
+        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
+            let actor_mesh = spawn_for_seq_test(&instance, &proc_mesh).await;
+
+            // Sequence numbers are scoped per sequencer, i.e. per client
+            // instance. So one cast from each of the three clients results in
+            // seq 1 for all actors.
+ for inst in [&first_instance, &second_instance, &third_instance] { + let expected_seqs = vec![1; 8]; + let session_id = inst.sequencer().session_id(); + testactor::assert_casting_correctness( + &actor_mesh, + inst, + Some((session_id, expected_seqs)), + ) + .await; + } + } + } + #[tokio::test] async fn test_failing_spawn_actor() { hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default()); diff --git a/hyperactor_mesh/src/v1/testactor.rs b/hyperactor_mesh/src/v1/testactor.rs index b430c83a5..27a889e46 100644 --- a/hyperactor_mesh/src/v1/testactor.rs +++ b/hyperactor_mesh/src/v1/testactor.rs @@ -12,7 +12,7 @@ //! does not work across crate boundaries) #[cfg(test)] -use std::collections::HashSet; +use std::collections::HashMap; use std::collections::VecDeque; #[cfg(test)] use std::time::Duration; @@ -35,12 +35,16 @@ use hyperactor::clock::Clock as _; use hyperactor::clock::RealClock; #[cfg(test)] use hyperactor::mailbox; +use hyperactor::proc::SEQ_INFO; +use hyperactor::proc::SeqInfo; use hyperactor::supervision::ActorSupervisionEvent; use ndslice::Point; #[cfg(test)] use ndslice::ViewExt as _; use serde::Deserialize; use serde::Serialize; +#[cfg(test)] +use uuid::Uuid; use crate::comm::multicast::CastInfo; #[cfg(test)] @@ -55,6 +59,7 @@ use crate::v1::testing; #[hyperactor::export( spawn = true, handlers = [ + () { cast = true }, GetActorId { cast = true }, GetCastInfo { cast = true }, CauseSupervisionEvent { cast = true }, @@ -63,9 +68,9 @@ use crate::v1::testing; )] pub struct TestActor; -/// A message that returns the recipient actor's id. +/// A message that returns the recipient actor's id and cast message's seq info. #[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)] -pub struct GetActorId(#[binding(include)] pub PortRef); +pub struct GetActorId(#[binding(include)] pub PortRef<(ActorId, Option)>); #[derive(Debug, Clone, Serialize, Deserialize)] pub enum SupervisionEventType { @@ -79,6 +84,13 @@ pub enum SupervisionEventType { #[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)] pub struct CauseSupervisionEvent(pub SupervisionEventType); +#[async_trait] +impl Handler<()> for TestActor { + async fn handle(&mut self, cx: &Context, _: ()) -> Result<(), anyhow::Error> { + Ok(()) + } +} + #[async_trait] impl Handler for TestActor { async fn handle( @@ -86,7 +98,8 @@ impl Handler for TestActor { cx: &Context, GetActorId(reply): GetActorId, ) -> Result<(), anyhow::Error> { - reply.send(cx, cx.self_id().clone())?; + let seq_info = cx.headers().get(SEQ_INFO).cloned(); + reply.send(cx, (cx.self_id().clone(), seq_info))?; Ok(()) } } @@ -233,7 +246,7 @@ impl Actor for FailingCreateTestActor { pub async fn assert_mesh_shape(actor_mesh: ActorMesh) { let instance = testing::instance().await; // Verify casting to the root actor mesh - assert_casting_correctness(&actor_mesh, instance).await; + assert_casting_correctness(&actor_mesh, instance, None).await; // Just pick the first dimension. Slice half of it off. // actor_mesh.extent(). @@ -242,29 +255,48 @@ pub async fn assert_mesh_shape(actor_mesh: ActorMesh) { // Verify casting to the sliced actor mesh let sliced_actor_mesh = actor_mesh.range(&label, 0..size).unwrap(); - assert_casting_correctness(&sliced_actor_mesh, instance).await; + assert_casting_correctness(&sliced_actor_mesh, instance, None).await; } #[cfg(test)] -/// Cast to the actor mesh, and verify that all actors are reached. 
+/// Cast to the actor mesh, and verify that all actors are reached, and the +/// sequence numbers, if provided, are correct. pub async fn assert_casting_correctness( actor_mesh: &ActorMeshRef, instance: &Instance<()>, + expected_seqs: Option<(Uuid, Vec)>, ) { - let (port, mut rx) = mailbox::open_port(instance); - actor_mesh.cast(instance, GetActorId(port.bind())).unwrap(); - - let mut expected_actor_ids: HashSet<_> = actor_mesh + let (port, mut rx) = mailbox::open_port(&instance); + actor_mesh.cast(&instance, GetActorId(port.bind())).unwrap(); + let expected_actor_ids = actor_mesh .values() .map(|actor_ref| actor_ref.actor_id().clone()) - .collect(); + .collect::>(); + let mut expected: HashMap<&ActorId, Option> = match expected_seqs { + None => expected_actor_ids + .iter() + .map(|actor_id| (actor_id, None)) + .collect(), + Some((session_id, seqs)) => expected_actor_ids + .iter() + .zip( + seqs.into_iter() + .map(|seq| Some(SeqInfo::Session { session_id, seq })), + ) + .collect(), + }; - while !expected_actor_ids.is_empty() { - let actor_id = rx.recv().await.unwrap(); + while !expected.is_empty() { + let (actor_id, rcved) = rx.recv().await.unwrap(); + let rcv_seq_info = rcved.unwrap(); + let removed = expected.remove(&actor_id); assert!( - expected_actor_ids.remove(&actor_id), + removed.is_some(), "got {actor_id}, expect {expected_actor_ids:?}" ); + if let Some(expected) = removed.unwrap() { + assert_eq!(expected, rcv_seq_info, "got different seq for {actor_id}"); + } } // No more messages diff --git a/hyperactor_mesh/src/v1/value_mesh.rs b/hyperactor_mesh/src/v1/value_mesh.rs index 4c7b49a9c..2a6489be3 100644 --- a/hyperactor_mesh/src/v1/value_mesh.rs +++ b/hyperactor_mesh/src/v1/value_mesh.rs @@ -15,6 +15,8 @@ use futures::Future; use ndslice::view; use ndslice::view::Ranked; use ndslice::view::Region; +use serde::Deserialize; +use serde::Serialize; /// A mesh of values, where each value is associated with a rank. /// @@ -22,7 +24,7 @@ use ndslice::view::Region; /// The mesh is *complete*: `ranks.len()` always equals /// `region.num_ranks()`. Every rank in the region has exactly one /// associated value. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] // only if T implements +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] // only if T implements pub struct ValueMesh { region: Region, ranks: Vec, diff --git a/hyperactor_multiprocess/src/proc_actor.rs b/hyperactor_multiprocess/src/proc_actor.rs index 2640bb6a0..e70339c8e 100644 --- a/hyperactor_multiprocess/src/proc_actor.rs +++ b/hyperactor_multiprocess/src/proc_actor.rs @@ -878,7 +878,6 @@ mod tests { use hyperactor::channel::ChannelTransport; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; - use hyperactor::data::Named; use hyperactor::forward; use hyperactor::id; use hyperactor::reference::ActorRef; @@ -1205,9 +1204,8 @@ mod tests { // A test supervisor. let mut system = System::new(server_handle.local_addr().clone()); let supervisor = system.attach().await.unwrap(); - let (supervisor_supervision_tx, mut supervisor_supervision_receiver) = - supervisor.open_port::(); - supervisor_supervision_tx.bind_to(ProcSupervisionMessage::port()); + let (_supervisor_supervision_tx, mut supervisor_supervision_receiver) = + supervisor.bind_actor_port::(); let supervisor_actor_ref: ActorRef = ActorRef::attest(supervisor.self_id().clone()); @@ -1388,8 +1386,7 @@ mod tests { // Build a supervisor. 
let supervisor = system.attach().await.unwrap(); - let (sup_tx, _sup_rx) = supervisor.open_port::(); - sup_tx.bind_to(ProcSupervisionMessage::port()); + let (_sup_tx, _sup_rx) = supervisor.bind_actor_port::(); let sup_ref = ActorRef::::attest(supervisor.self_id().clone()); // Construct a system sender. @@ -1417,7 +1414,7 @@ mod tests { ) .await .unwrap(); - let proc_0_client = proc_0.attach("client").unwrap(); + let (proc_0_client, _) = proc_0.instance("client").unwrap(); let (proc_0_undeliverable_tx, mut proc_0_undeliverable_rx) = proc_0_client.open_port(); // Bootstrap a second proc 'world[1]', join the system. @@ -1462,7 +1459,10 @@ mod tests { let ttl = 66 + i as u64; // Avoid ttl = 66! let (once_handle, _) = proc_0_client.open_once_port::(); ping_handle - .send(PingPongMessage(ttl, pong_handle.bind(), once_handle.bind())) + .send( + &proc_0_client, + PingPongMessage(ttl, pong_handle.bind(), once_handle.bind()), + ) .unwrap(); } @@ -1512,8 +1512,7 @@ mod tests { // Build a supervisor. let supervisor = system.attach().await.unwrap(); - let (sup_tx, _sup_rx) = supervisor.open_port::(); - sup_tx.bind_to(ProcSupervisionMessage::port()); + let (_sup_tx, _sup_rx) = supervisor.bind_actor_port::(); let sup_ref = ActorRef::::attest(supervisor.self_id().clone()); // Construct a system sender. @@ -1579,7 +1578,10 @@ mod tests { let ttl = 10u64; // Avoid ttl = 66! let (once_tx, once_rx) = system_client.open_once_port::(); ping_handle - .send(PingPongMessage(ttl, pong_handle.bind(), once_tx.bind())) + .send( + &system_client, + PingPongMessage(ttl, pong_handle.bind(), once_tx.bind()), + ) .unwrap(); assert!(once_rx.recv().await.unwrap()); diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index 7f5929dd2..240b53ab8 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -1652,11 +1652,14 @@ impl Handler for SystemActor { // The proc has expired heartbeating and it manages the lifecycle of system, schedule system stop let (tx, _) = cx.open_once_port::<()>(); - cx.port().send(SystemMessage::Stop { - worlds: None, - proc_timeout: Duration::from_secs(5), - reply_port: tx.bind(), - })?; + cx.port().send( + &cx, + SystemMessage::Stop { + worlds: None, + proc_timeout: Duration::from_secs(5), + reply_port: tx.bind(), + }, + )?; } if world.state.status == WorldStatus::Live { @@ -1993,7 +1996,8 @@ mod tests { async fn test_host_join_before_world() { // Spins up a new world with 2 hosts, with 3 procs each. let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10)); - let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap(); + let (system_actor_handle, system_proc) = SystemActor::bootstrap(params).await.unwrap(); + let (client, _) = system_proc.instance("client").unwrap(); // Use a local proc actor to join the system. let mut host_actors: Vec = Vec::new(); @@ -2006,14 +2010,17 @@ mod tests { for host_actor in host_actors.iter_mut() { // Join the world. 
system_actor_handle - .send(SystemMessage::Join { - proc_id: host_actor.local_proc_id.clone(), - world_id: world_id.clone(), - proc_message_port: host_actor.local_proc_message_port.bind(), - proc_addr: host_actor.local_proc_addr.clone(), - labels: HashMap::new(), - lifecycle_mode: ProcLifecycleMode::ManagedBySystem, - }) + .send( + &client, + SystemMessage::Join { + proc_id: host_actor.local_proc_id.clone(), + world_id: world_id.clone(), + proc_message_port: host_actor.local_proc_message_port.bind(), + proc_addr: host_actor.local_proc_addr.clone(), + labels: HashMap::new(), + lifecycle_mode: ProcLifecycleMode::ManagedBySystem, + }, + ) .unwrap(); // We should get a joined message. @@ -2028,13 +2035,16 @@ mod tests { let num_procs = 6; let shape = Shape::Definite(vec![2, 3]); system_actor_handle - .send(SystemMessage::UpsertWorld { - world_id: world_id.clone(), - shape, - num_procs_per_host: 3, - env: Environment::Local, - labels: HashMap::new(), - }) + .send( + &client, + SystemMessage::UpsertWorld { + world_id: world_id.clone(), + shape, + num_procs_per_host: 3, + env: Environment::Local, + labels: HashMap::new(), + }, + ) .unwrap(); let mut all_procs: Vec = Vec::new(); @@ -2067,7 +2077,8 @@ mod tests { async fn test_host_join_after_world() { // Spins up a new world with 2 hosts, with 3 procs each. let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10)); - let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap(); + let (system_actor_handle, system_proc) = SystemActor::bootstrap(params).await.unwrap(); + let (client, _) = system_proc.instance("client").unwrap(); // Create a new world message and send to system actor let world_name = "test".to_string(); @@ -2075,13 +2086,16 @@ mod tests { let num_procs = 6; let shape = Shape::Definite(vec![2, 3]); system_actor_handle - .send(SystemMessage::UpsertWorld { - world_id: world_id.clone(), - shape, - num_procs_per_host: 3, - env: Environment::Local, - labels: HashMap::new(), - }) + .send( + &client, + SystemMessage::UpsertWorld { + world_id: world_id.clone(), + shape, + num_procs_per_host: 3, + env: Environment::Local, + labels: HashMap::new(), + }, + ) .unwrap(); // Use a local proc actor to join the system. @@ -2093,14 +2107,17 @@ mod tests { for host_actor in host_actors.iter_mut() { // Join the world. system_actor_handle - .send(SystemMessage::Join { - proc_id: host_actor.local_proc_id.clone(), - world_id: world_id.clone(), - proc_message_port: host_actor.local_proc_message_port.bind(), - proc_addr: host_actor.local_proc_addr.clone(), - labels: HashMap::new(), - lifecycle_mode: ProcLifecycleMode::ManagedBySystem, - }) + .send( + &client, + SystemMessage::Join { + proc_id: host_actor.local_proc_id.clone(), + world_id: world_id.clone(), + proc_message_port: host_actor.local_proc_message_port.bind(), + proc_addr: host_actor.local_proc_addr.clone(), + labels: HashMap::new(), + lifecycle_mode: ProcLifecycleMode::ManagedBySystem, + }, + ) .unwrap(); // We should get a joined message. @@ -2217,13 +2234,16 @@ mod tests { // Create one. 
let world_id = id!(world); system_actor_handle - .send(SystemMessage::UpsertWorld { - world_id: world_id.clone(), - shape: Shape::Definite(vec![1]), - num_procs_per_host: 1, - env: Environment::Local, - labels: HashMap::new(), - }) + .send( + &client, + SystemMessage::UpsertWorld { + world_id: world_id.clone(), + shape: Shape::Definite(vec![1]), + num_procs_per_host: 1, + env: Environment::Local, + labels: HashMap::new(), + }, + ) .unwrap(); // Now we should know a world. @@ -2242,8 +2262,7 @@ mod tests { // Build a supervisor. let supervisor = system.attach().await.unwrap(); - let (sup_tx, _sup_rx) = supervisor.open_port::(); - sup_tx.bind_to(ProcSupervisionMessage::port()); + let (_sup_tx, _sup_rx) = supervisor.bind_actor_port::(); let sup_ref = ActorRef::::attest(supervisor.self_id().clone()); // Construct a system sender. @@ -2268,7 +2287,7 @@ mod tests { ) .await .unwrap(); - let proc_0_client = proc_0.attach("client").unwrap(); + let (proc_0_client, _) = proc_0.instance("client").unwrap(); let (proc_0_undeliverable_tx, _proc_0_undeliverable_rx) = proc_0_client.open_port(); // Bootstrap a second proc 'world[1]', join the system. @@ -2324,7 +2343,10 @@ mod tests { let ttl = 1_u64; let (game_over, on_game_over) = proc_0_client.open_once_port::(); ping_handle - .send(PingPongMessage(ttl, pong_handle.bind(), game_over.bind())) + .send( + &proc_0_client, + PingPongMessage(ttl, pong_handle.bind(), game_over.bind()), + ) .unwrap(); // We expect message delivery failure prevents the game from @@ -2363,11 +2385,14 @@ mod tests { // Create a new world message and send to system actor let (client_tx, client_rx) = client.open_once_port::<()>(); - system_actor_handle.send(SystemMessage::Stop { - worlds: None, - proc_timeout: Duration::from_secs(5), - reply_port: client_tx.bind(), - })?; + system_actor_handle.send( + &client, + SystemMessage::Stop { + worlds: None, + proc_timeout: Duration::from_secs(5), + reply_port: client_tx.bind(), + }, + )?; client_rx.recv().await?; // Check that it has stopped. diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 758658874..129d0defb 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -13,6 +13,7 @@ use std::time::Duration; use hyperactor::ActorHandle; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; +use hyperactor::context; use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::logging::LogClientActor; @@ -22,6 +23,7 @@ use hyperactor_mesh::logging::LogForwardMessage; use hyperactor_mesh::selection::Selection; use hyperactor_mesh::shared_cell::SharedCell; use monarch_hyperactor::context::PyInstance; +use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::logging::LoggerRuntimeActor; use monarch_hyperactor::logging::LoggerRuntimeMessage; use monarch_hyperactor::proc_mesh::PyProcMesh; @@ -47,6 +49,7 @@ pub struct LoggingMeshClient { impl LoggingMeshClient { async fn flush_internal( + cx: &impl context::Actor, client_actor: ActorHandle, forwarder_mesh: SharedCell>, ) -> Result<(), anyhow::Error> { @@ -61,11 +64,14 @@ impl LoggingMeshClient { .open_once_port::(); // First initialize a sync flush. 
- client_actor.send(LogClientMessage::StartSyncFlush { - expected_procs: forwarder_inner_mesh.proc_mesh().shape().slice().len(), - reply: reply_tx.bind(), - version: version_tx.bind(), - })?; + client_actor.send( + cx, + LogClientMessage::StartSyncFlush { + expected_procs: forwarder_inner_mesh.proc_mesh().shape().slice().len(), + reply: reply_tx.bind(), + version: version_tx.bind(), + }, + )?; let version = version_rx.recv().await?; @@ -86,8 +92,9 @@ impl LoggingMeshClient { #[pymethods] impl LoggingMeshClient { #[staticmethod] - fn spawn(_instance: &PyInstance, proc_mesh: &PyProcMesh) -> PyResult { + fn spawn(instance: &PyInstance, proc_mesh: &PyProcMesh) -> PyResult { let proc_mesh = proc_mesh.try_inner()?; + let instance_for_task = instance.clone(); PyPythonTask::new(async move { let client_actor = proc_mesh.client_proc().spawn("log_client", ()).await?; let client_actor_ref = client_actor.bind(); @@ -99,16 +106,19 @@ impl LoggingMeshClient { let forwarder_mesh_for_callback = forwarder_mesh.clone(); proc_mesh .register_onstop_callback(|| async move { - match RealClock - .timeout( - FLUSH_TIMEOUT, - Self::flush_internal( - client_actor_for_callback, - forwarder_mesh_for_callback, - ), - ) - .await - { + let flush_result = instance_dispatch!(instance_for_task, |cx_instance| { + RealClock + .timeout( + FLUSH_TIMEOUT, + Self::flush_internal( + cx_instance, + client_actor_for_callback, + forwarder_mesh_for_callback, + ), + ) + .await + }); + match flush_result { Ok(Ok(())) => { tracing::debug!("flush completed successfully during shutdown"); } @@ -136,7 +146,7 @@ impl LoggingMeshClient { fn set_mode<'py>( &self, _py: Python<'py>, - _instance: &PyInstance, + instance: &PyInstance, stream_to_client: bool, aggregate_window_sec: Option, level: u8, @@ -167,24 +177,34 @@ impl LoggingMeshClient { ) .map_err(|e| PyErr::new::(e.to_string()))?; - self.client_actor - .send(LogClientMessage::SetAggregate { - aggregate_window_sec, - }) - .map_err(anyhow::Error::msg)?; - + instance_dispatch!(instance, |cx_instance| { + self.client_actor + .send( + cx_instance, + LogClientMessage::SetAggregate { + aggregate_window_sec, + }, + ) + .map_err(anyhow::Error::msg)?; + }); Ok(()) } // A sync flush mechanism for the client make sure all the stdout/stderr are streamed back and flushed. 
- fn flush(&self, _instance: &PyInstance) -> PyResult { + fn flush(&self, instance: &PyInstance) -> PyResult { let forwarder_mesh = self.forwarder_mesh.clone(); let client_actor = self.client_actor.clone(); + let instance_for_task = instance.clone(); PyPythonTask::new(async move { - Self::flush_internal(client_actor, forwarder_mesh) - .await - .map_err(|e| PyErr::new::(e.to_string())) + instance_dispatch!(instance_for_task, |cx_instance| { + Self::flush_internal(cx_instance, client_actor, forwarder_mesh) + .await + .map_err(|e| { + PyErr::new::(e.to_string()) + })?; + }); + Ok(()) }) } } diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 7ebdab54d..d15694b53 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -29,6 +29,7 @@ use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Instance; use hyperactor::PortRef; +use hyperactor::Proc; use hyperactor::context; use hyperactor::mailbox::MailboxSenderError; use hyperactor_mesh::Mesh; @@ -78,6 +79,7 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult module = "monarch._rust_bindings.monarch_extension.mesh_controller" )] struct _Controller { + instance: Instance<()>, // The actor that represents this object. controller_handle: Arc>>, all_ranks: Slice, broker_id: (String, usize), @@ -98,6 +100,7 @@ impl _Controller { fn new(py: Python, py_proc_mesh: &PyProcMesh) -> PyResult { let proc_mesh: SharedCell = py_proc_mesh.inner.clone(); let proc_mesh_ref = proc_mesh.borrow().unwrap(); + let instance = proc_mesh_ref.client().clone_for_py(); let shape = proc_mesh_ref.shape(); let slice = shape.slice(); let all_ranks = shape.slice().clone(); @@ -122,8 +125,8 @@ impl _Controller { Ok(Arc::new(Mutex::new(controller_handle))); r })??; - Ok(Self { + instance, controller_handle, all_ranks, // note that 0 is the _pid_ of the broker, which will be 0 for @@ -165,23 +168,26 @@ impl _Controller { }; self.controller_handle .blocking_lock() - .send(msg) + .send(&self.instance, msg) .map_err(to_py_error) } fn drop_refs(&mut self, refs: Vec) -> PyResult<()> { self.controller_handle .blocking_lock() - .send(ClientToControllerMessage::DropRefs { refs }) + .send(&self.instance, ClientToControllerMessage::DropRefs { refs }) .map_err(to_py_error) } fn sync_at_exit(&mut self, port: PyPortId) -> PyResult<()> { self.controller_handle .blocking_lock() - .send(ClientToControllerMessage::SyncAtExit { - port: PortRef::attest(port.into()), - }) + .send( + &self.instance, + ClientToControllerMessage::SyncAtExit { + port: PortRef::attest(port.into()), + }, + ) .map_err(to_py_error) } @@ -195,16 +201,22 @@ impl _Controller { let message: WorkerMessage = convert(message)?; self.controller_handle .blocking_lock() - .send(ClientToControllerMessage::Send { slices, message }) + .send( + &self.instance, + ClientToControllerMessage::Send { slices, message }, + ) .map_err(to_py_error) } fn _drain_and_stop(&mut self) -> PyResult<()> { self.controller_handle .blocking_lock() - .send(ClientToControllerMessage::Send { - slices: vec![self.all_ranks.clone()], - message: WorkerMessage::Exit { error: None }, - }) + .send( + &self.instance, + ClientToControllerMessage::Send { + slices: vec![self.all_ranks.clone()], + message: WorkerMessage::Exit { error: None }, + }, + ) .map_err(to_py_error)?; self.controller_handle .blocking_lock() diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index b354007c2..bae48b37c 100644 --- 
a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -52,6 +52,7 @@ use tracing::Instrument; use crate::buffers::FrozenBuffer; use crate::config::SHARED_ASYNCIO_RUNTIME; use crate::context::PyInstance; +use crate::instance_dispatch; use crate::local_state_broker::BrokerId; use crate::local_state_broker::LocalStateBrokerMessage; use crate::mailbox::EitherPortRef; @@ -287,7 +288,7 @@ impl PythonMessage { } => { let broker = BrokerId::new(local_state_broker).resolve(cx).unwrap(); let (send, recv) = cx.open_once_port(); - broker.send(LocalStateBrokerMessage::Get(id, send))?; + broker.send(cx, LocalStateBrokerMessage::Get(id, send))?; let state = recv.recv().await?; let mut state_it = state.state.into_iter(); Python::with_gil(|py| { @@ -434,11 +435,12 @@ pub(super) struct PythonActorHandle { #[pymethods] impl PythonActorHandle { // TODO: do the pickling in rust - // TODO(pzhang) Use instance after its required by PortHandle. - fn send(&self, _instance: &PyInstance, message: &PythonMessage) -> PyResult<()> { - self.inner - .send(message.clone()) - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + fn send(&self, instance: &PyInstance, message: &PythonMessage) -> PyResult<()> { + instance_dispatch!(instance, |cx_instance| { + self.inner + .send(cx_instance, message.clone()) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + }); Ok(()) } @@ -676,7 +678,7 @@ impl Actor for PythonActorPanicWatcher { } async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { - this.handle().send(HandlePanic {})?; + this.handle().send(this, HandlePanic {})?; Ok(()) } } @@ -692,7 +694,7 @@ impl Handler for PythonActorPanicWatcher { // async endpoint executed successfully. // run again let h = cx.deref().handle(); - h.send(HandlePanic {})?; + h.send(cx, HandlePanic {})?; } Some(Err(err)) => { tracing::error!("caught error in async endpoint {}", err); diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs index c15467ae4..086b34414 100644 --- a/monarch_hyperactor/src/code_sync/manager.rs +++ b/monarch_hyperactor/src/code_sync/manager.rs @@ -246,11 +246,14 @@ impl CodeSyncMessageHandler for CodeSyncManager { // Forward rsync connection port to the RsyncActor, which will do the actual // connection and run the client. let (tx, mut rx) = cx.open_port::>(); - self.get_rsync_actor(cx).await?.send(RsyncMessage { - connect, - result: tx.bind(), - workspace, - })?; + self.get_rsync_actor(cx).await?.send( + cx, + RsyncMessage { + connect, + result: tx.bind(), + workspace, + }, + )?; // Observe any errors. let _ = rx.recv().await?.map_err(anyhow::Error::msg)?; } @@ -261,14 +264,15 @@ impl CodeSyncMessageHandler for CodeSyncManager { // Forward rsync connection port to the RsyncActor, which will do the actual // connection and run the client. let (tx, mut rx) = cx.open_port::>(); - self.get_conda_sync_actor(cx) - .await? - .send(CondaSyncMessage { + self.get_conda_sync_actor(cx).await?.send( + cx, + CondaSyncMessage { connect, result: tx.bind(), workspace, path_prefix_replacements, - })?; + }, + )?; // Observe any errors. let _ = rx.recv().await?.map_err(anyhow::Error::msg)?; } @@ -322,7 +326,7 @@ impl CodeSyncMessageHandler for CodeSyncManager { let (tx, mut rx) = cx.open_port(); self.get_auto_reload_actor(cx) .await? 
- .send(AutoReloadMessage { result: tx.bind() })?; + .send(cx, AutoReloadMessage { result: tx.bind() })?; rx.recv().await?.map_err(anyhow::Error::msg)?; anyhow::Ok(()) } diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 97b7db63b..0496d5355 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -249,11 +249,12 @@ pub(super) struct PythonPortHandle { #[pymethods] impl PythonPortHandle { - // TODO(pzhang) Use instance after its required by PortHandle. - fn send(&self, _instance: &PyInstance, message: PythonMessage) -> PyResult<()> { - self.inner - .send(message) - .map_err(|err| PyErr::new::(format!("Port closed: {}", err)))?; + fn send(&self, instance: &PyInstance, message: PythonMessage) -> PyResult<()> { + instance_dispatch!(instance, |cx_instance| { + self.inner + .send(cx_instance, message) + .map_err(|err| PyErr::new::(format!("Port closed: {}", err)))?; + }); Ok(()) } diff --git a/monarch_hyperactor/src/proc.rs b/monarch_hyperactor/src/proc.rs index 97f3f1606..b05b69f9e 100644 --- a/monarch_hyperactor/src/proc.rs +++ b/monarch_hyperactor/src/proc.rs @@ -25,7 +25,6 @@ use std::time::SystemTime; use anyhow::Result; use hyperactor::ActorRef; -use hyperactor::Named; use hyperactor::RemoteMessage; use hyperactor::actor::Signal; use hyperactor::channel; @@ -453,11 +452,9 @@ impl InstanceWrapper { fn new_with_instance_and_clock(instance: Instance<()>, clock: ClockKind) -> Result { // TEMPORARY: remove after using fixed message ports. - let (message_port, message_receiver) = instance.open_port::(); - message_port.bind_to(M::port()); + let (_message_port, message_receiver) = instance.bind_actor_port::(); - let (signal_port, signal_receiver) = instance.open_port::(); - signal_port.bind_to(::port()); + let (signal_port, signal_receiver) = instance.bind_actor_port::(); let (controller_error_sender, controller_error_receiver) = watch::channel("".to_string()); let actor_id = instance.self_id().clone(); @@ -683,7 +680,7 @@ async fn check_actor_supervision_state( // TODO: should allow for multiple attempts tracing::error!("system actor is not alive, aborting!"); // Send a signal to the client to abort. - signal_port.send(Signal::Stop).unwrap(); + signal_port.send(&instance, Signal::Stop).unwrap(); } } Ok(()) diff --git a/monarch_hyperactor/src/v1/logging.rs b/monarch_hyperactor/src/v1/logging.rs index f4f9c363c..639fc7ff9 100644 --- a/monarch_hyperactor/src/v1/logging.rs +++ b/monarch_hyperactor/src/v1/logging.rs @@ -54,11 +54,14 @@ impl LoggingMeshClient { let (version_tx, version_rx) = cx.instance().open_once_port::(); // First initialize a sync flush. 
- client_actor.send(LogClientMessage::StartSyncFlush { - expected_procs: forwarder_mesh.region().num_ranks(), - reply: reply_tx.bind(), - version: version_tx.bind(), - })?; + client_actor.send( + cx, + LogClientMessage::StartSyncFlush { + expected_procs: forwarder_mesh.region().num_ranks(), + reply: reply_tx.bind(), + version: version_tx.bind(), + }, + )?; let version = version_rx.recv().await?; @@ -133,12 +136,15 @@ impl LoggingMeshClient { }) .map_err(|e| PyErr::new::(e.to_string()))?; - self.client_actor - .send(LogClientMessage::SetAggregate { - aggregate_window_sec, - }) - .map_err(anyhow::Error::msg)?; - + instance_dispatch!(instance, |cx_instance| { + self.client_actor.send( + cx_instance, + LogClientMessage::SetAggregate { + aggregate_window_sec, + }, + ) + }) + .map_err(|e| PyErr::new::(e.to_string()))?; Ok(()) } diff --git a/monarch_tensor_worker/Cargo.toml b/monarch_tensor_worker/Cargo.toml index c34d57fce..c622407f8 100644 --- a/monarch_tensor_worker/Cargo.toml +++ b/monarch_tensor_worker/Cargo.toml @@ -13,6 +13,7 @@ async-trait = "0.1.86" bincode = "1.3.3" clap = { version = "4.5.42", features = ["derive", "env", "string", "unicode", "wrap_help"] } cxx = "1.0.119" +derivative = "2.2" derive_more = { version = "1.0.0", features = ["full"] } futures = { version = "0.3.31", features = ["async-await", "compat"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } diff --git a/monarch_tensor_worker/src/comm.rs b/monarch_tensor_worker/src/comm.rs index 51293829c..dc4d2d3e8 100644 --- a/monarch_tensor_worker/src/comm.rs +++ b/monarch_tensor_worker/src/comm.rs @@ -15,13 +15,14 @@ use anyhow::bail; use anyhow::ensure; use async_trait::async_trait; use cxx::CxxVector; +use derivative::Derivative; use hyperactor::Actor; use hyperactor::HandleClient; use hyperactor::Handler; +use hyperactor::Instance; use hyperactor::Named; use hyperactor::actor::ActorHandle; use hyperactor::forward; -use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::OncePortReceiver; use parking_lot::Mutex; @@ -463,10 +464,12 @@ impl Work for CommWork { } } -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub struct CommBackend { + #[derivative(Debug = "ignore")] + instance: Instance<()>, // The actor that represents this object. comm: Arc>, - mailbox: Mailbox, rank: usize, // Size of group. This is less than or equal to world_size. group_size: usize, @@ -477,8 +480,8 @@ pub struct CommBackend { impl CommBackend { pub fn new( + instance: Instance<()>, comm: Arc>, - mailbox: Mailbox, rank: usize, group_size: usize, world_size: usize, @@ -488,8 +491,8 @@ impl CommBackend { "Group must be smaller or equal to the world size" ); Self { + instance, comm, - mailbox, rank, group_size, world_size, @@ -570,13 +573,16 @@ impl Backend for CommBackend { let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() }); // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::AllReduce( - cell.clone(), - convert_reduce_op(opts.reduce_op)?, - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::AllReduce( + cell.clone(), + convert_reduce_op(opts.reduce_op)?, + Stream::get_current_stream(), + tx, + ), + )?; Ok(Box::new(CommWork::from(vec![cell], rx).await?)) } @@ -603,15 +609,18 @@ impl Backend for CommBackend { } // Call into `NcclCommActor`. 
- let (tx, rx) = self.mailbox.open_once_port(); + let (tx, rx) = self.instance.open_once_port(); // This is not implemented in this function because the broadcasts we need // to create will change their behavior based on rank. - self.comm.send(CommMessage::AllGather( - output_cell.clone(), - input_cell.clone(), - Stream::get_current_stream(), - tx, - ))?; + self.comm.send( + &self.instance, + CommMessage::AllGather( + output_cell.clone(), + input_cell.clone(), + Stream::get_current_stream(), + tx, + ), + )?; let mut input_cells = vec![]; input_cells.extend(output_cell); input_cells.push(input_cell); @@ -631,13 +640,16 @@ impl Backend for CommBackend { let input_cell = TensorCell::new(unsafe { input.clone_unsafe() }); // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::AllGatherIntoTensor( - output_cell.clone(), - input_cell.clone(), - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::AllGatherIntoTensor( + output_cell.clone(), + input_cell.clone(), + Stream::get_current_stream(), + tx, + ), + )?; Ok(Box::new( CommWork::from(vec![output_cell, input_cell], rx).await?, )) @@ -645,10 +657,13 @@ impl Backend for CommBackend { async fn barrier(&self, _opts: BarrierOptions) -> Result>> { // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); + let (tx, rx) = self.instance.open_once_port(); self.comm // There's no native barrier op in nccl, so impl via all-reduce. - .send(CommMessage::Barrier(Stream::get_current_stream(), tx))?; + .send( + &self.instance, + CommMessage::Barrier(Stream::get_current_stream(), tx), + )?; Ok(Box::new(CommWork::from(vec![], rx).await?)) } @@ -663,14 +678,17 @@ impl Backend for CommBackend { let input_cell = TensorCell::new(unsafe { input.clone_unsafe() }); // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::Reduce( - input_cell.clone(), - convert_reduce_op(opts.reduce_op)?, - opts.root_rank, - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::Reduce( + input_cell.clone(), + convert_reduce_op(opts.reduce_op)?, + opts.root_rank, + Stream::get_current_stream(), + tx, + ), + )?; Ok(Box::new(CommWork::from(vec![input_cell], rx).await?)) } @@ -697,14 +715,17 @@ impl Backend for CommBackend { } // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::ReduceScatterTensor( - output_cell.clone(), - input_cell.clone(), - convert_reduce_op(opts.reduce_op)?, - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::ReduceScatterTensor( + output_cell.clone(), + input_cell.clone(), + convert_reduce_op(opts.reduce_op)?, + Stream::get_current_stream(), + tx, + ), + )?; Ok(Box::new( CommWork::from(vec![output_cell, input_cell], rx).await?, )) @@ -726,13 +747,11 @@ impl Backend for CommBackend { let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() }); // Call into `NcclCommActor`. 
- let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::Send( - cell.clone(), - dst_rank, - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::Send(cell.clone(), dst_rank, Stream::get_current_stream(), tx), + )?; Ok(Box::new(CommWork::from(vec![cell], rx).await?)) } @@ -752,13 +771,11 @@ impl Backend for CommBackend { let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() }); // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::Recv( - cell.clone(), - src_rank, - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::Recv(cell.clone(), src_rank, Stream::get_current_stream(), tx), + )?; Ok(Box::new(CommWork::from(vec![cell], rx).await?)) } @@ -782,7 +799,7 @@ impl Backend for CommBackend { assert_type_and_sizes_match(outputs.as_slice(), input.scalar_type(), &input.sizes())?; // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); + let (tx, rx) = self.instance.open_once_port(); let mut messages = vec![]; // All ranks other than the root Recv, and the root rank calls Send. if self.rank == root { @@ -795,7 +812,7 @@ impl Backend for CommBackend { } for (r, output) in output_cells.clone().into_iter().enumerate() { if r != root { - let (tx_recv, _rx_recv) = self.mailbox.open_once_port(); + let (tx_recv, _rx_recv) = self.instance.open_once_port(); messages.push(CommMessage::Recv( output, r as i32, @@ -814,7 +831,7 @@ impl Backend for CommBackend { output_cells.len() )); } - let (tx_send, _rx_send) = self.mailbox.open_once_port(); + let (tx_send, _rx_send) = self.instance.open_once_port(); messages.push(CommMessage::Send( input_cell.clone(), root as i32, @@ -822,11 +839,10 @@ impl Backend for CommBackend { tx_send, )); } - self.comm.send(CommMessage::Group( - messages, - Stream::get_current_stream(), - tx, - ))?; + self.comm.send( + &self.instance, + CommMessage::Group(messages, Stream::get_current_stream(), tx), + )?; let mut inputs = vec![]; inputs.extend(output_cells); inputs.push(input_cell); @@ -853,7 +869,7 @@ impl Backend for CommBackend { assert_type_and_sizes_match(inputs.as_slice(), output.scalar_type(), &output.sizes())?; // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); + let (tx, rx) = self.instance.open_once_port(); let mut messages = vec![]; // Implementation is the inverse set of messages from gather, where all ranks // other than the root Send, and the root rank calls Recv. 
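`CommBackend` now owns the `Instance<()>` it sends from instead of a detached `Mailbox`, and `Instance<()>` has no `Debug` impl, which is why `derivative` enters Cargo.toml above: the derive must skip that field. A small runnable sketch of the same derive pattern; `CommBackendLike` and its closure field are illustrative stand-ins for the real struct:

use derivative::Derivative; // derivative = "2.2", as added to Cargo.toml above

#[derive(Derivative)]
#[derivative(Debug)]
struct CommBackendLike {
    // A `Box<dyn Fn()>` stands in for a field, like `Instance<()>`, that does
    // not implement `Debug`; `derivative` omits it from the derived output.
    #[derivative(Debug = "ignore")]
    instance: Box<dyn Fn()>,
    rank: usize,
    group_size: usize,
}

fn main() {
    let backend = CommBackendLike {
        instance: Box::new(|| ()),
        rank: 0,
        group_size: 4,
    };
    // Prints the struct with `rank` and `group_size` but without `instance`.
    println!("{:?}", backend);
}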
@@ -867,7 +883,7 @@ impl Backend for CommBackend { } for (r, input) in input_cells.clone().into_iter().enumerate() { if r != root { - let (tx_send, _rx_send) = self.mailbox.open_once_port(); + let (tx_send, _rx_send) = self.instance.open_once_port(); messages.push(CommMessage::Send( input, r as i32, @@ -886,7 +902,7 @@ impl Backend for CommBackend { input_cells.len() )); } - let (tx_recv, _rx_recv) = self.mailbox.open_once_port(); + let (tx_recv, _rx_recv) = self.instance.open_once_port(); messages.push(CommMessage::Recv( output_cell.clone(), root as i32, @@ -894,11 +910,10 @@ impl Backend for CommBackend { tx_recv, )); } - self.comm.send(CommMessage::Group( - messages, - Stream::get_current_stream(), - tx, - ))?; + self.comm.send( + &self.instance, + CommMessage::Group(messages, Stream::get_current_stream(), tx), + )?; let mut inputs = vec![]; inputs.push(output_cell); inputs.extend(input_cells); @@ -916,13 +931,16 @@ impl Backend for CommBackend { let cell = TensorCell::new(unsafe { as_singleton(tensors.as_slice())?.clone_unsafe() }); // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::Broadcast( - cell.clone(), - opts.root_rank, - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::Broadcast( + cell.clone(), + opts.root_rank, + Stream::get_current_stream(), + tx, + ), + )?; Ok(Box::new(CommWork::from(vec![cell], rx).await?)) } @@ -940,13 +958,16 @@ impl Backend for CommBackend { let input_cell = TensorCell::new(unsafe { input_buffer.clone_unsafe() }); // Call into `NcclCommActor`. - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::AllToAllSingle( - output_cell.clone(), - input_cell.clone(), - Stream::get_current_stream(), - tx, - ))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm.send( + &self.instance, + CommMessage::AllToAllSingle( + output_cell.clone(), + input_cell.clone(), + Stream::get_current_stream(), + tx, + ), + )?; Ok(Box::new( CommWork::from(vec![output_cell, input_cell], rx).await?, )) @@ -995,8 +1016,8 @@ impl Backend for CommBackend { for r in 0..output_tensors.len() { let output_cell = &output_cells[r]; let input_cell = &input_cells[r]; - let (tx_send, _rx_send) = self.mailbox.open_once_port(); - let (tx_recv, _rx_recv) = self.mailbox.open_once_port(); + let (tx_send, _rx_send) = self.instance.open_once_port(); + let (tx_recv, _rx_recv) = self.instance.open_once_port(); messages.push(CommMessage::Send( input_cell.clone(), r as i32, @@ -1010,8 +1031,9 @@ impl Backend for CommBackend { tx_recv, )); } - let (tx, rx) = self.mailbox.open_once_port(); - self.comm.send(CommMessage::Group(messages, stream, tx))?; + let (tx, rx) = self.instance.open_once_port(); + self.comm + .send(&self.instance, CommMessage::Group(messages, stream, tx))?; let mut all_cells = vec![]; all_cells.extend(output_cells); all_cells.extend(input_cells); @@ -1752,21 +1774,27 @@ mod tests { let cell0 = TensorCell::new(factory_float_tensor(&[1.0], device0.into())); let port0 = client.open_once_port(); - handle0.send(CommMessage::Send( - cell0.clone(), - 1, - Stream::get_current_stream_on_device(device0), - port0.0, - ))?; + handle0.send( + &client, + CommMessage::Send( + cell0.clone(), + 1, + Stream::get_current_stream_on_device(device0), + port0.0, + ), + )?; let cell1 = TensorCell::new(factory_float_tensor(&[1.0], device1.into())); let port1 = client.open_once_port(); - handle1.send(CommMessage::Recv( - cell1.clone(), 
- 0, - Stream::get_current_stream_on_device(device1), - port1.0, - ))?; + handle1.send( + &client, + CommMessage::Recv( + cell1.clone(), + 0, + Stream::get_current_stream_on_device(device1), + port1.0, + ), + )?; let (work0, work1) = tokio::join!( CommWork::from(vec![cell0.clone()], port0.1), CommWork::from(vec![cell1.clone()], port1.1) diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 956d7ca3e..8a5890873 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -1174,7 +1174,6 @@ mod tests { use anyhow::Result; use hyperactor::Instance; - use hyperactor::Named; use hyperactor::WorldId; use hyperactor::actor::ActorStatus; use hyperactor::channel::ChannelAddr; @@ -2348,8 +2347,7 @@ mod tests { // Create a fake controller for the workers to talk to. let client = System::new(system_addr.clone()).attach().await?; - let (handle, mut controller_rx) = client.open_port::(); - handle.bind_to(ControllerMessage::port()); + let (_handle, mut controller_rx) = client.bind_actor_port::(); let controller_ref: ActorRef = ActorRef::attest(client.self_id().clone()); // Create the worker world diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 590482e72..7ef065c56 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::collections::hash_map::Entry; use std::future::Future; +use std::ops::Deref; use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; @@ -32,7 +33,6 @@ use hyperactor::PortHandle; use hyperactor::actor::ActorHandle; use hyperactor::data::Serialized; use hyperactor::forward; -use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; @@ -829,9 +829,10 @@ impl StreamActor { // it to create a new torch group. 
let ranks = mesh.get_ranks_for_dim_slice(&dims)?; let group_size = ranks.len(); + let (child_instance, _) = cx.child()?; let backend = CommBackend::new( + child_instance, comm, - Mailbox::new_detached(cx.self_id().clone()), self.rank, group_size, self.world_size, @@ -1095,7 +1096,7 @@ impl StreamActor { let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap(); broker - .send(message) + .send(cx, message) .map_err(|e| CallFunctionError::Error(e.into()))?; let result = recv .recv() @@ -1159,7 +1160,7 @@ impl StreamMessageHandler for StreamActor { async fn borrow_create( &mut self, - _cx: &Context, + cx: &Context, borrow: u64, tensor: Ref, first_use_sender: PortHandle<(Option, TensorCellResult)>, @@ -1188,7 +1189,7 @@ impl StreamMessageHandler for StreamActor { }; let event = self.cuda_stream().map(|stream| stream.record_event(None)); - first_use_sender.send((event, result)).map_err(|err| { + first_use_sender.send(cx, (event, result)).map_err(|err| { anyhow!( "failed sending first use event for borrow {:?}: {:?}", borrow, @@ -1245,7 +1246,7 @@ impl StreamMessageHandler for StreamActor { async fn borrow_last_use( &mut self, - _cx: &Context, + cx: &Context, borrow: u64, result: Ref, last_use_sender: PortHandle<(Option, TensorCellResult)>, @@ -1269,7 +1270,7 @@ impl StreamMessageHandler for StreamActor { _ => bail!("invalid rvalue type for borrow_last_use"), }; - last_use_sender.send((event, tensor)).map_err(|err| { + last_use_sender.send(cx, (event, tensor)).map_err(|err| { anyhow!( "failed sending last use event for borrow {:?}: {:?}", borrow, @@ -1693,7 +1694,7 @@ impl StreamMessageHandler for StreamActor { // Actually send the value. if let Some(pipe) = pipe { - pipe.send(PipeMessage::SendValue(value))?; + pipe.send(cx, PipeMessage::SendValue(value))?; } else { let result = match value { Ok(value) => Ok(Serialized::serialize(&value).map_err(anyhow::Error::from)?), @@ -1761,7 +1762,7 @@ impl StreamMessageHandler for StreamActor { self.try_define(cx, seq, results, &vec![], async |self| { let (tx, rx) = cx.open_once_port(); - pipe.send(PipeMessage::RecvValue(tx)) + pipe.send(cx, PipeMessage::RecvValue(tx)) .map_err(anyhow::Error::from) .map_err(CallFunctionError::from)?; let value = rx.recv().await.map_err(anyhow::Error::from)?; diff --git a/ndslice/src/view.rs b/ndslice/src/view.rs index 4c2555f68..c3ebd10ed 100644 --- a/ndslice/src/view.rs +++ b/ndslice/src/view.rs @@ -1515,7 +1515,7 @@ pub trait MapIntoExt: Ranked { M::build_dense_unchecked(region, values) } - fn try_map_into(self, f: impl Fn(&Self::Item) -> Result) -> Result + fn try_map_into(&self, f: impl Fn(&Self::Item) -> Result) -> Result where Self: Sized, M: BuildFromRegion, diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index cc44b591f..0d3f2c3c6 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -8,7 +8,6 @@ import asyncio import ctypes -import enum import importlib.resources import logging import operator @@ -21,6 +20,7 @@ import time import unittest import unittest.mock +from enum import auto, Enum from tempfile import TemporaryDirectory from types import ModuleType from typing import cast, Dict, Tuple @@ -79,9 +79,10 @@ from typing_extensions import assert_type -class ApiVersion(enum.Enum): - V0 = "v0" - V1 = "v1" +class ApiVersion(Enum): + V0 = auto() + V1 = auto() + V1Native = auto() needs_cuda = pytest.mark.skipif( @@ -122,11 +123,22 @@ def spawn_procs_on_host( return host.spawn_procs(per_host) +def setup_env_vars(api_ver: ApiVersion): + 
match api_ver: + case ApiVersion.V1Native: + os.environ["HYPERACTOR_ENABLE_NATIVE_V1_CASTING"] = "true" + os.environ["HYPERACTOR_ENABLE_DEST_ACTOR_REORDERING_BUFFER"] = "true" + case _: + pass + + def spawn_procs_on_fake_host( api_ver: ApiVersion, per_host: Dict[str, int] ) -> ProcMesh | ProcMeshV1: + setup_env_vars(api_ver) + match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: return spawn_procs_on_host(fake_in_process_host_v1("fake_host"), per_host) case ApiVersion.V0: return spawn_procs_on_host(fake_in_process_host(), per_host) @@ -137,8 +149,10 @@ def spawn_procs_on_fake_host( def spawn_procs_on_this_host( api_ver: ApiVersion, per_host: Dict[str, int] ) -> ProcMesh | ProcMeshV1: + setup_env_vars(api_ver) + match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: return spawn_procs_on_host(this_host_v1(), per_host) case ApiVersion.V0: return spawn_procs_on_host(this_host(), per_host) @@ -148,7 +162,7 @@ def spawn_procs_on_this_host( def get_this_proc(api_ver: ApiVersion): match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: return this_proc_v1() case ApiVersion.V0: return this_proc() @@ -156,7 +170,7 @@ def get_this_proc(api_ver: ApiVersion): raise ValueError(f"Unknown API version: {api_ver}") -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_choose(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -176,7 +190,7 @@ async def test_choose(api_ver: ApiVersion): assert result2 == result3 -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_stream(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -198,7 +212,7 @@ async def fetch(self, to: To): return [await x for x in to.whoami.stream()] -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_mesh_passed_to_mesh(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -209,7 +223,7 @@ async def test_mesh_passed_to_mesh(api_ver: ApiVersion): assert all[0] != all[1] -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_mesh_passed_to_mesh_on_different_proc_mesh(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -221,7 +235,7 @@ async def test_mesh_passed_to_mesh_on_different_proc_mesh(api_ver: ApiVersion): assert all[0] != all[1] -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_actor_slicing(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -238,7 +252,7 @@ def test_actor_slicing(api_ver: ApiVersion): assert result[0] == result[1] -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_aggregate(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ 
-259,7 +273,7 @@ async def return_current_rank_str(self): return str(current_rank()) -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_rank_size(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -271,11 +285,11 @@ async def test_rank_size(api_ver: ApiVersion): assert 4 == await acc.accumulate(lambda: current_size()["gpus"]) -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_rank_string(api_ver: ApiVersion): match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: per_host = {"gpus": 2} case ApiVersion.V0: per_host = {"hosts": 1, "gpus": 2} @@ -296,7 +310,7 @@ def sync_endpoint(self, a_counter: Counter): return a_counter.value.choose().get() -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_sync_actor(api_ver: ApiVersion): proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -306,7 +320,7 @@ async def test_sync_actor(api_ver: ApiVersion): assert r == 5 -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_sync_actor_sync_client(api_ver: ApiVersion) -> None: proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -316,14 +330,14 @@ def test_sync_actor_sync_client(api_ver: ApiVersion) -> None: assert r == 5 -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_proc_mesh_size(api_ver: ApiVersion) -> None: proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) assert 2 == proc.size("gpus") -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_rank_size_sync(api_ver: ApiVersion) -> None: proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -334,7 +348,7 @@ def test_rank_size_sync(api_ver: ApiVersion) -> None: assert 4 == acc.accumulate(lambda: current_size()["gpus"]).get() -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_accumulate_sync(api_ver: ApiVersion) -> None: proc = spawn_procs_on_fake_host(api_ver, {"gpus": 2}) @@ -351,11 +365,11 @@ def doit(self, c: Counter): return list(c.value.call().get()) -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_value_mesh(api_ver: ApiVersion) -> None: match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: per_host = {"gpus": 2} case ApiVersion.V0: per_host = {"hosts": 1, "gpus": 2} @@ -400,7 +414,7 @@ def check(module, path): check(bindings, "monarch._rust_bindings") -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) 
@pytest.mark.timeout(60) def test_proc_mesh_liveness(api_ver: ApiVersion) -> None: mesh = spawn_procs_on_this_host(api_ver, {"gpus": 2}) @@ -436,7 +450,7 @@ async def get_async(self): return self.local.value -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_actor_tls(api_ver: ApiVersion) -> None: """Test that thread-local state is respected.""" @@ -467,7 +481,7 @@ def get_value(self): return self.local.value -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_actor_tls_full_sync(api_ver: ApiVersion) -> None: """Test that thread-local state is respected.""" @@ -495,7 +509,7 @@ async def no_more(self) -> None: self.should_exit = True -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_async_concurrency(api_ver: ApiVersion): """Test that async endpoints will be processed concurrently.""" @@ -633,7 +647,7 @@ def _handle_undeliverable_message( return True -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_actor_log_streaming(api_ver: ApiVersion) -> None: # Save original file descriptors @@ -703,10 +717,12 @@ async def test_actor_log_streaming(api_ver: ApiVersion) -> None: await am.log.call("has log streaming as level matched") match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: await asyncio.sleep(1) case ApiVersion.V0: await pm.stop() + case _: + raise ValueError(f"Unknown API version: {api_ver}") # Flush all outputs stdout_file.flush() @@ -787,7 +803,7 @@ async def test_actor_log_streaming(api_ver: ApiVersion) -> None: pass -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(120) async def test_alloc_based_log_streaming(api_ver: ApiVersion) -> None: """Test both AllocHandle.stream_logs = False and True cases.""" @@ -899,7 +915,7 @@ def _stream_logs(self) -> bool: await test_stream_logs_case(True, "stream_logs_true") -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_logging_option_defaults(api_ver: ApiVersion) -> None: # Save original file descriptors @@ -936,11 +952,13 @@ async def test_logging_option_defaults(api_ver: ApiVersion) -> None: await am.log.call("log streaming") match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: # Wait for > default aggregation window (3 seconds) await asyncio.sleep(5) case ApiVersion.V0: await pm.stop() + case _: + raise ValueError(f"Unknown API version: {api_ver}") # Flush all outputs stdout_file.flush() @@ -1018,7 +1036,7 @@ def __init__(self): # oss_skip: pytest keeps complaining about mocking get_ipython module @pytest.mark.oss_skip -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) async def test_flush_called_only_once(api_ver: ApiVersion) -> 
None: """Test that flush is called only once when ending an ipython cell""" mock_ipython = MockIPython() @@ -1046,7 +1064,7 @@ async def test_flush_called_only_once(api_ver: ApiVersion) -> None: # oss_skip: pytest keeps complaining about mocking get_ipython module @pytest.mark.oss_skip -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(180) async def test_flush_logs_ipython(api_ver: ApiVersion) -> None: """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" @@ -1179,7 +1197,7 @@ async def test_flush_logs_fast_exit() -> None: ), process.stdout -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_flush_on_disable_aggregation(api_ver: ApiVersion) -> None: """Test that logs are flushed when disabling aggregation. @@ -1224,11 +1242,13 @@ async def test_flush_on_disable_aggregation(api_ver: ApiVersion) -> None: await am.print.call("single log line") match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: # Wait for > default aggregation window (3 secs) await asyncio.sleep(5) case ApiVersion.V0: await pm.stop() + case _: + raise ValueError(f"Unknown API version: {api_ver}") # Flush all outputs stdout_file.flush() @@ -1274,7 +1294,7 @@ async def test_flush_on_disable_aggregation(api_ver: ApiVersion) -> None: pass -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(120) async def test_multiple_ongoing_flushes_no_deadlock(api_ver: ApiVersion) -> None: """ @@ -1305,7 +1325,7 @@ async def test_multiple_ongoing_flushes_no_deadlock(api_ver: ApiVersion) -> None futures[-1].get() -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_adjust_aggregation_window(api_ver: ApiVersion) -> None: """Test that the flush deadline is updated when the aggregation window is adjusted. 
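The tests in this file gain a third `ApiVersion.V1Native` variant; its only difference from `V1` is the pair of `HYPERACTOR_*` environment flags exported by `setup_env_vars` before a proc mesh is spawned. A sketch of how such boolean flags are conventionally read on the Rust side; `env_flag` is a hypothetical helper for illustration, not hyperactor's actual config API:

use std::env;

// Hypothetical helper: a flag counts as enabled only when set to "true",
// matching the values the tests export.
fn env_flag(name: &str) -> bool {
    env::var(name).map(|v| v == "true").unwrap_or(false)
}

fn main() {
    // Neither variable is exported here, so both flags default to off.
    assert!(!env_flag("HYPERACTOR_ENABLE_NATIVE_V1_CASTING"));
    assert!(!env_flag("HYPERACTOR_ENABLE_DEST_ACTOR_REORDERING_BUFFER"));
}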
@@ -1347,11 +1367,13 @@ async def test_adjust_aggregation_window(api_ver: ApiVersion) -> None: await am.print.call("second batch of logs") match api_ver: - case ApiVersion.V1: + case ApiVersion.V1 | ApiVersion.V1Native: # Wait for > aggregation window (2 secs) await asyncio.sleep(4) case ApiVersion.V0: await pm.stop() + case _: + raise ValueError(f"Unknown API version: {api_ver}") # Flush all outputs stdout_file.flush() @@ -1397,7 +1419,7 @@ async def send(self, port: Port[int]): port.send(i) -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_port_as_argument(api_ver: ApiVersion) -> None: proc_mesh = spawn_procs_on_fake_host(api_ver, {"gpus": 1}) @@ -1509,7 +1531,7 @@ def add(self, port: "Port[int]", b: int) -> None: port.send(3 + b) -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) def test_ported_actor(api_ver: ApiVersion): proc_mesh = spawn_procs_on_fake_host(api_ver, {"gpus": 1}) @@ -1551,7 +1573,7 @@ async def sleep(self, t: float) -> None: await asyncio.sleep(t) -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) def test_mesh_len(api_ver: ApiVersion): proc_mesh = spawn_procs_on_fake_host(api_ver, {"gpus": 12}) s = proc_mesh.spawn("sleep_actor", SleepActor) @@ -1608,7 +1630,7 @@ def _handle_undeliverable_message( return True -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_undeliverable_message_with_override(api_ver: ApiVersion) -> None: pm = spawn_procs_on_this_host(api_ver, {"gpus": 1}) @@ -1624,7 +1646,7 @@ async def test_undeliverable_message_with_override(api_ver: ApiVersion) -> None: pm.stop().get() -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) @pytest.mark.timeout(60) async def test_undeliverable_message_without_override(api_ver: ApiVersion) -> None: pm = spawn_procs_on_this_host(api_ver, {"gpus": 1}) @@ -1636,7 +1658,7 @@ async def test_undeliverable_message_without_override(api_ver: ApiVersion) -> No pm.stop().get() -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) def test_this_and_that(api_ver: ApiVersion): proc = get_this_proc(api_ver) counter = proc.spawn("counter", Counter, 7) @@ -1649,7 +1671,7 @@ def status(self): return 1 -@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1]) +@pytest.mark.parametrize("api_ver", [ApiVersion.V0, ApiVersion.V1, ApiVersion.V1Native]) async def test_things_survive_losing_python_reference(api_ver: ApiVersion) -> None: """Test the slice_receptor_mesh function in LOCAL mode, verifying that setup methods are called."""
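One smaller change earlier in this diff deserves a note: `ndslice`'s `try_map_into` signature moved from `self` to `&self`, so a fallible map no longer consumes the view and the caller can keep using it afterwards. A free-standing sketch of that by-reference fallible-map shape, with illustrative types in place of `Ranked` and `BuildFromRegion`:

use std::num::ParseIntError;

// Borrow the input, apply a fallible function to each item, and stop at the
// first error; the original collection stays usable at the call site.
fn try_map_into<T, U, E>(items: &[T], f: impl Fn(&T) -> Result<U, E>) -> Result<Vec<U>, E> {
    items.iter().map(f).collect()
}

fn main() -> Result<(), ParseIntError> {
    let raw = vec!["1", "2", "3"];
    let parsed = try_map_into(&raw, |s| s.parse::<i32>())?;
    assert_eq!(parsed, vec![1, 2, 3]);
    // `raw` is still alive because the map only borrowed it.
    assert_eq!(raw.len(), 3);
    Ok(())
}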