Skip to content

Commit c9c2cac

Browse files
committed
[monarch] mesh: support v1 actor meshes in RDMA buffer
Pull Request resolved: #1467 This uses the v0 shim to support v1 meshes in RDMA buffer. ghstack-source-id: 315017598 Differential Revision: [D84181912](https://our.internmc.facebook.com/intern/diff/D84181912/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84181912/)!
1 parent 1000ca1 commit c9c2cac

File tree

3 files changed

+66
-33
lines changed

3 files changed

+66
-33
lines changed

monarch_hyperactor/src/v1/proc_mesh.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use crate::v1::actor_mesh::PythonActorMeshImpl;
3636
name = "ProcMesh",
3737
module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
3838
)]
39-
pub(crate) enum PyProcMesh {
39+
pub enum PyProcMesh {
4040
Owned(PyProcMeshImpl),
4141
Ref(PyProcMeshRefImpl),
4242
}
@@ -50,7 +50,7 @@ impl PyProcMesh {
5050
Self::Ref(PyProcMeshRefImpl(inner))
5151
}
5252

53-
pub(crate) fn mesh_ref(&self) -> Result<ProcMeshRef, anyhow::Error> {
53+
pub fn mesh_ref(&self) -> Result<ProcMeshRef, anyhow::Error> {
5454
match self {
5555
PyProcMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
5656
PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
@@ -195,7 +195,7 @@ impl PyProcMesh {
195195
name = "ProcMeshImpl",
196196
module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
197197
)]
198-
pub(crate) struct PyProcMeshImpl(SharedCell<ProcMesh>);
198+
pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
199199

200200
impl PyProcMeshImpl {
201201
fn __repr__(&self) -> PyResult<String> {
@@ -211,7 +211,7 @@ impl PyProcMeshImpl {
211211
name = "ProcMeshRefImpl",
212212
module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
213213
)]
214-
pub(crate) struct PyProcMeshRefImpl(ProcMeshRef);
214+
pub struct PyProcMeshRefImpl(ProcMeshRef);
215215

216216
impl PyProcMeshRefImpl {
217217
fn __repr__(&self) -> PyResult<String> {

monarch_rdma/extension/lib.rs

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use monarch_hyperactor::instance_dispatch;
1818
use monarch_hyperactor::proc_mesh::PyProcMesh;
1919
use monarch_hyperactor::pytokio::PyPythonTask;
2020
use monarch_hyperactor::runtime::signal_safe_block_on;
21+
use monarch_hyperactor::v1::proc_mesh::PyProcMesh as PyProcMeshV1;
2122
use monarch_rdma::RdmaBuffer;
2223
use monarch_rdma::RdmaManagerActor;
2324
use monarch_rdma::RdmaManagerMessageClient;
@@ -289,31 +290,52 @@ impl PyRdmaManager {
289290
#[classmethod]
290291
fn create_rdma_manager_nonblocking(
291292
_cls: &Bound<'_, PyType>,
292-
proc_mesh: &PyProcMesh,
293+
proc_mesh: &Bound<'_, PyAny>,
293294
client: PyInstance,
294295
) -> PyResult<PyPythonTask> {
295296
tracing::debug!("spawning RDMA manager on target proc_mesh nodes");
296297

297-
let tracked_proc_mesh = proc_mesh.try_inner()?;
298+
if let Ok(v0) = proc_mesh.downcast::<PyProcMesh>() {
299+
let tracked_proc_mesh = v0.borrow().try_inner()?;
300+
PyPythonTask::new(async move {
301+
// Spawns the `RdmaManagerActor` on the target proc_mesh.
302+
// This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware.
303+
let actor_mesh = instance_dispatch!(client, |cx| {
304+
tracked_proc_mesh
305+
// Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
306+
// TODO - make IbverbsConfig configurable
307+
.spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
308+
.await
309+
.map_err(|err| PyException::new_err(err.to_string()))?
310+
});
298311

299-
PyPythonTask::new(async move {
300-
// Spawns the `RdmaManagerActor` on the target proc_mesh.
301-
// This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware.
302-
let actor_mesh = instance_dispatch!(client, |cx| {
303-
tracked_proc_mesh
304-
// Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
305-
// TODO - make IbverbsConfig configurable
306-
.spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
307-
.await
308-
.map_err(|err| PyException::new_err(err.to_string()))?
309-
});
312+
// Use placeholder device name since actual device is determined on remote node
313+
Ok(Some(PyRdmaManager {
314+
inner: actor_mesh,
315+
device: "remote_rdma_device".to_string(),
316+
}))
317+
})
318+
} else {
319+
let proc_mesh = proc_mesh.downcast::<PyProcMeshV1>()?.borrow().mesh_ref()?;
320+
PyPythonTask::new(async move {
321+
let actor_mesh = instance_dispatch!(client, |cx| {
322+
proc_mesh
323+
// Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
324+
// TODO - make IbverbsConfig configurable
325+
.spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
326+
.await
327+
.map_err(|err| PyException::new_err(err.to_string()))?
328+
});
310329

311-
// Use placeholder device name since actual device is determined on remote node
312-
Ok(Some(PyRdmaManager {
313-
inner: actor_mesh,
314-
device: "remote_rdma_device".to_string(),
315-
}))
316-
})
330+
let actor_mesh = RootActorMesh::from(actor_mesh);
331+
let actor_mesh = SharedCell::from(actor_mesh);
332+
333+
Ok(Some(PyRdmaManager {
334+
inner: actor_mesh,
335+
device: "remote_rdma_device".to_string(),
336+
}))
337+
})
338+
}
317339
}
318340
}
319341

python/monarch/_src/rdma/rdma.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@
2727
from monarch._src.actor.actor_mesh import Actor, context
2828
from monarch._src.actor.endpoint import endpoint
2929
from monarch._src.actor.future import Future
30-
from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
30+
from monarch._src.actor.proc_mesh import (
31+
get_or_spawn_controller as get_or_spawn_controller_v0,
32+
ProcMesh as ProcMeshV0,
33+
)
34+
from monarch._src.actor.v1.proc_mesh import (
35+
get_or_spawn_controller as get_or_spawn_controller_v1,
36+
ProcMesh as ProcMeshV1,
37+
)
3138
from pyre_extensions import none_throws
3239

3340

@@ -54,13 +61,17 @@ def is_rdma_available():
5461
@functools.cache
5562
def _ensure_init_rdma_manager() -> Shared[None]:
5663
async def task() -> None:
57-
await (
58-
await get_or_spawn_controller("rdma_controller", RdmaController)
59-
).init_rdma_on_mesh.call_one(
60-
# FIXME(slurye): Fix this once controller API is working properly
61-
# for v1.
62-
cast(ProcMesh, none_throws(context().actor_instance.proc_mesh))
63-
)
64+
proc_mesh = context().actor_instance.proc_mesh
65+
if isinstance(proc_mesh, ProcMeshV0):
66+
controller = await get_or_spawn_controller_v0(
67+
"rdma_controller", RdmaController
68+
)
69+
else:
70+
controller = await get_or_spawn_controller_v1(
71+
"rdma_controller", RdmaController
72+
)
73+
74+
await controller.init_rdma_on_mesh.call_one(proc_mesh)
6475

6576
return PythonTask.from_coroutine(task()).spawn()
6677

@@ -120,10 +131,10 @@ def _get_addr_and_size(buf: torch.Tensor | memoryview) -> tuple[int, int]:
120131

121132
class RdmaController(Actor):
122133
def __init__(self) -> None:
123-
self._manager_futures: Dict[ProcMesh, Future[_RdmaManager]] = {}
134+
self._manager_futures: Dict[ProcMeshV0 | ProcMeshV1, Future[_RdmaManager]] = {}
124135

125136
@endpoint
126-
async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None:
137+
async def init_rdma_on_mesh(self, proc_mesh: ProcMeshV0 | ProcMeshV1) -> None:
127138
# Note: RdmaController acts as coordinator and can run on any node
128139
# The RDMA support check should happen on the target proc_mesh nodes, not on RdmaController's node
129140

0 commit comments

Comments
 (0)