Skip to content

Commit b0b68af

Browse files
zdevitofacebook-github-bot
authored andcommitted
Context and Instance python APIs (meta-pytorch#918)
Summary: Pull Request resolved: meta-pytorch#918 Adding Context/Instance objects as proposed in https://docs.google.com/document/d/11ohUVih8yPmT1lZaV-OpL-Lawj49-RhJJ6tqxlW94S0/edit?tab=t.0 . This also replaces our global root mailbox with a global root client instance object. Context is created from the rust Context<Self> object, though it cannot directly contain Context because the lifetime of the rust thing is limited. (Is there some way to store a limited lifetime thing with a longer lifetime and dynamically assert it is still alvie)? ghstack-source-id: 304344057 Reviewed By: mariusae Differential Revision: D80505166 fbshipit-source-id: 2b0bfe1610d20199105a3fdc5a1ed37a1444be18
1 parent 50ef200 commit b0b68af

File tree

14 files changed

+257
-132
lines changed

14 files changed

+257
-132
lines changed

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ use async_trait::async_trait;
1616
use dashmap::DashMap;
1717
use futures::future::join_all;
1818
use hyperactor::Actor;
19+
use hyperactor::ActorHandle;
1920
use hyperactor::ActorId;
2021
use hyperactor::ActorRef;
22+
use hyperactor::Instance;
2123
use hyperactor::Mailbox;
2224
use hyperactor::Named;
2325
use hyperactor::RemoteMessage;
@@ -79,24 +81,24 @@ pub(crate) fn global_router() -> &'static MailboxRouter {
7981
GLOBAL_ROUTER.get_or_init(MailboxRouter::new)
8082
}
8183

82-
/// Global mailbox used by the root client to send messages.
84+
/// Context use by root client to send messages.
8385
/// This mailbox allows us to open ports before we know which proc the
8486
/// messages will be sent to.
85-
pub fn global_mailbox() -> Mailbox {
86-
static GLOBAL_MAILBOX: OnceLock<Mailbox> = OnceLock::new();
87-
GLOBAL_MAILBOX
88-
.get_or_init(|| {
89-
let world_id = WorldId(ShortUuid::generate().to_string());
90-
let client_proc_id = ProcId::Ranked(world_id.clone(), 0);
91-
let client_proc = Proc::new(
92-
client_proc_id.clone(),
93-
BoxedMailboxSender::new(global_router().clone()),
94-
);
95-
global_router().bind(world_id.clone().into(), client_proc.clone());
96-
97-
client_proc.attach("client").expect("root mailbox creation")
98-
})
99-
.clone()
87+
pub fn global_root_client() -> &'static Instance<()> {
88+
static GLOBAL_INSTANCE: OnceLock<(Instance<()>, ActorHandle<()>)> = OnceLock::new();
89+
let (instance, _) = GLOBAL_INSTANCE.get_or_init(|| {
90+
let world_id = WorldId(ShortUuid::generate().to_string());
91+
let client_proc_id = ProcId::Ranked(world_id.clone(), 0);
92+
let client_proc = Proc::new(
93+
client_proc_id.clone(),
94+
BoxedMailboxSender::new(global_router().clone()),
95+
);
96+
global_router().bind(world_id.clone().into(), client_proc.clone());
97+
client_proc
98+
.instance("client")
99+
.expect("root instance create")
100+
});
101+
instance
100102
}
101103

102104
type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;

monarch_hyperactor/src/actor.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -625,15 +625,11 @@ impl Handler<PythonMessage> for PythonActor {
625625
let (sender, receiver) = oneshot::channel();
626626

627627
let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> {
628-
let mailbox = mailbox(py, cx);
629-
let (rank, shape) = cx.cast_info();
630628
let awaitable = self.actor.call_method(
631629
py,
632630
"handle",
633631
(
634-
mailbox,
635-
rank,
636-
PyShape::from(shape),
632+
crate::mailbox::Context::new(py, cx),
637633
resolved.method,
638634
resolved.bytes,
639635
PanicFlag {

monarch_hyperactor/src/mailbox.rs

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
use std::hash::DefaultHasher;
1010
use std::hash::Hash;
1111
use std::hash::Hasher;
12+
use std::ops::Deref;
1213
use std::sync::Arc;
1314

15+
use hyperactor::ActorId;
1416
use hyperactor::Mailbox;
1517
use hyperactor::Named;
1618
use hyperactor::OncePortHandle;
@@ -33,8 +35,12 @@ use hyperactor::mailbox::monitored_return_handle;
3335
use hyperactor::message::Bind;
3436
use hyperactor::message::Bindings;
3537
use hyperactor::message::Unbind;
38+
use hyperactor_mesh::comm::multicast::CastInfo;
3639
use hyperactor_mesh::comm::multicast::set_cast_info_on_headers;
40+
use hyperactor_mesh::proc_mesh::global_root_client;
3741
use monarch_types::PickledPyObject;
42+
use monarch_types::py_global;
43+
use ndslice::shape::Shape;
3844
use pyo3::IntoPyObjectExt;
3945
use pyo3::exceptions::PyEOFError;
4046
use pyo3::exceptions::PyRuntimeError;
@@ -51,6 +57,7 @@ use crate::proc::PyActorId;
5157
use crate::pytokio::PyPythonTask;
5258
use crate::pytokio::PythonTask;
5359
use crate::runtime::signal_safe_block_on;
60+
use crate::shape::PyPoint;
5461
use crate::shape::PyShape;
5562
#[derive(Clone, Debug)]
5663
#[pyclass(
@@ -69,12 +76,6 @@ impl PyMailbox {
6976

7077
#[pymethods]
7178
impl PyMailbox {
72-
#[staticmethod]
73-
fn root_client_mailbox() -> PyMailbox {
74-
PyMailbox {
75-
inner: hyperactor_mesh::proc_mesh::global_mailbox(),
76-
}
77-
}
7879
fn open_port<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
7980
let (handle, receiver) = self.inner.open_port();
8081
let handle = Py::new(py, PythonPortHandle { inner: handle })?;
@@ -702,6 +703,86 @@ inventory::submit! {
702703
}
703704
}
704705

706+
#[pyclass(name = "Instance", module = "monarch._src.actor.actor_mesh")]
707+
struct Instance {
708+
mailbox: Mailbox,
709+
actor_id: ActorId,
710+
#[pyo3(get, set)]
711+
proc_mesh: Option<PyObject>,
712+
#[pyo3(get, set, name = "_controller_controller")]
713+
controller_controller: Option<PyObject>,
714+
}
715+
#[pymethods]
716+
impl Instance {
717+
#[getter]
718+
fn _mailbox(&self) -> PyMailbox {
719+
PyMailbox {
720+
inner: self.mailbox.clone(),
721+
}
722+
}
723+
#[getter]
724+
fn actor_id(&self) -> PyActorId {
725+
self.actor_id.clone().into()
726+
}
727+
}
728+
729+
impl<A: hyperactor::Actor> From<&hyperactor::proc::Instance<A>> for Instance {
730+
fn from(ins: &hyperactor::proc::Instance<A>) -> Self {
731+
Instance {
732+
mailbox: ins.mailbox_for_py().clone(),
733+
actor_id: ins.self_id().clone(),
734+
proc_mesh: None,
735+
controller_controller: None,
736+
}
737+
}
738+
}
739+
740+
#[pyclass(name = "Context", module = "monarch._src.actor.actor_mesh")]
741+
pub(crate) struct Context {
742+
instance: Py<Instance>,
743+
rank: usize,
744+
shape: Shape,
745+
}
746+
747+
py_global!(point, "monarch._src.actor.actor_mesh", "Point");
748+
749+
#[pymethods]
750+
impl Context {
751+
#[getter]
752+
fn actor_instance(&self) -> &Py<Instance> {
753+
&self.instance
754+
}
755+
#[getter]
756+
fn message_rank<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
757+
let shape = PyShape::from(self.shape.clone());
758+
point(py).call1((self.rank, shape.into_pyobject(py).unwrap().unbind()))
759+
}
760+
#[staticmethod]
761+
fn _root_client_context(py: Python<'_>) -> Context {
762+
let instance: Instance = global_root_client().into();
763+
Context {
764+
instance: instance.into_pyobject(py).unwrap().into(),
765+
rank: 0,
766+
shape: Shape::unity(),
767+
}
768+
}
769+
}
770+
771+
impl Context {
772+
pub(crate) fn new<T: hyperactor::actor::Actor>(
773+
py: Python<'_>,
774+
cx: &hyperactor::proc::Context<T>,
775+
) -> Context {
776+
let instance: Instance = cx.deref().into();
777+
let (rank, shape) = cx.cast_info();
778+
Context {
779+
instance: instance.into_pyobject(py).unwrap().into(),
780+
rank,
781+
shape,
782+
}
783+
}
784+
}
785+
705786
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
706787
hyperactor_mod.add_class::<PyMailbox>()?;
707788
hyperactor_mod.add_class::<PyPortId>()?;
@@ -713,5 +794,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
713794
hyperactor_mod.add_class::<PythonOncePortHandle>()?;
714795
hyperactor_mod.add_class::<PythonOncePortRef>()?;
715796
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;
797+
hyperactor_mod.add_class::<Instance>()?;
798+
hyperactor_mod.add_class::<Context>()?;
716799
Ok(())
717800
}

monarch_hyperactor/src/pytokio.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,13 @@ impl PyPythonTask {
183183

184184
#[staticmethod]
185185
fn from_coroutine(py: Python<'_>, coro: PyObject) -> PyResult<PyPythonTask> {
186-
// MonarchContext.get() used inside a PythonTask should inherit the value of
187-
// MonarchContext from the context in which the PythonTask was constructed.
186+
// context() used inside a PythonTask should inherit the value of
187+
// context() from the context in which the PythonTask was constructed.
188188
// We need to do this manually because the value of the contextvar isn't
189189
// maintained inside the tokio runtime.
190190
let monarch_context = py
191191
.import("monarch._src.actor.actor_mesh")?
192-
.getattr("MonarchContext")?
193-
.call_method0("get")?
192+
.call_method0("context")?
194193
.unbind();
195194
PyPythonTask::new(async move {
196195
let (coroutine_iterator, none) = Python::with_gil(|py| {
@@ -206,7 +205,7 @@ impl PyPythonTask {
206205
loop {
207206
let action: PyResult<Action> = Python::with_gil(|py| {
208207
// We may be executing in a new thread at this point, so we need to set the value
209-
// of MonarchContext.
208+
// of context().
210209
let _context = py
211210
.import("monarch._src.actor.actor_mesh")?
212211
.getattr("_context")?;
@@ -220,7 +219,7 @@ impl PyPythonTask {
220219
.call_method1("throw", (pyerr.into_value(py),)),
221220
};
222221

223-
// Reset MonarchContext so that when this tokio thread yields, it has its original state.
222+
// Reset context() so that when this tokio thread yields, it has its original state.
224223
_context.call_method1("set", (old_context,))?;
225224
match result {
226225
Ok(task) => Ok(Action::Wait(

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,7 @@ class PortProtocol(Generic[R], Protocol):
266266
class Actor(Protocol):
267267
async def handle(
268268
self,
269-
mailbox: Mailbox,
270-
rank: int,
271-
shape: Shape,
269+
context: Any,
272270
method: MethodSpecifier,
273271
message: bytes,
274272
panic_flag: PanicFlag,

python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ class Mailbox:
137137
"""
138138
A mailbox from that can receive messages.
139139
"""
140-
@staticmethod
141-
def root_client_mailbox() -> "Mailbox": ...
142140
def open_port(self) -> tuple[PortHandle, PortReceiver]:
143141
"""Open a port to receive `PythonMessage` messages."""
144142
...

python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
# pyre-strict
88

9-
from typing import AsyncIterator, final, Literal, overload, Type
9+
from typing import Any, AsyncIterator, final, Literal, overload, Type, TYPE_CHECKING
1010

11-
from monarch._rust_bindings.monarch_hyperactor.actor import Actor
11+
if TYPE_CHECKING:
12+
from monarch._rust_bindings.monarch_hyperactor.actor import Actor
1213
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import (
1314
PythonActorMesh,
1415
PythonActorMeshImpl,
@@ -23,7 +24,7 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2324
@final
2425
class ProcMesh:
2526
@classmethod
26-
def allocate_nonblocking(self, alloc: Alloc) -> PythonTask[ProcMesh]:
27+
def allocate_nonblocking(self, alloc: Alloc) -> PythonTask["ProcMesh"]:
2728
"""
2829
Allocate a process mesh according to the provided alloc.
2930
Returns when the mesh is fully allocated.
@@ -34,7 +35,9 @@ class ProcMesh:
3435
...
3536

3637
def spawn_nonblocking(
37-
self, name: str, actor: Type[Actor]
38+
self,
39+
name: str,
40+
actor: Any,
3841
) -> PythonTask[PythonActorMesh]:
3942
"""
4043
Spawn a new actor on this mesh.
@@ -47,7 +50,7 @@ class ProcMesh:
4750

4851
@staticmethod
4952
def spawn_async(
50-
proc_mesh: Shared[ProcMesh], name: str, actor: Type[Actor], emulated: bool
53+
proc_mesh: Shared["ProcMesh"], name: str, actor: Type["Actor"], emulated: bool
5154
) -> PythonActorMesh: ...
5255
async def monitor(self) -> ProcMeshMonitor:
5356
"""

0 commit comments

Comments
 (0)