Skip to content

Commit ea4e5c8

Browse files
Sam Lurye authored and facebook-github-bot committed
Semi-private python API for overriding handle_undeliverable_message inside PythonActor (#797)
Summary: This diff makes undeliverable-message handling overridable for Python actors, using the newly introduced `Actor._handle_undeliverable_message` method. Previously, the Rust implementation of `PythonActor` simply used the default `Actor::handle_undeliverable_message` implementation. Now, `PythonActor` overrides `handle_undeliverable_message` to call into the corresponding method on the underlying Python class.
Differential Revision: D79841379
1 parent a99e0cc commit ea4e5c8

File tree

6 files changed

+172
-18
lines changed

6 files changed

+172
-18
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use hyperactor::Handler;
2121
use hyperactor::Instance;
2222
use hyperactor::Named;
2323
use hyperactor::OncePortHandle;
24+
use hyperactor::mailbox::MessageEnvelope;
25+
use hyperactor::mailbox::Undeliverable;
2426
use hyperactor::message::Bind;
2527
use hyperactor::message::Bindings;
2628
use hyperactor::message::Unbind;
@@ -50,6 +52,7 @@ use crate::local_state_broker::BrokerId;
5052
use crate::local_state_broker::LocalStateBrokerMessage;
5153
use crate::mailbox::EitherPortRef;
5254
use crate::mailbox::PyMailbox;
55+
use crate::mailbox::PythonUndeliverableMessageEnvelope;
5356
use crate::proc::InstanceWrapper;
5457
use crate::proc::PyActorId;
5558
use crate::proc::PyProc;
@@ -498,6 +501,26 @@ impl Actor for PythonActor {
498501
);
499502
Ok(())
500503
}
504+
505+
async fn handle_undeliverable_message(
506+
&mut self,
507+
cx: &Instance<Self>,
508+
envelope: Undeliverable<MessageEnvelope>,
509+
) -> Result<(), anyhow::Error> {
510+
assert_eq!(envelope.0.sender(), cx.self_id());
511+
512+
Python::with_gil(|py| {
513+
self.actor
514+
.call_method(
515+
py,
516+
"_handle_undeliverable_message",
517+
(PythonUndeliverableMessageEnvelope { inner: envelope },),
518+
None,
519+
)
520+
.map_err(|err| anyhow::Error::from(SerializablePyErr::from(py, &err)))
521+
})
522+
.map(|_| ())
523+
}
501524
}
502525

503526
/// Create a new TaskLocals with its own asyncio event loop in a dedicated thread.

monarch_hyperactor/src/mailbox.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,38 @@ impl PythonPortReceiver {
420420
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
421421
)]
422422
pub(crate) struct PythonUndeliverableMessageEnvelope {
423-
#[allow(dead_code)] // At this time, field `inner` isn't read.
424423
pub(crate) inner: Undeliverable<MessageEnvelope>,
425424
}
426425

426+
#[pymethods]
427+
impl PythonUndeliverableMessageEnvelope {
428+
fn __repr__(&self) -> String {
429+
format!(
430+
"UndeliverableMessageEnvelope(sender={}, dest={}, error={})",
431+
self.inner.0.sender(),
432+
self.inner.0.dest(),
433+
self.error_msg()
434+
)
435+
}
436+
437+
fn sender(&self) -> PyActorId {
438+
PyActorId {
439+
inner: self.inner.0.sender().clone(),
440+
}
441+
}
442+
443+
fn dest(&self) -> PyPortId {
444+
self.inner.0.dest().clone().into()
445+
}
446+
447+
fn error_msg(&self) -> String {
448+
self.inner
449+
.0
450+
.error()
451+
.map_or("None".to_string(), |e| e.to_string())
452+
}
453+
}
454+
427455
#[derive(Debug)]
428456
#[pyclass(
429457
name = "UndeliverablePortReceiver",
@@ -713,5 +741,6 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
713741
hyperactor_mod.add_class::<PythonOncePortHandle>()?;
714742
hyperactor_mod.add_class::<PythonOncePortRef>()?;
715743
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;
744+
hyperactor_mod.add_class::<PythonUndeliverableMessageEnvelope>()?;
716745
Ok(())
717746
}

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,6 @@ class PythonMessage:
210210
@property
211211
def kind(self) -> PythonMessageKind: ...
212212

213-
class UndeliverableMessageEnvelope:
214-
"""
215-
An envelope representing a message that could not be delivered.
216-
217-
This object is opaque; its contents are not accessible from Python.
218-
"""
219-
220-
...
221-
222213
@final
223214
class PythonActorHandle:
224215
"""

python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
from typing import final, Protocol
1010

11-
from monarch._rust_bindings.monarch_hyperactor.actor import (
12-
PythonMessage,
13-
UndeliverableMessageEnvelope,
14-
)
11+
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
1512

1613
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1714
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
@@ -20,9 +17,9 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2017

2118
@final
2219
class PortId:
23-
def __init__(self, actor_id: ActorId, index: int) -> None:
20+
def __init__(self, *, actor_id: ActorId, port: int) -> None:
2421
"""
25-
Create a new port id given an actor id and an index.
22+
Create a new port id given an actor id and a port index.
2623
"""
2724
...
2825
def __repr__(self) -> str: ...
@@ -68,6 +65,12 @@ class PortRef:
6865
A reference to a remote port over which PythonMessages can be sent.
6966
"""
7067

68+
def __init__(self, port_id: PortId) -> None:
69+
"""
70+
Create a new port ref given a port id.
71+
"""
72+
...
73+
7174
def send(self, mailbox: Mailbox, message: PythonMessage) -> None:
7275
"""Send a single message to the port's receiver."""
7376
...
@@ -220,3 +223,27 @@ class Reducer(Protocol):
220223
221224
This method's Rust counterpart is `CommReducer::reduce`.
222225
"""
226+
227+
class UndeliverableMessageEnvelope:
228+
"""
229+
An envelope representing a message that could not be delivered.
230+
"""
231+
232+
def __repr__(self) -> str: ...
233+
def sender(self) -> ActorId:
234+
"""
235+
The actor id of the sender.
236+
"""
237+
...
238+
239+
def dest(self) -> PortId:
240+
"""
241+
The port id of the destination.
242+
"""
243+
...
244+
245+
def error_msg(self) -> str:
246+
"""
247+
The error message describing why the message could not be delivered.
248+
"""
249+
...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
OncePortRef,
5959
PortReceiver as HyPortReceiver,
6060
PortRef,
61+
UndeliverableMessageEnvelope,
6162
)
6263
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
6364

@@ -969,6 +970,17 @@ def _post_mortem_debug(self, exc_tb) -> None:
969970
pdb_wrapper.post_mortem(exc_tb)
970971
self._maybe_exit_debugger(do_continue=False)
971972

973+
def _handle_undeliverable_message(
    self, message: UndeliverableMessageEnvelope
) -> None:
    """
    Forward a returned (undeliverable) message to the wrapped instance.

    Looks up `_handle_undeliverable_message` on `self.instance` and calls
    it with `message`. If the instance has no such attribute (or it is
    None), a RuntimeError describing the returned message is raised.
    """
    handler = getattr(self.instance, "_handle_undeliverable_message", None)
    if handler is None:
        raise RuntimeError(f"a message was undeliverable and returned: {message}")
    handler(message)
983+
972984

973985
def _is_mailbox(x: object) -> bool:
974986
if hasattr(x, "__monarch_ref__"):
@@ -1011,6 +1023,11 @@ def _new_with_shape(self, shape: Shape) -> Self:
10111023
"actor implementations are not meshes, but we can't convince the typechecker of it..."
10121024
)
10131025

1026+
def _handle_undeliverable_message(
    self, message: UndeliverableMessageEnvelope
) -> None:
    """
    Default handling for a message that was returned as undeliverable:
    always raise a RuntimeError that includes the envelope's string form.
    """
    description = f"a message was undeliverable and returned: {message}"
    raise RuntimeError(description)
1030+
10141031

10151032
class ActorMesh(MeshTrait, Generic[T]):
10161033
def __init__(

python/tests/test_python_actors.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,24 @@
1818
import time
1919
import unittest
2020
from types import ModuleType
21-
from typing import cast
21+
from typing import cast, Tuple
2222

2323
import pytest
2424

2525
import torch
26+
from monarch._rust_bindings.monarch_hyperactor.actor import (
27+
PythonMessage,
28+
PythonMessageKind,
29+
)
30+
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
31+
PortId,
32+
PortRef,
33+
UndeliverableMessageEnvelope,
34+
)
35+
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
2636
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
2737

28-
from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port
38+
from monarch._src.actor.actor_mesh import ActorMesh, Channel, MonarchContext, Port
2939

3040
from monarch.actor import (
3141
Accumulator,
@@ -832,3 +842,60 @@ def s(t):
832842
b = PythonTask.spawn_blocking(lambda: s(0))
833843
r = PythonTask.select_one([a.task(), b.task()]).block_on()
834844
assert r == (0, 1)
845+
846+
847+
class UndeliverableMessageReceiver(Actor):
    """Test actor that queues reported undeliverable-message details."""

    def __init__(self):
        # Holds (sender, dest, error_msg) tuples in arrival order.
        self._messages = asyncio.Queue()

    @endpoint
    async def receive_undeliverable(
        self, sender: ActorId, dest: PortId, error_msg: str
    ) -> None:
        """Enqueue one report as a (sender, dest, error_msg) tuple."""
        report = (sender, dest, error_msg)
        await self._messages.put(report)

    @endpoint
    async def get_messages(self) -> Tuple[ActorId, PortId, str]:
        """Wait for and return the next queued report."""
        report = await self._messages.get()
        return report
860+
861+
862+
class UndeliverableMessageSender(Actor):
    """
    Test actor that sends to a made-up port and reports any message that
    comes back as undeliverable to the given receiver actor.
    """

    def __init__(self, receiver: UndeliverableMessageReceiver):
        self._receiver = receiver

    @endpoint
    def send_undeliverable(self) -> None:
        """Send a message to port 1234 of an actor named "bogus" in this actor's world."""
        mailbox = MonarchContext.get().mailbox
        bogus_actor = ActorId(
            world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
        )
        bogus_port = PortRef(PortId(actor_id=bogus_actor, port=1234))
        payload = PythonMessage(PythonMessageKind.Result(None), b"123")
        bogus_port.send(mailbox, payload)

    def _handle_undeliverable_message(
        self, message: UndeliverableMessageEnvelope
    ) -> None:
        """Relay the returned envelope's sender, dest and error to the receiver and wait for it."""
        pending = self._receiver.receive_undeliverable.call_one(
            message.sender(), message.dest(), message.error_msg()
        )
        pending.get()
887+
888+
889+
@pytest.mark.timeout(60)
async def test_undeliverable_message() -> None:
    """
    A message sent to a bogus port should come back to the sending actor's
    `_handle_undeliverable_message` override, which reports it to the receiver.
    """
    pm = proc_mesh(gpus=1)
    receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver).get()
    sender_actor = pm.spawn(
        "undeliverable_sender", UndeliverableMessageSender, receiver
    ).get()

    sender_actor.send_undeliverable.call().get()

    # The receiver hands back the (sender, dest, error_msg) tuple it was given.
    reported = receiver.get_messages.call_one().get()
    returned_sender, returned_dest, error_msg = reported
    assert returned_sender.actor_name == "undeliverable_sender"
    assert returned_dest.actor_id.actor_name == "bogus"
    assert error_msg is not None
    pm.stop().get()

0 commit comments

Comments
 (0)