Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hyperactor_mesh/src/v1/host_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ impl HostMeshRef {
for (host_rank, host) in self.ranks.iter().enumerate() {
for per_host_rank in 0..per_host.num_ranks() {
let create_rank = per_host.num_ranks() * host_rank + per_host_rank;
let proc_name = Name::new(format!("{}-{}", name, per_host_rank));
let proc_name = Name::new(format!("{}_{}", name, per_host_rank));
host.mesh_agent()
.create_or_update(cx, proc_name.clone(), resource::Rank::new(create_rank), ())
.await
Expand Down
20 changes: 20 additions & 0 deletions hyperactor_mesh/src/v1/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,26 @@ impl ProcMeshRef {
self.spawn_with_name(cx, Name::new(name), params).await
}

/// Spawn a 'service' actor. Service actors are *singletons*, using
/// reserved names. The provided name is used verbatim as the actor's
/// name, and thus it may be persistently looked up by constructing
/// the appropriate name.
///
/// Note: avoid using service actors if possible; the mechanism will
/// be replaced by an actor registry.
pub async fn spawn_service<A: Actor + Referable>(
&self,
cx: &impl context::Actor,
name: &str,
params: &A::Params,
) -> v1::Result<ActorMesh<A>>
where
A::Params: RemoteMessage,
{
self.spawn_with_name(cx, Name::new_reserved(name), params)
.await
}

/// Spawn an actor on all procs in this mesh under the given
/// [`Name`], returning a new `ActorMesh`.
///
Expand Down
8 changes: 4 additions & 4 deletions monarch_hyperactor/src/v1/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::v1::actor_mesh::PythonActorMeshImpl;
name = "ProcMesh",
module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
)]
pub(crate) enum PyProcMesh {
pub enum PyProcMesh {
Owned(PyProcMeshImpl),
Ref(PyProcMeshRefImpl),
}
Expand All @@ -50,7 +50,7 @@ impl PyProcMesh {
Self::Ref(PyProcMeshRefImpl(inner))
}

pub(crate) fn mesh_ref(&self) -> Result<ProcMeshRef, anyhow::Error> {
pub fn mesh_ref(&self) -> Result<ProcMeshRef, anyhow::Error> {
match self {
PyProcMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
Expand Down Expand Up @@ -195,7 +195,7 @@ impl PyProcMesh {
name = "ProcMeshImpl",
module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
)]
pub(crate) struct PyProcMeshImpl(SharedCell<ProcMesh>);
pub struct PyProcMeshImpl(SharedCell<ProcMesh>);

impl PyProcMeshImpl {
fn __repr__(&self) -> PyResult<String> {
Expand All @@ -211,7 +211,7 @@ impl PyProcMeshImpl {
name = "ProcMeshRefImpl",
module = "monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh"
)]
pub(crate) struct PyProcMeshRefImpl(ProcMeshRef);
pub struct PyProcMeshRefImpl(ProcMeshRef);

impl PyProcMeshRefImpl {
fn __repr__(&self) -> PyResult<String> {
Expand Down
64 changes: 45 additions & 19 deletions monarch_rdma/extension/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use monarch_hyperactor::instance_dispatch;
use monarch_hyperactor::proc_mesh::PyProcMesh;
use monarch_hyperactor::pytokio::PyPythonTask;
use monarch_hyperactor::runtime::signal_safe_block_on;
use monarch_hyperactor::v1::proc_mesh::PyProcMesh as PyProcMeshV1;
use monarch_rdma::RdmaBuffer;
use monarch_rdma::RdmaManagerActor;
use monarch_rdma::RdmaManagerMessageClient;
Expand All @@ -36,6 +37,7 @@ fn setup_rdma_context(
local_proc_id: String,
) -> (ActorRef<RdmaManagerActor>, RdmaBuffer) {
let proc_id: ProcId = local_proc_id.parse().unwrap();
// TODO: find some better way to look this up, or else formally define "service names"
let local_owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
let local_owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(local_owner_id);
let buffer = rdma_buffer.buffer.clone();
Expand All @@ -56,6 +58,7 @@ async fn create_rdma_buffer(
client: PyInstance,
) -> PyResult<PyRdmaBuffer> {
// Get the owning RdmaManagerActor's ActorRef
// TODO: find some better way to look this up, or else formally define "service names"
let owner_id = ActorId(proc_id, "rdma_manager".to_string(), 0);
let owner_ref: ActorRef<RdmaManagerActor> = ActorRef::attest(owner_id);

Expand Down Expand Up @@ -289,31 +292,54 @@ impl PyRdmaManager {
#[classmethod]
fn create_rdma_manager_nonblocking(
_cls: &Bound<'_, PyType>,
proc_mesh: &PyProcMesh,
proc_mesh: &Bound<'_, PyAny>,
client: PyInstance,
) -> PyResult<PyPythonTask> {
tracing::debug!("spawning RDMA manager on target proc_mesh nodes");

let tracked_proc_mesh = proc_mesh.try_inner()?;
if let Ok(v0) = proc_mesh.downcast::<PyProcMesh>() {
let tracked_proc_mesh = v0.borrow().try_inner()?;
PyPythonTask::new(async move {
// Spawns the `RdmaManagerActor` on the target proc_mesh.
// This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware.
let actor_mesh = instance_dispatch!(client, |cx| {
tracked_proc_mesh
// Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
// TODO - make IbverbsConfig configurable
.spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
.await
.map_err(|err| PyException::new_err(err.to_string()))?
});

PyPythonTask::new(async move {
// Spawns the `RdmaManagerActor` on the target proc_mesh.
// This allows the `RdmaController` to run on any node while real RDMA operations occur on appropriate hardware.
let actor_mesh = instance_dispatch!(client, |cx| {
tracked_proc_mesh
// Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
// TODO - make IbverbsConfig configurable
.spawn::<RdmaManagerActor>(cx, "rdma_manager", &None)
.await
.map_err(|err| PyException::new_err(err.to_string()))?
});
// Use placeholder device name since actual device is determined on remote node
Ok(Some(PyRdmaManager {
inner: actor_mesh,
device: "remote_rdma_device".to_string(),
}))
})
} else {
let proc_mesh = proc_mesh.downcast::<PyProcMeshV1>()?.borrow().mesh_ref()?;
PyPythonTask::new(async move {
let actor_mesh = instance_dispatch!(client, |cx| {
proc_mesh
// Pass None to use default config - RdmaManagerActor will use default IbverbsConfig
// TODO - make IbverbsConfig configurable
.spawn_service::<RdmaManagerActor>(cx, "rdma_manager", &None)
.await
.map_err(|err| PyException::new_err(err.to_string()))?
});

// Use placeholder device name since actual device is determined on remote node
Ok(Some(PyRdmaManager {
inner: actor_mesh,
device: "remote_rdma_device".to_string(),
}))
})
eprintln!("spawned rdma_manager: {:?}", actor_mesh);

let actor_mesh = RootActorMesh::from(actor_mesh);
let actor_mesh = SharedCell::from(actor_mesh);

Ok(Some(PyRdmaManager {
inner: actor_mesh,
device: "remote_rdma_device".to_string(),
}))
})
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_src/actor/v1/host_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def spawn_procs(
per_host = {}

if not name:
name = ""
name = "anon"

return self._spawn_nonblocking(
name, Extent(list(per_host.keys()), list(per_host.values())), setup, True
Expand Down
37 changes: 25 additions & 12 deletions python/monarch/_src/rdma/rdma.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import warnings
from collections import defaultdict
from typing import cast, List, Optional, Tuple
from typing import Any, cast, List, Optional, Tuple

import torch
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
Expand All @@ -27,7 +27,14 @@
from monarch._src.actor.actor_mesh import Actor, context
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.future import Future
from monarch._src.actor.v1 import get_or_spawn_controller, ProcMesh
from monarch._src.actor.v1 import (
get_or_spawn_controller as get_or_spawn_controller_v0,
ProcMesh as ProcMeshV0,
)
from monarch._src.actor.v1.proc_mesh import (
get_or_spawn_controller as get_or_spawn_controller_v1,
ProcMesh as ProcMeshV1,
)
from pyre_extensions import none_throws


Expand All @@ -54,13 +61,17 @@ def is_rdma_available():
@functools.cache
def _ensure_init_rdma_manager() -> Shared[None]:
async def task() -> None:
await (
await get_or_spawn_controller("rdma_controller", RdmaController)
).init_rdma_on_mesh.call_one(
# FIXME(slurye): Fix this once controller API is working properly
# for v1.
cast(ProcMesh, none_throws(context().actor_instance.proc_mesh))
)
proc_mesh = context().actor_instance.proc_mesh
if isinstance(proc_mesh, ProcMeshV1):
controller = await get_or_spawn_controller_v1(
"rdma_controller", RdmaController
)
else:
controller = await get_or_spawn_controller_v0(
"rdma_controller", RdmaController
)

await controller.init_rdma_on_mesh.call_one(proc_mesh)

return PythonTask.from_coroutine(task()).spawn()

Expand Down Expand Up @@ -120,17 +131,19 @@ def _get_addr_and_size(buf: torch.Tensor | memoryview) -> tuple[int, int]:

class RdmaController(Actor):
def __init__(self) -> None:
self._manager_futures: Dict[ProcMesh, Future[_RdmaManager]] = {}
self._manager_futures: Dict[ProcMeshV0 | ProcMeshV1, Future[_RdmaManager]] = {}

@endpoint
async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None:
async def init_rdma_on_mesh(self, proc_mesh: ProcMeshV0 | ProcMeshV1) -> None:
# Note: RdmaController acts as coordinator and can run on any node
# The RDMA support check should happen on the target proc_mesh nodes, not on RdmaController's node

if proc_mesh not in self._manager_futures:

async def create_manager() -> _RdmaManager:
proc_mesh_result = await Future(coro=proc_mesh._proc_mesh.task())
proc_mesh_result = await Future(
coro=cast("PythonTask[Any]", proc_mesh._proc_mesh.task())
)
return none_throws(
await _RdmaManager.create_rdma_manager_nonblocking(
proc_mesh_result, context().actor_instance
Expand Down