Skip to content

Commit 279eb61

Browse files
zdevitofacebook-github-bot
authored andcommitted
root_client_mailbox: open ports without a proc (#683)
Summary: Pull Request resolved: #683 To get ready to make proc_mesh/actor_mesh remote references, we need to be able to open ports on the root client without having to wait for the proc to exist. This diff creates a global Mailbox hooked up to the global router that can be used on the root client to open ports. On actors the mailbox is looked up via the monarch context set while the actor is running. On the root proc, we use the new global mailbox. This makes it possible to create a function to create a port without other arguments. This function is actually a class named Channel[T] because of python typing reasons: there is way to pass a type argument to a function in Python. Called channel because that is the rust equivalent. Channel[T].open() returns what used to be called PortTuple, but is now just a real tuple. ghstack-source-id: 300749961 exported-using-ghexport Reviewed By: mariusae Differential Revision: D79215096 fbshipit-source-id: 5ace368b0813467969cc0f43ef587ab4775cbf50
1 parent 694951b commit 279eb61

File tree

11 files changed

+96
-83
lines changed

11 files changed

+96
-83
lines changed

hyperactor/src/channel.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,18 @@ pub async fn serve<M: RemoteMessage>(
575575
.map(|(addr, inner)| (addr, ChannelRx { inner }))
576576
}
577577

578+
/// Serve on the local address. The server is turned down
579+
/// when the returned Rx is dropped.
580+
pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
581+
let (port, rx) = local::serve::<M>();
582+
(
583+
ChannelAddr::Local(port),
584+
ChannelRx {
585+
inner: ChannelRxKind::Local(rx),
586+
},
587+
)
588+
}
589+
578590
#[cfg(test)]
579591
mod tests {
580592
use std::assert_matches::assert_matches;

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use crate::proc_mesh::mesh_agent::MeshAgent;
6363
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
6464
use crate::proc_mesh::mesh_agent::StopActorResult;
6565
use crate::reference::ProcMeshId;
66+
use crate::shortuuid::ShortUuid;
6667

6768
pub mod mesh_agent;
6869

@@ -79,6 +80,26 @@ pub(crate) fn global_router() -> &'static MailboxRouter {
7980
GLOBAL_ROUTER.get_or_init(MailboxRouter::new)
8081
}
8182

83+
/// Global mailbox used by the root client to send messages.
84+
/// This mailbox allows us to open ports before we know which proc the
85+
/// messages will be sent to.
86+
pub fn global_mailbox() -> Mailbox {
87+
static GLOBAL_MAILBOX: OnceLock<Mailbox> = OnceLock::new();
88+
GLOBAL_MAILBOX
89+
.get_or_init(|| {
90+
let world_id = WorldId(ShortUuid::generate().to_string());
91+
let client_proc_id = ProcId(world_id.clone(), 0);
92+
let client_proc = Proc::new(
93+
client_proc_id.clone(),
94+
BoxedMailboxSender::new(global_router().clone()),
95+
);
96+
global_router().bind(world_id.clone().into(), client_proc.clone());
97+
98+
client_proc.attach("client").expect("root mailbox creation")
99+
})
100+
.clone()
101+
}
102+
82103
type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
83104
/// A ProcMesh maintains a mesh of procs whose lifecycles are managed by
84105
/// an allocator.

monarch_hyperactor/src/mailbox.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ impl PyMailbox {
6969

7070
#[pymethods]
7171
impl PyMailbox {
72+
#[staticmethod]
73+
fn root_client_mailbox() -> PyMailbox {
74+
PyMailbox {
75+
inner: hyperactor_mesh::proc_mesh::global_mailbox(),
76+
}
77+
}
7278
fn open_port<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
7379
let (handle, receiver) = self.inner.open_port();
7480
let handle = Py::new(py, PythonPortHandle { inner: handle })?;

python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ class Mailbox:
137137
"""
138138
A mailbox from that can receive messages.
139139
"""
140-
140+
@staticmethod
141+
def root_client_mailbox() -> "Mailbox": ...
141142
def open_port(self) -> tuple[PortHandle, PortReceiver]:
142143
"""Open a port to receive `PythonMessage` messages."""
143144
...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ class MonarchContext:
120120
def get() -> "MonarchContext":
121121
return _context.get()
122122

123+
@staticmethod
124+
def current_mailbox() -> "Mailbox":
125+
context = _context.get(None)
126+
if context is not None:
127+
return context.mailbox
128+
return Mailbox.root_client_mailbox()
129+
123130

124131
_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
125132
"monarch.actor_mesh._context"
@@ -344,13 +351,13 @@ def _send(
344351
shape = self._actor_mesh._shape
345352
return Extent(shape.labels, shape.ndslice.sizes)
346353

347-
def _port(self, once: bool = False) -> "PortTuple[R]":
348-
p, r = PortTuple.create(self._mailbox, once)
354+
def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
355+
p, r = super()._port(once=once)
349356
if TYPE_CHECKING:
350357
assert isinstance(
351358
r._receiver, (HyPortReceiver | OncePortReceiver)
352359
), "unexpected receiver type"
353-
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
360+
return (p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
354361

355362
def _rref(self, args, kwargs):
356363
self._check_arguments(args, kwargs)
@@ -526,49 +533,25 @@ def exception(self, obj: Exception) -> None:
526533

527534
T = TypeVar("T")
528535

529-
if TYPE_CHECKING:
530-
# Python <= 3.10 cannot inherit from Generic[R] and NamedTuple at the same time.
531-
# we only need it for type checking though, so copypasta it until 3.11.
532-
class PortTuple(NamedTuple, Generic[R]):
533-
sender: "Port[R]"
534-
receiver: "PortReceiver[R]"
535-
536-
@staticmethod
537-
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
538-
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
539-
port_ref = handle.bind()
540-
return PortTuple(
541-
Port(port_ref, mailbox, rank=None),
542-
PortReceiver(mailbox, receiver),
543-
)
544-
else:
545-
546-
class PortTuple(NamedTuple):
547-
sender: "Port[Any]"
548-
receiver: "PortReceiver[Any]"
549-
550-
@staticmethod
551-
def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
552-
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
553-
port_ref = handle.bind()
554-
return PortTuple(
555-
Port(port_ref, mailbox, rank=None),
556-
PortReceiver(mailbox, receiver),
557-
)
558-
559536

560537
# advance lower-level API for sending messages. This is intentially
561538
# not part of the Endpoint API because they way it accepts arguments
562539
# and handles concerns is different.
563-
def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]":
564-
return endpoint._port(once)
565-
540+
class Channel(Generic[R]):
541+
@staticmethod
542+
def open(once: bool = False) -> Tuple["Port[R]", "PortReceiver[R]"]:
543+
mailbox = MonarchContext.current_mailbox()
544+
handle, receiver = mailbox.open_once_port() if once else mailbox.open_port()
545+
port_ref = handle.bind()
546+
return (
547+
Port(port_ref, mailbox, rank=None),
548+
PortReceiver(mailbox, receiver),
549+
)
566550

567-
def ranked_port(
568-
endpoint: Endpoint[P, R], once: bool = False
569-
) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
570-
p, receiver = port(endpoint, once)
571-
return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver)
551+
@staticmethod
552+
def open_ranked(once: bool = False) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
553+
send, recv = Channel[R].open()
554+
return (send, recv.ranked())
572555

573556

574557
class PortReceiver(Generic[R]):
@@ -597,6 +580,9 @@ def _process(self, msg: PythonMessage) -> R:
597580
def recv(self) -> "Future[R]":
598581
return Future(coro=self._recv())
599582

583+
def ranked(self) -> "RankedPortReceiver[R]":
584+
return RankedPortReceiver[R](self._mailbox, self._receiver)
585+
600586

601587
class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
602588
def _process(self, msg: PythonMessage) -> Tuple[int, R]:

python/monarch/_src/actor/endpoint.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
HyPortReceiver,
4141
OncePortReceiver,
4242
Port,
43-
PortTuple,
43+
PortReceiver,
4444
ValueMesh,
4545
)
4646

@@ -90,9 +90,10 @@ def _send(
9090
"""
9191
pass
9292

93-
@abstractmethod
94-
def _port(self, once: bool = False) -> "PortTuple[R]":
95-
pass
93+
def _port(self, once: bool = False) -> "Tuple[Port[R], PortReceiver[R]]":
94+
from monarch._src.actor.actor_mesh import Channel
95+
96+
return Channel[R].open(once)
9697

9798
@abstractmethod
9899
def _call_name(self) -> Any:
@@ -115,17 +116,14 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
115116
116117
Load balanced RPC-style entrypoint for request/response messaging.
117118
"""
118-
from monarch._src.actor.actor_mesh import port
119119

120-
p, r = port(self, once=True)
120+
p, r = self._port(once=True)
121121
# pyre-ignore
122122
self._send(args, kwargs, port=p, selection="choose")
123123
return r.recv()
124124

125125
def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
126-
from monarch._src.actor.actor_mesh import port
127-
128-
p, r = port(self, once=True)
126+
p, r = self._port(once=True)
129127
# pyre-ignore
130128
extent = self._send(args, kwargs, port=p, selection="choose")
131129
if extent.nelements != 1:
@@ -135,9 +133,10 @@ def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
135133
return r.recv()
136134

137135
def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
138-
from monarch._src.actor.actor_mesh import ranked_port, ValueMesh
136+
from monarch._src.actor.actor_mesh import ValueMesh
139137

140-
p, r = ranked_port(self)
138+
p, unranked = self._port()
139+
r = unranked.ranked()
141140
# pyre-ignore
142141
extent = self._send(args, kwargs, port=p)
143142

@@ -166,9 +165,8 @@ def _stream(
166165
This enables processing results from multiple actors incrementally as
167166
they become available. Returns an async generator of response values.
168167
"""
169-
from monarch._src.actor.actor_mesh import port
170168

171-
p, r = port(self)
169+
p, r = self._port()
172170
# pyre-ignore
173171
extent = self._send(args, kwargs, port=p)
174172
for _ in range(extent.nelements):

python/monarch/_src/actor/proc_mesh.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@
4141
)
4242
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
4343
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
44-
from monarch._src.actor.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
44+
from monarch._src.actor.actor_mesh import (
45+
_Actor,
46+
_ActorMeshRefImpl,
47+
Actor,
48+
ActorMeshRef,
49+
MonarchContext,
50+
)
4551

4652
from monarch._src.actor.allocator import (
4753
AllocateMixin,
@@ -317,7 +323,7 @@ async def _spawn_nonblocking_on(
317323
service = ActorMeshRef(
318324
Class,
319325
_ActorMeshRefImpl.from_hyperactor_mesh(pm.client, actor_mesh, self),
320-
pm.client,
326+
MonarchContext.current_mailbox(),
321327
)
322328
# useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by
323329
# doing `ActorMeshRef(Class, actor_handle)` but not calling _create.

python/monarch/actor/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
Actor,
1414
ActorError,
1515
as_endpoint,
16+
Channel,
1617
current_actor_name,
1718
current_rank,
1819
current_size,
1920
Point,
20-
port,
2121
send,
2222
ValueMesh,
2323
)
@@ -45,7 +45,7 @@
4545
"Point",
4646
"proc_mesh",
4747
"ProcMesh",
48-
"port",
48+
"Channel",
4949
"send",
5050
"sim_proc_mesh",
5151
"ValueMesh",

python/monarch/common/remote.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@
2828
import monarch.common.messages as messages
2929

3030
import torch
31-
from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
3231
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
33-
from monarch._src.actor.actor_mesh import Port, PortTuple
32+
from monarch._src.actor.actor_mesh import Port
3433
from monarch._src.actor.endpoint import Extent, Selection
3534

3635
from monarch.common import _coalescing, device_mesh, stream
@@ -135,20 +134,6 @@ def _send(
135134
client._request_status()
136135
return Extent(ambient_mesh._labels, ambient_mesh._ndslice.sizes)
137136

138-
def _port(self, once: bool = False) -> "PortTuple[R]":
139-
ambient_mesh = device_mesh._active
140-
if ambient_mesh is None:
141-
raise ValueError(
142-
"FIXME - cannot create a port without an active proc_mesh, because there is not way to create a port without a mailbox"
143-
)
144-
mesh_controller = getattr(ambient_mesh.client, "_mesh_controller", None)
145-
if mesh_controller is None:
146-
raise ValueError(
147-
"Cannot create raw port objects with an old-style tensor engine controller."
148-
)
149-
mailbox: Mailbox = mesh_controller._mailbox
150-
return PortTuple.create(mailbox, once)
151-
152137
@property
153138
def _resolvable(self):
154139
return resolvable_function(self._remote_impl)

python/monarch/mesh_controller.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
ActorId,
4444
)
4545
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
46-
from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple
46+
from monarch._src.actor.actor_mesh import ActorEndpoint, Channel, Port
4747
from monarch._src.actor.endpoint import Selection
4848
from monarch._src.actor.shape import NDSlice
4949
from monarch.common import device_mesh, messages, stream
@@ -156,7 +156,7 @@ def fetch(
156156
defs: Tuple["Tensor", ...],
157157
uses: Tuple["Tensor", ...],
158158
) -> "OldFuture": # the OldFuture is a lie
159-
sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
159+
sender, receiver = Channel.open(once=True)
160160

161161
ident = self.new_node(defs, uses, cast("OldFuture", sender))
162162
process = mesh._process(shard)
@@ -192,7 +192,7 @@ def shutdown(
192192
atexit.unregister(self._atexit)
193193
self._shutdown = True
194194

195-
sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
195+
sender, receiver = Channel.open(once=True)
196196
assert sender._port_ref is not None
197197
self._mesh_controller.sync_at_exit(sender._port_ref.port_id)
198198
receiver.recv().get(timeout=60)

0 commit comments

Comments
 (0)