Commit 5a40b91
[monarch] mesh: support v1 actor meshes in RDMA buffer
Pull Request resolved: #1467

This uses the v0 shim to support v1 meshes in RDMA buffer.

ghstack-source-id: 315126536
Differential Revision: [D84181912](https://our.internmc.facebook.com/intern/diff/D84181912/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84181912/)!
1 parent d7899c9 commit 5a40b91

6 files changed, +96 -37 lines changed

hyperactor_mesh/src/v1/host_mesh.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -550,7 +550,7 @@ impl HostMeshRef {
         for (host_rank, host) in self.ranks.iter().enumerate() {
             for per_host_rank in 0..per_host.num_ranks() {
                 let create_rank = per_host.num_ranks() * host_rank + per_host_rank;
-                let proc_name = Name::new(format!("{}-{}", name, per_host_rank));
+                let proc_name = Name::new(format!("{}_{}", name, per_host_rank));
                 host.mesh_agent()
                     .create_or_update(cx, proc_name.clone(), resource::Rank::new(create_rank), ())
                     .await
```

hyperactor_mesh/src/v1/proc_mesh.rs

Lines changed: 20 additions & 0 deletions
```diff
@@ -630,6 +630,26 @@ impl ProcMeshRef {
         self.spawn_with_name(cx, Name::new(name), params).await
     }
 
+    /// Spawn a 'service' actor. Service actors are *singletons*, using
+    /// reserved names. The provided name is used verbatim as the actor's
+    /// name, and thus it may be persistently looked up by constructing
+    /// the appropriate name.
+    ///
+    /// Note: avoid using service actors if possible; the mechanism will
+    /// be replaced by an actor registry.
+    pub async fn spawn_service<A: Actor + Referable>(
+        &self,
+        cx: &impl context::Actor,
+        name: &str,
+        params: &A::Params,
+    ) -> v1::Result<ActorMesh<A>>
+    where
+        A::Params: RemoteMessage,
+    {
+        self.spawn_with_name(cx, Name::new_reserved(name), params)
+            .await
+    }
+
     /// Spawn an actor on all procs in this mesh under the given
     /// [`Name`], returning a new `ActorMesh`.
     ///
```
monarch_hyperactor/src/v1/proc_mesh.rs

Lines changed: 4 additions & 4 deletions
```diff
@@ -36,7 +36,7 @@ use crate::v1::actor_mesh::PythonActorMeshImpl;
     name = "ProcMesh",
     module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
 )]
-pub(crate) enum PyProcMesh {
+pub enum PyProcMesh {
     Owned(PyProcMeshImpl),
     Ref(PyProcMeshRefImpl),
 }
@@ -50,7 +50,7 @@ impl PyProcMesh {
         Self::Ref(PyProcMeshRefImpl(inner))
     }
 
-    pub(crate) fn mesh_ref(&self) -> Result<ProcMeshRef, anyhow::Error> {
+    pub fn mesh_ref(&self) -> Result<ProcMeshRef, anyhow::Error> {
         match self {
             PyProcMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
             PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
@@ -195,7 +195,7 @@ impl PyProcMesh {
     name = "ProcMeshImpl",
     module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
 )]
-pub(crate) struct PyProcMeshImpl(SharedCell<ProcMesh>);
+pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
 
 impl PyProcMeshImpl {
     fn __repr__(&self) -> PyResult<String> {
@@ -211,7 +211,7 @@ impl PyProcMeshImpl {
     name = "ProcMeshRefImpl",
     module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
 )]
-pub(crate) struct PyProcMeshRefImpl(ProcMeshRef);
+pub struct PyProcMeshRefImpl(ProcMeshRef);
 
 impl PyProcMeshRefImpl {
     fn __repr__(&self) -> PyResult<String> {
```

monarch_rdma/extension/lib.rs

Lines changed: 45 additions & 19 deletions
```diff
@@ -18,6 +18,7 @@ use monarch_hyperactor::instance_dispatch;
 use monarch_hyperactor::proc_mesh::PyProcMesh;
 use monarch_hyperactor::pytokio::PyPythonTask;
 use monarch_hyperactor::runtime::signal_safe_block_on;
+use monarch_hyperactor::v1::proc_mesh::PyProcMesh as PyProcMeshV1;
 use monarch_rdma::RdmaBuffer;
 use monarch_rdma::RdmaManagerActor;
 use monarch_rdma::RdmaManagerMessageClient;
@@ -36,6 +36,7 @@ fn setup_rdma_context(
     local_proc_id: String,
 ) -> (ActorRef<RdmaManagerActor>, RdmaBuffer) {
     let proc_id: ProcId = local_proc_id.parse().unwrap();
+    // TODO: find some better way to look this up, or else formally define "service names"
     let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
     let local_owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(local_owner_id);
     let buffer = rdma_buffer.buffer.clone();
@@ -56,6 +58,7 @@
     client: PyInstance,
 ) -> PyResult<PyRdmaBuffer> {
     // Get the owning RdmaManagerActor's ActorRef
+    // TODO: find some better way to look this up, or else formally define "service names"
     let owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
     let owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(owner_id);
 
@@ -289,31 +292,54 @@ impl PyRdmaManager {
     #[classmethod]
     fn create_rdma_manager_nonblocking(
         _cls: &Bound<'_, PyType>,
-        proc_mesh: &PyProcMesh,
+        proc_mesh: &Bound<'_, PyAny>,
         client: PyInstance,
     ) -> PyResult<PyPythonTask> {
         tracing::debug!("spawning RDMA manager on target proc_mesh nodes");
 
-        let tracked_proc_mesh = proc_mesh.try_inner()?;
+        if let Ok(v0) = proc_mesh.downcast::<PyProcMesh>() {
+            let tracked_proc_mesh = v0.borrow().try_inner()?;
+            PyPythonTask::new(async move {
+                // Spawns the `RdmaManagerActor` on the target proc_mesh.
+                // This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware.
+                let actor_mesh = instance_dispatch!(client, |cx| {
+                    tracked_proc_mesh
+                        // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
+                        // TODO - make IbverbsConfig configurable
+                        .spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
+                        .await
+                        .map_err(|err| PyException::new_err(err.to_string()))?
+                });
 
-        PyPythonTask::new(async move {
-            // Spawns the `RdmaManagerActor` on the target proc_mesh.
-            // This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware.
-            let actor_mesh = instance_dispatch!(client, |cx| {
-                tracked_proc_mesh
-                    // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
-                    // TODO - make IbverbsConfig configurable
-                    .spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
-                    .await
-                    .map_err(|err| PyException::new_err(err.to_string()))?
-            });
+                // Use placeholder device name since actual device is determined on remote node
+                Ok(Some(PyRdmaManager {
+                    inner: actor_mesh,
+                    device: "remote_rdma_device".to_string(),
+                }))
+            })
+        } else {
+            let proc_mesh = proc_mesh.downcast::<PyProcMeshV1>()?.borrow().mesh_ref()?;
+            PyPythonTask::new(async move {
+                let actor_mesh = instance_dispatch!(client, |cx| {
+                    proc_mesh
+                        // Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
+                        // TODO - make IbverbsConfig configurable
+                        .spawn_service::<RdmaManagerActor>(cx, "rdma_manager", &None)
+                        .await
+                        .map_err(|err| PyException::new_err(err.to_string()))?
+                });
 
-            // Use placeholder device name since actual device is determined on remote node
-            Ok(Some(PyRdmaManager {
-                inner: actor_mesh,
-                device: "remote_rdma_device".to_string(),
-            }))
-        })
+                eprintln!("spawned rdma_manager: {:?}", actor_mesh);
+
+                let actor_mesh = RootActorMesh::from(actor_mesh);
+                let actor_mesh = SharedCell::from(actor_mesh);
+
+                Ok(Some(PyRdmaManager {
+                    inner: actor_mesh,
+                    device: "remote_rdma_device".to_string(),
+                }))
+            })
+        }
     }
 }
```
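The v0/v1 branching above hinges on pyo3 downcasting of a dynamically typed `&Bound<'_, PyAny>` argument: try the v0 binding first, fall back to the v1 binding. A minimal, self-contained sketch of that pattern, using hypothetical stand-in pyclasses rather than the Monarch types, might look like:

```rust
// Minimal sketch of the downcast-dispatch pattern used above: accept any
// Python object and try each concrete #[pyclass] binding in turn.
// `MeshV0` / `MeshV1` are hypothetical stand-ins for the v0 and v1 ProcMesh
// bindings, not real Monarch types.
use pyo3::exceptions::PyException;
use pyo3::prelude::*;

#[pyclass]
struct MeshV0;

#[pyclass]
struct MeshV1;

#[pyfunction]
fn describe_mesh(mesh: &Bound<'_, PyAny>) -> PyResult<String> {
    if let Ok(_v0) = mesh.downcast::<MeshV0>() {
        // Same shape as the v0 branch above: borrow the pyclass and use it.
        Ok("v0 proc mesh".to_string())
    } else if let Ok(_v1) = mesh.downcast::<MeshV1>() {
        Ok("v1 proc mesh".to_string())
    } else {
        Err(PyException::new_err("expected a v0 or v1 ProcMesh"))
    }
}
```

This keeps a single Python-facing entry point while both mesh implementations coexist.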

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -139,7 +139,7 @@ def spawn_procs(
         per_host = {}
 
         if not name:
-            name = ""
+            name = "anon"
 
         return self._spawn_nonblocking(
             name, Extent(list(per_host.keys()), list(per_host.values())), setup, True
```

python/monarch/_src/rdma/rdma.py

Lines changed: 25 additions & 12 deletions
```diff
@@ -10,7 +10,7 @@
 import logging
 import warnings
 from collections import defaultdict
-from typing import cast, List, Optional, Tuple
+from typing import Any, cast, List, Optional, Tuple
 
 import torch
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
@@ -27,7 +27,14 @@
 from monarch._src.actor.actor_mesh import Actor, context
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import Future
-from monarch._src.actor.v1 import get_or_spawn_controller, ProcMesh
+from monarch._src.actor.v1 import (
+    get_or_spawn_controller as get_or_spawn_controller_v0,
+    ProcMesh as ProcMeshV0,
+)
+from monarch._src.actor.v1.proc_mesh import (
+    get_or_spawn_controller as get_or_spawn_controller_v1,
+    ProcMesh as ProcMeshV1,
+)
 from pyre_extensions import none_throws
 
 
@@ -54,13 +61,17 @@ def is_rdma_available():
 @functools.cache
 def _ensure_init_rdma_manager() -> Shared[None]:
     async def task() -> None:
-        await (
-            await get_or_spawn_controller("rdma_controller", RdmaController)
-        ).init_rdma_on_mesh.call_one(
-            # FIXME(slurye): Fix this once controller API is working properly
-            # for v1.
-            cast(ProcMesh, none_throws(context().actor_instance.proc_mesh))
-        )
+        proc_mesh = context().actor_instance.proc_mesh
+        if isinstance(proc_mesh, ProcMeshV1):
+            controller = await get_or_spawn_controller_v1(
+                "rdma_controller", RdmaController
+            )
+        else:
+            controller = await get_or_spawn_controller_v0(
+                "rdma_controller", RdmaController
+            )
+
+        await controller.init_rdma_on_mesh.call_one(proc_mesh)
 
     return PythonTask.from_coroutine(task()).spawn()
 
@@ -120,17 +131,19 @@ def _get_addr_and_size(buf: torch.Tensor | memoryview) -> tuple[int, int]:
 
 class RdmaController(Actor):
     def __init__(self) -> None:
-        self._manager_futures: Dict[ProcMesh, Future[_RdmaManager]] = {}
+        self._manager_futures: Dict[ProcMeshV0 | ProcMeshV1, Future[_RdmaManager]] = {}
 
     @endpoint
-    async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None:
+    async def init_rdma_on_mesh(self, proc_mesh: ProcMeshV0 | ProcMeshV1) -> None:
         # Note: RdmaController acts as coordinator and can run on any node
         # The RDMA support check should happen on the target proc_mesh nodes, not on RdmaController's node
 
         if proc_mesh not in self._manager_futures:
 
             async def create_manager() -> _RdmaManager:
-                proc_mesh_result = await Future(coro=proc_mesh._proc_mesh.task())
+                proc_mesh_result = await Future(
+                    coro=cast("PythonTask[Any]", proc_mesh._proc_mesh.task())
+                )
                 return none_throws(
                     await _RdmaManager.create_rdma_manager_nonblocking(
                         proc_mesh_result, context().actor_instance
```
