Skip to content

Commit c3e0fb0

Browse files
samlurye authored and facebook-github-bot committed
Additional dims per host when spawning proc mesh on host mesh (#1345)
Summary: Pull Request resolved: #1345.
Add an API to specify an `Extent` for adding additional dims on each host when spawning a proc mesh on a v1 host mesh. The extent of the proc mesh will be the extent of the host mesh followed by the per-host extent.
ghstack-source-id: 312970198
exported-using-ghexport
Reviewed By: mariusae
Differential Revision: D83274881
fbshipit-source-id: 4309309a151c1169d0d4f2f434062512ee0b228a
1 parent fa64d1b commit c3e0fb0

File tree

8 files changed

+145
-62
lines changed

8 files changed

+145
-62
lines changed

hyperactor_mesh/src/bootstrap.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1709,6 +1709,7 @@ mod tests {
17091709
use hyperactor::context::Mailbox as _;
17101710
use hyperactor::host::ProcHandle;
17111711
use hyperactor::id;
1712+
use ndslice::Extent;
17121713
use ndslice::ViewExt;
17131714
use ndslice::extent;
17141715
use tokio::process::Command;
@@ -2863,7 +2864,10 @@ mod tests {
28632864
//
28642865
// (4) We collect the per-host procs into a `ProcMesh` and
28652866
// return it.
2866-
let proc_mesh = host_mesh.spawn(&instance, "p0").await.unwrap();
2867+
let proc_mesh = host_mesh
2868+
.spawn(&instance, "p0", Extent::unity())
2869+
.await
2870+
.unwrap();
28672871

28682872
// Note: There is no support for status() in v1.
28692873
// assert!(proc_mesh.status(&instance).await.is_err());

hyperactor_mesh/src/v1/actor_mesh.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ mod tests {
341341
use hyperactor::clock::RealClock;
342342
use hyperactor::context::Mailbox as _;
343343
use hyperactor::mailbox;
344+
use ndslice::Extent;
344345
use ndslice::ViewExt;
345346
use ndslice::extent;
346347
use ndslice::view::Ranked;
@@ -530,7 +531,10 @@ mod tests {
530531

531532
let instance = testing::instance().await;
532533
let host_mesh = testing::host_mesh(extent!(host = 4)).await;
533-
let proc_mesh = host_mesh.spawn(instance, "test").await.unwrap();
534+
let proc_mesh = host_mesh
535+
.spawn(instance, "test", Extent::unity())
536+
.await
537+
.unwrap();
534538
let actor_mesh = proc_mesh
535539
.spawn::<testactor::TestActor>(instance, "test", &())
536540
.await

hyperactor_mesh/src/v1/host_mesh.rs

Lines changed: 101 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use hyperactor::channel::ChannelTransport;
1010
pub mod mesh_agent;
1111

12+
use std::collections::HashSet;
1213
use std::ops::Deref;
1314
use std::str::FromStr;
1415
use std::sync::Arc;
@@ -460,31 +461,69 @@ impl HostMeshRef {
460461
}
461462
}
462463

463-
/// Spawn a ProcMesh onto this host mesh.
464-
// TODO: add an "additional dims" API
465-
pub async fn spawn(&self, cx: &impl context::Actor, name: &str) -> v1::Result<ProcMesh> {
466-
let name = Name::new(name);
464+
/// Spawn a ProcMesh onto this host mesh. The per_host extent specifies the shape
465+
/// of the procs to spawn on each host.
466+
pub async fn spawn(
467+
&self,
468+
cx: &impl context::Actor,
469+
name: &str,
470+
per_host: Extent,
471+
) -> v1::Result<ProcMesh> {
472+
let per_host_labels = per_host.labels().iter().collect::<HashSet<_>>();
473+
let host_labels = self.region.labels().iter().collect::<HashSet<_>>();
474+
if !per_host_labels
475+
.intersection(&host_labels)
476+
.collect::<Vec<_>>()
477+
.is_empty()
478+
{
479+
return Err(v1::Error::ConfigurationError(anyhow::anyhow!(
480+
"per_host dims overlap with existing dims when spawning proc mesh"
481+
)));
482+
}
483+
484+
let labels = self
485+
.region
486+
.labels()
487+
.to_vec()
488+
.into_iter()
489+
.chain(per_host.labels().to_vec().into_iter())
490+
.collect();
491+
let sizes = self
492+
.region
493+
.extent()
494+
.sizes()
495+
.to_vec()
496+
.into_iter()
497+
.chain(per_host.sizes().to_vec().into_iter())
498+
.collect();
499+
let extent =
500+
Extent::new(labels, sizes).map_err(|err| v1::Error::ConfigurationError(err.into()))?;
501+
502+
let mesh_name = Name::new(name);
467503
let mut procs = Vec::new();
468-
for (rank, host) in self.ranks.iter().enumerate() {
469-
let _ok = host
470-
.mesh_agent()
471-
.create_or_update(cx, name.clone(), ())
472-
.await
473-
.map_err(|e| {
474-
v1::Error::HostMeshAgentConfigurationError(
475-
host.mesh_agent().actor_id().clone(),
476-
format!("failed while creating proc: {}", e),
477-
)
478-
})?;
479-
procs.push(ProcRef::new(
480-
host.named_proc(&name),
481-
rank,
482-
// TODO: specify or retrieve from state instead, to avoid attestation.
483-
ActorRef::attest(host.named_proc(&name).actor_id("agent", 0)),
484-
));
504+
for (host_rank, host) in self.ranks.iter().enumerate() {
505+
for per_host_rank in 0..per_host.num_ranks() {
506+
let proc_name = Name::new(format!("{}-{}", name, per_host_rank));
507+
let _ok = host
508+
.mesh_agent()
509+
.create_or_update(cx, proc_name.clone(), ())
510+
.await
511+
.map_err(|e| {
512+
v1::Error::HostMeshAgentConfigurationError(
513+
host.mesh_agent().actor_id().clone(),
514+
format!("failed while creating proc: {}", e),
515+
)
516+
})?;
517+
procs.push(ProcRef::new(
518+
host.named_proc(&proc_name),
519+
per_host.num_ranks() * host_rank + per_host_rank,
520+
// TODO: specify or retrieve from state instead, to avoid attestation.
521+
ActorRef::attest(host.named_proc(&proc_name).actor_id("agent", 0)),
522+
));
523+
}
485524
}
486525

487-
ProcMesh::create_owned_unchecked(cx, name, self.clone(), procs).await
526+
ProcMesh::create_owned_unchecked(cx, mesh_name, extent, self.clone(), procs).await
488527
}
489528
}
490529

@@ -621,12 +660,28 @@ mod tests {
621660
.await
622661
.unwrap();
623662

624-
let proc_mesh1 = host_mesh.spawn(instance, "test_1").await.unwrap();
663+
let proc_mesh1 = host_mesh
664+
.spawn(instance, "test_1", Extent::unity())
665+
.await
666+
.unwrap();
625667
let actor_mesh1: ActorMesh<testactor::TestActor> =
626668
proc_mesh1.spawn(instance, "test", &()).await.unwrap();
627-
let proc_mesh2 = host_mesh.spawn(instance, "test_2").await.unwrap();
669+
let proc_mesh2 = host_mesh
670+
.spawn(instance, "test_2", extent!(gpus = 3, extra = 2))
671+
.await
672+
.unwrap();
673+
assert_eq!(
674+
proc_mesh2.extent(),
675+
extent!(replicas = 4, gpus = 3, extra = 2)
676+
);
677+
assert_eq!(proc_mesh2.values().count(), 24);
628678
let actor_mesh2: ActorMesh<testactor::TestActor> =
629679
proc_mesh2.spawn(instance, "test", &()).await.unwrap();
680+
assert_eq!(
681+
actor_mesh2.extent(),
682+
extent!(replicas = 4, gpus = 3, extra = 2)
683+
);
684+
assert_eq!(actor_mesh2.values().count(), 24);
630685

631686
// Host meshes can be dereferenced to produce a concrete ref.
632687
let host_mesh_ref: HostMeshRef = host_mesh.clone();
@@ -637,23 +692,24 @@ mod tests {
637692
);
638693

639694
// Validate we can cast:
640-
641-
let (port, mut rx) = instance.mailbox().open_port();
642-
actor_mesh1
643-
.cast(instance, testactor::GetActorId(port.bind()))
644-
.unwrap();
645-
646-
let mut expected_actor_ids: HashSet<_> = actor_mesh1
647-
.values()
648-
.map(|actor_ref| actor_ref.actor_id().clone())
649-
.collect();
650-
651-
while !expected_actor_ids.is_empty() {
652-
let actor_id = rx.recv().await.unwrap();
653-
assert!(
654-
expected_actor_ids.remove(&actor_id),
655-
"got {actor_id}, expect {expected_actor_ids:?}"
656-
);
695+
for actor_mesh in [&actor_mesh1, &actor_mesh2] {
696+
let (port, mut rx) = instance.mailbox().open_port();
697+
actor_mesh
698+
.cast(instance, testactor::GetActorId(port.bind()))
699+
.unwrap();
700+
701+
let mut expected_actor_ids: HashSet<_> = actor_mesh
702+
.values()
703+
.map(|actor_ref| actor_ref.actor_id().clone())
704+
.collect();
705+
706+
while !expected_actor_ids.is_empty() {
707+
let actor_id = rx.recv().await.unwrap();
708+
assert!(
709+
expected_actor_ids.remove(&actor_id),
710+
"got {actor_id}, expect {expected_actor_ids:?}"
711+
);
712+
}
657713
}
658714

659715
// Now forward a message through all directed edges across the two meshes.
@@ -719,7 +775,10 @@ mod tests {
719775

720776
let instance = testing::instance().await;
721777
let host_mesh = HostMeshRef::from_hosts(hosts);
722-
let proc_mesh = host_mesh.spawn(&instance, "test").await.unwrap();
778+
let proc_mesh = host_mesh
779+
.spawn(&testing::instance().await, "test", Extent::unity())
780+
.await
781+
.unwrap();
723782
let actor_mesh: ActorMesh<testactor::TestActor> = proc_mesh
724783
.spawn(&testing::instance().await, "test", &())
725784
.await

hyperactor_mesh/src/v1/proc_mesh.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,10 @@ impl ProcMesh {
217217
pub(crate) async fn create_owned_unchecked(
218218
cx: &impl context::Actor,
219219
name: Name,
220+
extent: Extent,
220221
hosts: HostMeshRef,
221222
ranks: Vec<ProcRef>,
222223
) -> v1::Result<Self> {
223-
let extent = hosts.extent();
224224
Self::create(
225225
cx,
226226
name,

monarch_hyperactor/src/shape.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ impl From<Extent> for PyExtent {
102102
}
103103
}
104104

105+
impl From<PyExtent> for Extent {
106+
fn from(py_extent: PyExtent) -> Self {
107+
py_extent.inner
108+
}
109+
}
110+
105111
#[derive(Serialize, Deserialize, Clone)]
106112
#[pyclass(
107113
name = "Region",

monarch_hyperactor/src/v1/host_mesh.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::alloc::PyAlloc;
2727
use crate::context::PyInstance;
2828
use crate::instance_dispatch;
2929
use crate::pytokio::PyPythonTask;
30+
use crate::shape::PyExtent;
3031
use crate::shape::PyRegion;
3132
use crate::v1::proc_mesh::PyProcMesh;
3233

@@ -142,12 +143,18 @@ impl PyHostMesh {
142143
})
143144
}
144145

145-
fn spawn_nonblocking(&self, instance: &PyInstance, name: String) -> PyResult<PyPythonTask> {
146+
fn spawn_nonblocking(
147+
&self,
148+
instance: &PyInstance,
149+
name: String,
150+
per_host: &PyExtent,
151+
) -> PyResult<PyPythonTask> {
146152
let host_mesh = self.mesh_ref()?.clone();
147153
let instance = instance.clone();
154+
let per_host = per_host.clone().into();
148155
let mesh_impl = async move {
149156
let proc_mesh = instance_dispatch!(instance, async move |cx_instance| {
150-
host_mesh.spawn(cx_instance, &name).await
157+
host_mesh.spawn(cx_instance, &name, per_host).await
151158
})
152159
.map_err(to_py_error)?;
153160
Ok(PyProcMesh::new_owned(proc_mesh))

python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc
1212
from monarch._rust_bindings.monarch_hyperactor.context import Instance
1313
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
1414

15-
from monarch._rust_bindings.monarch_hyperactor.shape import Region
15+
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region
1616
from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import ProcMesh
1717

1818
@final
@@ -40,13 +40,15 @@ class HostMesh:
4040
self,
4141
instance: Instance,
4242
name: str,
43+
per_host: Extent,
4344
) -> PythonTask[ProcMesh]:
4445
"""
4546
Spawn a new actor on this mesh.
4647
4748
Arguments:
4849
- `instance`: The instance to use to spawn the mesh.
4950
- `name`: Name of the proc mesh
51+
- `per_host`: Extent describing the shape of the proc mesh on each host.
5052
"""
5153
...
5254

python/tests/_monarch/test_actor_mesh.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
AllocConstraints,
2626
AllocSpec,
2727
)
28-
from monarch._rust_bindings.monarch_hyperactor.shape import Region, Slice
28+
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Region, Slice
2929
from monarch._src.actor.proc_mesh import _get_bootstrap_args, ProcessAllocator
3030

3131
if TYPE_CHECKING:
@@ -256,7 +256,7 @@ async def test_host_mesh() -> None:
256256
async def run() -> None:
257257
cmd, args, bootstrap_env = _get_bootstrap_args()
258258
allocator = ProcessAllocator(cmd, args, bootstrap_env)
259-
spec: AllocSpec = AllocSpec(AllocConstraints(), replicas=2, hosts=2, gpus=4)
259+
spec: AllocSpec = AllocSpec(AllocConstraints(), hosts=2)
260260
alloc = allocator.allocate(spec)
261261

262262
host_mesh = await HostMesh.allocate_nonblocking(
@@ -271,34 +271,35 @@ async def run() -> None:
271271
),
272272
).spawn()
273273

274-
assert host_mesh.region.labels() == ["replicas", "hosts", "gpus"]
275-
assert host_mesh.region.slice() == Slice(
276-
offset=0, sizes=[2, 2, 4], strides=[8, 4, 1]
277-
)
274+
assert host_mesh.region.labels() == ["hosts"]
275+
assert host_mesh.region.slice() == Slice(offset=0, sizes=[2], strides=[1])
278276

279277
proc_mesh = await host_mesh.spawn_nonblocking(
280-
context().actor_instance._as_rust(), "proc_mesh"
278+
context().actor_instance._as_rust(),
279+
"proc_mesh",
280+
Extent(["gpus", "replicas"], [2, 4]),
281281
).spawn()
282282
actor_mesh = await spawn_actor_mesh(proc_mesh)
283283

284284
await verify_cast_to_call(actor_mesh, context().actor_instance, list(range(16)))
285285

286-
# Ranks 4, 6, 12, 14 (gpus 0 and 2 on host 1 on both replicas)
287286
sliced_hm = host_mesh.sliced(
288287
Region(
289-
labels=["replicas", "gpus"],
290-
slice=Slice(offset=4, sizes=[2, 2], strides=[8, 2]),
288+
labels=["hosts"],
289+
slice=Slice(offset=1, sizes=[1], strides=[1]),
291290
)
292291
)
293292

294-
assert sliced_hm.region.labels() == ["replicas", "gpus"]
295-
assert sliced_hm.region.slice() == Slice(offset=4, sizes=[2, 2], strides=[8, 2])
293+
assert sliced_hm.region.labels() == ["hosts"]
294+
assert sliced_hm.region.slice() == Slice(offset=1, sizes=[1], strides=[1])
296295

297296
sliced_pm = await sliced_hm.spawn_nonblocking(
298-
context().actor_instance._as_rust(), "sliced_pm"
297+
context().actor_instance._as_rust(),
298+
"sliced_pm",
299+
Extent(["gpus", "replicas"], [2, 3]),
299300
)
300301
sliced_am = await spawn_actor_mesh(sliced_pm)
301302

302-
await verify_cast_to_call(sliced_am, context().actor_instance, list(range(4)))
303+
await verify_cast_to_call(sliced_am, context().actor_instance, list(range(6)))
303304

304305
run()

0 commit comments

Comments
 (0)