Skip to content

Commit a932201

Browse files
Sam Lurye and facebook-github-bot
authored and committed
Semi-private python API for overriding handle_undeliverable_message inside PythonActor (#797)
Summary: This diff makes undeliverable-message handling overridable for Python actors, using the newly introduced `Actor._handle_undeliverable_message` method. Previously, the Rust implementation of `PythonActor` simply used the default `Actor::handle_undeliverable_message` implementation. Now, `PythonActor` overrides `handle_undeliverable_message` to call into the corresponding method on the underlying Python class. Differential Revision: D79841379
1 parent b9209cc commit a932201

File tree

6 files changed

+172
-18
lines changed

6 files changed

+172
-18
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use hyperactor::Handler;
2121
use hyperactor::Instance;
2222
use hyperactor::Named;
2323
use hyperactor::OncePortHandle;
24+
use hyperactor::mailbox::MessageEnvelope;
25+
use hyperactor::mailbox::Undeliverable;
2426
use hyperactor::message::Bind;
2527
use hyperactor::message::Bindings;
2628
use hyperactor::message::Unbind;
@@ -50,6 +52,7 @@ use crate::local_state_broker::BrokerId;
5052
use crate::local_state_broker::LocalStateBrokerMessage;
5153
use crate::mailbox::EitherPortRef;
5254
use crate::mailbox::PyMailbox;
55+
use crate::mailbox::PythonUndeliverableMessageEnvelope;
5356
use crate::proc::InstanceWrapper;
5457
use crate::proc::PyActorId;
5558
use crate::proc::PyProc;
@@ -498,6 +501,26 @@ impl Actor for PythonActor {
498501
);
499502
Ok(())
500503
}
504+
505+
async fn handle_undeliverable_message(
506+
&mut self,
507+
cx: &Instance<Self>,
508+
envelope: Undeliverable<MessageEnvelope>,
509+
) -> Result<(), anyhow::Error> {
510+
assert_eq!(envelope.0.sender(), cx.self_id());
511+
512+
Python::with_gil(|py| {
513+
self.actor
514+
.call_method(
515+
py,
516+
"_handle_undeliverable_message",
517+
(PythonUndeliverableMessageEnvelope { inner: envelope },),
518+
None,
519+
)
520+
.map_err(|err| anyhow::Error::from(SerializablePyErr::from(py, &err)))
521+
})
522+
.map(|_| ())
523+
}
501524
}
502525

503526
/// Create a new TaskLocals with its own asyncio event loop in a dedicated thread.

monarch_hyperactor/src/mailbox.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,38 @@ impl PythonPortReceiver {
420420
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
421421
)]
422422
pub(crate) struct PythonUndeliverableMessageEnvelope {
423-
#[allow(dead_code)] // At this time, field `inner` isn't read.
424423
pub(crate) inner: Undeliverable<MessageEnvelope>,
425424
}
426425

426+
#[pymethods]
427+
impl PythonUndeliverableMessageEnvelope {
428+
fn __repr__(&self) -> String {
429+
format!(
430+
"UndeliverableMessageEnvelope(sender={}, dest={}, error={})",
431+
self.inner.0.sender(),
432+
self.inner.0.dest(),
433+
self.error_msg()
434+
)
435+
}
436+
437+
fn sender(&self) -> PyActorId {
438+
PyActorId {
439+
inner: self.inner.0.sender().clone(),
440+
}
441+
}
442+
443+
fn dest(&self) -> PyPortId {
444+
self.inner.0.dest().clone().into()
445+
}
446+
447+
fn error_msg(&self) -> String {
448+
self.inner
449+
.0
450+
.error()
451+
.map_or("None".to_string(), |e| e.to_string())
452+
}
453+
}
454+
427455
#[derive(Debug)]
428456
#[pyclass(
429457
name = "UndeliverablePortReceiver",
@@ -713,5 +741,6 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
713741
hyperactor_mod.add_class::<PythonOncePortHandle>()?;
714742
hyperactor_mod.add_class::<PythonOncePortRef>()?;
715743
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;
744+
hyperactor_mod.add_class::<PythonUndeliverableMessageEnvelope>()?;
716745
Ok(())
717746
}

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,6 @@ class PythonMessage:
210210
@property
211211
def kind(self) -> PythonMessageKind: ...
212212

213-
class UndeliverableMessageEnvelope:
214-
"""
215-
An envelope representing a message that could not be delivered.
216-
217-
This object is opaque; its contents are not accessible from Python.
218-
"""
219-
220-
...
221-
222213
@final
223214
class PythonActorHandle:
224215
"""

python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
from typing import final, Protocol
1010

11-
from monarch._rust_bindings.monarch_hyperactor.actor import (
12-
PythonMessage,
13-
UndeliverableMessageEnvelope,
14-
)
11+
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
1512

1613
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1714
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
@@ -20,9 +17,9 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2017

2118
@final
2219
class PortId:
23-
def __init__(self, actor_id: ActorId, index: int) -> None:
20+
def __init__(self, *, actor_id: ActorId, port: int) -> None:
2421
"""
25-
Create a new port id given an actor id and an index.
22+
Create a new port id given an actor id and a port index.
2623
"""
2724
...
2825
def __repr__(self) -> str: ...
@@ -68,6 +65,12 @@ class PortRef:
6865
A reference to a remote port over which PythonMessages can be sent.
6966
"""
7067

68+
def __init__(self, port_id: PortId) -> None:
    """Construct a port ref from the given port id."""
    ...
73+
7174
def send(self, mailbox: Mailbox, message: PythonMessage) -> None:
7275
"""Send a single message to the port's receiver."""
7376
...
@@ -220,3 +223,27 @@ class Reducer(Protocol):
220223
221224
This method's Rust counterpart is `CommReducer::reduce`.
222225
"""
226+
227+
class UndeliverableMessageEnvelope:
    """
    Envelope for a message that could not be delivered.
    """

    def __repr__(self) -> str: ...
    def sender(self) -> ActorId:
        """Return the actor id of the message's sender."""
        ...

    def dest(self) -> PortId:
        """Return the id of the port the message was addressed to."""
        ...

    def error_msg(self) -> str:
        """Return the error message explaining why delivery failed."""
        ...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
OncePortRef,
5959
PortReceiver as HyPortReceiver,
6060
PortRef,
61+
UndeliverableMessageEnvelope,
6162
)
6263
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
6364

@@ -941,6 +942,17 @@ def _post_mortem_debug(self, exc_tb) -> None:
941942
pdb_wrapper.post_mortem(exc_tb)
942943
self._maybe_exit_debugger(do_continue=False)
943944

945+
def _handle_undeliverable_message(
    self, message: UndeliverableMessageEnvelope
) -> None:
    """
    Dispatch an undeliverable message to the wrapped actor instance.

    If ``self.instance`` has a ``_handle_undeliverable_message`` attribute,
    it is called with ``message``. Otherwise a ``RuntimeError`` describing
    the returned message is raised.
    """
    handler = getattr(self.instance, "_handle_undeliverable_message", None)
    if handler is None:
        raise RuntimeError(f"a message was undeliverable and returned: {message}")
    handler(message)
955+
944956

945957
def _is_mailbox(x: object) -> bool:
946958
if hasattr(x, "__monarch_ref__"):
@@ -983,6 +995,11 @@ def _new_with_shape(self, shape: Shape) -> Self:
983995
"actor implementations are not meshes, but we can't convince the typechecker of it..."
984996
)
985997

998+
def _handle_undeliverable_message(
    self, message: UndeliverableMessageEnvelope
) -> None:
    """
    Handle a message that was returned as undeliverable.

    This implementation always raises ``RuntimeError`` with the message's
    string form embedded in the error text.
    """
    description = f"a message was undeliverable and returned: {message}"
    raise RuntimeError(description)
1002+
9861003

9871004
class ActorMesh(MeshTrait, Generic[T]):
9881005
def __init__(

python/tests/test_python_actors.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,24 @@
1818
import time
1919
import unittest
2020
from types import ModuleType
21-
from typing import cast
21+
from typing import cast, Tuple
2222

2323
import pytest
2424

2525
import torch
26+
from monarch._rust_bindings.monarch_hyperactor.actor import (
27+
PythonMessage,
28+
PythonMessageKind,
29+
)
30+
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
31+
PortId,
32+
PortRef,
33+
UndeliverableMessageEnvelope,
34+
)
35+
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
2636
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
2737

28-
from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port
38+
from monarch._src.actor.actor_mesh import ActorMesh, Channel, MonarchContext, Port
2939

3040
from monarch.actor import (
3141
Accumulator,
@@ -1010,3 +1020,60 @@ def test_mesh_len():
10101020
proc_mesh = local_proc_mesh(gpus=12).get()
10111021
s = proc_mesh.spawn("sync_actor", SyncActor).get()
10121022
assert 12 == len(s)
1023+
1024+
1025+
class UndeliverableMessageReceiver(Actor):
    """Test actor that queues undeliverable-message reports for later retrieval."""

    def __init__(self):
        # Each entry is a (sender, dest, error_msg) tuple.
        self._messages = asyncio.Queue()

    @endpoint
    async def receive_undeliverable(
        self, sender: ActorId, dest: PortId, error_msg: str
    ) -> None:
        """Enqueue one report describing an undeliverable message."""
        report = (sender, dest, error_msg)
        await self._messages.put(report)

    @endpoint
    async def get_messages(self) -> Tuple[ActorId, PortId, str]:
        """Wait for and return the next queued report."""
        report = await self._messages.get()
        return report
1038+
1039+
1040+
class UndeliverableMessageSender(Actor):
    """Test actor that sends to a bogus port and reports the returned message."""

    def __init__(self, receiver: UndeliverableMessageReceiver):
        self._receiver = receiver

    @endpoint
    def send_undeliverable(self) -> None:
        """Send a message to port 1234 of an actor named "bogus" in this actor's world."""
        mailbox = MonarchContext.get().mailbox
        bogus_actor = ActorId(
            world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
        )
        bogus_port = PortRef(PortId(actor_id=bogus_actor, port=1234))
        payload = PythonMessage(PythonMessageKind.Result(None), b"123")
        bogus_port.send(mailbox, payload)

    def _handle_undeliverable_message(
        self, message: UndeliverableMessageEnvelope
    ) -> None:
        """Forward the returned message's sender, dest and error text to the receiver."""
        pending = self._receiver.receive_undeliverable.call_one(
            message.sender(), message.dest(), message.error_msg()
        )
        # Block until the receiver has recorded the report.
        pending.get()
1065+
1066+
1067+
@pytest.mark.timeout(60)
async def test_undeliverable_message() -> None:
    """
    An overridden ``_handle_undeliverable_message`` receives the returned
    envelope: the sender actor forwards its contents to the receiver actor,
    and the test checks the reported sender and destination names.
    """
    pm = proc_mesh(gpus=1)
    try:
        receiver = pm.spawn(
            "undeliverable_receiver", UndeliverableMessageReceiver
        ).get()
        sender = pm.spawn(
            "undeliverable_sender", UndeliverableMessageSender, receiver
        ).get()
        sender.send_undeliverable.call().get()
        # Distinct names so the unpacked values do not rebind the `sender`
        # actor handle above.
        reported_sender, reported_dest, error_msg = (
            receiver.get_messages.call_one().get()
        )
        assert reported_sender.actor_name == "undeliverable_sender"
        assert reported_dest.actor_id.actor_name == "bogus"
        assert error_msg is not None
    finally:
        # Stop the proc mesh even when a step or assertion above fails.
        pm.stop().get()

0 commit comments

Comments
 (0)