diff --git a/python/monarch/_src/actor/bootstrap.py b/python/monarch/_src/actor/bootstrap.py index 60a209657..051762cc0 100644 --- a/python/monarch/_src/actor/bootstrap.py +++ b/python/monarch/_src/actor/bootstrap.py @@ -114,5 +114,9 @@ def attach_to_workers( host_mesh: PythonTask[HyHostMesh] = _attach_to_workers(workers_tasks, name=name) extent = Extent(["hosts"], [len(workers)]) return HostMesh( - host_mesh.spawn(), extent.region, stream_logs=True, is_fake_in_process=False + host_mesh.spawn(), + extent.region, + stream_logs=True, + is_fake_in_process=False, + _initialized_hy_host_mesh=None, ) diff --git a/python/monarch/_src/actor/v1/host_mesh.py b/python/monarch/_src/actor/v1/host_mesh.py index 0eb789cce..8929d01b7 100644 --- a/python/monarch/_src/actor/v1/host_mesh.py +++ b/python/monarch/_src/actor/v1/host_mesh.py @@ -96,7 +96,17 @@ def __init__( region: Region, stream_logs: bool, is_fake_in_process: bool, + _initialized_hy_host_mesh: Optional[HyHostMesh], ) -> None: + self._initialized_host_mesh = _initialized_hy_host_mesh + if not self._initialized_host_mesh: + + async def task(hy_host_mesh_task: Shared[HyHostMesh]) -> HyHostMesh: + self._initialized_host_mesh = await hy_host_mesh_task + return self._initialized_host_mesh + + hy_host_mesh = PythonTask.from_coroutine(task(hy_host_mesh)).spawn() + self._hy_host_mesh = hy_host_mesh self._region = region self._stream_logs = stream_logs @@ -127,6 +137,7 @@ async def task() -> HyHostMesh: extent.region, alloc.stream_logs, isinstance(allocator, LocalAllocator), + None, ) def spawn_procs( @@ -187,15 +198,25 @@ def _new_with_shape(self, shape: Shape) -> "HostMesh": if shape.region == self._region: return self + initialized_hm: Optional[HyHostMesh] = ( + None + if self._initialized_host_mesh is None + else self._initialized_host_mesh.sliced(shape.region) + ) + async def task() -> HyHostMesh: - hy_host_mesh = await self._hy_host_mesh - return hy_host_mesh.sliced(shape.region) + return ( + initialized_hm + if initialized_hm + else (await self._hy_host_mesh).sliced(shape.region) + ) return HostMesh( PythonTask.from_coroutine(task()).spawn(), shape.region, self.stream_logs, self.is_fake_in_process, + initialized_hm, ) @property @@ -222,11 +243,12 @@ async def task() -> HyHostMesh: region, stream_logs, is_fake_in_process, + hy_host_mesh, ) def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: return HostMesh._from_initialized_hy_host_mesh, ( - self._hy_host_mesh.block_on(), + self._initialized_mesh(), self._region, self.stream_logs, self.is_fake_in_process, @@ -238,12 +260,18 @@ def is_fake_in_process(self) -> bool: def __eq__(self, other: "HostMesh") -> bool: return ( - self._hy_host_mesh.block_on() == other._hy_host_mesh.block_on() + self._initialized_mesh() == other._initialized_mesh() and self._region == other._region and self.stream_logs == other.stream_logs and self.is_fake_in_process == other.is_fake_in_process ) + def _initialized_mesh(self) -> HyHostMesh: + if self._initialized_host_mesh is None: + self._hy_host_mesh.block_on() + assert self._initialized_host_mesh is not None + return self._initialized_host_mesh + def fake_in_process_host(name: str) -> "HostMesh": """ diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py index e5836634f..f5fe83833 100644 --- a/python/monarch/_src/actor/v1/proc_mesh.py +++ b/python/monarch/_src/actor/v1/proc_mesh.py @@ -81,9 +81,20 @@ def __init__( host_mesh: "HostMesh", region: Region, root_region: Region, + _initialized_hy_proc_mesh: Optional[HyProcMesh], _device_mesh: Optional["DeviceMesh"] = None, ) -> None: _proc_mesh_registry.add(self) + + self._initialized_proc_mesh = _initialized_hy_proc_mesh + if not self._initialized_proc_mesh: + + async def task(hy_proc_mesh_task: Shared[HyProcMesh]) -> HyProcMesh: + self._initialized_proc_mesh = await hy_proc_mesh_task + return self._initialized_proc_mesh + + hy_proc_mesh = PythonTask.from_coroutine(task(hy_proc_mesh)).spawn() + self._proc_mesh = hy_proc_mesh self._host_mesh = host_mesh self._region = region @@ -138,14 +149,25 @@ def _new_with_shape(self, shape: Shape) -> "ProcMesh": else self._maybe_device_mesh._new_with_shape(shape) ) + initialized_pm: Optional[HyProcMesh] = ( + None + if self._initialized_proc_mesh is None + else self._initialized_proc_mesh.sliced(shape.region) + ) + async def task() -> HyProcMesh: - return (await self._proc_mesh).sliced(shape.region) + return ( + initialized_pm + if initialized_pm + else (await self._proc_mesh).sliced(shape.region) + ) return ProcMesh( PythonTask.from_coroutine(task()).spawn(), self._host_mesh, shape.region, self._root_region, + initialized_pm, _device_mesh=device_mesh, ) @@ -188,7 +210,7 @@ def from_host_mesh( setup: Callable[[], None] | None = None, _attach_controller_controller: bool = True, ) -> "ProcMesh": - pm = ProcMesh(hy_proc_mesh, host_mesh, region, region) + pm = ProcMesh(hy_proc_mesh, host_mesh, region, region, None) if _attach_controller_controller: instance = context().actor_instance @@ -367,11 +389,14 @@ async def task() -> HyProcMesh: host_mesh, region, root_region, + _initialized_hy_proc_mesh=hy_proc_mesh, ) def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: return ProcMesh._from_initialized_hy_proc_mesh, ( - self._proc_mesh.block_on(), + self._initialized_proc_mesh + if self._initialized_proc_mesh + else self._proc_mesh.block_on(), self._host_mesh, self._region, self._root_region, diff --git a/python/tests/test_proc_mesh.py b/python/tests/test_proc_mesh.py index 282b5d576..9aca7f26f 100644 --- a/python/tests/test_proc_mesh.py +++ b/python/tests/test_proc_mesh.py @@ -8,7 +8,10 @@ from typing import cast +import cloudpickle + import pytest +from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice from monarch._src.actor.actor_mesh import Actor, ActorMesh, context, ValueMesh from monarch._src.actor.endpoint import endpoint @@ -137,3 +140,20 @@ def test_nested_meshes() -> None: assert value == point.rank + 1 for point, value in res_1: assert value == point.rank + + +@pytest.mark.timeout(60) +async def test_pickle_initialized_proc_mesh_in_tokio_thread() -> None: + host = create_local_host_mesh("host", Extent(["hosts"], [2])) + proc = host.spawn_procs(per_host={"gpus": 2}) + + async def task(): + cloudpickle.dumps(proc) + + await proc.initialized + PythonTask.from_coroutine(task()).block_on() + + async def task(): + cloudpickle.dumps(proc.slice(gpus=0, hosts=0)) + + PythonTask.from_coroutine(task()).block_on()