diff --git a/apps/grpo/main.py b/apps/grpo/main.py index c64f00bc..4bb7d17d 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -6,6 +6,16 @@ # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml +from forge.env_constants import IS_MONARCH_HOSTMESH_V1 +if IS_MONARCH_HOSTMESH_V1: + from monarch._rust_bindings.monarch_hyperactor.config import configure + from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport + configure( + default_transport=ChannelTransport.MetaTlsWithHostname, + ) + +from forge.env_constants import IS_MONARCH_HOSTMESH_V1 + import asyncio import time import uuid @@ -31,7 +41,8 @@ from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer - +from forge.controller.provisioner import init_provisioner +from forge.env_constants import IS_MONARCH_HOSTMESH_V1 from forge.types import LauncherConfig, ProvisionerConfig from forge.util.ops import compute_logprobs from monarch.actor import endpoint @@ -314,14 +325,21 @@ async def main(cfg: DictConfig): max_res_tokens = cfg.max_res_tokens # ---- Global setups ---- # + provisioner_config = None if cfg.get("provisioner", None) is not None: - await init_provisioner( - ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) + provisioner_config = ProvisionerConfig( + launcher_config=LauncherConfig(**cfg.provisioner) ) + + provisioner = provisioner = await init_provisioner(provisioner_config) + metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(metric_logging_cfg) - await ts.initialize(strategy=ts.ControllerStorageVolumes()) + + assert provisioner is not None and IS_MONARCH_HOSTMESH_V1 + if provisioner is None or not IS_MONARCH_HOSTMESH_V1: + await ts.initialize(strategy=ts.ControllerStorageVolumes()) # ---- Setup 
services ---- # @@ -348,8 +366,18 @@ async def main(cfg: DictConfig): reward_functions=[MathReward(), ThinkingReward()] ), ) - - print("All services initialized successfully!") + print("Services initialized successfully!") + + if provisioner is not None and IS_MONARCH_HOSTMESH_V1: + #TODO: support multiple host meshes + trainer_host_mesh_name = cfg.actors.trainer["mesh_name"] + #TODO: support multiple host meshes + trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name) + await ts.initialize( + mesh=trainer_hosts.spawn_procs(per_host={"gpus": 8}), + strategy=ts.LocalRankStrategy() + ) + print("Torchstore initialized successfully") # ---- Core RL loops ---- # async def continuous_rollouts(): diff --git a/apps/mast/main.py b/apps/mast/main.py index cd5de0be..8564381c 100644 --- a/apps/mast/main.py +++ b/apps/mast/main.py @@ -4,6 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from forge.env_constants import IS_MONARCH_HOSTMESH_V1 +if IS_MONARCH_HOSTMESH_V1: + from monarch._rust_bindings.monarch_hyperactor.config import configure + from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport + configure( + default_transport=ChannelTransport.MetaTlsWithHostname, + ) + import asyncio import getpass import uuid diff --git a/apps/mast/qwen3_1_7b_mast.yaml b/apps/mast/qwen3_1_7b_mast.yaml index 58d87957..ee7d0541 100644 --- a/apps/mast/qwen3_1_7b_mast.yaml +++ b/apps/mast/qwen3_1_7b_mast.yaml @@ -9,7 +9,7 @@ max_res_tokens: 512 model: "Qwen/Qwen3-1.7B" off_by_n: 1 # Off by one by default launcher: mast -job_name: forge-qwen3-1_7b +job_name: forge-qwen3-1_7b-0 checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/ # Main loop configuration diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 7a399e4f..426b81f7 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -18,6 +18,17 @@ import 
torch.distributed.checkpoint as dcp import torchstore as ts +from forge.actors._torchstore_utils import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, +) + +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + from monarch.actor import current_rank, current_size, endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -39,17 +50,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, -) - -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -112,7 +112,7 @@ class RLTrainer(ForgeActor): # Non JobConfig-related fields loss: Callable = lambda logits, **targets: logits state_dict_key: str = "model_state_dict" - use_dcp: bool = True + use_dcp: bool = False dcp_path: str = "forge_dcp_tmp" vllm_tp_DEPRECATED: int = 1 # noqa: N815 use_vllm_builtin_load: bool = True @@ -174,6 +174,7 @@ async def setup(self): "use_dcp", "use_vllm_builtin_load", "dcp_path", + "job", "vllm_tp_DEPRECATED", }: engine_config.pop(key) # Not part of job config diff --git a/src/forge/controller/launcher.py b/src/forge/controller/launcher.py index f2fe5f0f..fb84403b 100644 --- a/src/forge/controller/launcher.py +++ b/src/forge/controller/launcher.py @@ -190,7 +190,6 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str] async def remote_setup(self, procs: ProcMesh) -> tuple[str, int]: setup = procs.spawn(f"setup-{uuid.uuid1()}", MastSetupActor) await 
setup.mount.call(mount_dst="/mnt/wsfuse") - return await setup.get_info.choose() async def launch_mast_job(self): handle = self.create_server_handle() diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 25884942..bb7f5e1a 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. """Resource allocation and provisioning for both local and remote.""" + import asyncio import functools import logging @@ -13,20 +14,22 @@ import socket import uuid -from monarch._src.actor.shape import NDSlice, Shape -from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host +from monarch._src.actor.shape import NDSlice, Shape, Extent from monarch.tools import commands +from monarch.actor import Actor, endpoint, ProcMesh -from forge.controller.launcher import BaseLauncher, get_launcher - -from forge.env_constants import FORGE_DISABLE_METRICS +from forge.env_constants import IS_MONARCH_HOSTMESH_V1, FORGE_DISABLE_METRICS +if IS_MONARCH_HOSTMESH_V1: + from monarch._src.actor.v1.host_mesh import HostMesh, this_host +else: + from monarch.actor import HostMesh, this_host +from forge.controller.launcher import BaseLauncher, get_launcher from forge.types import ProcessConfig, ProvisionerConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - def _get_port() -> str: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("localhost", 0)) @@ -125,6 +128,7 @@ def __init__(self, cfg: ProvisionerConfig | None = None): self._this_host_id: GpuManager(available_local_devices), } self._proc_host_map = {} + self._host_mesh_map = {} self.launcher: BaseLauncher | None = get_launcher( cfg.launcher_config if cfg is not None else None ) @@ -148,15 +152,32 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh: alloc, alloc_constraints, server_name = await self.launcher.get_allocator( name, num_hosts 
) - return ( - HostMesh( + if IS_MONARCH_HOSTMESH_V1: + # "no_dim" here is actually a dummy dimension, which Monarch requires but will ignore + # TODO: remove dummy dimension once Monarch supports HostMesh without it + host_mesh = HostMesh.allocate_nonblocking( + name=name, + extent=Extent(["hosts", "no_dim"], [num_hosts, 1]), + allocator=alloc, + alloc_constraints=alloc_constraints, + ) + else: + host_mesh = HostMesh( Shape(["hosts"], NDSlice.new_row_major([num_hosts])), allocator=alloc, alloc_constraints=alloc_constraints, - ), + ) + self._host_mesh_map[name] = host_mesh + return ( + host_mesh + , server_name, ) + def get_host_mesh(self, name: str) -> HostMesh: + """Returns a HostMesh by name. Assumes the requested hostmesh already exists.""" + return self._host_mesh_map[name] + async def get_proc_mesh( self, num_procs: int, @@ -240,10 +261,16 @@ def bootstrap(env: dict[str, str]): # Shows detailed logs for Monarch rust failures env_vars["RUST_BACKTRACE"] = "1" - procs = host_mesh.spawn_procs( - per_host={"gpus": num_procs}, - bootstrap=functools.partial(bootstrap, env=env_vars), - ) + if IS_MONARCH_HOSTMESH_V1: + procs = host_mesh.spawn_procs( + per_host={"gpus": num_procs}, + setup=functools.partial(bootstrap, env=env_vars), + ) + else: + procs = host_mesh.spawn_procs( + per_host={"gpus": num_procs}, + bootstrap=functools.partial(bootstrap, env=env_vars), + ) if is_remote: await self.launcher.remote_setup(procs) diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index 6e0fc30e..a9814b09 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
- +import os """Centralized constants for environment variable names used in the project.""" # Performance metrics in forge.observability.perf_tracker.py becomes no-op @@ -16,3 +16,6 @@ # Makes forge.observability.metrics.record_metric a no-op # and disables spawning LocalFetcherActor in get_or_create_metric_logger FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS" + +# Experimental monarch features +IS_MONARCH_HOSTMESH_V1 = os.environ.get("MONARCH_HOSTMESH_V1", "1") == "1" diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index e50cc3fd..a66cff42 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -9,9 +9,16 @@ import os from typing import Any, Union -from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc +from forge.env_constants import IS_MONARCH_HOSTMESH_V1, FORGE_DISABLE_METRICS + +from monarch.actor import Actor, endpoint, ProcMesh +if IS_MONARCH_HOSTMESH_V1: + from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller + from monarch._src.actor.v1.host_mesh import this_proc +else: + from monarch.actor import get_or_spawn_controller, this_proc + -from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( BackendRole, get_logger_backend_class,