From d1b83d1cc6b34230c6f9c453b263278c21a9b4bd Mon Sep 17 00:00:00 2001 From: slurye Date: Tue, 12 Aug 2025 12:14:58 -0700 Subject: [PATCH 1/2] Make NetTx<->NetRx handshake two-way to prevent errant reconnections Summary: The purpose of this diff is to handle the following scenario: 1. Process A starts serving a NetRx. 2. Process B creates a NetTx that connects to process A's NetRx. 3. B sends a few messages to A, and the messages are acked. 4. Process A dies/is killed, while B stays alive. 5. A new process C starts serving a NetRx on the same channel as in step 1. 6. B's NetTx connects to C's NetRx, *with no way of knowing it has connected to a different process than before*. 7. B sends messages to C, starting from where it left off with A. 8. C rejects all of B's messages because of invalid sequence numbers. 9. B's NetTx eventually times out after a long time with no acks. In order to distinguish among connections from different NetTx instances to the same NetRx instance, each NetTx generates a random unique session id. This session id gets sent as part of an initial handshake from NetTx -> NetRx before the NetTx starts sending normal messages. Currently, though, the handshake is one-way: NetTx does not wait for any response from NetRx before it starts sending messages. To resolve the issue described above, this diff introduces a global (per-process) "rx session id". When a NetTx first connects to a NetRx, the NetRx responds with its rx session id as part of the handshake. The NetTx waits for the handshake response and extracts the rx session id. If this is the first time the NetTx is connecting, the NetTx stores the rx session id. On subsequent connection attempts, the NetTx will validate the rx session id it receives from the handshake against the rx session id it previously stored; if there is a mismatch, the NetTx returns the appropriate error to its caller. 
Differential Revision: D79607092 --- hyperactor/src/channel/net.rs | 319 +++++++++++++++++++++++++++++++--- 1 file changed, 292 insertions(+), 27 deletions(-) diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index 3557af98f..11e85127b 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -28,6 +28,7 @@ use std::mem::take; use std::net::ToSocketAddrs; use std::pin::Pin; use std::sync::Arc; +use std::sync::LazyLock; use std::task::Poll; use backoff::ExponentialBackoffBuilder; @@ -171,6 +172,65 @@ fn build_codec() -> LengthDelimitedCodec { .new_codec() } +#[derive(thiserror::Error, Debug)] +enum HandshakeError { + #[error(transparent)] + DeserializationFailed(bincode::Error), + #[error("sending init frame failed: {0}")] + SendFailed(String), + #[error("timed out waiting for handshake response")] + Timeout, + #[error("channel closed waiting for handshake response")] + ChannelClosed, + #[error(transparent)] + IoError(io::Error), + #[error("mismatched rx session ids: {0} (stored) vs. {1} (received)")] + MismatchedRxSessionIds(u64, u64), +} + +async fn do_handshake( + tx_session_id: u64, + rx_session_id: &mut Option, + sink: &mut SplitSink, Bytes>, + stream: &mut SplitStream>, +) -> Result<(), HandshakeError> { + let data = bincode::serialize(&Frame::::Init(tx_session_id)) + .expect("unexpected serialization error"); + + if let Err(err) = sink.send(data.into()).await { + return Err(HandshakeError::SendFailed(format!("{:?}", err))); + } + + let handshake_response = tokio::select! 
{ + _ = RealClock.sleep(config::global::get(config::MESSAGE_DELIVERY_TIMEOUT)) => None, + response = tokio_stream::StreamExt::next(stream) => Some(response) + }; + + match handshake_response { + None => Err(HandshakeError::Timeout), + Some(None) => Err(HandshakeError::ChannelClosed), + Some(Some(handshake_response)) => match handshake_response { + Err(err) => Err(HandshakeError::IoError(err)), + Ok(handshake_response) => match bincode::deserialize::(&handshake_response) { + Err(err) => Err(HandshakeError::DeserializationFailed(err)), + Ok(received_rx_session_id) => { + if rx_session_id + .is_some_and(|rx_session_id| rx_session_id != received_rx_session_id) + { + Err(HandshakeError::MismatchedRxSessionIds( + rx_session_id.unwrap(), + received_rx_session_id, + )) + } else { + rx_session_id.replace(received_rx_session_id); + Ok(()) + } + } + }, + }, + } +} + /// A Tx implemented on top of a Link. The Tx manages the link state, /// reconnections, etc. #[derive(Debug)] @@ -499,6 +559,7 @@ impl NetTx { } let session_id = rand::random(); + let mut rx_session_id = None; let log_id = format!("session {}.{}", link.dest(), session_id); let mut state = State::init(&log_id); let mut conn = Conn::reconnect_with_default(); @@ -716,10 +777,7 @@ impl NetTx { match link.connect().await { Ok(stream) => { let framed = Framed::new(stream, build_codec()); - let (mut sink, stream) = futures::StreamExt::split(framed); - let data = bincode::serialize(&Frame::::Init(session_id)) - .expect("unexpected serialization error"); - let initialized = sink.send(data.into()).await.is_ok(); + let (mut sink, mut stream) = futures::StreamExt::split(framed); metrics::CHANNEL_CONNECTIONS.add( 1, @@ -732,21 +790,71 @@ impl NetTx { // Need to resend unacked after reconnecting. let largest_acked = unacked.largest_acked; outbox.requeue_unacked(unacked); - ( - State::Running(Deliveries { - outbox, - // unacked messages are put back to outbox. 
So they are not - // considered as "sent yet unacked" message anymore. But - // we still want to keep `largest_acked` to known Rx's watermark. - unacked: Unacked::new(largest_acked, &log_id), - }), - if initialized { - backoff.reset(); - Conn::Connected { sink, stream } - } else { - Conn::reconnect(backoff) - }, + + let handshake_result = do_handshake::<_, M>( + session_id, + &mut rx_session_id, + &mut sink, + &mut stream, ) + .await; + + match handshake_result { + Ok(_) => { + ( + State::Running(Deliveries { + outbox, + // unacked messages are put back to outbox. So they are not + // considered as "sent yet unacked" message anymore. But + // we still want to keep `largest_acked` to known Rx's watermark. + unacked: Unacked::new(largest_acked, &log_id), + }), + { + backoff.reset(); + Conn::Connected { sink, stream } + }, + ) + } + Err(err) => match err { + HandshakeError::SendFailed(msg) => { + tracing::debug!( + "session {}.{}: handshake send failed: {}", + link.dest(), + session_id, + msg + ); + ( + State::Running(Deliveries { + outbox, + unacked: Unacked::new(largest_acked, &log_id), + }), + Conn::reconnect(backoff), + ) + } + err => { + let error_msg = format!( + "session {}.{}: handshake failed: {}", + link.dest(), + session_id, + err + ); + tracing::error!(error_msg); + ( + State::Closing { + deliveries: Deliveries { + outbox, + unacked: Unacked::new( + largest_acked, + &log_id, + ), + }, + reason: error_msg, + }, + Conn::reconnect_with_default(), + ) + } + }, + } } Err(err) => { tracing::debug!( @@ -1152,6 +1260,13 @@ impl ServerConn { let serialized = serialize_ack(next_seq - 1); futures::SinkExt::send(sink, serialized).await } + + pub async fn send_handshake_response(&mut self, rx_session_id: u64) -> anyhow::Result<()> { + let serialized = bincode::serialize(&rx_session_id)?; + futures::SinkExt::send(&mut self.sink, serialized.into()) + .await + .map_err(anyhow::Error::from) + } } /// An MVar is a primitive that combines synchronization and the exchange @@ 
-1240,6 +1355,10 @@ struct Next { ack: u64, } +/// Unique NetRx session identifier for the process. Used during handshake +/// to determine whether or not to allow a reconnection from a NetTx. +static GLOBAL_RX_SESSION_ID: LazyLock = LazyLock::new(rand::random); + /// Manages persistent sessions, ensuring that only one connection can own /// a session at a time, and arranging for session handover. #[derive(Clone)] @@ -1264,9 +1383,10 @@ impl SessionManager { S: AsyncRead + AsyncWrite + Send + 'static + Unpin, M: RemoteMessage, { - let session_id = conn.handshake::().await?; + let tx_session_id = conn.handshake::().await?; + conn.send_handshake_response(*GLOBAL_RX_SESSION_ID).await?; - let session_var = match self.sessions.entry(session_id) { + let session_var = match self.sessions.entry(tx_session_id) { Entry::Occupied(entry) => entry.get().clone(), Entry::Vacant(entry) => { // We haven't seen this session before. We begin with seq=0 and ack=0. @@ -1277,7 +1397,7 @@ impl SessionManager { }; let next = session_var.take().await; - let (next, res) = conn.process(session_id, tx, cancel_token, next).await; + let (next, res) = conn.process(tx_session_id, tx, cancel_token, next).await; session_var.put(next).await; if let Err(ref err) = res { @@ -2561,6 +2681,16 @@ mod tests { ) .await .unwrap(); + + // Process handshake + bincode::deserialize::( + tokio_stream::StreamExt::next(framed) + .await + .unwrap() + .unwrap() + .as_ref(), + ) + .unwrap(); } for (seq, message) in messages { @@ -2669,6 +2799,31 @@ mod tests { }; } + #[async_timed_test(timeout_secs = 60)] + async fn test_server_session_handshake_response() { + let config = config::global::lock(); + let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1); + + let manager = SessionManager::new(); + let (_handle, mut framed, _rx, _cancel_token) = serve::(&manager).await; + + framed + .send(bincode::serialize(&Frame::::Init(123)).unwrap().into()) + .await + .unwrap(); + + let rx_session_id = 
bincode::deserialize::( + tokio_stream::StreamExt::next(&mut framed) + .await + .unwrap() + .unwrap() + .as_ref(), + ) + .unwrap(); + + assert_eq!(rx_session_id, *GLOBAL_RX_SESSION_ID); + } + #[async_timed_test(timeout_secs = 60)] async fn test_ack_from_server_session() { let config = config::global::lock(); @@ -2765,8 +2920,11 @@ mod tests { assert_eq!(frame, expected, "from ln={loc}"); } + static RX_SESSION_ID: u64 = 456; + async fn verify_stream( stream: &mut SplitStream>, + sink: &mut SplitSink, Bytes>, expects: &[(u64, M)], expect_session_id: Option, loc: u32, @@ -2783,6 +2941,8 @@ mod tests { assert_eq!(session_id, expected_id, "from ln={loc}"); } + send_handshake_response(sink, RX_SESSION_ID).await; + for expect in expects { verify_message(stream, expect.clone(), loc).await; } @@ -2790,6 +2950,15 @@ mod tests { session_id } + async fn send_handshake_response( + sink: &mut SplitSink, Bytes>, + handshake_response: u64, + ) { + sink.send(bincode::serialize(&handshake_response).unwrap().into()) + .await + .unwrap(); + } + async fn net_tx_send(tx: &NetTx, msgs: &[u64]) { for msg in msgs { tx.try_post(*msg, unused_return_channel()).unwrap(); @@ -2809,6 +2978,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; let id = verify_stream( &mut stream, + &mut sink, &[(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)], None, line!(), @@ -2827,8 +2997,15 @@ mod tests { // Sent a new message to verify all sent messages will not be resent. net_tx_send(&tx, &[105]).await; { - let (_sink, mut stream) = take_receiver(&receiver_storage).await; - verify_stream(&mut stream, &[(5, 105)], Some(session_id), line!()).await; + let (mut sink, mut stream) = take_receiver(&receiver_storage).await; + verify_stream( + &mut stream, + &mut sink, + &[(5, 105)], + Some(session_id), + line!(), + ) + .await; // client DuplexStream is dropped here. This breaks the connection. 
}; } @@ -2855,6 +3032,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; let id = verify_stream( &mut stream, + &mut sink, &[(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)], session_id, line!(), @@ -2881,9 +3059,10 @@ mod tests { { let client = receiver_storage.take().await; let framed = Framed::new(client, build_codec()); - let (_sink, mut stream) = futures::StreamExt::split(framed); + let (mut sink, mut stream) = futures::StreamExt::split(framed); verify_stream( &mut stream, + &mut sink, &[(2, 102), (3, 103), (4, 104)], session_id, line!(), @@ -2902,6 +3081,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 1st send. (2, 102), @@ -2939,6 +3119,7 @@ mod tests { let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 1st send. (4, 104), @@ -2966,9 +3147,10 @@ mod tests { for _ in 0..n { { - let (_sink, mut stream) = take_receiver(&receiver_storage).await; + let (mut sink, mut stream) = take_receiver(&receiver_storage).await; verify_stream( &mut stream, + &mut sink, &[ // From the 2nd send. (8, 108), @@ -2983,6 +3165,89 @@ mod tests { } } + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_timeout() { + let link = MockLink::::new(); + // Override the default (1m) for the purposes of this test. 
+ let config = config::global::lock(); + let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1)); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + // Never send the handshake response + verify_tx_closed(&mut tx.status, "timed out waiting for handshake response").await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_incorrect_rx_session_id_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + { + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + send_handshake_response(&mut sink, 1234321).await; + // Wait for handshake response to be processed by NetTx before dropping sink/stream. Otherwise + // the channel will be closed and we will get the wrong error. + RealClock.sleep(tokio::time::Duration::from_secs(3)).await; + } + + { + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + send_handshake_response(&mut sink, 123).await; + // Ditto + RealClock.sleep(tokio::time::Duration::from_secs(3)).await; + } + + verify_tx_closed( + &mut tx.status, + format!( + "mismatched rx session ids: {} (stored) vs. 
{} (received)", + 1234321, 123 + ) + .as_str(), + ) + .await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_deserialization_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + let (mut sink, _stream) = take_receiver(&receiver_storage).await; + sink.send(Bytes::from_static(b"bad")).await.unwrap(); + + verify_tx_closed( + &mut tx.status, + "handshake failed: io error: unexpected end of file", + ) + .await; + } + + #[tracing_test::traced_test] + #[async_timed_test(timeout_secs = 60)] + async fn test_handshake_channel_closed_failure() { + let link = MockLink::::new(); + let receiver_storage = link.receiver_storage(); + let mut tx = NetTx::::new(link); + net_tx_send(&tx, &[100]).await; + + { + let _ = take_receiver(&receiver_storage).await; + } + + verify_tx_closed( + &mut tx.status, + "handshake failed: channel closed waiting for handshake response", + ) + .await; + } + #[async_timed_test(timeout_secs = 15)] async fn test_ack_before_redelivery_in_net_tx() { let link = MockLink::::new(); @@ -2994,7 +3259,7 @@ mod tests { let (return_channel_tx, return_channel_rx) = oneshot::channel(); net_tx.try_post(100, return_channel_tx).unwrap(); let (mut sink, mut stream) = take_receiver(&receiver_storage).await; - verify_stream(&mut stream, &[(0, 100)], None, line!()).await; + verify_stream(&mut stream, &mut sink, &[(0, 100)], None, line!()).await; // ack it sink.send(serialize_ack(0)).await.unwrap(); // confirm Tx received ack @@ -3037,7 +3302,7 @@ mod tests { tx.try_post(100, unused_return_channel()).unwrap(); let (mut sink, mut stream) = take_receiver(&receiver_storage).await; // Confirm message is sent to rx. 
- verify_stream(&mut stream, &[(0, 100)], None, line!()).await; + verify_stream(&mut stream, &mut sink, &[(0, 100)], None, line!()).await; // ack it sink.send(serialize_ack(0)).await.unwrap(); RealClock.sleep(Duration::from_secs(3)).await; From db21c844871830a23c835d24e92586e356d8aeeb Mon Sep 17 00:00:00 2001 From: Sam Lurye Date: Tue, 12 Aug 2025 12:44:12 -0700 Subject: [PATCH 2/2] Semi-private python API for overriding handle_undeliverable_message inside PythonActor (#797) Summary: Pull Request resolved: https://github.com/meta-pytorch/monarch/pull/797 This diff makes undeliverable message handling overridable for python actors, using the newly introduced `Actor._handle_undeliverable_message` method. Previously, the rust implementation of `PythonActor` simply used the default `Actor::handle_undeliverable_message` implementation. Now, `PythonActor` overrides `handle_undeliverable_message` to call into the corresponding method on the underlying python class. Differential Revision: D79841379 --- monarch_hyperactor/src/actor.rs | 23 ++++++ monarch_hyperactor/src/mailbox.rs | 31 +++++++- .../monarch_hyperactor/actor.pyi | 9 --- .../monarch_hyperactor/mailbox.pyi | 39 ++++++++-- python/monarch/_src/actor/actor_mesh.py | 17 +++++ python/tests/test_python_actors.py | 71 ++++++++++++++++++- 6 files changed, 172 insertions(+), 18 deletions(-) diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 84893fa37..b86cdaf0e 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -21,6 +21,8 @@ use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; use hyperactor::OncePortHandle; +use hyperactor::mailbox::MessageEnvelope; +use hyperactor::mailbox::Undeliverable; use hyperactor::message::Bind; use hyperactor::message::Bindings; use hyperactor::message::Unbind; @@ -50,6 +52,7 @@ use crate::local_state_broker::BrokerId; use crate::local_state_broker::LocalStateBrokerMessage; use 
crate::mailbox::EitherPortRef; use crate::mailbox::PyMailbox; +use crate::mailbox::PythonUndeliverableMessageEnvelope; use crate::proc::InstanceWrapper; use crate::proc::PyActorId; use crate::proc::PyProc; @@ -498,6 +501,26 @@ impl Actor for PythonActor { ); Ok(()) } + + async fn handle_undeliverable_message( + &mut self, + cx: &Instance, + envelope: Undeliverable, + ) -> Result<(), anyhow::Error> { + assert_eq!(envelope.0.sender(), cx.self_id()); + + Python::with_gil(|py| { + self.actor + .call_method( + py, + "_handle_undeliverable_message", + (PythonUndeliverableMessageEnvelope { inner: envelope },), + None, + ) + .map_err(|err| anyhow::Error::from(SerializablePyErr::from(py, &err))) + }) + .map(|_| ()) + } } /// Create a new TaskLocals with its own asyncio event loop in a dedicated thread. diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 114f64624..042de1f85 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -420,10 +420,38 @@ impl PythonPortReceiver { module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] pub(crate) struct PythonUndeliverableMessageEnvelope { - #[allow(dead_code)] // At this time, field `inner` isn't read. 
pub(crate) inner: Undeliverable, } +#[pymethods] +impl PythonUndeliverableMessageEnvelope { + fn __repr__(&self) -> String { + format!( + "UndeliverableMessageEnvelope(sender={}, dest={}, error={})", + self.inner.0.sender(), + self.inner.0.dest(), + self.error_msg() + ) + } + + fn sender(&self) -> PyActorId { + PyActorId { + inner: self.inner.0.sender().clone(), + } + } + + fn dest(&self) -> PyPortId { + self.inner.0.dest().clone().into() + } + + fn error_msg(&self) -> String { + self.inner + .0 + .error() + .map_or("None".to_string(), |e| e.to_string()) + } +} + #[derive(Debug)] #[pyclass( name = "UndeliverablePortReceiver", @@ -713,5 +741,6 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index 1fc4607b2..973dbe25e 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -210,15 +210,6 @@ class PythonMessage: @property def kind(self) -> PythonMessageKind: ... -class UndeliverableMessageEnvelope: - """ - An envelope representing a message that could not be delivered. - - This object is opaque; its contents are not accessible from Python. - """ - - ... 
- @final class PythonActorHandle: """ diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi index a1ea21473..cb258c3cd 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi @@ -8,10 +8,7 @@ from typing import final, Protocol -from monarch._rust_bindings.monarch_hyperactor.actor import ( - PythonMessage, - UndeliverableMessageEnvelope, -) +from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask @@ -20,9 +17,9 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Shape @final class PortId: - def __init__(self, actor_id: ActorId, index: int) -> None: + def __init__(self, *, actor_id: ActorId, port: int) -> None: """ - Create a new port id given an actor id and an index. + Create a new port id given an actor id and a port index. """ ... def __repr__(self) -> str: ... @@ -68,6 +65,12 @@ class PortRef: A reference to a remote port over which PythonMessages can be sent. """ + def __init__(self, port_id: PortId) -> None: + """ + Create a new port ref given a port id. + """ + ... + def send(self, mailbox: Mailbox, message: PythonMessage) -> None: """Send a single message to the port's receiver.""" ... @@ -220,3 +223,27 @@ class Reducer(Protocol): This method's Rust counterpart is `CommReducer::reduce`. """ + +class UndeliverableMessageEnvelope: + """ + An envelope representing a message that could not be delivered. + """ + + def __repr__(self) -> str: ... + def sender(self) -> ActorId: + """ + The actor id of the sender. + """ + ... + + def dest(self) -> PortId: + """ + The port id of the destination. + """ + ... + + def error_msg(self) -> str: + """ + The error message describing why the message could not be delivered. + """ + ... 
diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 0e659e2c2..60291f915 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -58,6 +58,7 @@ OncePortRef, PortReceiver as HyPortReceiver, PortRef, + UndeliverableMessageEnvelope, ) from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared @@ -941,6 +942,17 @@ def _post_mortem_debug(self, exc_tb) -> None: pdb_wrapper.post_mortem(exc_tb) self._maybe_exit_debugger(do_continue=False) + def _handle_undeliverable_message( + self, message: UndeliverableMessageEnvelope + ) -> None: + handle_undeliverable = getattr( + self.instance, "_handle_undeliverable_message", None + ) + if handle_undeliverable is not None: + handle_undeliverable(message) + else: + raise RuntimeError(f"a message was undeliverable and returned: {message}") + def _is_mailbox(x: object) -> bool: if hasattr(x, "__monarch_ref__"): @@ -983,6 +995,11 @@ def _new_with_shape(self, shape: Shape) -> Self: "actor implementations are not meshes, but we can't convince the typechecker of it..." 
) + def _handle_undeliverable_message( + self, message: UndeliverableMessageEnvelope + ) -> None: + raise RuntimeError(f"a message was undeliverable and returned: {message}") + class ActorMesh(MeshTrait, Generic[T]): def __init__( diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 8c8014ab1..8cb46b102 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -18,14 +18,24 @@ import time import unittest from types import ModuleType -from typing import cast +from typing import cast, Tuple import pytest import torch +from monarch._rust_bindings.monarch_hyperactor.actor import ( + PythonMessage, + PythonMessageKind, +) +from monarch._rust_bindings.monarch_hyperactor.mailbox import ( + PortId, + PortRef, + UndeliverableMessageEnvelope, +) +from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask -from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port +from monarch._src.actor.actor_mesh import ActorMesh, Channel, MonarchContext, Port from monarch.actor import ( Accumulator, @@ -1010,3 +1020,60 @@ def test_mesh_len(): proc_mesh = local_proc_mesh(gpus=12).get() s = proc_mesh.spawn("sync_actor", SyncActor).get() assert 12 == len(s) + + +class UndeliverableMessageReceiver(Actor): + def __init__(self): + self._messages = asyncio.Queue() + + @endpoint + async def receive_undeliverable( + self, sender: ActorId, dest: PortId, error_msg: str + ) -> None: + await self._messages.put((sender, dest, error_msg)) + + @endpoint + async def get_messages(self) -> Tuple[ActorId, PortId, str]: + return await self._messages.get() + + +class UndeliverableMessageSender(Actor): + def __init__(self, receiver: UndeliverableMessageReceiver): + self._receiver = receiver + + @endpoint + def send_undeliverable(self) -> None: + mailbox = MonarchContext.get().mailbox + port_id = PortId( + actor_id=ActorId( + 
world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus" + ), + port=1234, + ) + port_ref = PortRef(port_id) + port_ref.send( + mailbox, + PythonMessage(PythonMessageKind.Result(None), b"123"), + ) + + def _handle_undeliverable_message( + self, message: UndeliverableMessageEnvelope + ) -> None: + self._receiver.receive_undeliverable.call_one( + message.sender(), message.dest(), message.error_msg() + ).get() + + +@pytest.mark.timeout(60) +async def test_undeliverable_message() -> None: + pm = proc_mesh(gpus=1) + receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver).get() + sender = pm.spawn( + "undeliverable_sender", UndeliverableMessageSender, receiver + ).get() + sender.send_undeliverable.call().get() + sender, dest, error_msg = receiver.get_messages.call_one().get() + assert sender.actor_name == "undeliverable_sender" + assert dest.actor_id.actor_name == "bogus" + assert error_msg is not None + pm.stop().get()