Skip to content

Commit 1090b0f

Browse files
samlurye authored and meta-codesync[bot] committed
Fix this_host for v1 to only return a single host instead of the whole host mesh (meta-pytorch#1459)
Summary:
Pull Request resolved: meta-pytorch#1459

As titled.

ghstack-source-id: 314719294
exported-using-ghexport

Reviewed By: zdevito
Differential Revision: D84091757
fbshipit-source-id: d487b2debde10a9b7827743082c91fe875418f69
1 parent cce0cd4 commit 1090b0f

File tree

11 files changed

+159
-15
lines changed

11 files changed

+159
-15
lines changed

monarch_hyperactor/src/ndslice.rs

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -201,6 +201,12 @@ impl PySlice {
201201
fn new_row_major(sizes: Vec<usize>) -> PySlice {
202202
ndslice::Slice::new_row_major(sizes).into()
203203
}
204+
205+
fn get(&self, index: usize) -> PyResult<usize> {
206+
self.inner
207+
.get(index)
208+
.map_err(|err| PyValueError::new_err(err.to_string()))
209+
}
204210
}
205211

206212
impl From<&PySlice> for ndslice::Slice {

monarch_hyperactor/src/shape.rs

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,14 @@ impl PyExtent {
102102
inner: self.inner.region(),
103103
}
104104
}
105+
106+
fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
107+
if let Ok(other) = other.extract::<PyExtent>() {
108+
Ok(self.inner == other.inner)
109+
} else {
110+
Ok(false)
111+
}
112+
}
105113
}
106114

107115
impl From<Extent> for PyExtent {
@@ -173,6 +181,21 @@ impl PyRegion {
173181
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?
174182
.into())
175183
}
184+
185+
fn point_of_base_rank(&self, rank: usize) -> PyResult<PyPoint> {
186+
self.inner
187+
.point_of_base_rank(rank)
188+
.map_pyerr()
189+
.map(PyPoint::from)
190+
}
191+
192+
fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
193+
if let Ok(other) = other.extract::<PyRegion>() {
194+
Ok(self.inner == other.inner)
195+
} else {
196+
Ok(false)
197+
}
198+
}
176199
}
177200

178201
impl From<Region> for PyRegion {

monarch_hyperactor/src/v1/host_mesh.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -182,21 +182,25 @@ impl PyHostMesh {
182182
.getattr("py_host_mesh_from_bytes")?;
183183
Ok((from_bytes, py_bytes))
184184
}
185+
186+
fn __eq__(&self, other: &PyHostMesh) -> PyResult<bool> {
187+
Ok(self.mesh_ref()? == other.mesh_ref()?)
188+
}
185189
}
186190

187191
#[derive(Clone)]
188192
#[pyclass(
189193
name = "HostMeshImpl",
190194
module = "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh"
191195
)]
192-
struct PyHostMeshImpl(SharedCell<HostMesh>);
196+
pub(crate) struct PyHostMeshImpl(SharedCell<HostMesh>);
193197

194198
#[derive(Debug, Clone)]
195199
#[pyclass(
196200
name = "HostMeshRefImpl",
197201
module = "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh"
198202
)]
199-
struct PyHostMeshRefImpl(HostMeshRef);
203+
pub(crate) struct PyHostMeshRefImpl(HostMeshRef);
200204

201205
impl PyHostMeshRefImpl {
202206
fn __repr__(&self) -> PyResult<String> {

python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,19 @@ class Slice:
9797
...
9898

9999
def __repr__(self) -> str: ...
100+
def get(self, index: int) -> int:
101+
"""
102+
Given a logical index in row-major order, compute the physical
103+
memory offset according to the slice layout. Inverse of `index`.
104+
"""
105+
...
106+
107+
def index(self, value: int) -> int:
108+
"""
109+
Given a physical memory offset, compute the logical index in
110+
row-major order. Inverse of `get`.
111+
"""
112+
...
100113

101114
@final
102115
class Shape:
@@ -175,6 +188,7 @@ class Extent(collections.abc.Mapping):
175188
def __len__(self) -> int: ...
176189
@property
177190
def region(self) -> "Region": ...
191+
def __eq__(self, other: "Extent") -> bool: ...
178192

179193
class Point(collections.abc.Mapping):
180194
"""
@@ -220,3 +234,10 @@ class Region:
220234
...
221235

222236
def __reduce__(self) -> Any: ...
237+
def point_of_base_rank(self, rank: int) -> "Point":
238+
"""
239+
Get the point in this region that corresponds to the given base rank
240+
in the super-region that this region is a subset of.
241+
"""
242+
...
243+
def __eq__(self, other: "Region") -> bool: ...

python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@ class HostMesh:
6969
...
7070

7171
def __reduce__(self) -> Any: ...
72+
def __eq__(self, other: "HostMesh") -> bool: ...
7273

7374
@final
7475
class BootstrapCommand:

python/monarch/_src/actor/actor_mesh.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,16 @@ def items(self) -> Iterable[Tuple[Point, R]]:
494494
for i, _global_rank in enumerate(self._shape.ranks()):
495495
yield Point(i, extent), self._hy.get(i)
496496

497+
def values(self) -> Iterable[R]:
498+
"""
499+
Generator that iterates over just the values in the mesh.
500+
501+
Returns:
502+
Values at all coordinates.
503+
"""
504+
for _, value in self.items():
505+
yield value
506+
497507
def __iter__(self) -> Iterator[Tuple[Point, R]]:
498508
return iter(self.items())
499509

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -47,12 +47,7 @@ def this_host() -> "HostMesh":
4747
4848
This is just shorthand for looking it up via the context
4949
"""
50-
proc = this_proc()
51-
if proc.host_mesh.is_fake_in_process:
52-
return create_local_host_mesh("root_host")
53-
host_mesh = proc.host_mesh
54-
assert isinstance(host_mesh, HostMesh), "expected v1 HostMesh, got v0 HostMesh"
55-
return host_mesh
50+
return this_proc().host_mesh
5651

5752

5853
def this_proc() -> "ProcMesh":
@@ -241,6 +236,14 @@ def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
241236
def is_fake_in_process(self) -> bool:
242237
return self._is_fake_in_process
243238

239+
def __eq__(self, other: "HostMesh") -> bool:
240+
return (
241+
self._hy_host_mesh.block_on() == other._hy_host_mesh.block_on()
242+
and self._region == other._region
243+
and self.stream_logs == other.stream_logs
244+
and self.is_fake_in_process == other.is_fake_in_process
245+
)
246+
244247

245248
def fake_in_process_host(name: str) -> "HostMesh":
246249
"""

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -80,12 +80,14 @@ def __init__(
8080
hy_proc_mesh: "Shared[HyProcMesh]",
8181
host_mesh: "HostMesh",
8282
region: Region,
83+
root_region: Region,
8384
_device_mesh: Optional["DeviceMesh"] = None,
8485
) -> None:
8586
_proc_mesh_registry.add(self)
8687
self._proc_mesh = hy_proc_mesh
8788
self._host_mesh = host_mesh
8889
self._region = region
90+
self._root_region = root_region
8991
self._maybe_device_mesh = _device_mesh
9092
self._logging_manager = LoggingManager()
9193
self._controller_controller: Optional["_ControllerController"] = None
@@ -107,7 +109,16 @@ async def task() -> Literal[True]:
107109

108110
@property
109111
def host_mesh(self) -> "HostMesh":
110-
return self._host_mesh
112+
if self.extent.nelements != 1:
113+
raise NotImplementedError(
114+
"`ProcMesh.host_mesh` is not yet supported for non-singleton proc meshes."
115+
)
116+
elif self._host_mesh.is_fake_in_process:
117+
from monarch._src.actor.v1.host_mesh import create_local_host_mesh
118+
119+
return create_local_host_mesh("root_host")
120+
else:
121+
return self._host(0)
111122

112123
@property
113124
def _ndslice(self) -> Slice:
@@ -134,6 +145,7 @@ async def task() -> HyProcMesh:
134145
PythonTask.from_coroutine(task()).spawn(),
135146
self._host_mesh,
136147
shape.region,
148+
self._root_region,
137149
_device_mesh=device_mesh,
138150
)
139151

@@ -176,7 +188,7 @@ def from_host_mesh(
176188
setup: Callable[[], None] | None = None,
177189
_attach_controller_controller: bool = True,
178190
) -> "ProcMesh":
179-
pm = ProcMesh(hy_proc_mesh, host_mesh, region)
191+
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region)
180192

181193
if _attach_controller_controller:
182194
instance = context().actor_instance
@@ -341,7 +353,11 @@ async def __aexit__(
341353

342354
@classmethod
343355
def _from_initialized_hy_proc_mesh(
344-
cls, hy_proc_mesh: HyProcMesh, host_mesh: "HostMesh", region: Region
356+
cls,
357+
hy_proc_mesh: HyProcMesh,
358+
host_mesh: "HostMesh",
359+
region: Region,
360+
root_region: Region,
345361
) -> "ProcMesh":
346362
async def task() -> HyProcMesh:
347363
return hy_proc_mesh
@@ -350,13 +366,25 @@ async def task() -> HyProcMesh:
350366
PythonTask.from_coroutine(task()).spawn(),
351367
host_mesh,
352368
region,
369+
root_region,
353370
)
354371

355372
def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
356373
return ProcMesh._from_initialized_hy_proc_mesh, (
357374
self._proc_mesh.block_on(),
358-
self.host_mesh,
375+
self._host_mesh,
359376
self._region,
377+
self._root_region,
378+
)
379+
380+
def _host(self, proc_rank: int) -> "HostMesh":
381+
base_proc_rank = self._region.slice().get(proc_rank)
382+
n_procs = len(self._root_region.slice())
383+
procs_per_host = n_procs // len(self._host_mesh.region.slice())
384+
host_rank = base_proc_rank // procs_per_host
385+
base_host_rank = self._host_mesh.region.slice().get(host_rank)
386+
return self._host_mesh.slice(
387+
**self._host_mesh.region.point_of_base_rank(base_host_rank)
360388
)
361389

362390

python/tests/test_host_mesh.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def test_spawn_proc_mesh() -> None:
7474
Extent(["replicas", "hosts"], [2, 4]),
7575
)
7676
proc_mesh = host.spawn_procs(name="proc")
77-
assert proc_mesh.host_mesh is host
77+
assert proc_mesh._host_mesh is host
7878
assert proc_mesh._ndslice == host._ndslice
7979
assert tuple(proc_mesh._labels) == host._labels
8080
hy_proc_mesh = proc_mesh._proc_mesh.block_on()
@@ -89,7 +89,7 @@ def test_spawn_proc_mesh() -> None:
8989
name="proc_sliced", per_host={"gpus": 3, "just_for_fun": 4}
9090
)
9191
hy_sliced_proc = sliced_proc._proc_mesh.block_on()
92-
assert sliced_proc.host_mesh is sliced_host
92+
assert sliced_proc._host_mesh is sliced_host
9393
assert sliced_proc._ndslice == Slice(offset=0, sizes=[2, 3, 4], strides=[12, 4, 1])
9494
assert sliced_proc._labels == ["hosts", "gpus", "just_for_fun"]
9595
assert hy_sliced_proc.region.labels == sliced_proc._labels

python/tests/test_proc_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def test_nested_meshes() -> None:
128128
for i, nested in enumerate([nested_0, nested_1]):
129129
region = cast(
130130
ProcMesh, cast(ActorMesh[TestActor], nested)._proc_mesh
131-
).host_mesh.region
131+
)._host_mesh.region
132132
assert region.labels == ["hosts"]
133133
assert region.slice() == Slice(offset=i, sizes=[1], strides=[1])
134134
res_0 = nested_0.slice(gpus=0).call_on_other_mesh.call_one(nested_1).get()

0 commit comments

Comments
 (0)