Skip to content

Commit 9532ba1

Browse files
samluryefacebook-github-bot
authored andcommitted
v1 host mesh and proc mesh python integration (#1346)
Summary: Pull Request resolved: #1346 Write the python API for v1 host mesh and v1 proc mesh to (relatively closely) match the python APIs for v0. ghstack-source-id: 313115893 exported-using-ghexport Reviewed By: colin2328, zdevito Differential Revision: D83223400 fbshipit-source-id: 50e5f47c613ed350efa1a9a617a418acc0b42da3
1 parent 96e1fd8 commit 9532ba1

File tree

20 files changed

+1233
-116
lines changed

20 files changed

+1233
-116
lines changed

monarch_hyperactor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ async-trait = "0.1.86"
2222
bincode = "1.3.3"
2323
clap = { version = "4.5.42", features = ["derive", "env", "string", "unicode", "wrap_help"] }
2424
erased-serde = "0.3.27"
25+
fastrand = "2.1.1"
2526
fbinit = { version = "0.2.0", git = "https://github.com/facebookexperimental/rust-shed.git", branch = "main" }
2627
futures = { version = "0.3.31", features = ["async-await", "compat"] }
2728
hyperactor = { version = "0.0.0", path = "../hyperactor" }

monarch_hyperactor/src/shape.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use ndslice::Point;
1212
use ndslice::Region;
1313
use ndslice::Shape;
1414
use ndslice::Slice;
15+
use ndslice::View;
1516
use pyo3::IntoPyObjectExt;
1617
use pyo3::exceptions::PyValueError;
1718
use pyo3::prelude::*;
@@ -94,6 +95,13 @@ impl PyExtent {
9495
fn keys<'py>(&self, py: Python<'py>) -> PyResult<PyObject> {
9596
Ok(self.inner.labels().into_bound_py_any(py)?.into())
9697
}
98+
99+
#[getter]
100+
fn region(&self) -> PyRegion {
101+
PyRegion {
102+
inner: self.inner.region(),
103+
}
104+
}
97105
}
98106

99107
impl From<Extent> for PyExtent {
@@ -139,13 +147,32 @@ impl PyRegion {
139147
}
140148
}
141149

150+
#[getter]
142151
fn labels(&self) -> Vec<String> {
143152
self.inner.labels().to_vec()
144153
}
145154

146155
fn slice(&self) -> PySlice {
147156
self.inner.slice().clone().into()
148157
}
158+
159+
fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
160+
let bytes = bincode::serialize(&self.inner)
161+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
162+
let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
163+
let from_bytes = py
164+
.import("monarch._rust_bindings.monarch_hyperactor.shape")?
165+
.getattr("Region")?
166+
.getattr("from_bytes")?;
167+
Ok((from_bytes, py_bytes))
168+
}
169+
170+
#[staticmethod]
171+
fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
172+
Ok(bincode::deserialize::<Region>(bytes.as_bytes())
173+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?
174+
.into())
175+
}
149176
}
150177

151178
impl From<Region> for PyRegion {

monarch_hyperactor/src/v1/actor_mesh.rs

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
use hyperactor::ActorRef;
1010
use hyperactor_mesh::v1::actor_mesh::ActorMesh;
1111
use hyperactor_mesh::v1::actor_mesh::ActorMeshRef;
12+
use ndslice::Region;
1213
use ndslice::Selection;
14+
use ndslice::Slice;
15+
use ndslice::selection::structurally_equal;
1316
use ndslice::view::Ranked;
1417
use ndslice::view::RankedSliceable;
1518
use pyo3::IntoPyObjectExt;
1619
use pyo3::exceptions::PyException;
1720
use pyo3::exceptions::PyNotImplementedError;
21+
use pyo3::exceptions::PyRuntimeError;
1822
use pyo3::exceptions::PyValueError;
1923
use pyo3::prelude::*;
2024
use pyo3::types::PyBytes;
@@ -89,9 +93,8 @@ impl ActorMeshProtocol for PythonActorMeshImpl {
8993
}
9094

9195
fn supervision_event(&self) -> PyResult<Option<PyShared>> {
92-
Err(PyErr::new::<PyNotImplementedError, _>(
93-
"supervision_event is not implemented yet for v1::PythonActorMeshImpl",
94-
))
96+
// FIXME: implement supervision events for v1 actor mesh.
97+
Ok(None)
9598
}
9699

97100
fn new_with_region(&self, region: &PyRegion) -> PyResult<Box<dyn ActorMeshProtocol>> {
@@ -113,13 +116,37 @@ impl ActorMeshProtocol for ActorMeshRef<PythonActor> {
113116
fn cast(
114117
&self,
115118
message: PythonMessage,
116-
_selection: Selection,
119+
selection: Selection,
117120
instance: &PyInstance,
118121
) -> PyResult<()> {
119-
instance_dispatch!(instance, |cx_instance| {
120-
self.cast(cx_instance, message.clone())
121-
.map_err(|err| PyException::new_err(err.to_string()))?;
122-
});
122+
if structurally_equal(&selection, &Selection::All(Box::new(Selection::True))) {
123+
instance_dispatch!(instance, |cx_instance| {
124+
self.cast(cx_instance, message.clone())
125+
.map_err(|err| PyException::new_err(err.to_string()))?;
126+
});
127+
} else if structurally_equal(&selection, &Selection::Any(Box::new(Selection::True))) {
128+
let region = Ranked::region(self);
129+
let random_rank = fastrand::usize(0..region.num_ranks());
130+
let offset = region
131+
.slice()
132+
.get(random_rank)
133+
.map_err(anyhow::Error::from)?;
134+
let singleton_region = Region::new(
135+
Vec::new(),
136+
Slice::new(offset, Vec::new(), Vec::new()).map_err(anyhow::Error::from)?,
137+
);
138+
instance_dispatch!(instance, |cx_instance| {
139+
self.sliced(singleton_region)
140+
.cast(cx_instance, message.clone())
141+
.map_err(|err| PyException::new_err(err.to_string()))?;
142+
});
143+
} else {
144+
return Err(PyRuntimeError::new_err(format!(
145+
"invalid selection: {:?}",
146+
selection
147+
)));
148+
}
149+
123150
Ok(())
124151
}
125152

monarch_hyperactor/src/v1/host_mesh.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ impl PyHostMesh {
177177
let bytes = bincode::serialize(&self.mesh_ref()?)
178178
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
179179
let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
180-
let from_bytes = wrap_pyfunction!(py_host_mesh_from_bytes, py)?.into_any();
180+
let from_bytes =
181+
PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh")?
182+
.getattr("py_host_mesh_from_bytes")?;
181183
Ok((from_bytes, py_bytes))
182184
}
183185
}

monarch_hyperactor/src/v1/proc_mesh.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use hyperactor_mesh::v1::proc_mesh::ProcMesh;
1111
use hyperactor_mesh::v1::proc_mesh::ProcMeshRef;
1212
use monarch_types::PickledPyObject;
1313
use ndslice::View;
14+
use ndslice::view::RankedSliceable;
1415
use pyo3::IntoPyObjectExt;
1516
use pyo3::exceptions::PyException;
1617
use pyo3::exceptions::PyNotImplementedError;
@@ -181,6 +182,12 @@ impl PyProcMesh {
181182
"v1::PyProcMesh::stop not implemented",
182183
))
183184
}
185+
186+
fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
187+
Ok(Self::new_ref(
188+
self.mesh_ref()?.sliced(region.as_inner().clone()),
189+
))
190+
}
184191
}
185192

186193
#[derive(Clone)]

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class PortProtocol(Generic[R], Protocol):
257257
class Actor(Protocol):
258258
async def handle(
259259
self,
260-
context: Any,
260+
ctx: Any,
261261
method: MethodSpecifier,
262262
message: bytes,
263263
panic_flag: PanicFlag,

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ class Extent(collections.abc.Mapping):
173173
def __iter__(self) -> "Iterator[str]": ...
174174
def __getitem__(self, label: str) -> int: ...
175175
def __len__(self) -> int: ...
176+
@property
177+
def region(self) -> "Region": ...
176178

177179
class Point(collections.abc.Mapping):
178180
"""
@@ -204,6 +206,7 @@ class Region:
204206
"""
205207
def __init__(self, labels: Sequence[str], slice: Slice) -> None: ...
206208
def as_shape(self) -> "Shape": ...
209+
@property
207210
def labels(self) -> List[str]:
208211
"""
209212
The labels for each dimension of the region.
@@ -215,3 +218,5 @@ class Region:
215218
The slice of the region.
216219
"""
217220
...
221+
222+
def __reduce__(self) -> Any: ...

python/monarch/_rust_bindings/monarch_hyperactor/v1/proc_mesh.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,8 @@ class ProcMesh:
8080
...
8181

8282
def __repr__(self) -> str: ...
83+
def sliced(self, region: Region) -> "ProcMesh":
84+
"""
85+
Returns a new mesh that is a slice of this mesh with the given region.
86+
"""
87+
...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@
9292
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ActorMeshProtocol
9393
from monarch._rust_bindings.monarch_hyperactor.mailbox import PortReceiverBase
9494
from monarch._src.actor.proc_mesh import _ControllerController, ProcMesh
95+
from monarch._src.actor.v1.proc_mesh import (
96+
_ControllerController as _ControllerControllerV1,
97+
ProcMesh as ProcMeshV1,
98+
)
9599
from monarch._src.actor.telemetry import get_monarch_tracer
96100

97101
CallMethod = PythonMessageKind.CallMethod
@@ -142,7 +146,7 @@ def actor_id(self) -> ActorId:
142146
...
143147

144148
@property
145-
def proc(self) -> "ProcMesh":
149+
def proc(self) -> "ProcMesh | ProcMeshV1":
146150
"""
147151
The singleton proc mesh that corresponds to just this actor.
148152
"""
@@ -155,14 +159,14 @@ def proc(self) -> "ProcMesh":
155159
The actors __init__ message.
156160
"""
157161
rank: Point
158-
proc_mesh: "ProcMesh"
159-
_controller_controller: "_ControllerController"
162+
proc_mesh: "ProcMesh | ProcMeshV1"
163+
_controller_controller: "_ControllerController | _ControllerControllerV1"
160164

161165
# this property is used to hold the handles to actors and processes launched by this actor
162166
# in order to keep them alive until this actor exits.
163-
_children: "Optional[List[ActorMesh | ProcMesh]]"
167+
_children: "Optional[List[ActorMesh | ProcMesh | ProcMeshV1]]"
164168

165-
def _add_child(self, child: "ActorMesh | ProcMesh") -> None:
169+
def _add_child(self, child: "ActorMesh | ProcMesh | ProcMeshV1") -> None:
166170
if self._children is None:
167171
self._children = [child]
168172
else:
@@ -212,13 +216,16 @@ def context() -> Context:
212216
if c is None:
213217
c = Context._root_client_context()
214218
_context.set(c)
219+
220+
# FIXME: Switch to the v1 APIs when it becomes the default.
215221
from monarch._src.actor.host_mesh import create_local_host_mesh
216222
from monarch._src.actor.proc_mesh import _get_controller_controller
217223

218224
c.actor_instance.proc_mesh, c.actor_instance._controller_controller = (
219225
_get_controller_controller()
220226
)
221-
c.actor_instance.proc_mesh._host_mesh = create_local_host_mesh()
227+
228+
c.actor_instance.proc_mesh._host_mesh = create_local_host_mesh() # type: ignore
222229
return c
223230

224231

@@ -281,7 +288,7 @@ def __init__(
281288
self,
282289
actor_mesh: "ActorMeshProtocol",
283290
shape: Shape,
284-
proc_mesh: "Optional[ProcMesh]",
291+
proc_mesh: "Optional[ProcMesh] | Optional[ProcMeshV1]",
285292
name: MethodSpecifier,
286293
impl: Callable[Concatenate[Any, P], Awaitable[R]],
287294
propagator: Propagator,
@@ -931,7 +938,7 @@ def __init__(
931938
Class: Type[T],
932939
inner: "ActorMeshProtocol",
933940
shape: Shape,
934-
proc_mesh: "Optional[ProcMesh]",
941+
proc_mesh: "Optional[ProcMesh] | Optional[ProcMeshV1]",
935942
) -> None:
936943
self.__name__: str = Class.__name__
937944
self._class: Type[T] = Class
@@ -986,8 +993,9 @@ def _create(
986993
Class: Type[T],
987994
actor_mesh: "PythonActorMesh",
988995
shape: Shape,
989-
proc_mesh: "ProcMesh",
990-
controller_controller: Optional["_ControllerController"],
996+
proc_mesh: "ProcMesh | ProcMeshV1",
997+
controller_controller: Optional["_ControllerController"]
998+
| Optional["_ControllerControllerV1"],
991999
# args and kwargs are passed to the __init__ method of the user defined
9921000
# python actor object.
9931001
*args: Any,
@@ -1019,7 +1027,7 @@ def from_actor_id(
10191027
return cls(Class, _SingletonActorAdapator(actor_id), singleton_shape, None)
10201028

10211029
def __reduce_ex__(self, protocol: ...) -> "Tuple[Type[ActorMesh], Tuple[Any, ...]]":
1022-
return ActorMesh, (self._class, self._inner, self._shape, None)
1030+
return ActorMesh, (self._class, self._inner, self._shape, self._proc_mesh)
10231031

10241032
@property
10251033
def _ndslice(self) -> NDSlice:

python/monarch/_src/actor/host_mesh.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@ def this_host() -> "HostMesh":
2525
2626
This is just shorthand for looking it up via the context
2727
"""
28-
return context().actor_instance.proc.host_mesh
28+
host_mesh = context().actor_instance.proc.host_mesh
29+
assert isinstance(host_mesh, HostMesh), "expected v0 HostMesh, got v1 HostMesh"
30+
return host_mesh
2931

3032

3133
def this_proc() -> "ProcMesh":
3234
"""
3335
The current singleton process that this specific actor is
3436
running on
3537
"""
36-
return context().actor_instance.proc
38+
proc = context().actor_instance.proc
39+
assert isinstance(proc, ProcMesh), "expected v1 ProcMesh, got v0 ProcMesh"
40+
return proc
3741

3842

3943
def create_local_host_mesh() -> "HostMesh":

0 commit comments

Comments
 (0)