Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@

# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

from forge.env_constants import IS_MONARCH_HOSTMESH_V1
if IS_MONARCH_HOSTMESH_V1:
    # With the v1 HostMesh path enabled, make MetaTlsWithHostname the default
    # Monarch channel transport.
    # NOTE(review): presumably this must run before the imports further down
    # pull in Monarch -- confirm before relocating this block.
    from monarch._rust_bindings.monarch_hyperactor.config import configure
    from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport

    configure(default_transport=ChannelTransport.MetaTlsWithHostname)

from forge.env_constants import IS_MONARCH_HOSTMESH_V1

import asyncio
import time
import uuid
Expand All @@ -31,7 +41,8 @@
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from forge.controller.provisioner import init_provisioner
from forge.env_constants import IS_MONARCH_HOSTMESH_V1
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
Expand Down Expand Up @@ -314,14 +325,21 @@ async def main(cfg: DictConfig):
max_res_tokens = cfg.max_res_tokens

# ---- Global setups ---- #
provisioner_config = None
if cfg.get("provisioner", None) is not None:
await init_provisioner(
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
provisioner_config = ProvisionerConfig(
launcher_config=LauncherConfig(**cfg.provisioner)
)

provisioner = provisioner = await init_provisioner(provisioner_config)

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
await ts.initialize(strategy=ts.ControllerStorageVolumes())

assert provisioner is not None and IS_MONARCH_HOSTMESH_V1
if provisioner is None or not IS_MONARCH_HOSTMESH_V1:
await ts.initialize(strategy=ts.ControllerStorageVolumes())

# ---- Setup services ---- #

Expand All @@ -348,8 +366,18 @@ async def main(cfg: DictConfig):
reward_functions=[MathReward(), ThinkingReward()]
),
)

print("All services initialized successfully!")
print("Services initialized successfully!")

if provisioner is not None and IS_MONARCH_HOSTMESH_V1:
# TODO: support multiple host meshes
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
# TODO: support multiple host meshes
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
await ts.initialize(
mesh=trainer_hosts.spawn_procs(per_host={"gpus": 8}),
strategy=ts.LocalRankStrategy()
)
print("Torchstore initialized successfully")

# ---- Core RL loops ---- #
async def continuous_rollouts():
Expand Down
8 changes: 8 additions & 0 deletions apps/mast/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from forge.env_constants import IS_MONARCH_HOSTMESH_V1
if IS_MONARCH_HOSTMESH_V1:
    # With the v1 HostMesh path enabled, make MetaTlsWithHostname the default
    # Monarch channel transport.
    # NOTE(review): presumably this must run before the imports further down
    # pull in Monarch -- confirm before relocating this block.
    from monarch._rust_bindings.monarch_hyperactor.config import configure
    from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport

    configure(default_transport=ChannelTransport.MetaTlsWithHostname)

import asyncio
import getpass
import uuid
Expand Down
2 changes: 1 addition & 1 deletion apps/mast/qwen3_1_7b_mast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ max_res_tokens: 512
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default
launcher: mast
job_name: forge-qwen3-1_7b
job_name: forge-qwen3-1_7b-0
checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/

# Main loop configuration
Expand Down
25 changes: 13 additions & 12 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.actors._torchstore_utils import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from monarch.actor import current_rank, current_size, endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
Expand All @@ -39,17 +50,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors._torchstore_utils import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -112,7 +112,7 @@ class RLTrainer(ForgeActor):
# Non JobConfig-related fields
loss: Callable = lambda logits, **targets: logits
state_dict_key: str = "model_state_dict"
use_dcp: bool = True
use_dcp: bool = False
dcp_path: str = "forge_dcp_tmp"
vllm_tp_DEPRECATED: int = 1 # noqa: N815
use_vllm_builtin_load: bool = True
Expand Down Expand Up @@ -174,6 +174,7 @@ async def setup(self):
"use_dcp",
"use_vllm_builtin_load",
"dcp_path",
"job",
"vllm_tp_DEPRECATED",
}:
engine_config.pop(key) # Not part of job config
Expand Down
1 change: 0 additions & 1 deletion src/forge/controller/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ async def get_allocator(self, name: str, num_hosts: int) -> tuple[Any, Any, str]
async def remote_setup(self, procs: ProcMesh) -> tuple[str, int]:
setup = procs.spawn(f"setup-{uuid.uuid1()}", MastSetupActor)
await setup.mount.call(mount_dst="/mnt/wsfuse")
return await setup.get_info.choose()

async def launch_mast_job(self):
handle = self.create_server_handle()
Expand Down
53 changes: 40 additions & 13 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

"""Resource allocation and provisioning for both local and remote."""

import asyncio
import functools
import logging
Expand All @@ -13,20 +14,22 @@
import socket
import uuid

from monarch._src.actor.shape import NDSlice, Shape
from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
from monarch._src.actor.shape import NDSlice, Shape, Extent
from monarch.tools import commands
from monarch.actor import Actor, endpoint, ProcMesh

from forge.controller.launcher import BaseLauncher, get_launcher

from forge.env_constants import FORGE_DISABLE_METRICS
from forge.env_constants import IS_MONARCH_HOSTMESH_V1, FORGE_DISABLE_METRICS
if IS_MONARCH_HOSTMESH_V1:
from monarch._src.actor.v1.host_mesh import HostMesh, this_host
else:
from monarch.actor import HostMesh, this_host

from forge.controller.launcher import BaseLauncher, get_launcher
from forge.types import ProcessConfig, ProvisionerConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def _get_port() -> str:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", 0))
Expand Down Expand Up @@ -125,6 +128,7 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
self._this_host_id: GpuManager(available_local_devices),
}
self._proc_host_map = {}
self._host_mesh_map = {}
self.launcher: BaseLauncher | None = get_launcher(
cfg.launcher_config if cfg is not None else None
)
Expand All @@ -148,15 +152,32 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
alloc, alloc_constraints, server_name = await self.launcher.get_allocator(
name, num_hosts
)
return (
HostMesh(
if IS_MONARCH_HOSTMESH_V1:
# "no_dim" here is actually a dummy dimension, which Monarch requires but will ignore
# TODO: remove dummy dimension once Monarch supports HostMesh without it
host_mesh = HostMesh.allocate_nonblocking(
name=name,
extent=Extent(["hosts", "no_dim"], [num_hosts, 1]),
allocator=alloc,
alloc_constraints=alloc_constraints,
)
else:
host_mesh = HostMesh(
Shape(["hosts"], NDSlice.new_row_major([num_hosts])),
allocator=alloc,
alloc_constraints=alloc_constraints,
),
)
self._host_mesh_map[name] = host_mesh
return (
host_mesh
,
server_name,
)

def get_host_mesh(self, name: str) -> HostMesh:
    """Return the HostMesh previously registered under ``name``.

    This is a pure lookup in ``self._host_mesh_map``; it never creates a mesh.

    Args:
        name: Name the host mesh was registered with.

    Returns:
        The HostMesh stored for ``name``.

    Raises:
        KeyError: If no host mesh has been registered under ``name``. The
            message lists the names that are currently registered.
    """
    try:
        return self._host_mesh_map[name]
    except KeyError:
        # Re-raise with context: a bare KeyError('x') gives no hint about
        # which meshes actually exist.
        known = list(self._host_mesh_map)
        raise KeyError(
            f"No host mesh named {name!r} has been created; "
            f"known host meshes: {known}"
        ) from None

async def get_proc_mesh(
self,
num_procs: int,
Expand Down Expand Up @@ -240,10 +261,16 @@ def bootstrap(env: dict[str, str]):
# Shows detailed logs for Monarch rust failures
env_vars["RUST_BACKTRACE"] = "1"

procs = host_mesh.spawn_procs(
per_host={"gpus": num_procs},
bootstrap=functools.partial(bootstrap, env=env_vars),
)
if IS_MONARCH_HOSTMESH_V1:
procs = host_mesh.spawn_procs(
per_host={"gpus": num_procs},
setup=functools.partial(bootstrap, env=env_vars),
)
else:
procs = host_mesh.spawn_procs(
per_host={"gpus": num_procs},
bootstrap=functools.partial(bootstrap, env=env_vars),
)

if is_remote:
await self.launcher.remote_setup(procs)
Expand Down
5 changes: 4 additions & 1 deletion src/forge/env_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Centralized constants for environment variable names used in the project."""

import os

# Performance metrics in forge.observability.perf_tracker.py become no-ops
Expand All @@ -16,3 +16,6 @@
# Makes forge.observability.metrics.record_metric a no-op
# and disables spawning LocalFetcherActor in get_or_create_metric_logger
FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS"

# Experimental monarch features
IS_MONARCH_HOSTMESH_V1 = os.environ.get("MONARCH_HOSTMESH_V1", "1") == "1"
11 changes: 9 additions & 2 deletions src/forge/observability/metric_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
import os
from typing import Any, Union

from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
from forge.env_constants import IS_MONARCH_HOSTMESH_V1, FORGE_DISABLE_METRICS

from monarch.actor import Actor, endpoint, ProcMesh
if IS_MONARCH_HOSTMESH_V1:
from monarch._src.actor.v1.proc_mesh import get_or_spawn_controller
from monarch._src.actor.v1.host_mesh import this_proc
else:
from monarch.actor import get_or_spawn_controller, this_proc


from forge.env_constants import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
BackendRole,
get_logger_backend_class,
Expand Down