Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/monarch/_src/actor/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,9 @@ def attach_to_workers(
host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name)
extent = Extent(["hosts"], [len(workers)])
return HostMesh(
host_mesh.spawn(), extent.region, stream_logs=True, is_fake_in_process=False
host_mesh.spawn(),
extent.region,
stream_logs=True,
is_fake_in_process=False,
_initialized_hy_host_mesh=None,
)
36 changes: 32 additions & 4 deletions python/monarch/_src/actor/v1/host_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,17 @@ def __init__(
region: Region,
stream_logs: bool,
is_fake_in_process: bool,
_initialized_hy_host_mesh: Optional[HyHostMesh],
) -> None:
self._initialized_host_mesh = _initialized_hy_host_mesh
if not self._initialized_host_mesh:

async def task(hy_host_mesh_task: Shared[HyHostMesh]) -> HyHostMesh:
self._initialized_host_mesh = await hy_host_mesh_task
return self._initialized_host_mesh

hy_host_mesh = PythonTask.from_coroutine(task(hy_host_mesh)).spawn()

self._hy_host_mesh = hy_host_mesh
self._region = region
self._stream_logs = stream_logs
Expand Down Expand Up @@ -127,6 +137,7 @@ async def task() -> HyHostMesh:
extent.region,
alloc.stream_logs,
isinstance(allocator, LocalAllocator),
None,
)

def spawn_procs(
Expand Down Expand Up @@ -187,15 +198,25 @@ def _new_with_shape(self, shape: Shape) -> "HostMesh":
if shape.region == self._region:
return self

initialized_hm: Optional[HyHostMesh] = (
None
if self._initialized_host_mesh is None
else self._initialized_host_mesh.sliced(shape.region)
)

async def task() -> HyHostMesh:
hy_host_mesh = await self._hy_host_mesh
return hy_host_mesh.sliced(shape.region)
return (
initialized_hm
if initialized_hm
else (await self._hy_host_mesh).sliced(shape.region)
)

return HostMesh(
PythonTask.from_coroutine(task()).spawn(),
shape.region,
self.stream_logs,
self.is_fake_in_process,
initialized_hm,
)

@property
Expand All @@ -222,11 +243,12 @@ async def task() -> HyHostMesh:
region,
stream_logs,
is_fake_in_process,
hy_host_mesh,
)

def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
return HostMesh._from_initialized_hy_host_mesh, (
self._hy_host_mesh.block_on(),
self._initialized_mesh(),
self._region,
self.stream_logs,
self.is_fake_in_process,
Expand All @@ -238,12 +260,18 @@ def is_fake_in_process(self) -> bool:

def __eq__(self, other: "HostMesh") -> bool:
    """
    Return whether ``other`` describes the same host mesh as ``self``.

    Two meshes compare equal when their initialized underlying host
    meshes, their regions, their ``stream_logs`` flags and their
    ``is_fake_in_process`` flags are all equal.

    The underlying meshes are fetched through ``_initialized_mesh()``
    instead of calling ``block_on()`` on the pending task directly, so a
    mesh that is already initialized is compared without blocking.

    Returns ``NotImplemented`` when ``other`` is not a ``HostMesh`` so that
    Python can fall back to the reflected comparison.
    """
    if not isinstance(other, HostMesh):
        return NotImplemented
    return (
        # Evaluated first, as before: this may block until each mesh has
        # finished initializing.
        self._initialized_mesh() == other._initialized_mesh()
        and self._region == other._region
        and self.stream_logs == other.stream_logs
        and self.is_fake_in_process == other.is_fake_in_process
    )

def _initialized_mesh(self) -> HyHostMesh:
    """
    Return the initialized underlying host mesh, waiting for it if needed.

    When ``_initialized_host_mesh`` is already set it is returned right
    away without touching the pending task. Otherwise this blocks on
    ``_hy_host_mesh`` and caches the mesh it yields in
    ``_initialized_host_mesh`` so later calls take the fast path.

    Returns:
        The initialized ``HyHostMesh`` backing this host mesh.
    """
    mesh = self._initialized_host_mesh
    if mesh is None:
        # Take the value block_on() hands back and cache it explicitly,
        # rather than relying on the awaited task having stored it as a
        # side effect (previously guarded only by an ``assert``, which is
        # removed when Python runs with -O).
        mesh = self._hy_host_mesh.block_on()
        self._initialized_host_mesh = mesh
    return mesh


def fake_in_process_host(name: str) -> "HostMesh":
"""
Expand Down
31 changes: 28 additions & 3 deletions python/monarch/_src/actor/v1/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,20 @@ def __init__(
host_mesh: "HostMesh",
region: Region,
root_region: Region,
_initialized_hy_proc_mesh: Optional[HyProcMesh],
_device_mesh: Optional["DeviceMesh"] = None,
) -> None:
_proc_mesh_registry.add(self)

self._initialized_proc_mesh = _initialized_hy_proc_mesh
if not self._initialized_proc_mesh:

async def task(hy_proc_mesh_task: Shared[HyProcMesh]) -> HyProcMesh:
self._initialized_proc_mesh = await hy_proc_mesh_task
return self._initialized_proc_mesh

hy_proc_mesh = PythonTask.from_coroutine(task(hy_proc_mesh)).spawn()

self._proc_mesh = hy_proc_mesh
self._host_mesh = host_mesh
self._region = region
Expand Down Expand Up @@ -138,14 +149,25 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh":
else self._maybe_device_mesh._new_with_shape(shape)
)

initialized_pm: Optional[HyProcMesh] = (
None
if self._initialized_proc_mesh is None
else self._initialized_proc_mesh.sliced(shape.region)
)

async def task() -> HyProcMesh:
return (await self._proc_mesh).sliced(shape.region)
return (
initialized_pm
if initialized_pm
else (await self._proc_mesh).sliced(shape.region)
)

return ProcMesh(
PythonTask.from_coroutine(task()).spawn(),
self._host_mesh,
shape.region,
self._root_region,
initialized_pm,
_device_mesh=device_mesh,
)

Expand Down Expand Up @@ -188,7 +210,7 @@ def from_host_mesh(
setup: Callable[[], None] | None = None,
_attach_controller_controller: bool = True,
) -> "ProcMesh":
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region)
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region, None)

if _attach_controller_controller:
instance = context().actor_instance
Expand Down Expand Up @@ -367,11 +389,14 @@ async def task() -> HyProcMesh:
host_mesh,
region,
root_region,
_initialized_hy_proc_mesh=hy_proc_mesh,
)

def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]:
return ProcMesh._from_initialized_hy_proc_mesh, (
self._proc_mesh.block_on(),
self._initialized_proc_mesh
if self._initialized_proc_mesh
else self._proc_mesh.block_on(),
self._host_mesh,
self._region,
self._root_region,
Expand Down
20 changes: 20 additions & 0 deletions python/tests/test_proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

from typing import cast

import cloudpickle

import pytest
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice
from monarch._src.actor.actor_mesh import Actor, ActorMesh, context, ValueMesh
from monarch._src.actor.endpoint import endpoint
Expand Down Expand Up @@ -137,3 +140,20 @@ def test_nested_meshes() -> None:
assert value == point.rank + 1
for point, value in res_1:
assert value == point.rank


@pytest.mark.timeout(60)
async def test_pickle_initialized_proc_mesh_in_tokio_thread() -> None:
    """
    Pickle an initialized proc mesh, and a slice of it, from inside a task.

    The test passes when both ``cloudpickle.dumps`` calls complete without
    raising; the pickled bytes themselves are not inspected.
    """
    host_mesh = create_local_host_mesh("host", Extent(["hosts"], [2]))
    proc_mesh = host_mesh.spawn_procs(per_host={"gpus": 2})

    async def pickle_whole_mesh():
        cloudpickle.dumps(proc_mesh)

    async def pickle_sliced_mesh():
        # The slice is taken inside the coroutine so that slicing also
        # happens within the spawned task.
        cloudpickle.dumps(proc_mesh.slice(gpus=0, hosts=0))

    # Make sure the mesh is initialized before any pickling is attempted.
    await proc_mesh.initialized

    PythonTask.from_coroutine(pickle_whole_mesh()).block_on()
    PythonTask.from_coroutine(pickle_sliced_mesh()).block_on()