Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def context() -> Context:
_transport_lock = threading.Lock()


def enable_transport(transport: ChannelTransport) -> None:
def enable_transport(transport: "ChannelTransport | str") -> None:
"""
Allow monarch to communicate with transport type 'transport'
This must be called before any other calls in the monarch API.
Expand All @@ -247,6 +247,15 @@ def enable_transport(transport: ChannelTransport) -> None:
Currently only one transport type may be enabled at one time.
In the future we may allow multiple to be enabled.
"""
if isinstance(transport, str):
transport = {
"tcp": ChannelTransport.Tcp,
"ipc": ChannelTransport.Unix,
"metatls": ChannelTransport.MetaTlsWithIpV6,
}.get(transport)
if transport is None:
raise ValueError(f"unknown transport: {transport}")

if _context.get(None) is not None:
raise RuntimeError(
"`enable_transport()` must be called before any other calls in the monarch API. "
Expand Down
43 changes: 43 additions & 0 deletions python/monarch/_src/actor/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,46 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import os
from typing import TYPE_CHECKING

from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice

from monarch._src.actor.allocator import AllocateMixin

from monarch._src.actor.endpoint import Extent
from monarch._src.actor.host_mesh import HostMesh as HostMeshV0
from monarch._src.actor.v1.host_mesh import HostMesh as HostMeshV1

enabled = os.environ.get("MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE", "0") != "0"

if TYPE_CHECKING or not enabled:
from monarch._src.actor.host_mesh import HostMesh, this_host, this_proc
from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
else:
from monarch._src.actor.v1.host_mesh import HostMesh, this_host, this_proc
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller, ProcMesh


def host_mesh_from_alloc(
name: str, extent: Extent, allocator: AllocateMixin, constraints: AllocConstraints
) -> "HostMeshV0 | HostMeshV1":
if enabled:
return HostMeshV1.allocate_nonblocking(name, extent, allocator, constraints)
else:
return HostMeshV0(
Shape(extent.labels, Slice.new_row_major(extent.sizes)),
allocator,
constraints,
)


__all__ = [
"HostMesh",
"this_host",
"this_proc",
"get_or_spawn_controller",
"ProcMesh",
"host_mesh_from_alloc",
]
40 changes: 2 additions & 38 deletions python/monarch/_src/job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
from typing import cast, Dict, List, Literal, NamedTuple, Optional, Sequence

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure

from monarch._src.actor.bootstrap import attach_to_workers

# note: the jobs api is intended as a library so it should
# only be importing _public_ monarch API functions.
from monarch._src.actor.host_mesh import HostMesh, this_host

from typing_extensions import Self
from monarch.actor import enable_transport, HostMesh, this_host


class JobState:
Expand Down Expand Up @@ -441,47 +438,14 @@ def _kill(self):
pass


class FakeLocalLoginJob(LoginJob):
"""
Fake it that we are logging in by just making a local process that runs the bootstrap.
"""

def __init__(self):
super().__init__()
configure(default_transport=ChannelTransport.Tcp)

self._next_port = 12345

def _start_host(self, host: str) -> ProcessState:
port = self._next_port
self._next_port += 1

env = {**os.environ}
if "FB_XAR_INVOKED_NAME" in os.environ:
env["PYTHONPATH"] = ":".join(sys.path)
addr = f"tcp://[::1]:{port}"
bind_addr = f"tcp://[::1]:{port}"
proc = subprocess.Popen(
[
sys.executable,
"-c",
f'from monarch.actor import run_worker_loop_forever; run_worker_loop_forever(address={repr(bind_addr)}, ca="trust_all_connections")',
],
env=env,
start_new_session=True,
)
return ProcessState(proc.pid, addr)


class SSHJob(LoginJob):
def __init__(
self,
python_exe: str = "python",
ssh_args: Sequence[str] = (),
monarch_port: int = 22222,
):
configure(default_transport=ChannelTransport.Tcp)
enable_transport("tcp")
self._python_exe = python_exe
self._ssh_args = ssh_args
self._port = monarch_port
Expand Down
10 changes: 4 additions & 6 deletions python/monarch/_src/job/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
)

from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
from monarch._src.actor.allocator import AllocateMixin
from monarch._src.actor.host_mesh import HostMesh
from monarch._src.actor.meta.allocator import (
MastAllocator,
MastAllocatorBase,
MastAllocatorConfig,
)
from monarch._src.actor.v1 import host_mesh_from_alloc

from monarch._src.job.job import BatchJob, JobState, JobTrait

Expand Down Expand Up @@ -173,10 +173,8 @@ def _state(self) -> JobState:
job_started,
)
constraints = AllocConstraints({MastAllocator.ALLOC_LABEL_TASK_GROUP: name})
host_meshes[name] = HostMesh(
Shape(["hosts"], Slice.new_row_major([num_host])),
allocator,
constraints,
host_meshes[name] = host_mesh_from_alloc(
name, Extent(["hosts"], [num_host]), allocator, constraints
)

return JobState(host_meshes)
Expand Down
2 changes: 1 addition & 1 deletion python/monarch/_src/rdma/rdma.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from monarch._src.actor.actor_mesh import Actor, context
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.future import Future
from monarch._src.actor.proc_mesh import get_or_spawn_controller, ProcMesh
from monarch._src.actor.v1 import get_or_spawn_controller, ProcMesh
from pyre_extensions import none_throws


Expand Down
17 changes: 7 additions & 10 deletions python/monarch/actor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Monarch Actor API - Public interface for actor functionality.
"""

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.shape import Extent
from monarch._src.actor.actor_mesh import (
Accumulator,
Expand All @@ -35,19 +34,17 @@
from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.future import Future

from monarch._src.actor.host_mesh import (
from monarch._src.actor.host_mesh import hosts_from_config
from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, sim_proc_mesh

from monarch._src.actor.v1 import (
get_or_spawn_controller,
HostMesh,
hosts_from_config,
ProcMesh,
this_host,
this_proc,
)
from monarch._src.actor.proc_mesh import (
get_or_spawn_controller,
local_proc_mesh,
proc_mesh,
ProcMesh,
sim_proc_mesh,
)


__all__ = [
"Accumulator",
Expand Down
Loading