29 changes: 10 additions & 19 deletions apps/grpo/main.py
@@ -28,7 +28,6 @@
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.env import MONARCH_HOSTMESH_V1
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
@@ -327,12 +326,6 @@ async def main(cfg: DictConfig):
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)

# In the host mesh v0 case, actors on remote hosts are not able to communicate
# with one another. Therefore we use the controller as our storage volume.
if not MONARCH_HOSTMESH_V1.get_value():
await ts.initialize(strategy=ts.ControllerStorageVolumes())
print("Torchstore successfully initialized with controller storage strategy")

# ---- Setup services ---- #

(
Expand Down Expand Up @@ -364,21 +357,19 @@ async def main(cfg: DictConfig):

print("All services initialized successfully!")
shutdown_event = asyncio.Event()
# In the HostMesh v1 case, we spawn a torchstore storage volume
# per trainer process.
# Here we spawn a torchstore storage volume per trainer process.
# We initialize after service initialization because torchstore currently
# requires access to the underlying proc meshes in the local rank strategy.
# We should be able to hide this in the future.
if MONARCH_HOSTMESH_V1.get_value():
# TODO: support multiple host meshes
trainer_num_procs = cfg.actors.trainer["procs"]
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
strategy=ts.LocalRankStrategy(),
)
print("Torchstore successfully initialized with local rank strategy")
# TODO: support multiple host meshes
trainer_num_procs = cfg.actors.trainer["procs"]
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
strategy=ts.LocalRankStrategy(),
)
print("Torchstore successfully initialized with local rank strategy")

# ---- Core RL loops ---- #
async def continuous_rollouts():
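With the `MONARCH_HOSTMESH_V1` branch gone, `main()` no longer falls back to `ts.ControllerStorageVolumes()`: torchstore is always initialized with the local-rank strategy once the services are up. A minimal sketch of the resulting setup, assuming `ts` is the `torchstore` module imported near the top of `main.py` (that import sits outside the visible hunk) and reusing the config keys shown above:

```python
import torchstore as ts  # assumed alias; the actual import is outside the visible hunk


async def init_torchstore(cfg, provisioner) -> None:
    # One torchstore storage volume per trainer process (local-rank strategy).
    # This must run after service initialization because torchstore needs
    # access to the underlying proc meshes for the local-rank strategy.
    trainer_num_procs = cfg.actors.trainer["procs"]
    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
    trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
    await ts.initialize(
        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
        strategy=ts.LocalRankStrategy(),
    )
```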
20 changes: 8 additions & 12 deletions src/forge/controller/launcher.py
@@ -26,8 +26,6 @@
from monarch.tools.components import hyperactor
from monarch.tools.config import Config, Workspace

from forge.env import MONARCH_HOSTMESH_V1

from forge.types import Launcher, LauncherConfig

_MAST_AVAILABLE = False
@@ -120,11 +118,10 @@ async def remote_setup(self, procs: ProcMesh) -> None:

class Slurmlauncher(BaseLauncher):
async def initialize(self) -> None:
if MONARCH_HOSTMESH_V1.get_value():
# HostMeshV1 currently requires explicit configuration
# of the underlying transport from client to mesh.
# This can be removed in the future once this has been removed.
configure(default_transport=ChannelTransport.Tcp)
# HostMesh currently requires explicit configuration
# of the underlying transport from client to mesh.
# This can be removed once Monarch no longer requires it.
configure(default_transport=ChannelTransport.Tcp)

async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]:
appdef = hyperactor.host_mesh(
@@ -180,11 +177,10 @@ def __init__(self, cfg: LauncherConfig | None = None):
self.job_name = self.cfg.job_name or self.create_job_name()

async def initialize(self) -> None:
if MONARCH_HOSTMESH_V1.get_value():
# HostMeshV1 currently requires explicit configuration
# of the underlying transport from client to mesh.
# This can be removed in the future once this has been removed.
configure(default_transport=ChannelTransport.MetaTlsWithHostname)
# HostMesh currently requires explicit configuration
# of the underlying transport from client to mesh.
# This can be removed once Monarch no longer requires it.
configure(default_transport=ChannelTransport.MetaTlsWithHostname)

await self.launch_mast_job()

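Both launchers now configure the client-to-mesh transport unconditionally. A condensed sketch of the two initializers after this change; `configure` and `ChannelTransport` are Monarch imports whose exact paths sit above the visible hunks, and the MAST launcher's class name is likewise not shown, so it is marked as assumed:

```python
from forge.controller.launcher import BaseLauncher

# `configure` and `ChannelTransport` come from Monarch; their import lines
# are outside the hunks shown above, so they are not repeated here.


class Slurmlauncher(BaseLauncher):
    async def initialize(self) -> None:
        # Plain TCP between client and mesh on Slurm clusters.
        configure(default_transport=ChannelTransport.Tcp)


class MastLauncher(BaseLauncher):  # assumed class name; not visible in the diff
    async def initialize(self) -> None:
        # TLS with hostname verification for MAST jobs.
        configure(default_transport=ChannelTransport.MetaTlsWithHostname)
        await self.launch_mast_job()
```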
47 changes: 16 additions & 31 deletions src/forge/controller/provisioner.py
@@ -13,29 +13,21 @@
import socket
import uuid

from monarch._src.actor.shape import Extent, NDSlice, Shape
from monarch.actor import Actor, endpoint, ProcMesh
from monarch._src.actor.shape import Extent

from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host

from monarch.tools import commands

from forge.controller.launcher import BaseLauncher, get_launcher

from forge.env import all_env_vars, FORGE_DISABLE_METRICS, MONARCH_HOSTMESH_V1

from forge.env import all_env_vars, FORGE_DISABLE_METRICS
from forge.types import ProcessConfig, ProvisionerConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


if MONARCH_HOSTMESH_V1.get_value():
from monarch._src.actor.v1.host_mesh import HostMesh, this_host

logger.info("Using Monarch HostMesh v1...")
else:
from monarch.actor import HostMesh, this_host


def _get_port() -> str:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", 0))
@@ -159,27 +151,20 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
name, num_hosts
)

if MONARCH_HOSTMESH_V1.get_value():
# We are asking Monarch to allocate a single process on
# every host, reflected in the Extent we provide below.
# We are asking Monarch to allocate a single process on
# every host, reflected in the Extent we provide below.

# Technically, this is ["hosts", "procs"] but to reduce
# confusion on its relationship with procs elsewhere,
# we call it "no_dim".
# Technically, this is ["hosts", "procs"] but to reduce
# confusion on its relationship with procs elsewhere,
# we call it "no_dim".

# TODO - remove this once Monarch supports HostMesh without it.
host_mesh = HostMesh.allocate_nonblocking(
name=name,
extent=Extent(["hosts", "no_dim"], [num_hosts, 1]),
allocator=alloc,
alloc_constraints=alloc_constraints,
)
else:
host_mesh = HostMesh(
Shape(["hosts"], NDSlice.new_row_major([num_hosts])),
allocator=alloc,
alloc_constraints=alloc_constraints,
)
# TODO - remove this once Monarch supports HostMesh without it.
host_mesh = HostMesh.allocate_nonblocking(
name=name,
extent=Extent(["hosts", "no_dim"], [num_hosts, 1]),
allocator=alloc,
alloc_constraints=alloc_constraints,
)
return host_mesh, server_name

def get_host_mesh(self, name: str) -> HostMesh:
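With the v0 `Shape`/`NDSlice` constructor removed, host allocation always goes through `HostMesh.allocate_nonblocking`. A sketch of the remaining path, using only names that appear in the hunk above (`alloc` and `alloc_constraints` come from the launcher's `get_allocator` call just before it):

```python
from monarch._src.actor.shape import Extent
from monarch.actor import HostMesh


def allocate_host_mesh(name: str, num_hosts: int, alloc, alloc_constraints):
    # One process per host; the extra "no_dim" dimension exists only because
    # Monarch currently requires a 2-D extent here (see the TODO above).
    return HostMesh.allocate_nonblocking(
        name=name,
        extent=Extent(["hosts", "no_dim"], [num_hosts, 1]),
        allocator=alloc,
        alloc_constraints=alloc_constraints,
    )
```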
6 changes: 0 additions & 6 deletions src/forge/env.py
@@ -99,12 +99,6 @@ def get_value(self) -> Any:
description="Sets the maximum frame length for Monarch's actor message delivery in bytes.",
)

MONARCH_HOSTMESH_V1 = EnvVar(
name="MONARCH_HOST_MESH_V1_REMOVE_ME_BEFORE_RELEASE",
default=False,
description="Whether or not to use Monarch's experimental hostmesh v1 APIs",
)

TORCHSTORE_USE_RDMA = EnvVar(
name="TORCHSTORE_RDMA_ENABLED",
default=0,
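The flag itself is gone; the surviving entries in `forge/env.py` keep the same `EnvVar(name=..., default=..., description=...)` shape and are read via `get_value()`. A sketch of that pattern with a deliberately hypothetical flag (nothing named `FORGE_EXAMPLE_FLAG` exists in the codebase):

```python
from forge.env import EnvVar  # assumes EnvVar is importable from forge.env, where the entries above live

# Hypothetical flag, shown only to illustrate the pattern MONARCH_HOSTMESH_V1 followed.
FORGE_EXAMPLE_FLAG = EnvVar(
    name="FORGE_EXAMPLE_FLAG",
    default=False,
    description="Illustrative flag; not part of forge.env.",
)

if FORGE_EXAMPLE_FLAG.get_value():
    print("example flag enabled via the environment")
```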
10 changes: 2 additions & 8 deletions src/forge/observability/metric_actors.py
@@ -8,9 +8,9 @@
import logging
from typing import Any, Union

from monarch.actor import Actor, endpoint, ProcMesh
from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc

from forge.env import FORGE_DISABLE_METRICS, MONARCH_HOSTMESH_V1
from forge.env import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
BackendRole,
get_logger_backend_class,
@@ -19,12 +19,6 @@
reduce_metrics_states,
)

if MONARCH_HOSTMESH_V1.get_value():
from monarch._src.actor.v1.host_mesh import this_proc
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller
else:
from monarch.actor import get_or_spawn_controller, this_proc


logger = logging.getLogger(__name__)

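The same collapse happens in `metric_actors.py`: the v0/v1 import switch disappears and everything comes from the public `monarch.actor` surface. Condensed from the hunks above:

```python
# Before: the proc and controller handles were imported behind the flag.
# if MONARCH_HOSTMESH_V1.get_value():
#     from monarch._src.actor.v1.host_mesh import this_proc
#     from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller
# else:
#     from monarch.actor import get_or_spawn_controller, this_proc

# After: a single unconditional import.
from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
```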