Skip to content

Commit a79fdd5

Browse files
pzhan9meta-codesync[bot]
authored andcommitted
Plumbing: Update some send APIs on the python side with PythonInstance (#1362)
Summary: Pull Request resolved: #1362 This diff is part of the effect to adding sequencing logic to sender actor. See D83371710 for details. This diff specifically updates the send method in `PythonPortRef`, `PythonOncePortRef`, `PythonActorRef`, and `PythonActorHandle` to have the `instance: &PyInstance`, so it can be plumbed to the Rust backend later. Reviewed By: mariusae, moonli Differential Revision: D83413791 fbshipit-source-id: 0163c8f3dfc859ea6be8e39da552f9a0281f3418
1 parent a5a1530 commit a79fdd5

File tree

7 files changed

+54
-30
lines changed

7 files changed

+54
-30
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use std::sync::Arc;
1313
use std::sync::OnceLock;
1414

1515
use async_trait::async_trait;
16-
use bytes::Bytes;
1716
use hyperactor::Actor;
1817
use hyperactor::ActorHandle;
1918
use hyperactor::ActorId;
@@ -50,9 +49,9 @@ use tokio::sync::mpsc::UnboundedSender;
5049
use tokio::sync::oneshot;
5150
use tracing::Instrument;
5251

53-
use crate::buffers::Buffer;
5452
use crate::buffers::FrozenBuffer;
5553
use crate::config::SHARED_ASYNCIO_RUNTIME;
54+
use crate::context::PyInstance;
5655
use crate::local_state_broker::BrokerId;
5756
use crate::local_state_broker::LocalStateBrokerMessage;
5857
use crate::mailbox::EitherPortRef;
@@ -275,10 +274,9 @@ impl PythonMessage {
275274
_ => panic!("PythonMessage is not a response but {:?}", self),
276275
}
277276
}
278-
279-
async fn resolve_indirect_call<T: Actor>(
277+
async fn resolve_indirect_call(
280278
self,
281-
cx: &Context<'_, T>,
279+
cx: &Context<'_, PythonActor>,
282280
) -> anyhow::Result<ResolvedCallMethod> {
283281
match self.kind {
284282
PythonMessageKind::CallMethodIndirect {
@@ -331,6 +329,7 @@ impl PythonMessage {
331329
.call_method1("repeat", (mailbox.clone(),))
332330
.unwrap()
333331
.unbind();
332+
let instance: PyInstance = cx.into();
334333
let response_port = response_port
335334
.map_or_else(
336335
|| {
@@ -343,7 +342,7 @@ impl PythonMessage {
343342
let point = cx.cast_point();
344343
py.import("monarch._src.actor.actor_mesh")
345344
.unwrap()
346-
.call_method1("Port", (x, mailbox, point.rank()))
345+
.call_method1("Port", (x, instance, point.rank()))
347346
.unwrap()
348347
},
349348
)
@@ -435,7 +434,8 @@ pub(super) struct PythonActorHandle {
435434
#[pymethods]
436435
impl PythonActorHandle {
437436
// TODO: do the pickling in rust
438-
fn send(&self, message: &PythonMessage) -> PyResult<()> {
437+
// TODO(pzhang) Use instance after its required by PortHandle.
438+
fn send(&self, _instance: &PyInstance, message: &PythonMessage) -> PyResult<()> {
439439
self.inner
440440
.send(message.clone())
441441
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
@@ -708,7 +708,11 @@ impl Handler<HandlePanic> for PythonActorPanicWatcher {
708708

709709
#[async_trait]
710710
impl Handler<PythonMessage> for PythonActor {
711-
async fn handle(&mut self, cx: &Context<Self>, message: PythonMessage) -> anyhow::Result<()> {
711+
async fn handle(
712+
&mut self,
713+
cx: &Context<PythonActor>,
714+
message: PythonMessage,
715+
) -> anyhow::Result<()> {
712716
let resolved = message.resolve_indirect_call(cx).await?;
713717

714718
// Create a channel for signaling panics in async endpoints.

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ impl ActorMeshProtocol for AsyncActorMesh {
618618
let port = py
619619
.import("monarch._src.actor.actor_mesh")
620620
.unwrap()
621-
.call_method1("Port", (port_ref, instance._mailbox(), 0))
621+
.call_method1("Port", (port_ref, instance, 0))
622622
.unwrap();
623623
port.call_method1("exception", (pyerr.value(py),)).unwrap();
624624
}),

monarch_hyperactor/src/mailbox.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use serde::Serialize;
4747

4848
use crate::actor::PythonMessage;
4949
use crate::actor::PythonMessageKind;
50+
use crate::context::PyInstance;
5051
use crate::proc::PyActorId;
5152
use crate::pytokio::PyPythonTask;
5253
use crate::pytokio::PythonTask;
@@ -264,7 +265,8 @@ pub(super) struct PythonPortHandle {
264265

265266
#[pymethods]
266267
impl PythonPortHandle {
267-
fn send(&self, message: PythonMessage) -> PyResult<()> {
268+
// TODO(pzhang) Use instance after its required by PortHandle.
269+
fn send(&self, _instance: &PyInstance, message: PythonMessage) -> PyResult<()> {
268270
self.inner
269271
.send(message)
270272
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))?;
@@ -318,9 +320,9 @@ impl PythonPortRef {
318320
Ok((slf.get_type(), (id,)))
319321
}
320322

321-
fn send(&self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> {
323+
fn send(&self, instance: &PyInstance, message: PythonMessage) -> PyResult<()> {
322324
self.inner
323-
.send(&mailbox.inner, message)
325+
.send(&instance._mailbox().inner, message)
324326
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))?;
325327
Ok(())
326328
}
@@ -530,13 +532,13 @@ impl PythonOncePortRef {
530532
Ok((slf.get_type(), (id,)))
531533
}
532534

533-
fn send(&mut self, mailbox: &PyMailbox, message: PythonMessage) -> PyResult<()> {
535+
fn send(&mut self, instance: &PyInstance, message: PythonMessage) -> PyResult<()> {
534536
let Some(port_ref) = self.inner.take() else {
535537
return Err(PyErr::new::<PyValueError, _>("OncePortRef is already used"));
536538
};
537539

538540
port_ref
539-
.send(&mailbox.inner, message)
541+
.send(&instance._mailbox().inner, message)
540542
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))?;
541543
Ok(())
542544
}

python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@
99
from typing import final, Protocol
1010

1111
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
12-
12+
from monarch._rust_bindings.monarch_hyperactor.context import Instance
1313
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1414
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
1515

16-
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
17-
1816
@final
1917
class PortId:
2018
def __init__(self, *, actor_id: ActorId, port: int) -> None:
@@ -52,7 +50,7 @@ class PortHandle:
5250
A handle to a port over which PythonMessages can be sent.
5351
"""
5452

55-
def send(self, message: PythonMessage) -> None:
53+
def send(self, instance: Instance, message: PythonMessage) -> None:
5654
"""Send a message to the port's receiver."""
5755

5856
def bind(self) -> PortRef:
@@ -71,7 +69,7 @@ class PortRef:
7169
"""
7270
...
7371

74-
def send(self, mailbox: Mailbox, message: PythonMessage) -> None:
72+
def send(self, instance: Instance, message: PythonMessage) -> None:
7573
"""Send a single message to the port's receiver."""
7674
...
7775

@@ -123,7 +121,7 @@ class OncePortRef:
123121
A reference to a remote once port over which a single PythonMessages can be sent.
124122
"""
125123

126-
def send(self, mailbox: Mailbox, message: PythonMessage) -> None:
124+
def send(self, instance: Instance, message: PythonMessage) -> None:
127125
"""Send a single message to the port's receiver."""
128126
...
129127

python/monarch/_src/actor/actor_mesh.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,11 +550,11 @@ class Port(Generic[R]):
550550
def __init__(
551551
self,
552552
port_ref: PortRef | OncePortRef,
553-
mailbox: Mailbox,
553+
instance: HyInstance,
554554
rank: Optional[int],
555555
) -> None:
556556
self._port_ref = port_ref
557-
self._mailbox = mailbox
557+
self._instance = instance
558558
self._rank = rank
559559

560560
def send(self, obj: R) -> None:
@@ -566,18 +566,35 @@ def send(self, obj: R) -> None:
566566
obj: R-typed object to send.
567567
"""
568568
self._port_ref.send(
569-
self._mailbox,
569+
self._instance,
570570
PythonMessage(PythonMessageKind.Result(self._rank), _pickle(obj)),
571571
)
572572

573573
def exception(self, obj: Exception) -> None:
574574
# we deliver each error exactly once, so if there is no port to respond to,
575575
# the error is sent to the current actor as an exception.
576576
self._port_ref.send(
577-
self._mailbox,
577+
self._instance,
578578
PythonMessage(PythonMessageKind.Exception(self._rank), _pickle(obj)),
579579
)
580580

581+
def __reduce__(self):
582+
"""
583+
When Port is sent over the wire, we do not want to send the actor instance
584+
from the current context. Instead, we want to reconstruct the Port with
585+
the receiver's context, since that is where the message will be sent
586+
from through this port.
587+
"""
588+
589+
def _reconstruct_port(port_ref, rank):
590+
instance = context().actor_instance._as_rust()
591+
return Port(port_ref, instance, rank)
592+
593+
return (
594+
_reconstruct_port,
595+
(self._port_ref, self._rank),
596+
)
597+
581598

582599
class DroppingPort:
583600
"""
@@ -617,11 +634,14 @@ class Channel(Generic[R]):
617634
@staticmethod
618635
def open(once: bool = False) -> Tuple["Port[R]", "PortReceiver[R]"]:
619636
""" """
620-
mailbox = context().actor_instance._mailbox
637+
actor_instance = context().actor_instance
638+
mailbox = actor_instance._mailbox
621639
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
622640
port_ref = handle.bind()
641+
hy_instance = actor_instance._as_rust()
642+
port = Port(port_ref, hy_instance, None)
623643
return (
624-
Port(port_ref, mailbox, rank=None),
644+
port,
625645
PortReceiver(mailbox, receiver),
626646
)
627647

python/tests/_monarch/test_mailbox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _python_task_test(
115115

116116
@_python_task_test
117117
async def test_accumulator() -> None:
118-
proc_mesh = await allocate()
118+
proc_mesh: ProcMesh = await allocate()
119119
mailbox: Mailbox = Instance._as_py(proc_mesh.client)._mailbox
120120

121121
def my_accumulate(state: str, update: int) -> str:
@@ -128,7 +128,7 @@ def my_accumulate(state: str, update: int) -> str:
128128

129129
def post_message(value: int) -> None:
130130
port_ref.send(
131-
mailbox,
131+
proc_mesh.client,
132132
PythonMessage(
133133
PythonMessageKind.CallMethod(
134134
MethodSpecifier.ReturnsResponse("test_accumulator"), None

python/tests/test_python_actors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,14 +1554,14 @@ async def get_messages(self) -> Tuple[str, str, str]:
15541554
class UndeliverableMessageSender(Actor):
15551555
@endpoint
15561556
def send_undeliverable(self) -> None:
1557-
mailbox = context().actor_instance._mailbox
1557+
actor_instance = context().actor_instance
15581558
port_id = PortId(
15591559
actor_id=ActorId(world_name="bogus", rank=0, actor_name="bogus"),
15601560
port=1234,
15611561
)
15621562
port_ref = PortRef(port_id)
15631563
port_ref.send(
1564-
mailbox,
1564+
actor_instance._as_rust(),
15651565
PythonMessage(PythonMessageKind.Result(None), b"123"),
15661566
)
15671567

0 commit comments

Comments
 (0)